Miguel Salgado 2025-05-30 14:42:08 -07:00
commit 19f335e300
21 changed files with 4408 additions and 0 deletions

.dockerignore Normal file
@@ -0,0 +1,6 @@
**/__pycache__
*.pyc
__marimo__
.egg-info
.git
data/

.gitignore vendored Normal file
@@ -0,0 +1,13 @@
# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info
__marimo__
data/
.env
# Virtual environments
.venv

.python-version Normal file
@@ -0,0 +1 @@
3.11

Dockerfile Normal file
@@ -0,0 +1,6 @@
FROM python:3.11-slim
COPY --from=ghcr.io/astral-sh/uv:0.4.20 /uv /bin/uv
ENV UV_SYSTEM_PYTHON=1
WORKDIR /app
COPY . .
RUN uv sync --frozen --no-cache

Justfile Normal file
@@ -0,0 +1,7 @@
api:
uv run --env-file .env fastapi dev src/api.py
docker:
docker compose build
docker compose up

README.md Normal file
@@ -0,0 +1,130 @@
# Take-at-Home Task - Backend (Vector DB)
Congrats on making it thus far in the interview process!
Here is a task for you to show us where you shine the most 🙂
The purpose is not to see how fast you go or what magic tricks you know in Python; it's mostly to understand how clearly you think and code.
If you think clearly and your code is clean, you are already better than 90% of applicants!
> ⚠ Feel free to use Cursor, but use it where it makes sense and don't overuse it: it introduces bugs, is often verbose, and is not really Pythonic.
>
## Objective
The goal of this project is to develop a REST API that allows users to **index** and **query** their documents within a Vector Database.
A Vector Database specializes in storing and indexing vector embeddings, enabling fast retrieval and similarity searches. This capability is crucial for applications involving natural language processing, recommendation systems, and many more…
The REST API should be containerized in a Docker container.
### Definitions
To ensure a clear understanding, let's define some key concepts (a minimal model sketch follows this list):
1. Chunk: A chunk is a piece of text with an associated embedding and metadata.
2. Document: A document is made up of multiple chunks and also contains its own metadata.
3. Library: A library is made out of a list of documents and can also contain other metadata.
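For illustration only, a fixed schema along the following lines would satisfy these definitions; the field names here are assumptions of this sketch, not part of the task:

```python
from datetime import datetime, timezone
from uuid import UUID, uuid4

from pydantic import BaseModel, Field


class Chunk(BaseModel):
    """A piece of text with its embedding and fixed metadata."""
    id: UUID = Field(default_factory=uuid4)
    text: str
    embedding: list[float]
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))


class Document(BaseModel):
    """A document is made up of chunks plus its own metadata."""
    id: UUID = Field(default_factory=uuid4)
    title: str
    chunks: list[Chunk] = Field(default_factory=list)


class Library(BaseModel):
    """A library is a collection of documents plus metadata."""
    id: UUID = Field(default_factory=uuid4)
    name: str
    documents: list[Document] = Field(default_factory=list)
```

Keeping the schemas fixed (no user-defined metadata fields) keeps validation and indexing simple.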
The API should:
1. Allow the users to create, read, update, and delete libraries.
2. Allow the users to create, read, update, and delete chunks within a library.
3. Index the contents of a library.
4. Do **k-Nearest Neighbor vector search** over the selected library with a given embedding query (a brute-force baseline sketch follows this list).
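As a baseline to reason about (an illustrative sketch, not the index you are asked to implement), brute-force k-NN over cosine distance looks like this; the `chunks` mapping of chunk ids to embeddings is an assumption of the example:

```python
import heapq
import math


def cosine_distance(a: list[float], b: list[float]) -> float:
    """1 - cosine similarity, so smaller means closer."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / norm if norm else 1.0


def knn(query: list[float], chunks: dict[str, list[float]], k: int = 5) -> list[tuple[str, float]]:
    """Scan every chunk: O(n * d) time per query, O(k) extra space."""
    scored = ((cosine_distance(query, vec), cid) for cid, vec in chunks.items())
    return [(cid, dist) for dist, cid in heapq.nsmallest(k, scored)]
```

Any real index (IVF-style clustering, HNSW, etc.) is ultimately judged against this baseline's recall and latency.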
### Guidelines:
The code should be **Python** since that is what we use to develop our backend.
Here is a suggested path on how to implement a basic solution to the problem.
1. Define the Chunk, Document and Library classes. To simplify schema definition, we suggest you use a fixed schema for each of the classes. This means not letting the user define which fields should be present within the metadata for each class.
2. Implement two or three indexing algorithms; do not use external libraries, since we want to see you code them up yourself.
1. What is the space and time complexity for each of the indexes?
2. Why did you choose this index?
3. Implement the necessary data structures/algorithms to ensure that there are no data races between reads and writes to the database (see the locking sketch after this list).
1. Explain your design choices.
4. Create the logic to do the CRUD operations on libraries and documents/chunks.
1. Ideally use Services to decouple API endpoints from actual work
5. Implement an API layer on top of that logic to let users interact with the vector database.
6. Create a docker image for the project
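For step 3, one possible approach (a minimal sketch using only the standard library, not a prescription) is a readers-writer lock per library, so many searches can proceed concurrently while writes get exclusive access:

```python
import threading
from contextlib import contextmanager


class ReadWriteLock:
    """Many concurrent readers, one exclusive writer (no writer preference)."""

    def __init__(self) -> None:
        self._cond = threading.Condition()
        self._readers = 0

    @contextmanager
    def read(self):
        with self._cond:
            self._readers += 1
        try:
            yield
        finally:
            with self._cond:
                self._readers -= 1
                if self._readers == 0:
                    self._cond.notify_all()

    @contextmanager
    def write(self):
        with self._cond:
            # Wait until all active readers have finished, then hold the
            # condition for the whole write so no new reader can start.
            while self._readers > 0:
                self._cond.wait()
            try:
                yield
            finally:
                self._cond.notify_all()
```

A repository can then wrap reads in `lock.read()` and mutations in `lock.write()`; an asyncio-based equivalent or per-file locking are equally valid choices.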
### Extra Points:
Here are some additional suggestions on how to enhance the project even further. You are not required to implement any of these, but if you do, we will value it. If you have other improvements in mind, please feel free to implement them and document them in the project's README file.
1. **Metadata filtering:**
- Add the possibility of using metadata filters to enhance query results, e.g. do kNN search over all chunks created after a given date, or whose name contains the string xyz, etc. (a filtering sketch follows this list).
2. **Persistence to Disk**:
- Implement a mechanism to persist the database state to disk, ensuring that the docker container can be restarted and resume its operation from the last checkpoint. Explain your design choices and tradeoffs, considering factors like performance, consistency, and durability.
3. **Leader-Follower Architecture**:
- Design and implement a leader-follower (master-slave) architecture to support multiple database nodes within the Kubernetes cluster. This architecture should handle read scalability and provide high availability. Explain how leader election, data replication, and failover are managed, along with the benefits and tradeoffs of this approach.
4. **Python SDK Client**:
- Develop a Python SDK client that interfaces with your API, making it easier for users to interact with the vector database programmatically. Include documentation and examples.
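To make item 1 concrete, filtering is just a predicate applied before (or while) running k-NN; the sketch below assumes chunk metadata carries `created_at` and `name` fields, which is an assumption of this example rather than a required schema:

```python
from datetime import datetime
from typing import Iterable

# (chunk_id, embedding, metadata) records; the shape is hypothetical.
ChunkRecord = tuple[str, list[float], dict]


def filter_chunks(
    chunks: Iterable[ChunkRecord],
    created_after: datetime | None = None,
    name_contains: str | None = None,
) -> list[ChunkRecord]:
    """Keep only the chunks whose metadata matches every given filter."""
    def keep(meta: dict) -> bool:
        if created_after is not None:
            created = meta.get("created_at")
            if created is None or created <= created_after:
                return False
        if name_contains is not None and name_contains not in meta.get("name", ""):
            return False
        return True

    return [record for record in chunks if keep(record[2])]
```

The k-NN search then runs over the filtered subset (pre-filtering); post-filtering the k-NN results is the other common trade-off.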
## Constraints
Do **not** use libraries like chroma-db, pinecone, FAISS, etc. to develop the project; we want to see you write the algorithms yourself. You can use numpy to calculate trigonometry functions such as `cos`, `sin`, etc.
You **do not need to build a document processing pipeline** (OCR + text extraction + chunking) to test your system. Using a bunch of manually created chunks will suffice.
## **Tech Stack**
- **API Backend:** Python + FastAPI + Pydantic
## Resources:
[Cohere](https://cohere.com/embeddings) API key to create the embeddings for your test.
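For reference, the embedding call this repository ends up using in `src/embedding.py` (shown further down in this commit) boils down to the following sketch; it assumes your key is exported as `CO_API_KEY`:

```python
import cohere

co = cohere.ClientV2()  # reads the API key from the environment

response = co.embed(
    inputs=[{"content": [{"type": "text", "text": "hello vector databases"}]}],
    model="embed-v4.0",
    output_dimension=256,
    input_type="search_document",
    embedding_types=["float"],
)
vector = response.embeddings.float[0]  # a list[float] of length 256
```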
## Evaluation Criteria
We will evaluate the code functionality and its quality.
**Code quality:**
- [SOLID design principles](https://realpython.com/solid-principles-python/).
- Use of static typing.
- FastAPI good practices.
- Pydantic schema validation
- Code modularity and reusability.
- Use of RESTful API endpoints.
- Project containerization with Docker.
- Testing
- Error handling.
- If you know what Domain-Driven Design is, do it that way!
- Separate API endpoints from business logic using services and from databases using repositories (see the layering sketch after this list)
- Keep code as pythonic as possible
- Do early returns
- Use inheritance where needed
- Use composition over inheritance
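The services/repositories bullet, sketched very roughly; this is not how `src/api.py` in this commit is structured, only an illustration of the layering:

```python
from fastapi import APIRouter, Depends


class LibraryRepository:
    """Persistence concerns only; an in-memory store stands in for the real one."""

    def __init__(self) -> None:
        self._libraries: dict[str, dict] = {}

    def list_slugs(self) -> list[str]:
        return list(self._libraries)


class LibraryService:
    """Business rules live here, between the API layer and the repository."""

    def __init__(self, repo: LibraryRepository) -> None:
        self._repo = repo

    def list_libraries(self) -> list[str]:
        return self._repo.list_slugs()


router = APIRouter()
_service = LibraryService(LibraryRepository())


def get_service() -> LibraryService:
    return _service


@router.get("/library/list")
def list_libraries(service: LibraryService = Depends(get_service)) -> list[str]:
    return service.list_libraries()
```

Swapping the repository (filesystem, SQLite, remote node) then never touches the endpoint or the service.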
**Functionality:**
- Does everything work as expected?
## Deliverable
1. **Source Code**: A link to a GitHub repository containing all your source code.
2. **Documentation**: A README file that documents the task, explains your technical choices, how to run the project locally, and any other relevant information.
3. **Demo video:**
1. A screen recording where you show how to install the project and interact with it in real time.
2. A screen recording of your design with an explanation of your design choices and thoughts/problem-solving.
## Timeline
As a reference, this task should take at most **4 days** (96h) from the receipt of this test to submit your deliverables 🚀 
But honestly, if you think you can do a much better job with some extra days (perhaps because you couldn't spend too many hours), be our guest!
At the end of the day, if it is not going to impress the team, it's not going to fly, so give it your best shot ✈️
## Questions
Feel free to reach out at any given time with questions about the task, particularly if you encounter problems outside your control that may block your progress.

docker-compose.yaml Normal file
@@ -0,0 +1,14 @@
services:
api:
build: .
command: uv run fastapi run src/api.py
ports:
- 8000:8000
env_file: .env
environment:
- storage_path=/app/data
volumes:
- ./data:/app/data

pyproject.toml Normal file
@@ -0,0 +1,38 @@
[project]
name = "python-vector-database"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.11"
dependencies = [
"cohere>=5.15.0",
"fastapi[standard]>=0.115.12",
"pydantic-settings>=2.9.1",
"pydantic>=2.11.5",
"uuid7>=0.1.0",
]
[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"
[tool.setuptools.packages.find]
where = ["src"]
[tool.setuptools.package-dir]
"" = "src"
[dependency-groups]
dev = [
"ipython>=9.2.0",
"marimo>=0.13.15",
"pandas>=2.2.3",
"pyarrow>=20.0.0",
"pytest>=8.3.5",
]
[tool.pytest.ini_options]
pythonpath = "src"
testpaths = "src/tests"
addopts = "-v --tb=short"

src/api.py Normal file
@@ -0,0 +1,136 @@
from pathlib import Path
from typing import Annotated
import time
import marimo
from fastapi import FastAPI, HTTPException, Depends, Request, BackgroundTasks
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from model import (
Document,
DocumentUpload,
Library,
LibraryDoesNotExist,
QueryResult,
LibraryQuery,
)
class ResponseStatus(BaseModel):
status: str = "ok"
def get_library(library_slug: str) -> Library:
try:
library = Library.load(library_slug)
except LibraryDoesNotExist:
raise HTTPException(404)
return library
app = FastAPI(title="Vector Database")
@app.get("/")
def index():
html_content = """
<html>
<head>
<title>Vector Database</title>
<link rel="stylesheet" href="https://unpkg.com/axist@latest/dist/axist.min.css" />
</head>
<body>
<article>
                <p>Here is the list of UI pages exposed by this API.</p>
<ul>
<li><a href="/docs">Open API Swagger-like docs</a></li>
<li><a href="/app/create">Marimo App to Create Libraries</a></li>
<li><a href="/app/seed">Marimo App to Seed the database with movie data.</a></li>
<li><a href="/app/insert">Marimo App to Insert Documents</a></li>
<li><a href="/app/search">Marimo App to Search the Library</a></li>
</ul>
</article>
</body>
</html>
"""
return HTMLResponse(content=html_content, status_code=200)
server = marimo.create_asgi_app(include_code=True, quiet=False)
src_dir = Path(__file__).parent
server = server.with_app(
path="/create",
root=src_dir / "notebook/app-create.py",
)
server = server.with_app(
path="/seed",
root=src_dir / "notebook/app-seed.py",
)
server = server.with_app(
path="/insert",
root=src_dir / "notebook/app-insert.py",
)
server = server.with_app(
path="/search",
root=src_dir / "notebook/app-search.py",
)
app.mount("/app", server.build())
@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
start_time = time.perf_counter()
response = await call_next(request)
process_time = time.perf_counter() - start_time
response.headers["X-Process-Time"] = str(process_time)
return response
@app.get("/library/list")
def list_libraries() -> list[str]:
return Library.list_libraries()
@app.post("/library/create")
def create_library(payload: Library):
if payload.slug in Library.list_libraries():
raise HTTPException(400, "Library already exists.")
payload.save()
return payload.stat
@app.get("/library/{library_slug}")
def get_library_details(
library: Annotated[Library, Depends(get_library)],
):
return library.stat
@app.post("/library/{library_slug}/documents")
def insert_new_document(
payload: DocumentUpload,
library: Annotated[Library, Depends(get_library)],
background_tasks: BackgroundTasks,
):
document = Document.model_validate(payload, from_attributes=True)
document.save(library)
def _create_chunks_for_document():
document.create_chunks(library)
background_tasks.add_task(_create_chunks_for_document)
return ResponseStatus()
@app.post("/library/{library_slug}/search")
def search(
payload: LibraryQuery,
library: Annotated[Library, Depends(get_library)],
) -> list[QueryResult]:
return library.query(payload.query, payload.results)

src/embedding.py Normal file
@@ -0,0 +1,121 @@
import hashlib
import json
from typing import Iterator
from itertools import chain
import cohere
from utils import chunk_list, text_to_chunks, Point, EmbeddingSize
from settings import settings
co = cohere.ClientV2()
def md5(text: str) -> str:
return hashlib.md5(text.encode("utf-8")).hexdigest()
def _text_chunks(
text: str, chunk_size: int, chunk_overlap: int, search: bool = True
) -> list[str]:
if search:
return [text]
return text_to_chunks(text, chunk_size, chunk_overlap)
def _call_cohere(
texts: list[str],
embedding_size: EmbeddingSize,
search: bool = True,
) -> Iterator[tuple[str, Point]]:
response = co.embed(
inputs=[{"content": [{"type": "text", "text": doc}]} for doc in texts],
model="embed-v4.0",
output_dimension=embedding_size,
input_type="search_query" if search else "search_document",
embedding_types=["float"],
)
return zip(texts, response.embeddings.float)
def embed_store(
text: str,
embedding: Point,
embedding_size: EmbeddingSize,
search: bool = False,
):
if not settings.cohere_cache:
return
_md5 = md5(text)
target_file = (
settings.storage_path
/ ".cohere_embedding"
/ ("search" if search else "storage")
/ _md5
)
embedding_table = {}
if target_file.exists():
embedding_table = json.loads(target_file.read_text())
embedding_table[str(embedding_size)] = embedding
target_file.parent.mkdir(exist_ok=True, parents=True)
target_file.write_text(
json.dumps(embedding_table, indent=2)
)
def embed_lookup(
text: str,
embedding_size: EmbeddingSize,
search: bool = False,
) -> Point | None:
if not settings.cohere_cache:
return
_md5 = md5(text)
target_file = (
settings.storage_path
/ ".cohere_embedding"
/ ("search" if search else "storage")
/ _md5
)
if target_file.exists():
embeddings = json.loads(target_file.read_text())
return embeddings.get(str(embedding_size), None)
def embed(
text: str,
chunk_size: int,
chunk_overlap: int,
embedding_size: EmbeddingSize,
search: bool = True,
) -> list[tuple[str, Point]]:
"""
https://docs.cohere.com/reference/embed
Input type:
> "search_document": Used for embeddings stored...
> "search_query": Used for embeddings of search queries...
Cohere API only allows for at most 96 texts at a time.
"""
text_chunks_list = _text_chunks(text, chunk_size, chunk_overlap, search)
text_chunks = {
_text: embed_lookup(_text, embedding_size, search) for _text in text_chunks_list
}
missing = [_text for _text, embedding in text_chunks.items() if embedding is None]
results = {}
for _texts in chunk_list(missing, 50):
results |= dict(_call_cohere(_texts, embedding_size, search))
for _text, embedding in results.items():
embed_store(_text, embedding, embedding_size, search)
returnable = []
for _text in text_chunks_list:
returnable.append((_text, text_chunks[_text] or results[_text]))
return returnable

src/index.py Normal file
@@ -0,0 +1,544 @@
import heapq
import math
import random
from uuid import UUID
from typing import Protocol, Generator, TYPE_CHECKING
from pydantic import BaseModel, Field
from uuid_extensions import uuid7
from utils import Point, KMeansBinary
if TYPE_CHECKING:
from model import Library, Chunk
class IndexStrategy(Protocol):
def create_index(self, library: "Library") -> None:
"""Initialize the index for a library"""
...
def add_chunk(self, library: "Library", chunk: "Chunk") -> None:
"""Add a chunk to the index"""
...
def search(
self, library: "Library", query_vector: Point
) -> Generator[tuple[UUID, float], None, None]:
"""Search for k nearest chunks, return (chunk_id, distance) pairs"""
...
class IndexDefinition(BaseModel):
centers: dict[UUID, Point] = {}
@classmethod
def read(cls, library: "Library"):
target_file = library.storage_directory / "index" / "definition.json"
if not target_file.exists():
return None
return cls.model_validate_json(target_file.read_text())
def save(self, library: "Library"):
target_file = library.storage_directory / "index" / "definition.json"
target_file.parent.mkdir(parents=True, exist_ok=True)
target_file.write_text(self.model_dump_json(indent=2))
class IndexBlob(BaseModel):
id: UUID = Field(default_factory=uuid7)
center: Point
members: list[UUID] = []
@classmethod
def read(cls, library: "Library", id: UUID):
target_file = library.storage_directory / "index" / str(id)
if not target_file.exists():
return None
return cls.model_validate_json(target_file.read_text())
def save(self, library: "Library"):
target_file = library.storage_directory / "index" / str(self.id)
target_file.parent.mkdir(parents=True, exist_ok=True)
target_file.write_text(self.model_dump_json(indent=2))
def delete(self, library: "Library"):
target_file = library.storage_directory / "index" / str(self.id)
if target_file.exists():
target_file.unlink()
def split(self, library: "Library"):
from model import Chunk
chunks = {chunk_id: Chunk.read(library, chunk_id) for chunk_id in self.members}
vectors = [chunk for chunk_id, chunk in chunks.items() if chunk is not None]
if len(vectors) < 2:
return
transformed = [library.vector_transformation(v.vector) for v in vectors]
# K-means with k=2 (manual implementation)
centers = KMeansBinary(transformed, library.metric)
# Create new blobs
A = IndexBlob(center=library.vector_transformation(centers[0]))
B = IndexBlob(center=library.vector_transformation(centers[1]))
# Assign members based on original metric
for i, chunk_id in enumerate(self.members):
chunk = chunks[chunk_id]
if chunk is None:
continue
original_vector = chunk.vector
distance_to_A = library.metric(original_vector, A.center)
distance_to_B = library.metric(original_vector, B.center)
if distance_to_A <= distance_to_B:
A.members.append(chunk_id)
else:
B.members.append(chunk_id)
if not A.members or not B.members:
return
index_definition = IndexDefinition.read(library)
if index_definition is None:
return
if self.id in index_definition.centers:
del index_definition.centers[self.id]
index_definition.centers[A.id] = A.center
index_definition.centers[B.id] = B.center
index_definition.save(library)
self.delete(library)
A.save(library)
B.save(library)
def add(self, library: "Library", member: UUID):
self.members.append(member)
self.save(library)
if len(self.members) >= library.index_blob_limit:
self.split(library)
class BlobIndexStrategy:
def create_index(self, library: "Library") -> None:
origin = [0.0] * library.embedding_size
if IndexDefinition.read(library) is None:
index_blob = IndexBlob(center=origin)
index_blob.save(library)
IndexDefinition(centers={index_blob.id: origin}).save(library)
def add_chunk(self, library: "Library", chunk: "Chunk") -> None:
index_definition = IndexDefinition.read(library)
if index_definition is None:
self.create_index(library)
index_definition = IndexDefinition.read(library)
distances = {
key: library.metric(
library.vector_transformation(center),
library.vector_transformation(chunk.vector),
)
for key, center in index_definition.centers.items()
}
if not distances:
return
nearest_index_id = min(distances, key=distances.get)
blob = IndexBlob.read(library, nearest_index_id)
if blob:
blob.add(library, chunk.id)
def search(
self, library: "Library", query_vector: Point
) -> Generator[tuple[UUID, float], None, None]:
from model import Chunk
index_definition = IndexDefinition.read(library)
if index_definition is None:
return
distances = {
key: library.metric(
library.vector_transformation(center),
library.vector_transformation(query_vector),
)
for key, center in index_definition.centers.items()
}
while len(distances) != 0:
nearest = min(distances, key=distances.get)
del distances[nearest]
index_blob = IndexBlob.read(library, nearest)
if index_blob is None:
continue
chunks = {
chunk_id: Chunk.read(library, chunk_id)
for chunk_id in index_blob.members
}
distances_to_chunks = {
chunk.id: library.metric(
query_vector,
library.vector_transformation(chunk.vector),
)
for chunk in chunks.values()
if chunk is not None
}
while len(distances_to_chunks) != 0:
nearest_chunk_id = min(
distances_to_chunks,
key=distances_to_chunks.get,
)
yield nearest_chunk_id, distances_to_chunks[nearest_chunk_id]
del distances_to_chunks[nearest_chunk_id]
class HNSWNode(BaseModel):
id: UUID = Field(default_factory=uuid7)
chunk_id: UUID
vector: Point
level: int
connections: dict[int, list[UUID]] = {} # level -> list of connected node IDs
@classmethod
def read(cls, library: "Library", id: UUID):
target_file = library.storage_directory / "index" / str(id)
if target_file.exists():
return cls.model_validate_json(target_file.read_text())
return None
def save(self, library: "Library"):
target_file = library.storage_directory / "index" / str(self.id)
target_file.parent.mkdir(parents=True, exist_ok=True)
try:
target_file.write_text(self.model_dump_json(indent=2))
except Exception:
pass # Fail silently to avoid breaking the indexing process
def delete(self, library: "Library"):
target_file = library.storage_directory / "index" / str(self.id)
if target_file.exists():
try:
target_file.unlink()
except Exception:
pass # Fail silently
class HNSWGraph(BaseModel):
entry_point: UUID | None = None
max_level: int = 0
ml: float = 1.0 / math.log(2.0)
max_connections_0: int = 16
max_connections: int = 8
ef_construction: int = 200 # Size of candidate list during construction
max_level_limit: int = 20
nodes: dict[UUID, UUID] = {} # chunk_id -> node_id mapping
@classmethod
def read(cls, library: "Library"):
target_file = library.storage_directory / "index" / "definition.json"
if not target_file.exists():
return cls()
try:
return cls.model_validate_json(target_file.read_text())
except Exception:
return cls()
def save(self, library: "Library"):
target_file = library.storage_directory / "index" / "definition.json"
target_file.parent.mkdir(parents=True, exist_ok=True)
try:
target_file.write_text(self.model_dump_json(indent=2))
except Exception:
pass
def get_random_level(self) -> int:
"""Generate random level for new node using exponential decay"""
level = int(-math.log(random.uniform(1e-8, 1.0)) * self.ml)
return min(level, self.max_level_limit)
class HNSWIndexStrategy:
def create_index(self, library: "Library") -> None:
"""Initialize empty HNSW graph"""
graph = HNSWGraph()
graph.save(library)
def add_chunk(self, library: "Library", chunk: "Chunk") -> None:
"""Add a chunk to the HNSW graph"""
graph = HNSWGraph.read(library)
# Create new node
level = graph.get_random_level()
node = HNSWNode(
chunk_id=chunk.id,
vector=chunk.vector,
level=level,
connections={i: [] for i in range(level + 1)},
)
# If this is the first node, make it the entry point
if graph.entry_point is None:
graph.entry_point = node.id
graph.max_level = level
node.save(library)
graph.nodes[chunk.id] = node.id
graph.save(library)
return
# Find entry point and verify it exists
entry_node = HNSWNode.read(library, graph.entry_point)
if entry_node is None:
# Entry point corrupted, make this the new entry
graph.entry_point = node.id
graph.max_level = level
node.save(library)
graph.nodes[chunk.id] = node.id
graph.save(library)
return
entry_dist = library.metric(node.vector, entry_node.vector)
current_closest = [(entry_dist, entry_node.id)]
# Search from top level down to level + 1 (greedy search with ef=1)
for lc in range(graph.max_level, level, -1):
current_closest = self._search_layer(
library, node.vector, current_closest, 1, lc
)
# Search and connect from level down to 0
for lc in range(min(level, graph.max_level), -1, -1):
# Use ef_construction for building robust connections
candidates = self._search_layer(
library, node.vector, current_closest, graph.ef_construction, lc
)
# Select neighbors and create bidirectional connections
max_conn = graph.max_connections if lc > 0 else graph.max_connections_0
selected = self._select_neighbors(
library, node.vector, candidates, max_conn
)
# Add connections
for dist, neighbor_id in selected:
node.connections[lc].append(neighbor_id)
neighbor_node = HNSWNode.read(library, neighbor_id)
if neighbor_node and lc in neighbor_node.connections:
neighbor_node.connections[lc].append(node.id)
# Prune neighbor connections if they exceed max
if len(neighbor_node.connections[lc]) > max_conn:
neighbor_candidates = []
for nid in neighbor_node.connections[lc]:
n = HNSWNode.read(library, nid)
if n is not None:
dist = library.metric(neighbor_node.vector, n.vector)
neighbor_candidates.append((dist, nid))
if neighbor_candidates:
pruned = self._select_neighbors(
library,
neighbor_node.vector,
neighbor_candidates,
max_conn,
)
neighbor_node.connections[lc] = [nid for _, nid in pruned]
neighbor_node.save(library)
current_closest = selected
# Update graph metadata
if level > graph.max_level:
graph.max_level = level
graph.entry_point = node.id
graph.nodes[chunk.id] = node.id
node.save(library)
graph.save(library)
def _search_layer(
self,
library: "Library",
query: Point,
entry_points: list[tuple[float, UUID]],
ef: int,
level: int,
) -> list[tuple[float, UUID]]:
"""
Search for ef closest nodes in a specific layer.
ef (exploration factor) determines how many candidates we track during search:
- Higher ef = more thorough search, better recall, slower
- Lower ef = faster search, potentially missing good candidates
- During construction: use ef_construction for robust graph building
- During search: can use smaller ef for speed vs quality tradeoff
"""
visited = set()
candidates = [] # max heap for exploration: (-distance, node_id)
w = [] # min heap for results: (distance, node_id)
# Initialize with entry points
for dist, node_id in entry_points:
if node_id not in visited:
heapq.heappush(candidates, (-dist, node_id))
heapq.heappush(w, (dist, node_id))
visited.add(node_id)
while candidates:
# Get closest unvisited candidate
neg_current_dist, current_id = heapq.heappop(candidates)
current_dist = -neg_current_dist
# Stop if current distance is worse than ef-th best found so far
if len(w) >= ef and current_dist > max(w)[0]:
break
current_node = HNSWNode.read(library, current_id)
if not current_node or level not in current_node.connections:
continue
# Explore all neighbors at this level
for neighbor_id in current_node.connections[level]:
if neighbor_id not in visited:
visited.add(neighbor_id)
neighbor_node = HNSWNode.read(library, neighbor_id)
if neighbor_node:
dist = library.metric(query, neighbor_node.vector)
if len(w) < ef:
# Still have room in result set
heapq.heappush(candidates, (-dist, neighbor_id))
heapq.heappush(w, (dist, neighbor_id))
elif dist < max(w)[0]:
# Better than worst in result set
heapq.heappush(candidates, (-dist, neighbor_id))
# Remove worst from w and add new candidate
w_list = list(w)
w_list.remove(max(w_list))
w = w_list
heapq.heapify(w)
heapq.heappush(w, (dist, neighbor_id))
return sorted(w)
def _select_neighbors(
self,
library: "Library",
query: Point,
candidates: list[tuple[float, UUID]],
max_connections: int,
) -> list[tuple[float, UUID]]:
"""Select diverse neighbors to avoid hubs and improve search quality"""
if len(candidates) <= max_connections:
return candidates
candidates.sort(key=lambda x: x[0])
selected = []
remaining = candidates[:]
# Always select closest
if remaining:
selected.append(remaining.pop(0))
# Select remaining with diversity consideration
while len(selected) < max_connections and remaining:
best_candidate = None
best_score = float("inf")
best_idx = -1
for i, (dist, candidate_id) in enumerate(remaining):
# Find minimum distance to already selected nodes (diversity)
diversity_score = float("inf")
candidate_node = HNSWNode.read(library, candidate_id)
if candidate_node:
for _, selected_id in selected:
selected_node = HNSWNode.read(library, selected_id)
if selected_node:
div_dist = library.metric(
candidate_node.vector, selected_node.vector
)
diversity_score = min(diversity_score, div_dist)
# Combined score: prefer close nodes that aren't too close to selected ones
combined_score = dist - 0.1 * diversity_score
if combined_score < best_score:
best_score = combined_score
best_candidate = (dist, candidate_id)
best_idx = i
if best_candidate:
selected.append(best_candidate)
remaining.pop(best_idx)
else:
selected.append(remaining.pop(0))
return selected
def search(
self, library: "Library", query_vector: Point
) -> Generator[tuple[UUID, float], None, None]:
"""
True generator search - yields results indefinitely as long as caller keeps asking.
The search explores the graph in waves, returning results in order of distance.
It maintains a frontier of unexplored candidates and yields the next best result
each time it's called.
"""
graph = HNSWGraph.read(library)
if graph.entry_point is None:
return
entry_node = HNSWNode.read(library, graph.entry_point)
if not entry_node:
return
# Start with entry point
entry_dist = library.metric(query_vector, entry_node.vector)
current_closest = [(entry_dist, graph.entry_point)]
# Navigate down to level 1 (greedy search)
for level in range(graph.max_level, 0, -1):
current_closest = self._search_layer(
library, query_vector, current_closest, 1, level
)
# Now do the actual search at level 0 using a priority queue for true streaming
visited = set()
candidates = [] # min heap: (distance, node_id)
# Initialize with level 0 entry points
for dist, node_id in current_closest:
if node_id not in visited:
heapq.heappush(candidates, (dist, node_id))
visited.add(node_id)
yielded = set() # Track what we've already yielded
while candidates:
# Get the closest unvisited candidate
current_dist, current_id = heapq.heappop(candidates)
# Yield this result if we haven't already
if current_id not in yielded:
current_node = HNSWNode.read(library, current_id)
if current_node:
yielded.add(current_id)
yield current_node.chunk_id, current_dist
# Add all unvisited neighbors to candidates for future exploration
current_node = HNSWNode.read(library, current_id)
if current_node and 0 in current_node.connections:
for neighbor_id in current_node.connections[0]:
if neighbor_id not in visited:
visited.add(neighbor_id)
neighbor_node = HNSWNode.read(library, neighbor_id)
if neighbor_node:
dist = library.metric(query_vector, neighbor_node.vector)
heapq.heappush(candidates, (dist, neighbor_id))

src/model.py Normal file
@@ -0,0 +1,228 @@
from pathlib import Path
from uuid import UUID
from pydantic import BaseModel, Field
from uuid_extensions import uuid7
from embedding import embed
from index import (
IndexStrategy,
BlobIndexStrategy,
HNSWIndexStrategy,
)
from utils import (
Metric,
Point,
IndexAlgorithm,
EmbeddingSize,
VectorTransformation,
cosine_distance,
identity_transformation,
)
from settings import settings
class LibraryDoesNotExist(Exception): ...
class LibraryQuery(BaseModel):
query: str
results: int = 5
class QueryResult(BaseModel):
document: "Document"
chunk: "Chunk"
hits: int = 0
distance: float
class Library(BaseModel):
slug: str
embedding_size: EmbeddingSize = 256
chunk_size: int = 128
chunk_overlap: int = 32
index: IndexAlgorithm = "cosine"
index_blob_limit: int = 50
@property
def stat(self) -> "LibrarySummary":
return LibrarySummary.model_validate(
{
**self.model_dump(),
"documents": len(
list((self.storage_directory / "documents").iterdir())
),
"chunks": len(list((self.storage_directory / "chunks").iterdir())),
"indexs": len(list((self.storage_directory / "index").iterdir())) - 1,
}
)
@property
def storage_directory(self) -> Path:
return settings.storage_path / self.slug
@classmethod
def list_libraries(cls):
return [
x.name
for x in filter(lambda x: x.is_dir(), settings.storage_path.iterdir())
if x.name[0] != "."
]
@classmethod
def load(cls, slug: str):
target_file = settings.storage_path / slug / "definition.json"
if not target_file.exists():
raise LibraryDoesNotExist()
return cls.model_validate_json(target_file.read_text())
def save(self):
target_file = self.storage_directory / "definition.json"
target_file.parent.mkdir(parents=True, exist_ok=True)
target_file.write_text(self.model_dump_json(indent=2))
(self.storage_directory / "documents").mkdir(exist_ok=True)
(self.storage_directory / "chunks").mkdir(exist_ok=True)
(self.storage_directory / "index").mkdir(exist_ok=True)
self.index_strategy.create_index(self)
@property
def index_strategy(self) -> IndexStrategy:
"""Factory method to get the appropriate index strategy"""
if self.index == "hnsw":
return HNSWIndexStrategy()
return BlobIndexStrategy()
@property
def metric(self) -> Metric:
return cosine_distance
@property
def vector_transformation(self) -> VectorTransformation:
return identity_transformation
def chunk_id(self) -> list[UUID]:
target_directory = self.storage_directory / "chunks"
def _file_to_uuid():
for file in target_directory.iterdir():
if file.is_file():
yield UUID(file.name)
return list(_file_to_uuid())
def _text_embed(
self,
text: str,
search: bool = True,
) -> list[tuple[str, Point]]:
return embed(
text,
chunk_size=self.chunk_size,
chunk_overlap=self.chunk_overlap,
embedding_size=self.embedding_size,
search=search,
)
def add_document(self, slug: str, content: str):
existing_document = Document.read(self, slug)
if existing_document:
return
new_document = Document(slug=slug, content=content)
new_document.save(self)
new_document.create_chunks(self)
return new_document
def query(self, content: str, k: int = 5) -> list[QueryResult]:
embeddings = self._text_embed(content, search=True)
vector = embeddings[0][1]
transformed_vector = self.vector_transformation(vector)
search_generator = self.index_strategy.search(self, transformed_vector)
hit_documents = {}
for chunk_id, distance in search_generator:
chunk = Chunk.read(self, chunk_id)
if chunk is None:
continue
if chunk.document not in hit_documents:
hit_documents[chunk.document] = QueryResult(
chunk=chunk,
document=Document.read(self, chunk.document),
hits=1,
distance=distance,
)
else:
hit_documents[chunk.document].hits += 1
if distance < hit_documents[chunk.document].distance:
# This should not happen.
hit_documents[chunk.document].distance = distance
hit_documents[chunk.document].chunk = chunk
if len(hit_documents) >= k:
break
return list(hit_documents.values())
class LibrarySummary(Library):
documents: int
chunks: int
    indexes: int
class Chunk(BaseModel):
id: UUID = Field(default_factory=uuid7)
vector: Point
text: str
document: str
@classmethod
def read(cls, library: Library, id: UUID):
target_file = library.storage_directory / "chunks" / str(id)
if not target_file.exists():
return None
return cls.model_validate_json(target_file.read_text())
def save(self, library: Library):
target_file = library.storage_directory / "chunks" / str(self.id)
target_file.write_text(self.model_dump_json(indent=2))
def index(self, library: Library):
library.index_strategy.add_chunk(library, self)
class DocumentUpload(BaseModel):
slug: str
content: str
class Document(DocumentUpload):
chunks: list[UUID] = []
@classmethod
def read(cls, library: Library, slug: str):
target_file = library.storage_directory / "documents" / slug
if target_file.exists():
return cls.model_validate_json(target_file.read_text())
return None
def save(self, library: Library):
target_file = library.storage_directory / "documents" / self.slug
target_file.write_text(self.model_dump_json(indent=2))
def create_chunks(self, library: Library):
slug, content = self.slug, self.content
embedding = library._text_embed(f"{slug}\n{content}", search=False)
for text, vector in embedding:
chunk = Chunk(vector=vector, text=text, document=slug)
chunk.save(library)
chunk.index(library)
self.chunks.append(chunk.id)
self.save(library)

src/notebook/app-create.py Normal file
@@ -0,0 +1,148 @@
import marimo
__generated_with = "0.13.15"
app = marimo.App(width="medium")
@app.cell
def _():
from itertools import chain
import marimo as mo
from model import Library, Chunk, Document, LibraryDoesNotExist
from utils import EmbeddingSize, IndexAlgorithm
from notebook.constants import animals, adjetives, datasets
return EmbeddingSize, IndexAlgorithm, Library, LibraryDoesNotExist, mo
@app.cell
def _(mo):
mo.md("""# Create Library Widget""")
return
@app.cell
def _(mo):
get_create_state, set_create_state = mo.state(False)
return get_create_state, set_create_state
@app.cell
def _(mo, set_create_state):
create_button = mo.ui.button(
label="Create Library", value=True, on_click=lambda _: set_create_state(True)
)
return (create_button,)
@app.cell
def _(EmbeddingSize, IndexAlgorithm, mo):
library_name = mo.ui.text(placeholder="Library Name", label="Library Name")
embedding_size = mo.ui.dropdown(
EmbeddingSize.__args__, label="Embedding Size", value=EmbeddingSize.__args__[0]
)
chunk_size = mo.ui.number(value=128, label="Chunk Size")
chunk_overlap = mo.ui.number(start=10, value=64, label="Chunk Overlap")
index_blob_limit = mo.ui.number(start=10, value=100, label="Index Blob Size")
index_type = mo.ui.dropdown(IndexAlgorithm.__args__, value=IndexAlgorithm.__args__[0], label="Index Type")
return (
chunk_overlap,
chunk_size,
embedding_size,
index_blob_limit,
index_type,
library_name,
)
@app.cell
def _(
chunk_overlap,
chunk_size,
create_button,
embedding_size,
index_blob_limit,
index_type,
library_name,
mo,
):
fields = [
library_name,
embedding_size,
chunk_size,
chunk_overlap,
index_type,
index_blob_limit,
]
mo.vstack(
[
*fields,
create_button,
]
)
return (fields,)
@app.cell
def _(
Library,
LibraryDoesNotExist,
chunk_overlap,
chunk_size,
embedding_size,
fields,
get_create_state,
index_blob_limit,
index_type,
library_name,
mo,
set_create_state,
):
output = [mo.md("### Result")]
if get_create_state() and all(map(lambda x: bool(x.value), fields)):
try:
library = Library.load(library_name.value)
output += [mo.md("> Library already exists!")]
except LibraryDoesNotExist:
library = Library(
slug=library_name.value,
embedding_size=embedding_size.value,
chunk_overlap=chunk_overlap.value,
chunk_size=chunk_size.value,
index_blob_limit=index_blob_limit.value,
index=index_type.value,
)
library.save()
output += [mo.md("> Create!")]
output += [mo.ui.table(library.stat.model_dump())]
set_create_state(False)
mo.vstack(output)
return
@app.cell
def _(Library, mo):
_libraries = Library.list_libraries()
mo.vstack(
[
mo.md("# Existing Libraries"),
*[
mo.md(
f"""/// details | **{library.slug}:**\n\n\n{mo.ui.table(library.stat.model_dump())}\n///"""
)
for library in map(Library.load, _libraries)
],
*([mo.md("No libraries created yet.")] if len(_libraries) == 0 else []),
],
)
return
@app.cell
def _():
return
if __name__ == "__main__":
app.run()

src/notebook/app-insert.py Normal file
@@ -0,0 +1,113 @@
import marimo
__generated_with = "0.13.15"
app = marimo.App(width="medium")
@app.cell
def _():
import marimo as mo
from model import Library, Chunk, Document
return Chunk, Library, mo
@app.cell
def _(Library, mo):
libraries = Library.list_libraries()
library_dropdown = mo.ui.dropdown(
options=libraries, value=libraries[0], label="Choose Active Library"
)
slug_input = mo.ui.text(placeholder="Document Title", label="Title")
content_input = mo.ui.text_area(
placeholder="Lorem ipsum...", label="Document Content", full_width=True
)
return content_input, library_dropdown, slug_input
@app.cell
def _(mo):
get_insert_state, set_insert_state = mo.state(False)
return get_insert_state, set_insert_state
@app.cell
def _(mo, set_insert_state):
insert_button = mo.ui.button(
label="Insert Document", value=True, on_click=lambda _: set_insert_state(True)
)
return (insert_button,)
@app.cell
def _(content_input, insert_button, library_dropdown, mo, slug_input):
mo.vstack([library_dropdown, slug_input, content_input, insert_button.right()])
return
@app.cell
def _(Library, library_dropdown, mo):
library = Library.load(library_dropdown.value)
mo.vstack([mo.md("### Library Summary"), mo.ui.table(library.stat.model_dump())])
return (library,)
@app.cell
def _(
Chunk,
content_input,
get_insert_state,
library,
mo,
set_insert_state,
slug_input,
):
output = [mo.md("### Result")]
if get_insert_state() and slug_input.value and content_input.value:
print("Inserting")
document = library.add_document(slug_input.value, content_input.value)
chunks = []
for chunk in document.chunks:
chunks.append(Chunk.read(library, chunk))
output += [
mo.md(f"""
/// details | **Chunk ID:** `{chunk.id}`
**Chunk Content**:
```
{chunk.text}
```
**Vector**:
```
{chunk.vector}
```
///
""")
for chunk in chunks
]
else:
output += [mo.md("> Please provide the data to insert a document.")]
if not slug_input.value:
print("Missing Title")
if not content_input.value:
print("Missing document")
set_insert_state(False)
mo.vstack(output)
return
@app.cell
def _():
return
@app.cell
def _():
return
if __name__ == "__main__":
app.run()

src/notebook/app-search.py Normal file
@@ -0,0 +1,102 @@
import marimo
__generated_with = "0.13.15"
app = marimo.App(width="medium")
@app.cell
def _():
import marimo as mo
from model import Library, Chunk, Document
mo.md("# Search on Library Widget")
return Library, mo
@app.cell
def _(Library, mo):
libraries = Library.list_libraries()
library_dropdown = mo.ui.dropdown(
options=libraries, value=libraries[0], label="Choose Active Library"
)
search_query = mo.ui.text(placeholder="Search...", label="Search Query")
n_results = mo.ui.number(start=1, stop=20, label="Results")
return library_dropdown, n_results, search_query
@app.cell
def _(library_dropdown, mo, n_results, search_query):
mo.hstack(
[
library_dropdown,
search_query,
n_results,
]
)
return
@app.cell
def _(Library, library_dropdown, mo, n_results, search_query):
search_query.value, n_results.value
library = Library.load(library_dropdown.value)
mo.vstack([mo.md("### Library Summary"), mo.ui.table(library.stat.model_dump())])
return (library,)
@app.cell
def _(library, mo, n_results, search_query):
output = []
if search_query.value and n_results.value:
results = library.query(search_query.value, n_results.value)
output = [mo.md("### Search Results")] + [
mo.md(
f"""
/// details | **Hits:** {result.hits} **Document:** `{result.document.slug}` | **Distance:** {result.distance}
**Chunk ID:** `{result.chunk.id}`
**Chunk**:
```
{result.chunk.text}
```
**Document**:
```
{result.document.content}
```
///
"""
)
for result in results
]
else:
output = [
mo.md("""
/// details | No search yet.
type: warn
Please type a search query and hit the return key (or click away from the input).
///
""")
]
mo.vstack(output)
return
@app.cell
def _():
return
@app.cell
def _():
return
if __name__ == "__main__":
app.run()

src/notebook/app-seed.py Normal file
@@ -0,0 +1,65 @@
import marimo
__generated_with = "0.13.15"
app = marimo.App(width="medium")
@app.cell
def _():
import marimo as mo
import pandas as pd
from model import Library
return Library, mo, pd
@app.cell
def _(mo):
mo.md("""### Seed Data To Database""")
return
@app.cell
def _(Library, mo):
libraries = Library.list_libraries()
library_dropdown = mo.ui.dropdown(options=libraries, value=libraries[0], label="Choose Active Library")
return (library_dropdown,)
@app.cell
def _(pd):
df = pd.read_json("hf://datasets/rohitsaxena/MovieSum/train.jsonl", lines=True)
return (df,)
@app.cell
def _(df):
df[["movie_name", "summary"]]
return
@app.cell
def _(Library, df, library_dropdown, mo):
def _on_click(*args, **kwargs):
library = Library.load(library_dropdown.value)
total = len(df)
for i, row in mo.status.progress_bar(df[["movie_name", "summary"]].iterrows(), total=total):
title, summary = row.to_list()
library.add_document(title, summary)
button = mo.ui.button(label="Add these documents.", on_click=_on_click)
return (button,)
@app.cell
def _(button, library_dropdown, mo):
mo.vstack([library_dropdown, button])
return
@app.cell
def _():
return
if __name__ == "__main__":
app.run()

src/notebook/constants.py Normal file
@@ -0,0 +1,756 @@
datasets = {
"math": "hf://datasets/WhiteGiverPlus/extract_theorem_en_400/data/train-00000-of-00001.parquet",
"stories": "hf://datasets/deven367/babylm-100M-children-stories/data/train-00000-of-00001-13f0a33230d64dd9.parquet",
"news": "hf://datasets/fancyzhx/ag_news/data/train-00000-of-00001.parquet",
}
animals = [
"Aardvark",
"Albatross",
"Alligator",
"Alpaca",
"Ant",
"Anteater",
"Antelope",
"Ape",
"Armadillo",
"Donkey",
"Baboon",
"Badger",
"Barracuda",
"Bat",
"Bear",
"Beaver",
"Bee",
"Bison",
"Boar",
"Buffalo",
"Butterfly",
"Camel",
"Capybara",
"Caribou",
"Cassowary",
"Cat",
"Caterpillar",
"Cattle",
"Chamois",
"Cheetah",
"Chicken",
"Chimpanzee",
"Chinchilla",
"Chough",
"Clam",
"Cobra",
"Cockroach",
"Cod",
"Cormorant",
"Coyote",
"Crab",
"Crane",
"Crocodile",
"Crow",
"Curlew",
"Deer",
"Dinosaur",
"Dog",
"Dogfish",
"Dolphin",
"Dotterel",
"Dove",
"Dragonfly",
"Duck",
"Dugong",
"Dunlin",
"Eagle",
"Echidna",
"Eel",
"Eland",
"Elephant",
"Elk",
"Emu",
"Falcon",
"Ferret",
"Finch",
"Fish",
"Flamingo",
"Fly",
"Fox",
"Frog",
"Gaur",
"Gazelle",
"Gerbil",
"Giraffe",
"Gnat",
"Gnu",
"Goat",
"Goldfinch",
"Goldfish",
"Goose",
"Gorilla",
"Goshawk",
"Grasshopper",
"Grouse",
"Guanaco",
"Gull",
"Hamster",
"Hare",
"Hawk",
"Hedgehog",
"Heron",
"Herring",
"Hippopotamus",
"Hornet",
"Horse",
"Human",
"Hummingbird",
"Hyena",
"Ibex",
"Ibis",
"Jackal",
"Jaguar",
"Jay",
"Jellyfish",
"Kangaroo",
"Kingfisher",
"Koala",
"Kookabura",
"Kouprey",
"Kudu",
"Lapwing",
"Lark",
"Lemur",
"Leopard",
"Lion",
"Llama",
"Lobster",
"Locust",
"Loris",
"Louse",
"Lyrebird",
"Magpie",
"Mallard",
"Manatee",
"Mandrill",
"Mantis",
"Marten",
"Meerkat",
"Mink",
"Mole",
"Mongoose",
"Monkey",
"Moose",
"Mosquito",
"Mouse",
"Mule",
"Narwhal",
"Newt",
"Nightingale",
"Octopus",
"Okapi",
"Opossum",
"Oryx",
"Ostrich",
"Otter",
"Owl",
"Oyster",
"Panther",
"Parrot",
"Partridge",
"Peafowl",
"Pelican",
"Penguin",
"Pheasant",
"Pig",
"Pigeon",
"Pony",
"Porcupine",
"Porpoise",
"Quail",
"Quelea",
"Quetzal",
"Rabbit",
"Raccoon",
"Rail",
"Ram",
"Rat",
"Raven",
"Red deer",
"Red panda",
"Reindeer",
"Rhinoceros",
"Rook",
"Salamander",
"Salmon",
"Sand Dollar",
"Sandpiper",
"Sardine",
"Scorpion",
"Seahorse",
"Seal",
"Shark",
"Sheep",
"Shrew",
"Skunk",
"Snail",
"Snake",
"Sparrow",
"Spider",
"Spoonbill",
"Squid",
"Squirrel",
"Starling",
"Stingray",
"Stinkbug",
"Stork",
"Swallow",
"Swan",
"Tapir",
"Tarsier",
"Termite",
"Tiger",
"Toad",
"Trout",
"Turkey",
"Turtle",
"Viper",
"Vulture",
"Wallaby",
"Walrus",
"Wasp",
"Weasel",
"Whale",
"Wildcat",
"Wolf",
"Wolverine",
"Wombat",
"Woodcock",
"Woodpecker",
"Worm",
"Wren",
"Yak",
"Zebra",
]
adjetives = [
"different",
"used",
"important",
"every",
"large",
"available",
"popular",
"able",
"basic",
"known",
"various",
"difficult",
"several",
"united",
"historical",
"hot",
"useful",
"mental",
"scared",
"additional",
"emotional",
"old",
"political",
"similar",
"healthy",
"financial",
"medical",
"traditional",
"federal",
"entire",
"strong",
"actual",
"significant",
"successful",
"electrical",
"expensive",
"pregnant",
"intelligent",
"interesting",
"poor",
"happy",
"responsible",
"cute",
"helpful",
"recent",
"willing",
"nice",
"wonderful",
"impossible",
"serious",
"huge",
"rare",
"technical",
"typical",
"competitive",
"critical",
"electronic",
"immediate",
"aware",
"educational",
"environmental",
"global",
"legal",
"relevant",
"accurate",
"capable",
"dangerous",
"dramatic",
"efficient",
"powerful",
"foreign",
"hungry",
"practical",
"psychological",
"severe",
"suitable",
"numerous",
"sufficient",
"unusual",
"consistent",
"cultural",
"existing",
"famous",
"pure",
"afraid",
"obvious",
"careful",
"latter",
"unhappy",
"acceptable",
"aggressive",
"boring",
"distinct",
"eastern",
"logical",
"reasonable",
"strict",
"administrative",
"automatic",
"civil",
"former",
"massive",
"southern",
"unfair",
"visible",
"alive",
"angry",
"desperate",
"exciting",
"friendly",
"lucky",
"realistic",
"sorry",
"ugly",
"unlikely",
"anxious",
"comprehensive",
"curious",
"impressive",
"informal",
"inner",
"pleasant",
"sexual",
"sudden",
"terrible",
"unable",
"weak",
"wooden",
"asleep",
"confident",
"conscious",
"decent",
"embarrassed",
"guilty",
"lonely",
"mad",
"nervous",
"odd",
"remarkable",
"substantial",
"suspicious",
"tall",
"tiny",
"other",
"such",
"even",
"new",
"just",
"good",
"any",
"each",
"much",
"own",
"great",
"another",
"same",
"few",
"free",
"right",
"still",
"best",
"public",
"human",
"both",
"local",
"sure",
"better",
"general",
"specific",
"enough",
"long",
"small",
"less",
"high",
"certain",
"little",
"common",
"next",
"simple",
"hard",
"past",
"big",
"possible",
"particular",
"real",
"major",
"personal",
"current",
"left",
"national",
"least",
"natural",
"physical",
"short",
"last",
"single",
"individual",
"main",
"potential",
"professional",
"international",
"lower",
"open",
"according",
"alternative",
"special",
"working",
"true",
"whole",
"clear",
"dry",
"easy",
"cold",
"commercial",
"full",
"low",
"primary",
"worth",
"necessary",
"positive",
"present",
"close",
"creative",
"green",
"late",
"fit",
"glad",
"proper",
"complex",
"content",
"due",
"effective",
"middle",
"regular",
"fast",
"independent",
"original",
"wide",
"beautiful",
"complete",
"active",
"negative",
"safe",
"visual",
"wrong",
"ago",
"quick",
"ready",
"straight",
"white",
"direct",
"excellent",
"extra",
"junior",
"pretty",
"unique",
"classic",
"final",
"overall",
"private",
"separate",
"western",
"alone",
"familiar",
"official",
"perfect",
"bright",
"broad",
"comfortable",
"flat",
"rich",
"warm",
"young",
"heavy",
"valuable",
"correct",
"leading",
"slow",
"clean",
"fresh",
"normal",
"secret",
"tough",
"brown",
"cheap",
"deep",
"objective",
"secure",
"thin",
"chemical",
"cool",
"extreme",
"exact",
"fair",
"fine",
"formal",
"opposite",
"remote",
"total",
"vast",
"lost",
"smooth",
"dark",
"double",
"equal",
"firm",
"frequent",
"internal",
"sensitive",
"constant",
"minor",
"previous",
"raw",
"soft",
"solid",
"weird",
"amazing",
"annual",
"busy",
"dead",
"false",
"round",
"sharp",
"thick",
"wise",
"equivalent",
"initial",
"narrow",
"nearby",
"proud",
"spiritual",
"wild",
"adult",
"apart",
"brief",
"crazy",
"prior",
"rough",
"sad",
"sick",
"strange",
"external",
"illegal",
"loud",
"mobile",
"nasty",
"ordinary",
"royal",
"senior",
"super",
"tight",
"upper",
"yellow",
"dependent",
"funny",
"gross",
"ill",
"spare",
"sweet",
"upstairs",
"usual",
"brave",
"calm",
"dirty",
"downtown",
"grand",
"honest",
"loose",
"male",
"quiet",
"brilliant",
"dear",
"drunk",
"empty",
"female",
"inevitable",
"neat",
"ok",
"representative",
"silly",
"slight",
"smart",
"stupid",
"temporary",
"weekly",
"that",
"this",
"what",
"which",
"time",
"these",
"work",
"no",
"only",
"then",
"first",
"money",
"over",
"business",
"his",
"game",
"think",
"after",
"life",
"day",
"home",
"economy",
"away",
"either",
"fat",
"key",
"training",
"top",
"level",
"far",
"fun",
"house",
"kind",
"future",
"action",
"live",
"period",
"subject",
"mean",
"stock",
"chance",
"beginning",
"upset",
"chicken",
"head",
"material",
"salt",
"car",
"appropriate",
"inside",
"outside",
"standard",
"medium",
"choice",
"north",
"square",
"born",
"capital",
"shot",
"front",
"living",
"plastic",
"express",
"feeling",
"otherwise",
"plus",
"savings",
"animal",
"budget",
"minute",
"character",
"maximum",
"novel",
"plenty",
"select",
"background",
"forward",
"glass",
"joint",
"master",
"red",
"vegetable",
"ideal",
"kitchen",
"mother",
"party",
"relative",
"signal",
"street",
"connect",
"minimum",
"sea",
"south",
"status",
"daughter",
"hour",
"trick",
"afternoon",
"gold",
"mission",
"agent",
"corner",
"east",
"neither",
"parking",
"routine",
"swimming",
"winter",
"airline",
"designer",
"dress",
"emergency",
"evening",
"extension",
"holiday",
"horror",
"mountain",
"patient",
"proof",
"west",
"wine",
"expert",
"native",
"opening",
"silver",
"waste",
"plane",
"leather",
"purple",
"specialist",
"bitter",
"incident",
"motor",
"pretend",
"prize",
"resident",
]

src/settings.py Normal file
@@ -0,0 +1,15 @@
from pathlib import Path
from pydantic import SecretStr
from pydantic_settings import BaseSettings
class Settings(BaseSettings):
co_api_key: SecretStr
cohere_cache: bool = True
storage_path: Path = Path("./data/")
fastapi_port: int = 8000
fastapi_host: str = "localhost"
settings = Settings()
settings.storage_path.mkdir(exist_ok=True, parents=True)

src/tests/test_kmeans.py Normal file
@@ -0,0 +1,267 @@
import pytest
from utils import (
KMeansBinary,
Point,
calculate_centroid,
classify_point_on_clusters,
compare_clusters,
vector_scalar_multiplication,
vector_sum,
)
# Test fixtures and helper functions
@pytest.fixture
def euclidean_distance():
def distance(p1: Point, p2: Point) -> float:
return (sum((a - b) ** 2 for a, b in zip(p1, p2)))**0.5
return distance
@pytest.fixture
def manhattan_distance():
def distance(p1: Point, p2: Point) -> float:
return sum(abs(a - b) for a, b in zip(p1, p2))
return distance
def points_close(p1: Point, p2: Point, tolerance: float = 1e-6) -> bool:
"""Check if two points are approximately equal"""
return all(abs(a - b) < tolerance for a, b in zip(p1, p2))
class TestVectorOperations:
def test_vector_scalar_multiplication(self):
vector = [1.0, 2.0, 3.0]
result = vector_scalar_multiplication(2, vector)
assert result == [2.0, 4.0, 6.0]
# Test with zero scalar
result = vector_scalar_multiplication(0, vector)
assert result == [0.0, 0.0, 0.0]
# Test with negative scalar
result = vector_scalar_multiplication(-1.5, vector)
assert result == [-1.5, -3.0, -4.5]
def test_vector_sum(self):
a = [1.0, 2.0, 3.0]
b = [4.0, 5.0, 6.0]
result = vector_sum(a, b)
assert result == [5.0, 7.0, 9.0]
# Test with zeros
zero = [0.0, 0.0, 0.0]
result = vector_sum(a, zero)
assert result == a
def test_vector_sum_different_lengths(self):
a = [1.0, 2.0]
b = [3.0, 4.0, 5.0]
with pytest.raises(AssertionError):
vector_sum(a, b)
class TestCentroid:
def test_calculate_centroid_simple(self):
points = [[0.0, 0.0], [2.0, 0.0], [1.0, 2.0]]
centroid = calculate_centroid(points)
expected = [1.0, 2.0 / 3.0]
assert points_close(centroid, expected)
def test_calculate_centroid_single_point(self):
points = [[3.0, 4.0]]
centroid = calculate_centroid(points)
assert points_close(centroid, [3.0, 4.0])
def test_calculate_centroid_empty_list(self):
points = []
centroid = calculate_centroid(points)
assert centroid == []
def test_calculate_centroid_identical_points(self):
points = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]
centroid = calculate_centroid(points)
assert points_close(centroid, [1.0, 1.0])
class TestClassification:
def test_classify_point_on_clusters(self, euclidean_distance):
point = [1.0, 1.0]
centers = ([0.0, 0.0], [3.0, 3.0])
# Point should be closer to first center
result = classify_point_on_clusters(point, centers, euclidean_distance)
assert result == True
# Test point closer to second center
point = [2.5, 2.5]
result = classify_point_on_clusters(point, centers, euclidean_distance)
assert result == False
def test_classify_point_equidistant(self, euclidean_distance):
point = [1.5, 1.5]
centers = ([0.0, 0.0], [3.0, 3.0])
# Point is equidistant, should return False (>=)
result = classify_point_on_clusters(point, centers, euclidean_distance)
assert result == False
class TestClusterComparison:
def test_compare_clusters_identical(self):
cluster_a = [[1.0, 2.0], [3.0, 4.0]]
cluster_b = [[1.0, 2.0], [3.0, 4.0]]
assert compare_clusters(cluster_a, cluster_b) == True
def test_compare_clusters_different_order(self):
cluster_a = [[1.0, 2.0], [3.0, 4.0]]
cluster_b = [[3.0, 4.0], [1.0, 2.0]]
assert compare_clusters(cluster_a, cluster_b) == True
def test_compare_clusters_different(self):
cluster_a = [[1.0, 2.0], [3.0, 4.0]]
cluster_b = [[1.0, 2.0], [5.0, 6.0]]
assert compare_clusters(cluster_a, cluster_b) == False
def test_compare_clusters_empty(self):
assert compare_clusters([], []) == True
assert compare_clusters([[1.0, 2.0]], []) == False
class TestKMeansBinary:
def test_kmeans_binary_simple_case(self, euclidean_distance):
# Two well-separated clusters
points = [
[0.0, 0.0],
[0.1, 0.1],
[0.2, 0.0], # Cluster 1
[5.0, 5.0],
[5.1, 5.1],
[5.0, 5.2], # Cluster 2
]
centroid1, centroid2 = KMeansBinary(points, euclidean_distance)
# Check that centroids are reasonable
assert len(centroid1) == 2
assert len(centroid2) == 2
# One centroid should be near (0.1, 0.033), other near (5.033, 5.1)
c1_near_origin = abs(centroid1[0]) < 1 and abs(centroid1[1]) < 1
c2_near_origin = abs(centroid2[0]) < 1 and abs(centroid2[1]) < 1
# Exactly one should be near origin
assert c1_near_origin != c2_near_origin
def test_kmeans_binary_single_point(self, euclidean_distance):
points = [[1.0, 1.0]]
centroid1, centroid2 = KMeansBinary(points, euclidean_distance)
# Both centroids should be the single point
assert points_close(centroid1, [1.0, 1.0])
assert points_close(centroid2, []) # Empty cluster
def test_kmeans_binary_two_points(self, euclidean_distance):
points = [[0.0, 0.0], [2.0, 2.0]]
centroid1, centroid2 = KMeansBinary(points, euclidean_distance)
# Should converge to the two original points
centroids = [centroid1, centroid2]
assert any(points_close(c, [0.0, 0.0]) for c in centroids)
assert any(points_close(c, [2.0, 2.0]) for c in centroids)
def test_kmeans_binary_convergence(self, euclidean_distance):
# Test that algorithm converges within reasonable iterations
points = [
[i / 10.0, i / 10.0]
for i in range(5) # Points along diagonal
] + [
[i / 10.0 + 5, i / 10.0 + 5]
for i in range(5) # Shifted cluster
]
centroid1, centroid2 = KMeansBinary(
points, euclidean_distance, max_iterations=50
)
# Should produce two distinct clusters
distance_between_centroids = euclidean_distance(centroid1, centroid2)
assert distance_between_centroids > 2.0 # Should be well separated
def test_kmeans_binary_with_manhattan_distance(self, manhattan_distance):
points = [[0, 0], [1, 0], [0, 1], [10, 10], [11, 10], [10, 11]]
centroid1, centroid2 = KMeansBinary(points, manhattan_distance)
# Should separate into two clusters
assert len(centroid1) == 2
assert len(centroid2) == 2
# One centroid should be near origin, other near (10,10)
c1_near_origin = abs(centroid1[0]) < 5 and abs(centroid1[1]) < 5
c2_near_origin = abs(centroid2[0]) < 5 and abs(centroid2[1]) < 5
assert c1_near_origin != c2_near_origin
def test_kmeans_binary_max_iterations(self, euclidean_distance):
# Test that max_iterations parameter works
points = [[i, i] for i in range(10)]
# Should work with very few iterations
centroid1, centroid2 = KMeansBinary(
points, euclidean_distance, max_iterations=1
)
assert len(centroid1) == 2
assert len(centroid2) == 2
def test_kmeans_binary_3d_points(self, euclidean_distance):
# Test with 3D points
points = [
[0, 0, 0],
[1, 0, 0],
[0, 1, 0], # Near origin
[5, 5, 5],
[6, 5, 5],
[5, 6, 5], # Far from origin
]
centroid1, centroid2 = KMeansBinary(points, euclidean_distance)
assert len(centroid1) == 3
assert len(centroid2) == 3
# Should separate the two groups
c1_near_origin = all(abs(x) < 3 for x in centroid1)
c2_near_origin = all(abs(x) < 3 for x in centroid2)
assert c1_near_origin != c2_near_origin
class TestEdgeCases:
def test_identical_points(self, euclidean_distance):
# All points are identical
points = [[1.0, 1.0]] * 5
centroid1, centroid2 = KMeansBinary(points, euclidean_distance)
# Both centroids should be the same point
assert points_close(centroid1, [1.0, 1.0])
assert points_close(centroid2, [1.0, 1.0])
def test_collinear_points(self, euclidean_distance):
# Points on a line
points = [[i, 0] for i in range(6)]
centroid1, centroid2 = KMeansBinary(points, euclidean_distance)
# Should still produce two clusters
assert len(centroid1) == 2
assert len(centroid2) == 2
assert centroid1[1] == 0 # y-coordinate should be 0
assert centroid2[1] == 0 # y-coordinate should be 0
if __name__ == "__main__":
pytest.main([__file__])

src/utils.py Normal file
@@ -0,0 +1,197 @@
from typing import Callable, Literal
from functools import reduce
Point = list[float]
Metric = Callable[[Point, Point], float]
VectorTransformation = Callable[[Point], Point]
IndexAlgorithm = Literal["blob", "hnsw"]
EmbeddingSize = Literal[256, 512, 1024, 1536]
def chunk_list(_list: list[str], chunk_size: int = 50) -> list[list[str]]:
"""
This function takes a list of strings, and splits it into
a list of lists of strings, in which each entry is a list with at most
`chunk_size` elements.
"""
_range = range(0, len(_list), chunk_size)
return [_list[i : i + chunk_size] for i in _range]
def text_to_chunks(
text: str,
chunk_size: int,
chunk_overlap: int,
) -> list[str]:
"""
    This function takes a string and creates a list of strings,
    each at most `chunk_size` characters long, where consecutive
    strings overlap by `chunk_overlap` characters.
"""
if len(text) <= chunk_size:
return [text]
if chunk_overlap >= chunk_size:
raise ValueError("chunk_overlap must be less than chunk_size")
chunks = []
start = 0
while start < len(text):
end = start + chunk_size
if end >= len(text):
chunk = text[start:]
if len(chunk) < chunk_overlap and chunks:
chunks[-1] = chunks[-1] + " " + chunk
else:
chunks.append(chunk)
break
chunk = text[start:end]
if end < len(text): # Not the last chunk
last_space = chunk.rfind(" ", max(0, len(chunk) - 50))
if last_space > len(chunk) // 2: # Only if we find a reasonable break point
chunk = chunk[:last_space]
end = start + last_space
chunks.append(chunk.strip())
# Move start position (with overlap)
start = end - chunk_overlap
chunks = [chunk for chunk in chunks if chunk.strip()]
return chunks
"""
Here are some vector math utils
"""
def identity_transformation(v: Point):
return v
def vector_dot_product(a: Point, b: Point) -> float:
return sum([i * j for i, j in zip(a, b)])
def vector_scalar_multiplication(scalar: float | int, vector: Point) -> Point:
return [scalar * v for v in vector]
def vector_sum(a: Point, b: Point) -> Point:
assert len(a) == len(b), "Inconsistent arguments provided."
return [i + j for i, j in zip(a, b)]
def l_distance_generator(L: int = 2) -> Metric:
"""
This function returns a metric function which follows
the l-norms.
"""
def _wrapper(a: Point, b: Point) -> float:
assert len(a) == len(b), "Inconsistent arguments provided."
return sum(abs(i - j) ** L for i, j in zip(a, b)) ** (1 / L)
return _wrapper
def euclidean_distance(a: Point, b: Point) -> float:
assert len(a) == len(b), "Inconsistent arguments provided."
difference = vector_sum(a, vector_scalar_multiplication(-1, b))
return vector_dot_product(difference, difference) ** 0.5
def cosine_distance(a: Point, b: Point) -> float:
"""
    We take a [0, 2] distance value as a metric: 1 - dot(a, b).
    Note that this assumes unit-normalized input vectors; without
    normalization it is not a true cosine distance.
"""
assert len(a) == len(b), "Inconsistent arguments provided."
return 1 - vector_dot_product(a, b)
def calculate_centroid(points: list[Point]) -> Point:
point_count = len(points)
if point_count == 0:
return []
zero = vector_scalar_multiplication(0, points[0])
centroid = reduce(lambda a, b: vector_sum(a, b), points, zero)
return vector_scalar_multiplication(1 / point_count, centroid)
def classify_point_on_clusters(
point: Point,
centers: tuple[Point, Point],
metric: Metric,
) -> bool:
return metric(point, centers[0]) < metric(point, centers[1])
def compare_clusters(cluster_a: list[Point], cluster_b: list[Point]) -> bool:
return set(map(tuple, cluster_a)) == set(map(tuple, cluster_b))
def KMeansBinary(
points: list[Point],
metric: Metric,
max_iterations: int = 100,
) -> tuple[Point, Point]:
zero = vector_scalar_multiplication(0, points[0])
center = vector_scalar_multiplication(
1 / len(points),
reduce(lambda a, b: vector_sum(a, b), points, zero),
)
point_distances = [(point, metric(center, point)) for point in points]
point_distances.sort(key=lambda x: x[1])
mid = len(points) // 2
cluster_1 = [point for point, _ in point_distances[:mid]]
cluster_2 = [point for point, _ in point_distances[mid:]]
for i in range(max_iterations):
centroids = (calculate_centroid(cluster_1), calculate_centroid(cluster_2))
new_cluster_1, new_cluster_2 = [], []
for point in points:
if classify_point_on_clusters(point, centroids, metric):
new_cluster_1.append(point)
else:
new_cluster_2.append(point)
same_clusters = any(
(
compare_clusters(cluster_1, new_cluster_1),
compare_clusters(cluster_1, new_cluster_2),
)
)
if same_clusters:
break
if not new_cluster_1 or not new_cluster_2:
point_distances = [(point, metric(centroids[0], point)) for point in points]
point_distances.sort(key=lambda x: x[1])
mid = len(points) // 2
new_cluster_1 = [point for point, _ in point_distances[:mid]]
new_cluster_2 = [point for point, _ in point_distances[mid:]]
if not new_cluster_2 and len(points) > 1:
new_cluster_2 = [new_cluster_1.pop()]
cluster_1, cluster_2 = new_cluster_1, new_cluster_2
centroid_1 = calculate_centroid(cluster_1)
centroid_2 = calculate_centroid(cluster_2)
return centroid_1, centroid_2

uv.lock Normal file (1501 lines)
File diff suppressed because it is too large.