import hashlib
import json
from typing import Iterator

import cohere

from utils import chunk_list, text_to_chunks, Point, EmbeddingSize
from settings import settings

co = cohere.ClientV2()


def md5(text: str) -> str:
    """Hex MD5 digest of the text, used as the cache file name."""
    return hashlib.md5(text.encode("utf-8")).hexdigest()


def _text_chunks(
    text: str, chunk_size: int, chunk_overlap: int, search: bool = True
) -> list[str]:
    # Search queries are embedded whole; documents are split into
    # overlapping chunks before embedding.
    if search:
        return [text]
    return text_to_chunks(text, chunk_size, chunk_overlap)


def _call_cohere(
    texts: list[str],
    embedding_size: EmbeddingSize,
    search: bool = True,
) -> Iterator[tuple[str, Point]]:
    response = co.embed(
        inputs=[{"content": [{"type": "text", "text": doc}]} for doc in texts],
        model="embed-v4.0",
        output_dimension=embedding_size,
        input_type="search_query" if search else "search_document",
        embedding_types=["float"],
    )
    return zip(texts, response.embeddings.float)


def embed_store(
    text: str,
    embedding: Point,
    embedding_size: EmbeddingSize,
    search: bool = False,
) -> None:
    """Cache an embedding on disk, keyed by the text's MD5 and embedding size."""
    if not settings.cohere_cache:
        return
    _md5 = md5(text)
    target_file = (
        settings.storage_path
        / ".cohere_embedding"
        / ("search" if search else "storage")
        / _md5
    )
    # One JSON file per text; each embedding size gets its own key.
    embedding_table = {}
    if target_file.exists():
        embedding_table = json.loads(target_file.read_text())
    embedding_table[str(embedding_size)] = embedding
    target_file.parent.mkdir(exist_ok=True, parents=True)
    target_file.write_text(json.dumps(embedding_table, indent=2))


def embed_lookup(
    text: str,
    embedding_size: EmbeddingSize,
    search: bool = False,
) -> Point | None:
    """Return a cached embedding, or None on a cache miss or a disabled cache."""
    if not settings.cohere_cache:
        return None
    _md5 = md5(text)
    target_file = (
        settings.storage_path
        / ".cohere_embedding"
        / ("search" if search else "storage")
        / _md5
    )
    if target_file.exists():
        embeddings = json.loads(target_file.read_text())
        return embeddings.get(str(embedding_size))
    return None


def embed(
    text: str,
    chunk_size: int,
    chunk_overlap: int,
    embedding_size: EmbeddingSize,
    search: bool = True,
) -> list[tuple[str, Point]]:
    """
    https://docs.cohere.com/reference/embed

    Input type:
    > "search_document": Used for embeddings stored...
    > "search_query": Used for embeddings of search queries...

    The Cohere API only allows at most 96 texts per request.
    """
    text_chunks_list = _text_chunks(text, chunk_size, chunk_overlap, search)
    # Resolve as many chunks as possible from the on-disk cache.
    text_chunks = {
        _text: embed_lookup(_text, embedding_size, search)
        for _text in text_chunks_list
    }
    missing = [_text for _text, embedding in text_chunks.items() if embedding is None]
    # Embed the cache misses in batches of 50, safely under the 96-text limit.
    results: dict[str, Point] = {}
    for _texts in chunk_list(missing, 50):
        results |= dict(_call_cohere(_texts, embedding_size, search))
    for _text, embedding in results.items():
        embed_store(_text, embedding, embedding_size, search)
    # Preserve the original chunk order, preferring cached embeddings.
    returnable = []
    for _text in text_chunks_list:
        cached = text_chunks[_text]
        returnable.append((_text, cached if cached is not None else results[_text]))
    return returnable
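

# --- Usage sketch (illustrative) -----------------------------------------
# A minimal example of the query-time path. The chunking parameters and the
# 256-dimension output size are assumptions for illustration, not values
# fixed by this module; running it requires a COHERE_API_KEY and a configured
# `settings` module.
if __name__ == "__main__":
    pairs = embed(
        "How do I reset my password?",
        chunk_size=512,
        chunk_overlap=64,
        embedding_size=256,
        search=True,
    )
    # Search queries are embedded whole, so there is exactly one pair.
    query_text, query_vector = pairs[0]
    print(f"{query_text!r} -> {len(query_vector)}-dim embedding")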