Embedding texts
A general-purpose utility that adds sentence-transformer embeddings to the records of a JSONL file.
# /// script
# dependencies = [
# "srsly", "typer", "sentence-transformers", "tqdm"
# ]
# ///
from pathlib import Path
from typing import Dict, Iterable, Iterator, List, Optional

import srsly
import typer
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
# Typer application object: add_embeddings is registered on it as a command via
# @app.command(), and the __main__ guard at the bottom calls app() to run the CLI.
app = typer.Typer()
def process_batch(
    model: SentenceTransformer,
    batch: List[Dict],
    text_key: str,
    batch_size: Optional[int] = None,
) -> List[Dict]:
    """Attach an ``"embedding"`` entry to every record in *batch*.

    All texts in the batch are passed to ``model.encode`` in a single call and
    each resulting vector is stored, as a plain Python list, under the
    ``"embedding"`` key of the matching record.

    Args:
        model: Sentence-transformer model used to encode the texts.
        batch: Records to embed; each one is modified in place.
        text_key: Key under which each record stores its text.
        batch_size: Forwarded to ``model.encode`` as ``batch_size``. When
            ``None`` (the default) the model's own default batch size is used.

    Returns:
        The same *batch* list, now with embeddings attached.

    Raises:
        KeyError: If a record has no *text_key* entry.
    """
    texts = [item[text_key] for item in batch]
    # Only pass batch_size when given, so callers that omit it keep encode()'s default.
    encode_kwargs = {} if batch_size is None else {"batch_size": batch_size}
    embeddings = model.encode(texts, **encode_kwargs)
    for item, embedding in zip(batch, embeddings):
        # tolist() turns each vector into a JSON-serialisable Python list.
        item["embedding"] = embedding.tolist()
    return batch


def _embed_records(
    model: SentenceTransformer,
    records: Iterable[Dict],
    text_key: str,
    batch_size: int,
) -> Iterator[Dict]:
    """Lazily yield *records* with embeddings, encoding *batch_size* at a time.

    Records are consumed through a tqdm progress bar and grouped into chunks of
    *batch_size*; each full chunk, plus the final partial one, is embedded with
    ``process_batch`` and yielded record by record, so at most one chunk is
    buffered here.
    """
    batch: List[Dict] = []
    for item in tqdm(records, desc="Processing items"):
        batch.append(item)
        if len(batch) >= batch_size:
            yield from process_batch(model, batch, text_key, batch_size=batch_size)
            batch = []
    if batch:  # flush the final, partially filled chunk
        yield from process_batch(model, batch, text_key, batch_size=batch_size)


# CLI command: reads records from input_file (JSONL), embeds the text stored under
# text_key with the sentence-transformer named by model_name, and writes every
# record, with an added "embedding" list, to output_file (replacing its contents).
# batch_size controls both how many records are buffered per chunk and the
# batch_size passed to model.encode(); min=1 because encode() cannot use 0 or less.
@app.command()
def add_embeddings(
    input_file: str = typer.Argument(..., help="Input JSONL file path"),
    output_file: str = typer.Argument(..., help="Output JSONL file path"),
    model_name: str = typer.Option("all-MiniLM-L6-v2", help="Sentence transformer model name"),
    text_key: str = typer.Option("text", help="Key for the text field in JSON objects"),
    batch_size: int = typer.Option(128, min=1, help="Batch size for embedding generation"),
):
    """
    Add sentence transformer embeddings to texts in a JSONL file.
    """
    # The output is opened for writing (truncated) before the input is read, so
    # using one file for both would silently destroy the input. "-" means stdout.
    if output_file != "-" and Path(output_file).resolve() == Path(input_file).resolve():
        raise typer.BadParameter(
            "must be different from INPUT_FILE", param_hint="OUTPUT_FILE"
        )
    typer.echo(f"Loading model {model_name}")
    model = SentenceTransformer(model_name)
    typer.echo(f"Processing {input_file}")
    reader = srsly.read_jsonl(input_file)
    # A single non-appending write of a lazy generator: it replaces any previous
    # output (appending made reruns duplicate records) and avoids the blank line
    # srsly's append mode writes before every chunk, while still streaming the
    # data chunk by chunk.
    srsly.write_jsonl(output_file, _embed_records(model, reader, text_key, batch_size))
    typer.echo("Done!")
# Run the Typer CLI only when this file is executed directly, not when imported.
if __name__ == "__main__":
    app()