1
0
python-vector-database/src/embedding.py

122 lines
3.1 KiB
Python

import hashlib
import json
from typing import Iterator
from itertools import chain
import cohere
from utils import chunk_list, text_to_chunks, Point, EmbeddingSize
from settings import settings
# Shared Cohere v2 client, created once at import time and used by `_call_cohere`.
# NOTE(review): no API key is passed explicitly; presumably the SDK resolves
# credentials from the environment — confirm against the deployment config.
co = cohere.ClientV2()
def md5(text: str) -> str:
    """Return the hexadecimal MD5 digest of ``text`` encoded as UTF-8.

    The digest is used in this module as a cache file name, not for security.
    """
    hasher = hashlib.md5()
    hasher.update(text.encode("utf-8"))
    return hasher.hexdigest()
def _text_chunks(
text: str, chunk_size: int, chunk_overlap: int, search: bool = True
) -> list[str]:
if search:
return [text]
return text_to_chunks(text, chunk_size, chunk_overlap)
def _call_cohere(
    texts: list[str],
    embedding_size: EmbeddingSize,
    search: bool = True,
) -> Iterator[tuple[str, Point]]:
    """Embed ``texts`` with one Cohere request and pair each text with its vector.

    Args:
        texts: Texts to embed in a single ``co.embed`` call.
        embedding_size: Passed to the API as ``output_dimension``.
        search: Selects the API ``input_type``: ``"search_query"`` when truthy,
            ``"search_document"`` otherwise.

    Returns:
        A lazy ``zip`` of ``(text, embedding)`` pairs in the order of ``texts``.
        Because ``zip`` stops at the shorter input, a response with fewer
        embeddings than texts silently yields fewer pairs.
    """
    if search:
        input_type = "search_query"
    else:
        input_type = "search_document"

    # Each text is wrapped in the content-list structure passed as ``inputs``.
    payload = []
    for doc in texts:
        payload.append({"content": [{"type": "text", "text": doc}]})

    response = co.embed(
        inputs=payload,
        model="embed-v4.0",
        output_dimension=embedding_size,
        input_type=input_type,
        embedding_types=["float"],
    )
    vectors = response.embeddings.float
    return zip(texts, vectors)
def embed_store(
    text: str,
    embedding: Point,
    embedding_size: EmbeddingSize,
    search: bool = False,
):
    """Persist one embedding in the on-disk cache.

    Does nothing when ``settings.cohere_cache`` is falsy. Otherwise the cache
    file is ``settings.storage_path / ".cohere_embedding" / <namespace> /
    md5(text)``, where the namespace is ``"search"`` or ``"storage"`` depending
    on ``search``. The file holds a JSON object keyed by ``str(embedding_size)``,
    so entries for other sizes of the same text are kept.

    Args:
        text: The text the embedding was computed for; its MD5 names the file.
        embedding: The vector to store (must be JSON-serialisable).
        embedding_size: Size whose string form is used as the JSON key.
        search: Selects the ``"search"`` (truthy) or ``"storage"`` namespace.
    """
    if not settings.cohere_cache:
        return

    digest = md5(text)
    namespace = "search" if search else "storage"
    cache_file = settings.storage_path / ".cohere_embedding" / namespace / digest

    # Start from the sizes already cached for this text so they are not lost.
    by_size = json.loads(cache_file.read_text()) if cache_file.exists() else {}
    by_size[str(embedding_size)] = embedding

    cache_file.parent.mkdir(exist_ok=True, parents=True)
    serialized = json.dumps(by_size, indent=2)
    cache_file.write_text(serialized)
def embed_lookup(
    text: str,
    embedding_size: EmbeddingSize,
    search: bool = False,
) -> Point | None:
    """Fetch a previously cached embedding, if there is one.

    Args:
        text: The text whose embedding is wanted; its MD5 names the cache file.
        embedding_size: Size whose string form is the key inside the file.
        search: Selects the ``"search"`` (truthy) or ``"storage"`` namespace.

    Returns:
        The stored value for ``str(embedding_size)``, or ``None`` when caching
        is disabled (``settings.cohere_cache`` falsy), the cache file does not
        exist, or the file has no entry for that size.
    """
    if not settings.cohere_cache:
        return None

    digest = md5(text)
    namespace = "search" if search else "storage"
    cache_file = settings.storage_path / ".cohere_embedding" / namespace / digest

    if not cache_file.exists():
        return None

    by_size = json.loads(cache_file.read_text())
    return by_size.get(str(embedding_size))
def embed(
    text: str,
    chunk_size: int,
    chunk_overlap: int,
    embedding_size: EmbeddingSize,
    search: bool = True,
) -> list[tuple[str, Point]]:
    """
    Embed ``text`` with Cohere, reusing cached embeddings where available.

    https://docs.cohere.com/reference/embed
    Input type:
    > "search_document": Used for embeddings stored...
    > "search_query": Used for embeddings of search queries...
    Cohere API only allows for at most 96 texts at a time.

    Args:
        text: The text to embed.
        chunk_size: Forwarded to ``_text_chunks``.
        chunk_overlap: Forwarded to ``_text_chunks``.
        embedding_size: Requested embedding size; also the cache key.
        search: Forwarded to chunking, cache lookup/store and the API call.

    Returns:
        One ``(chunk, embedding)`` pair per chunk, in chunk order. A chunk that
        occurs more than once appears once per occurrence with the same vector.

    Raises:
        KeyError: If a chunk has no cached embedding and the API call did not
            return one for it.
    """
    text_chunks_list = _text_chunks(text, chunk_size, chunk_overlap, search)

    # Keyed by chunk text, so a repeated chunk is requested from the API once.
    cached = {
        _text: embed_lookup(_text, embedding_size, search) for _text in text_chunks_list
    }
    missing = [_text for _text, embedding in cached.items() if embedding is None]

    fetched: dict[str, Point] = {}
    # Batches of 50 stay below the per-request limit noted above.
    for _texts in chunk_list(missing, 50):
        for _text, embedding in _call_cohere(_texts, embedding_size, search):
            fetched[_text] = embedding
            # Store as soon as a batch returns so a failure in a later batch
            # does not discard embeddings that were already obtained.
            embed_store(_text, embedding, embedding_size, search)

    # Use the same ``is None`` test that built ``missing``: a cached but falsy
    # vector (e.g. ``[]``) is a hit and must not fall through to ``fetched``.
    return [
        (_text, cached[_text] if cached[_text] is not None else fetched[_text])
        for _text in text_chunks_list
    ]