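"""Cohere embedding helpers with a simple on-disk cache.

Texts are embedded with Cohere's ``embed-v4.0`` model; the resulting vectors
are cached under ``settings.storage_path / ".cohere_embedding"``, keyed by the
MD5 of the embedded text, so repeated calls skip the API.
"""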
import hashlib
import json
from typing import Iterator
from itertools import chain

import cohere

from utils import chunk_list, text_to_chunks, Point, EmbeddingSize
from settings import settings


co = cohere.ClientV2()


def md5(text: str) -> str:
    """Hex MD5 digest of ``text``, used as the cache key on disk."""
    return hashlib.md5(text.encode("utf-8")).hexdigest()


def _text_chunks(
    text: str, chunk_size: int, chunk_overlap: int, search: bool = True
) -> list[str]:
    # Queries are embedded whole; documents are split into overlapping chunks.
    if search:
        return [text]
    return text_to_chunks(text, chunk_size, chunk_overlap)


def _call_cohere(
    texts: list[str],
    embedding_size: EmbeddingSize,
    search: bool = True,
) -> Iterator[tuple[str, Point]]:
    # Single embed request; the caller is responsible for batching texts.
    response = co.embed(
        inputs=[{"content": [{"type": "text", "text": doc}]} for doc in texts],
        model="embed-v4.0",
        output_dimension=embedding_size,
        input_type="search_query" if search else "search_document",
        embedding_types=["float"],
    )
    return zip(texts, response.embeddings.float)


def embed_store(
    text: str,
    embedding: Point,
    embedding_size: EmbeddingSize,
    search: bool = False,
):
    # Persist one embedding in the on-disk cache, keyed by the MD5 of the text.
    if not settings.cohere_cache:
        return
    _md5 = md5(text)
    target_file = (
        settings.storage_path
        / ".cohere_embedding"
        / ("search" if search else "storage")
        / _md5
    )
    embedding_table = {}
    if target_file.exists():
        embedding_table = json.loads(target_file.read_text())
    embedding_table[str(embedding_size)] = embedding
    target_file.parent.mkdir(exist_ok=True, parents=True)
    target_file.write_text(json.dumps(embedding_table, indent=2))


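# Cache layout note: embed_store and embed_lookup share one JSON file per text,
# e.g. <storage_path>/.cohere_embedding/storage/<md5-of-text> containing
# {"<embedding_size>": [...], ...}; the "search" directory holds query embeddings.

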
def embed_lookup(
    text: str,
    embedding_size: EmbeddingSize,
    search: bool = False,
) -> Point | None:
    # Return the cached embedding for ``text`` at this size, or None on a miss.
    if not settings.cohere_cache:
        return None
    _md5 = md5(text)
    target_file = (
        settings.storage_path
        / ".cohere_embedding"
        / ("search" if search else "storage")
        / _md5
    )
    if target_file.exists():
        embeddings = json.loads(target_file.read_text())
        return embeddings.get(str(embedding_size), None)
    return None


def embed(
    text: str,
    chunk_size: int,
    chunk_overlap: int,
    embedding_size: EmbeddingSize,
    search: bool = True,
) -> list[tuple[str, Point]]:
    """
    https://docs.cohere.com/reference/embed
    Input type:
    > "search_document": Used for embeddings stored...
    > "search_query": Used for embeddings of search queries...

    The Cohere API only allows at most 96 texts at a time, so uncached texts
    are sent in batches.
    """
    text_chunks_list = _text_chunks(text, chunk_size, chunk_overlap, search)

    # Check the on-disk cache first; misses map to None.
    text_chunks = {
        _text: embed_lookup(_text, embedding_size, search) for _text in text_chunks_list
    }

    missing = [_text for _text, embedding in text_chunks.items() if embedding is None]

    results = {}

    # Embed the misses in batches of 50, comfortably under the 96-text limit.
    for _texts in chunk_list(missing, 50):
        results |= dict(_call_cohere(_texts, embedding_size, search))

    for _text, embedding in results.items():
        embed_store(_text, embedding, embedding_size, search)

    # Preserve chunk order, taking the cached embedding when it exists.
    returnable = []
    for _text in text_chunks_list:
        returnable.append((_text, text_chunks[_text] or results[_text]))

    return returnable
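

# Minimal usage sketch (illustrative only). Assumptions not in this module:
# a Cohere API key is available to cohere.ClientV2(), EmbeddingSize accepts a
# plain dimension such as 1024, and Point behaves like a list of floats.
if __name__ == "__main__":
    pairs = embed(
        "How do I batch texts for the Cohere embed endpoint?",
        chunk_size=512,
        chunk_overlap=64,
        embedding_size=1024,
        search=True,  # queries are embedded whole, so chunk settings are unused
    )
    for chunk, vector in pairs:
        print(f"{len(vector)}-d embedding for {chunk[:60]!r}")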