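"""Embed texts from a JSONL file.

A general-purpose embedding utility: it reads records from a JSONL file,
adds a sentence-transformer embedding to each record, and writes the results
to another JSONL file.
"""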

# /// script
# dependencies = [
#   "srsly", "typer", "sentence-transformers", "tqdm"
# ]
# ///

from pathlib import Path
from typing import Dict, Iterable, Iterator, List

import srsly
import typer
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

# Typer application object. `add_embeddings` is registered on it below via
# `@app.command()`, and the `__main__` guard at the end of the file runs it.
app = typer.Typer()

def process_batch(model: SentenceTransformer, batch: List[Dict], text_key: str) -> List[Dict]:
    """Attach an ``embedding`` list to every record in *batch*.

    The strings stored under ``text_key`` are encoded in a single
    ``model.encode`` call. Each resulting vector is converted with
    ``.tolist()`` and stored on the matching record.

    Args:
        model: Model used to encode the texts.
        batch: Records to embed; they are modified in place.
        text_key: Key under which each record holds its text.

    Returns:
        The same ``batch`` list object, now carrying embeddings.
    """
    encoded = model.encode([record[text_key] for record in batch])
    for record, vector in zip(batch, encoded):
        record['embedding'] = vector.tolist()
    return batch

def _iter_with_embeddings(
    model: SentenceTransformer,
    items: Iterable[Dict],
    text_key: str,
    batch_size: int,
) -> Iterator[Dict]:
    """Yield every item from *items* after ``process_batch`` has embedded it.

    Items are buffered into chunks of ``batch_size`` and embedded one chunk at
    a time. At most one chunk is held in memory while the input streams through.
    """
    batch: List[Dict] = []
    for item in tqdm(items, desc="Processing items"):
        batch.append(item)
        if len(batch) >= batch_size:
            yield from process_batch(model, batch, text_key)
            batch = []
    if batch:  # flush the final, possibly shorter, chunk
        yield from process_batch(model, batch, text_key)


@app.command()
def add_embeddings(
    input_file: str = typer.Argument(..., help="Input JSONL file path"),
    output_file: str = typer.Argument(..., help="Output JSONL file path"),
    model_name: str = typer.Option("all-MiniLM-L6-v2", help="Sentence transformer model name"),
    text_key: str = typer.Option("text", help="Key for the text field in JSON objects"),
    batch_size: int = typer.Option(128, help="Batch size for embedding generation")
):
    """
    Add sentence transformer embeddings to texts in a JSONL file.

    Each record gains an "embedding" key holding its vector as a list. The
    output file is overwritten rather than appended to, so re-running the
    command reproduces the same file instead of accumulating duplicates.
    """
    # The output is opened in "w" mode before the lazy reader opens the input.
    # Writing into the input file would therefore wipe it before it is read.
    if input_file != "-" and Path(input_file).resolve() == Path(output_file).resolve():
        raise typer.BadParameter("output_file must differ from input_file")

    typer.echo(f"Loading model {model_name}")
    model = SentenceTransformer(model_name)

    typer.echo(f"Processing {input_file}")
    reader = srsly.read_jsonl(input_file)

    # A single write_jsonl call opens the output once (mode "w") and consumes
    # the generator lazily. The previous per-batch `append=True` writes kept
    # stale content from earlier runs. Through srsly's default
    # append_new_line=True, they also inserted a blank line before every batch.
    srsly.write_jsonl(output_file, _iter_with_embeddings(model, reader, text_key, batch_size))

    typer.echo("Done!")

# Run the Typer CLI only when the file is executed as a script, not on import.
if __name__ == "__main__":
    app()