Big Bang

commit 19f335e300
6 .dockerignore Normal file
@@ -0,0 +1,6 @@
**/__pycache__
*.pyc
__marimo__
.egg-info
.git
data/

13 .gitignore vendored Normal file
@@ -0,0 +1,13 @@
# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info
__marimo__
data/
.env

# Virtual environments
.venv

1 .python-version Normal file
@@ -0,0 +1 @@
3.11

6 Dockerfile Normal file
@@ -0,0 +1,6 @@
FROM python:3.11-slim
COPY --from=ghcr.io/astral-sh/uv:0.4.20 /uv /bin/uv
ENV UV_SYSTEM_PYTHON=1
WORKDIR /app
COPY . .
RUN uv sync --frozen --no-cache

7 Justfile Normal file
@@ -0,0 +1,7 @@
api:
    uv run --env-file .env fastapi dev src/api.py

docker:
    docker compose build
    docker compose up

130 README.md Normal file
@@ -0,0 +1,130 @@
# Take-Home Task - Backend (Vector DB)

Congrats on making it this far in the interview process!
Here is a task for you to show us where you shine the most 🙂

The purpose is not to see how fast you go or what magic tricks you know in Python; it's mostly to understand how clearly you think and code.

If you think clearly and your code is clean, you are better than 90% of applicants already!

> ⚠️ Feel free to use Cursor, but use it where it makes sense. Don't overuse it: it introduces bugs, and its output tends to be verbose and not very Pythonic.

## Objective

The goal of this project is to develop a REST API that allows users to **index** and **query** their documents within a Vector Database.

A Vector Database specializes in storing and indexing vector embeddings, enabling fast retrieval and similarity searches. This capability is crucial for applications involving natural language processing, recommendation systems, and many more.

The REST API should be containerized with Docker.

### Definitions

To ensure a clear understanding, let's define some key concepts (a schema sketch follows this list):

1. Chunk: A chunk is a piece of text with an associated embedding and metadata.
2. Document: A document is made up of multiple chunks and also contains metadata.
3. Library: A library is made up of a list of documents and can also contain other metadata.
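
As a minimal sketch, fixed-schema Pydantic models for these three concepts could look like the following. The field names and the shared `Metadata` model here are illustrative assumptions, not a prescribed schema:

```python
from datetime import datetime
from uuid import UUID, uuid4

from pydantic import BaseModel, Field


class Metadata(BaseModel):
    # Fixed schema: the user cannot add arbitrary fields.
    created_at: datetime = Field(default_factory=datetime.utcnow)
    source: str = ""


class Chunk(BaseModel):
    id: UUID = Field(default_factory=uuid4)
    text: str
    embedding: list[float]
    metadata: Metadata = Field(default_factory=Metadata)


class Document(BaseModel):
    id: UUID = Field(default_factory=uuid4)
    chunks: list[Chunk] = []
    metadata: Metadata = Field(default_factory=Metadata)


class Library(BaseModel):
    id: UUID = Field(default_factory=uuid4)
    documents: list[Document] = []
    metadata: Metadata = Field(default_factory=Metadata)
```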

The API should:

1. Allow the users to create, read, update, and delete libraries.
2. Allow the users to create, read, update, and delete chunks within a library.
3. Index the contents of a library.
4. Do **k-Nearest Neighbor vector search** over the selected library with a given embedding query.

### Guidelines:

The code should be **Python**, since that is what we use to develop our backend.

Here is a suggested path on how to implement a basic solution to the problem:

1. Define the Chunk, Document, and Library classes. To simplify schema definition, we suggest you use a fixed schema for each of the classes. This means not letting the user define which fields should be present within the metadata for each class.
2. Implement two or three indexing algorithms. Do not use external libraries; we want to see you code them up (a brute-force baseline is sketched after this list).
   1. What is the space and time complexity for each of the indexes?
   2. Why did you choose this index?
3. Implement the necessary data structures/algorithms to ensure that there are no data races between reads and writes to the database (a locking sketch also follows this list).
   1. Explain your design choices.
4. Create the logic to do the CRUD operations on libraries and documents/chunks.
   1. Ideally, use services to decouple the API endpoints from the actual work.
5. Implement an API layer on top of that logic to let users interact with the vector database.
6. Create a Docker image for the project.
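
For step 2, the simplest possible index is a brute-force linear scan, which makes a useful correctness baseline before writing anything cleverer. A minimal sketch (the class and function names are ours, not part of the task; query time is O(n*d) plus an O(n log n) sort, space is O(n*d)):

```python
import math


def cosine_distance(a: list[float], b: list[float]) -> float:
    """1 minus cosine similarity; 0 means identical direction."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / norm if norm else 1.0


class BruteForceIndex:
    """Linear-scan kNN: O(n*d) space, O(n*d + n log n) per query."""

    def __init__(self) -> None:
        self.vectors: dict[str, list[float]] = {}

    def add(self, chunk_id: str, embedding: list[float]) -> None:
        self.vectors[chunk_id] = embedding

    def search(self, query: list[float], k: int) -> list[tuple[str, float]]:
        scored = [(cid, cosine_distance(query, v)) for cid, v in self.vectors.items()]
        scored.sort(key=lambda pair: pair[1])
        return scored[:k]
```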
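
For step 3, one common approach is a readers-writer lock around each library, allowing concurrent reads while writes get exclusive access. A minimal sketch using only the standard library (the class name and per-library granularity are our assumptions):

```python
import threading


class ReadWriteLock:
    """Many concurrent readers, exclusive writers (no writer preference)."""

    def __init__(self) -> None:
        self._readers = 0
        self._readers_lock = threading.Lock()  # guards the reader count
        self._write_lock = threading.Lock()    # held by a writer, or by the first reader

    def acquire_read(self) -> None:
        with self._readers_lock:
            self._readers += 1
            if self._readers == 1:
                self._write_lock.acquire()

    def release_read(self) -> None:
        with self._readers_lock:
            self._readers -= 1
            if self._readers == 0:
                self._write_lock.release()

    def acquire_write(self) -> None:
        self._write_lock.acquire()

    def release_write(self) -> None:
        self._write_lock.release()
```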

### Extra Points:

Here are some additional suggestions on how to enhance the project even further. You are not required to implement any of these, but if you do, we will value it. If you have other improvements in mind, please feel free to implement them and document them in the project's README file.

1. **Metadata filtering:**
   - Add the possibility of using metadata filters to enhance query results, e.g. do a kNN search over all chunks created after a given date, or whose name contains a given string (a filter sketch follows this list).
2. **Persistence to Disk**:
   - Implement a mechanism to persist the database state to disk, ensuring that the Docker container can be restarted and resume its operation from the last checkpoint. Explain your design choices and tradeoffs, considering factors like performance, consistency, and durability.
3. **Leader-Follower Architecture**:
   - Design and implement a leader-follower (master-slave) architecture to support multiple database nodes within a Kubernetes cluster. This architecture should handle read scalability and provide high availability. Explain how leader election, data replication, and failover are managed, along with the benefits and tradeoffs of this approach.
4. **Python SDK Client**:
   - Develop a Python SDK client that interfaces with your API, making it easier for users to interact with the vector database programmatically. Include documentation and examples.
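
For extra point 1, a metadata filter can be as simple as a predicate applied before ranking. A sketch with assumed field names (`created_at`, `name`) and an assumed plain-dict metadata shape:

```python
from datetime import datetime
from typing import Callable

# A filter is just a predicate over a chunk's metadata dict (field names assumed).
MetadataFilter = Callable[[dict], bool]


def created_after(cutoff: datetime) -> MetadataFilter:
    return lambda md: datetime.fromisoformat(md["created_at"]) > cutoff


def name_contains(needle: str) -> MetadataFilter:
    return lambda md: needle in md.get("name", "")


def apply_filters(chunks: list[dict], filters: list[MetadataFilter]) -> list[dict]:
    """Keep only chunks whose metadata satisfies every filter."""
    return [c for c in chunks if all(f(c["metadata"]) for f in filters)]
```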

## Constraints

Do **not** use libraries like chroma-db, pinecone, FAISS, etc. to develop the project; we want to see you write the algorithms yourself. You can use numpy to calculate trigonometric functions such as `cos`, `sin`, etc.

You **do not need to build a document processing pipeline** (OCR + text extraction + chunking) to test your system. Using a bunch of manually created chunks will suffice.

## **Tech Stack**

- **API Backend:** Python + FastAPI + Pydantic

## Resources:

A [Cohere](https://cohere.com/embeddings) API key to create the embeddings for your test.

## Evaluation Criteria

We will evaluate both the functionality of the code and its quality.

**Code quality:**

- [SOLID design principles](https://realpython.com/solid-principles-python/).
- Use of static typing.
- FastAPI good practices.
- Pydantic schema validation.
- Code modularity and reusability.
- Use of RESTful API endpoints.
- Project containerization with Docker.
- Testing.
- Error handling.
- If you know what Domain-Driven Design is, do it that way!
- Separate API endpoints from business logic using services, and from databases using repositories.
- Keep code as Pythonic as possible.
- Use early returns.
- Use inheritance where needed.
- Prefer composition over inheritance.

**Functionality:**

- Does everything work as expected?

## Deliverable

1. **Source Code**: A link to a GitHub repository containing all your source code.
2. **Documentation**: A README file that documents the task, explains your technical choices, how to run the project locally, and any other relevant information.
3. **Demo video:**
   1. A screen recording where you show how to install the project and interact with it in real time.
   2. A screen recording of your design, with an explanation of your design choices and problem-solving.

## Timeline

As a reference, this task should take at most **4 days** (96h) from receipt of this test to submission of your deliverables 🚀

But honestly, if you think you can do a much better job with some extra days (perhaps because you couldn't spend too many hours), be our guest!

At the end of the day, if it is not going to impress the team, it's not going to fly, so give it your best shot ✈️

## Questions

Feel free to reach out at any time with questions about the task, particularly if you encounter problems outside your control that may block your progress.

14 docker-compose.yaml Normal file
@@ -0,0 +1,14 @@
services:
  api:
    build: .
    command: uv run fastapi run src/api.py
    ports:
      - 8000:8000
    env_file: .env
    environment:
      - storage_path=/app/data
    volumes:
      - ./data:/app/data

38 pyproject.toml Normal file
@@ -0,0 +1,38 @@
[project]
name = "python-vector-database"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.11"
dependencies = [
    "cohere>=5.15.0",
    "fastapi[standard]>=0.115.12",
    "pydantic-settings>=2.9.1",
    "pydantic>=2.11.5",
    "uuid7>=0.1.0",
]

[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"

[tool.setuptools.packages.find]
where = ["src"]

[tool.setuptools.package-dir]
"" = "src"

[dependency-groups]
dev = [
    "ipython>=9.2.0",
    "marimo>=0.13.15",
    "pandas>=2.2.3",
    "pyarrow>=20.0.0",
    "pytest>=8.3.5",
]

[tool.pytest.ini_options]
pythonpath = "src"
testpaths = "src/tests"
addopts = "-v --tb=short"

136 src/api.py Normal file
@@ -0,0 +1,136 @@
from pathlib import Path
from typing import Annotated
import time

import marimo
from fastapi import FastAPI, HTTPException, Depends, Request, BackgroundTasks
from fastapi.responses import HTMLResponse

from pydantic import BaseModel

from model import (
    Document,
    DocumentUpload,
    Library,
    LibraryDoesNotExist,
    QueryResult,
    LibraryQuery,
)


class ResponseStatus(BaseModel):
    status: str = "ok"


def get_library(library_slug: str) -> Library:
    try:
        library = Library.load(library_slug)
    except LibraryDoesNotExist:
        raise HTTPException(404)
    return library


app = FastAPI(title="Vector Database")


@app.get("/")
def index():
    html_content = """
    <html>
        <head>
            <title>Vector Database</title>
            <link rel="stylesheet" href="https://unpkg.com/axist@latest/dist/axist.min.css" />
        </head>
        <body>
            <article>
                <p>Here is the list of UI elements from this API.</p>
                <ul>
                    <li><a href="/docs">Open API Swagger-like docs</a></li>
                    <li><a href="/app/create">Marimo App to Create Libraries</a></li>
                    <li><a href="/app/seed">Marimo App to Seed the database with movie data.</a></li>
                    <li><a href="/app/insert">Marimo App to Insert Documents</a></li>
                    <li><a href="/app/search">Marimo App to Search the Library</a></li>
                </ul>
            </article>
        </body>
    </html>
    """
    return HTMLResponse(content=html_content, status_code=200)


# Mount the marimo notebooks as sub-applications under /app.
server = marimo.create_asgi_app(include_code=True, quiet=False)
src_dir = Path(__file__).parent
server = server.with_app(
    path="/create",
    root=src_dir / "notebook/app-create.py",
)
server = server.with_app(
    path="/seed",
    root=src_dir / "notebook/app-seed.py",
)
server = server.with_app(
    path="/insert",
    root=src_dir / "notebook/app-insert.py",
)
server = server.with_app(
    path="/search",
    root=src_dir / "notebook/app-search.py",
)

app.mount("/app", server.build())


@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
    start_time = time.perf_counter()
    response = await call_next(request)
    process_time = time.perf_counter() - start_time
    response.headers["X-Process-Time"] = str(process_time)
    return response


@app.get("/library/list")
def list_libraries() -> list[str]:
    return Library.list_libraries()


@app.post("/library/create")
def create_library(payload: Library):
    if payload.slug in Library.list_libraries():
        raise HTTPException(400, "Library already exists.")
    payload.save()
    return payload.stat


@app.get("/library/{library_slug}")
def get_library_details(
    library: Annotated[Library, Depends(get_library)],
):
    return library.stat


@app.post("/library/{library_slug}/documents")
def insert_new_document(
    payload: DocumentUpload,
    library: Annotated[Library, Depends(get_library)],
    background_tasks: BackgroundTasks,
):
    document = Document.model_validate(payload, from_attributes=True)
    document.save(library)

    # Chunking and embedding can be slow, so run it after the response is sent.
    def _create_chunks_for_document():
        document.create_chunks(library)

    background_tasks.add_task(_create_chunks_for_document)

    return ResponseStatus()


@app.post("/library/{library_slug}/search")
def search(
    payload: LibraryQuery,
    library: Annotated[Library, Depends(get_library)],
) -> list[QueryResult]:
    return library.query(payload.query, payload.results)

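With the service running (e.g. via `just api`), the endpoints above can be exercised with a short client sketch. The use of `httpx` and the example payload values are assumptions for illustration:

```python
import httpx

BASE = "http://localhost:8000"

# Create a library (fields follow the Library model; values are illustrative).
httpx.post(f"{BASE}/library/create", json={"slug": "movies"})

# Insert a document; chunking and embedding happen in a background task.
httpx.post(
    f"{BASE}/library/movies/documents",
    json={"slug": "blade-runner", "content": "A blade runner must pursue ..."},
)

# Search the library once indexing has caught up.
results = httpx.post(
    f"{BASE}/library/movies/search",
    json={"query": "dystopian future", "results": 3},
)
print(results.json())
```
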
121 src/embedding.py Normal file
@@ -0,0 +1,121 @@
import hashlib
import json
from typing import Iterator

import cohere

from utils import chunk_list, text_to_chunks, Point, EmbeddingSize
from settings import settings


co = cohere.ClientV2()


def md5(text: str) -> str:
    return hashlib.md5(text.encode("utf-8")).hexdigest()


def _text_chunks(
    text: str, chunk_size: int, chunk_overlap: int, search: bool = True
) -> list[str]:
    # Search queries are embedded whole; stored documents are split into chunks.
    if search:
        return [text]
    return text_to_chunks(text, chunk_size, chunk_overlap)


def _call_cohere(
    texts: list[str],
    embedding_size: EmbeddingSize,
    search: bool = True,
) -> Iterator[tuple[str, Point]]:
    response = co.embed(
        inputs=[{"content": [{"type": "text", "text": doc}]} for doc in texts],
        model="embed-v4.0",
        output_dimension=embedding_size,
        input_type="search_query" if search else "search_document",
        embedding_types=["float"],
    )
    return zip(texts, response.embeddings.float)


def embed_store(
    text: str,
    embedding: Point,
    embedding_size: EmbeddingSize,
    search: bool = False,
):
    if not settings.cohere_cache:
        return
    _md5 = md5(text)
    target_file = (
        settings.storage_path
        / ".cohere_embedding"
        / ("search" if search else "storage")
        / _md5
    )
    embedding_table = {}
    if target_file.exists():
        embedding_table = json.loads(target_file.read_text())
    embedding_table[str(embedding_size)] = embedding
    target_file.parent.mkdir(exist_ok=True, parents=True)
    target_file.write_text(json.dumps(embedding_table, indent=2))


def embed_lookup(
    text: str,
    embedding_size: EmbeddingSize,
    search: bool = False,
) -> Point | None:
    if not settings.cohere_cache:
        return None
    _md5 = md5(text)
    target_file = (
        settings.storage_path
        / ".cohere_embedding"
        / ("search" if search else "storage")
        / _md5
    )
    if target_file.exists():
        embeddings = json.loads(target_file.read_text())
        return embeddings.get(str(embedding_size), None)
    return None


def embed(
    text: str,
    chunk_size: int,
    chunk_overlap: int,
    embedding_size: EmbeddingSize,
    search: bool = True,
) -> list[tuple[str, Point]]:
    """
    https://docs.cohere.com/reference/embed

    Input type:
    > "search_document": Used for embeddings stored...
    > "search_query": Used for embeddings of search queries...

    The Cohere API only allows at most 96 texts at a time.
    """
    text_chunks_list = _text_chunks(text, chunk_size, chunk_overlap, search)

    # Resolve as many chunks as possible from the on-disk cache first.
    text_chunks = {
        _text: embed_lookup(_text, embedding_size, search) for _text in text_chunks_list
    }

    missing = [_text for _text, embedding in text_chunks.items() if embedding is None]

    results = {}

    # Embed cache misses in batches of 50, comfortably under the API's 96-text limit.
    for _texts in chunk_list(missing, 50):
        results |= dict(_call_cohere(_texts, embedding_size, search))

    for _text, embedding in results.items():
        embed_store(_text, embedding, embedding_size, search)

    returnable = []
    for _text in text_chunks_list:
        returnable.append((_text, text_chunks[_text] or results[_text]))

    return returnable

544 src/index.py Normal file
@@ -0,0 +1,544 @@
import heapq
import math
import random
from uuid import UUID
from typing import Protocol, Generator, TYPE_CHECKING

from pydantic import BaseModel, Field
from uuid_extensions import uuid7

from utils import Point, KMeansBinary


if TYPE_CHECKING:
    from model import Library, Chunk


class IndexStrategy(Protocol):
    def create_index(self, library: "Library") -> None:
        """Initialize the index for a library"""
        ...

    def add_chunk(self, library: "Library", chunk: "Chunk") -> None:
        """Add a chunk to the index"""
        ...

    def search(
        self, library: "Library", query_vector: Point
    ) -> Generator[tuple[UUID, float], None, None]:
        """Search for k nearest chunks, return (chunk_id, distance) pairs"""
        ...


class IndexDefinition(BaseModel):
    centers: dict[UUID, Point] = {}

    @classmethod
    def read(cls, library: "Library"):
        target_file = library.storage_directory / "index" / "definition.json"
        if not target_file.exists():
            return None
        return cls.model_validate_json(target_file.read_text())

    def save(self, library: "Library"):
        target_file = library.storage_directory / "index" / "definition.json"
        target_file.parent.mkdir(parents=True, exist_ok=True)
        target_file.write_text(self.model_dump_json(indent=2))


class IndexBlob(BaseModel):
    id: UUID = Field(default_factory=uuid7)
    center: Point
    members: list[UUID] = []

    @classmethod
    def read(cls, library: "Library", id: UUID):
        target_file = library.storage_directory / "index" / str(id)
        if not target_file.exists():
            return None
        return cls.model_validate_json(target_file.read_text())

    def save(self, library: "Library"):
        target_file = library.storage_directory / "index" / str(self.id)
        target_file.parent.mkdir(parents=True, exist_ok=True)
        target_file.write_text(self.model_dump_json(indent=2))

    def delete(self, library: "Library"):
        target_file = library.storage_directory / "index" / str(self.id)
        if target_file.exists():
            target_file.unlink()

    def split(self, library: "Library"):
        from model import Chunk

        chunks = {chunk_id: Chunk.read(library, chunk_id) for chunk_id in self.members}
        vectors = [chunk for chunk_id, chunk in chunks.items() if chunk is not None]
        if len(vectors) < 2:
            return
        transformed = [library.vector_transformation(v.vector) for v in vectors]
        # K-means with k=2 (manual implementation)
        centers = KMeansBinary(transformed, library.metric)

        # Create new blobs
        A = IndexBlob(center=library.vector_transformation(centers[0]))
        B = IndexBlob(center=library.vector_transformation(centers[1]))

        # Assign members based on original metric
        for i, chunk_id in enumerate(self.members):
            chunk = chunks[chunk_id]
            if chunk is None:
                continue
            original_vector = chunk.vector
            distance_to_A = library.metric(original_vector, A.center)
            distance_to_B = library.metric(original_vector, B.center)
            if distance_to_A <= distance_to_B:
                A.members.append(chunk_id)
            else:
                B.members.append(chunk_id)

        if not A.members or not B.members:
            return

        index_definition = IndexDefinition.read(library)
        if index_definition is None:
            return

        if self.id in index_definition.centers:
            del index_definition.centers[self.id]
        index_definition.centers[A.id] = A.center
        index_definition.centers[B.id] = B.center
        index_definition.save(library)
        self.delete(library)
        A.save(library)
        B.save(library)

    def add(self, library: "Library", member: UUID):
        self.members.append(member)
        self.save(library)

        if len(self.members) >= library.index_blob_limit:
            self.split(library)


class BlobIndexStrategy:
    def create_index(self, library: "Library") -> None:
        origin = [0.0] * library.embedding_size
        if IndexDefinition.read(library) is None:
            index_blob = IndexBlob(center=origin)
            index_blob.save(library)
            IndexDefinition(centers={index_blob.id: origin}).save(library)

    def add_chunk(self, library: "Library", chunk: "Chunk") -> None:
        index_definition = IndexDefinition.read(library)
        if index_definition is None:
            self.create_index(library)
            index_definition = IndexDefinition.read(library)

        distances = {
            key: library.metric(
                library.vector_transformation(center),
                library.vector_transformation(chunk.vector),
            )
            for key, center in index_definition.centers.items()
        }
        if not distances:
            return

        nearest_index_id = min(distances, key=distances.get)
        blob = IndexBlob.read(library, nearest_index_id)
        if blob:
            blob.add(library, chunk.id)

    def search(
        self, library: "Library", query_vector: Point
    ) -> Generator[tuple[UUID, float], None, None]:
        from model import Chunk

        index_definition = IndexDefinition.read(library)
        if index_definition is None:
            return

        distances = {
            key: library.metric(
                library.vector_transformation(center),
                library.vector_transformation(query_vector),
            )
            for key, center in index_definition.centers.items()
        }

        # Visit blobs from the nearest center outwards, yielding each blob's
        # members in order of distance to the query.
        while len(distances) != 0:
            nearest = min(distances, key=distances.get)
            del distances[nearest]

            index_blob = IndexBlob.read(library, nearest)
            if index_blob is None:
                continue

            chunks = {
                chunk_id: Chunk.read(library, chunk_id)
                for chunk_id in index_blob.members
            }
            distances_to_chunks = {
                chunk.id: library.metric(
                    query_vector,
                    library.vector_transformation(chunk.vector),
                )
                for chunk in chunks.values()
                if chunk is not None
            }

            while len(distances_to_chunks) != 0:
                nearest_chunk_id = min(
                    distances_to_chunks,
                    key=distances_to_chunks.get,
                )
                yield nearest_chunk_id, distances_to_chunks[nearest_chunk_id]
                del distances_to_chunks[nearest_chunk_id]


class HNSWNode(BaseModel):
    id: UUID = Field(default_factory=uuid7)
    chunk_id: UUID
    vector: Point
    level: int
    connections: dict[int, list[UUID]] = {}  # level -> list of connected node IDs

    @classmethod
    def read(cls, library: "Library", id: UUID):
        target_file = library.storage_directory / "index" / str(id)
        if target_file.exists():
            return cls.model_validate_json(target_file.read_text())
        return None

    def save(self, library: "Library"):
        target_file = library.storage_directory / "index" / str(self.id)
        target_file.parent.mkdir(parents=True, exist_ok=True)
        try:
            target_file.write_text(self.model_dump_json(indent=2))
        except Exception:
            pass  # Fail silently to avoid breaking the indexing process

    def delete(self, library: "Library"):
        target_file = library.storage_directory / "index" / str(self.id)
        if target_file.exists():
            try:
                target_file.unlink()
            except Exception:
                pass  # Fail silently


class HNSWGraph(BaseModel):
    entry_point: UUID | None = None
    max_level: int = 0
    ml: float = 1.0 / math.log(2.0)
    max_connections_0: int = 16
    max_connections: int = 8
    ef_construction: int = 200  # Size of candidate list during construction
    max_level_limit: int = 20
    nodes: dict[UUID, UUID] = {}  # chunk_id -> node_id mapping

    @classmethod
    def read(cls, library: "Library"):
        target_file = library.storage_directory / "index" / "definition.json"
        if not target_file.exists():
            return cls()
        try:
            return cls.model_validate_json(target_file.read_text())
        except Exception:
            return cls()

    def save(self, library: "Library"):
        target_file = library.storage_directory / "index" / "definition.json"
        target_file.parent.mkdir(parents=True, exist_ok=True)
        try:
            target_file.write_text(self.model_dump_json(indent=2))
        except Exception:
            pass

    def get_random_level(self) -> int:
        """Generate random level for new node using exponential decay"""
        level = int(-math.log(random.uniform(1e-8, 1.0)) * self.ml)
        return min(level, self.max_level_limit)


class HNSWIndexStrategy:
    def create_index(self, library: "Library") -> None:
        """Initialize empty HNSW graph"""
        graph = HNSWGraph()
        graph.save(library)

    def add_chunk(self, library: "Library", chunk: "Chunk") -> None:
        """Add a chunk to the HNSW graph"""
        graph = HNSWGraph.read(library)

        # Create new node
        level = graph.get_random_level()
        node = HNSWNode(
            chunk_id=chunk.id,
            vector=chunk.vector,
            level=level,
            connections={i: [] for i in range(level + 1)},
        )

        # If this is the first node, make it the entry point
        if graph.entry_point is None:
            graph.entry_point = node.id
            graph.max_level = level
            node.save(library)
            graph.nodes[chunk.id] = node.id
            graph.save(library)
            return

        # Find entry point and verify it exists
        entry_node = HNSWNode.read(library, graph.entry_point)
        if entry_node is None:
            # Entry point corrupted, make this the new entry
            graph.entry_point = node.id
            graph.max_level = level
            node.save(library)
            graph.nodes[chunk.id] = node.id
            graph.save(library)
            return

        entry_dist = library.metric(node.vector, entry_node.vector)
        current_closest = [(entry_dist, entry_node.id)]

        # Search from top level down to level + 1 (greedy search with ef=1)
        for lc in range(graph.max_level, level, -1):
            current_closest = self._search_layer(
                library, node.vector, current_closest, 1, lc
            )

        # Search and connect from level down to 0
        for lc in range(min(level, graph.max_level), -1, -1):
            # Use ef_construction for building robust connections
            candidates = self._search_layer(
                library, node.vector, current_closest, graph.ef_construction, lc
            )

            # Select neighbors and create bidirectional connections
            max_conn = graph.max_connections if lc > 0 else graph.max_connections_0
            selected = self._select_neighbors(
                library, node.vector, candidates, max_conn
            )

            # Add connections
            for dist, neighbor_id in selected:
                node.connections[lc].append(neighbor_id)
                neighbor_node = HNSWNode.read(library, neighbor_id)
                if neighbor_node and lc in neighbor_node.connections:
                    neighbor_node.connections[lc].append(node.id)

                    # Prune neighbor connections if they exceed max
                    if len(neighbor_node.connections[lc]) > max_conn:
                        neighbor_candidates = []
                        for nid in neighbor_node.connections[lc]:
                            n = HNSWNode.read(library, nid)
                            if n is not None:
                                dist = library.metric(neighbor_node.vector, n.vector)
                                neighbor_candidates.append((dist, nid))

                        if neighbor_candidates:
                            pruned = self._select_neighbors(
                                library,
                                neighbor_node.vector,
                                neighbor_candidates,
                                max_conn,
                            )
                            neighbor_node.connections[lc] = [nid for _, nid in pruned]

                    neighbor_node.save(library)

            current_closest = selected

        # Update graph metadata
        if level > graph.max_level:
            graph.max_level = level
            graph.entry_point = node.id

        graph.nodes[chunk.id] = node.id
        node.save(library)
        graph.save(library)

    def _search_layer(
        self,
        library: "Library",
        query: Point,
        entry_points: list[tuple[float, UUID]],
        ef: int,
        level: int,
    ) -> list[tuple[float, UUID]]:
        """
        Search for ef closest nodes in a specific layer.

        ef (exploration factor) determines how many candidates we track during search:
        - Higher ef = more thorough search, better recall, slower
        - Lower ef = faster search, potentially missing good candidates
        - During construction: use ef_construction for robust graph building
        - During search: can use smaller ef for speed vs quality tradeoff
        """
        visited = set()
        candidates = []  # max heap for exploration: (-distance, node_id)
        w = []  # min heap for results: (distance, node_id)

        # Initialize with entry points
        for dist, node_id in entry_points:
            if node_id not in visited:
                heapq.heappush(candidates, (-dist, node_id))
                heapq.heappush(w, (dist, node_id))
                visited.add(node_id)

        while candidates:
            # Get closest unvisited candidate
            neg_current_dist, current_id = heapq.heappop(candidates)
            current_dist = -neg_current_dist

            # Stop if current distance is worse than ef-th best found so far
            if len(w) >= ef and current_dist > max(w)[0]:
                break

            current_node = HNSWNode.read(library, current_id)
            if not current_node or level not in current_node.connections:
                continue

            # Explore all neighbors at this level
            for neighbor_id in current_node.connections[level]:
                if neighbor_id not in visited:
                    visited.add(neighbor_id)
                    neighbor_node = HNSWNode.read(library, neighbor_id)
                    if neighbor_node:
                        dist = library.metric(query, neighbor_node.vector)

                        if len(w) < ef:
                            # Still have room in result set
                            heapq.heappush(candidates, (-dist, neighbor_id))
                            heapq.heappush(w, (dist, neighbor_id))
                        elif dist < max(w)[0]:
                            # Better than worst in result set
                            heapq.heappush(candidates, (-dist, neighbor_id))

                            # Remove worst from w and add new candidate
                            w_list = list(w)
                            w_list.remove(max(w_list))
                            w = w_list
                            heapq.heapify(w)
                            heapq.heappush(w, (dist, neighbor_id))

        return sorted(w)

    def _select_neighbors(
        self,
        library: "Library",
        query: Point,
        candidates: list[tuple[float, UUID]],
        max_connections: int,
    ) -> list[tuple[float, UUID]]:
        """Select diverse neighbors to avoid hubs and improve search quality"""
        if len(candidates) <= max_connections:
            return candidates

        candidates.sort(key=lambda x: x[0])
        selected = []
        remaining = candidates[:]

        # Always select closest
        if remaining:
            selected.append(remaining.pop(0))

        # Select remaining with diversity consideration
        while len(selected) < max_connections and remaining:
            best_candidate = None
            best_score = float("inf")
            best_idx = -1

            for i, (dist, candidate_id) in enumerate(remaining):
                # Find minimum distance to already selected nodes (diversity)
                diversity_score = float("inf")
                candidate_node = HNSWNode.read(library, candidate_id)

                if candidate_node:
                    for _, selected_id in selected:
                        selected_node = HNSWNode.read(library, selected_id)
                        if selected_node:
                            div_dist = library.metric(
                                candidate_node.vector, selected_node.vector
                            )
                            diversity_score = min(diversity_score, div_dist)

                # Combined score: prefer close nodes that aren't too close to selected ones
                combined_score = dist - 0.1 * diversity_score

                if combined_score < best_score:
                    best_score = combined_score
                    best_candidate = (dist, candidate_id)
                    best_idx = i

            if best_candidate:
                selected.append(best_candidate)
                remaining.pop(best_idx)
            else:
                selected.append(remaining.pop(0))

        return selected

    def search(
        self, library: "Library", query_vector: Point
    ) -> Generator[tuple[UUID, float], None, None]:
        """
        True generator search: yields results indefinitely as long as the caller keeps asking.

        The search explores the graph in waves, returning results in order of distance.
        It maintains a frontier of unexplored candidates and yields the next best result
        each time it's called.
        """
        graph = HNSWGraph.read(library)
        if graph.entry_point is None:
            return

        entry_node = HNSWNode.read(library, graph.entry_point)
        if not entry_node:
            return

        # Start with entry point
        entry_dist = library.metric(query_vector, entry_node.vector)
        current_closest = [(entry_dist, graph.entry_point)]

        # Navigate down to level 1 (greedy search)
        for level in range(graph.max_level, 0, -1):
            current_closest = self._search_layer(
                library, query_vector, current_closest, 1, level
            )

        # Now do the actual search at level 0 using a priority queue for true streaming
        visited = set()
        candidates = []  # min heap: (distance, node_id)

        # Initialize with level 0 entry points
        for dist, node_id in current_closest:
            if node_id not in visited:
                heapq.heappush(candidates, (dist, node_id))
                visited.add(node_id)

        yielded = set()  # Track what we've already yielded

        while candidates:
            # Get the closest unexplored candidate
            current_dist, current_id = heapq.heappop(candidates)

            current_node = HNSWNode.read(library, current_id)

            # Yield this result if we haven't already
            if current_node and current_id not in yielded:
                yielded.add(current_id)
                yield current_node.chunk_id, current_dist

            # Add all unvisited neighbors to candidates for future exploration
            if current_node and 0 in current_node.connections:
                for neighbor_id in current_node.connections[0]:
                    if neighbor_id not in visited:
                        visited.add(neighbor_id)
                        neighbor_node = HNSWNode.read(library, neighbor_id)
                        if neighbor_node:
                            dist = library.metric(query_vector, neighbor_node.vector)
                            heapq.heappush(candidates, (dist, neighbor_id))

228 src/model.py Normal file
@@ -0,0 +1,228 @@
from pathlib import Path
from uuid import UUID

from pydantic import BaseModel, Field
from uuid_extensions import uuid7

from embedding import embed
from index import (
    IndexStrategy,
    BlobIndexStrategy,
    HNSWIndexStrategy,
)
from utils import (
    Metric,
    Point,
    IndexAlgorithm,
    EmbeddingSize,
    VectorTransformation,
    cosine_distance,
    identity_transformation,
)
from settings import settings


class LibraryDoesNotExist(Exception): ...


class LibraryQuery(BaseModel):
    query: str
    results: int = 5


class QueryResult(BaseModel):
    document: "Document"
    chunk: "Chunk"
    hits: int = 0
    distance: float


class Library(BaseModel):
    slug: str

    embedding_size: EmbeddingSize = 256

    chunk_size: int = 128
    chunk_overlap: int = 32

    index: IndexAlgorithm = "cosine"
    index_blob_limit: int = 50

    @property
    def stat(self) -> "LibrarySummary":
        return LibrarySummary.model_validate(
            {
                **self.model_dump(),
                "documents": len(
                    list((self.storage_directory / "documents").iterdir())
                ),
                "chunks": len(list((self.storage_directory / "chunks").iterdir())),
                "indexs": len(list((self.storage_directory / "index").iterdir())) - 1,
            }
        )

    @property
    def storage_directory(self) -> Path:
        return settings.storage_path / self.slug

    @classmethod
    def list_libraries(cls):
        return [
            x.name
            for x in filter(lambda x: x.is_dir(), settings.storage_path.iterdir())
            if x.name[0] != "."
        ]

    @classmethod
    def load(cls, slug: str):
        target_file = settings.storage_path / slug / "definition.json"
        if not target_file.exists():
            raise LibraryDoesNotExist()
        return cls.model_validate_json(target_file.read_text())

    def save(self):
        target_file = self.storage_directory / "definition.json"
        target_file.parent.mkdir(parents=True, exist_ok=True)
        target_file.write_text(self.model_dump_json(indent=2))
        (self.storage_directory / "documents").mkdir(exist_ok=True)
        (self.storage_directory / "chunks").mkdir(exist_ok=True)
        (self.storage_directory / "index").mkdir(exist_ok=True)
        self.index_strategy.create_index(self)

    @property
    def index_strategy(self) -> IndexStrategy:
        """Factory method to get the appropriate index strategy"""
        if self.index == "hnsw":
            return HNSWIndexStrategy()
        return BlobIndexStrategy()

    @property
    def metric(self) -> Metric:
        return cosine_distance

    @property
    def vector_transformation(self) -> VectorTransformation:
        return identity_transformation

    def chunk_id(self) -> list[UUID]:
        target_directory = self.storage_directory / "chunks"

        def _file_to_uuid():
            for file in target_directory.iterdir():
                if file.is_file():
                    yield UUID(file.name)

        return list(_file_to_uuid())

    def _text_embed(
        self,
        text: str,
        search: bool = True,
    ) -> list[tuple[str, Point]]:
        return embed(
            text,
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
            embedding_size=self.embedding_size,
            search=search,
        )

    def add_document(self, slug: str, content: str):
        existing_document = Document.read(self, slug)
        if existing_document:
            return existing_document
        new_document = Document(slug=slug, content=content)
        new_document.save(self)
        new_document.create_chunks(self)
        return new_document

    def query(self, content: str, k: int = 5) -> list[QueryResult]:
        embeddings = self._text_embed(content, search=True)

        vector = embeddings[0][1]
        transformed_vector = self.vector_transformation(vector)

        search_generator = self.index_strategy.search(self, transformed_vector)

        hit_documents = {}

        for chunk_id, distance in search_generator:
            chunk = Chunk.read(self, chunk_id)
            if chunk is None:
                continue
            if chunk.document not in hit_documents:
                hit_documents[chunk.document] = QueryResult(
                    chunk=chunk,
                    document=Document.read(self, chunk.document),
                    hits=1,
                    distance=distance,
                )
            else:
                hit_documents[chunk.document].hits += 1
                if distance < hit_documents[chunk.document].distance:
                    # A later hit can be nearer when results span multiple index blobs.
                    hit_documents[chunk.document].distance = distance
                    hit_documents[chunk.document].chunk = chunk

            if len(hit_documents) >= k:
                break

        return list(hit_documents.values())


class LibrarySummary(Library):
    documents: int
    chunks: int
    indexs: int


class Chunk(BaseModel):
    id: UUID = Field(default_factory=uuid7)
    vector: Point
    text: str
    document: str

    @classmethod
    def read(cls, library: Library, id: UUID):
        target_file = library.storage_directory / "chunks" / str(id)
        if not target_file.exists():
            return None
        return cls.model_validate_json(target_file.read_text())

    def save(self, library: Library):
        target_file = library.storage_directory / "chunks" / str(self.id)
        target_file.write_text(self.model_dump_json(indent=2))

    def index(self, library: Library):
        library.index_strategy.add_chunk(library, self)


class DocumentUpload(BaseModel):
    slug: str
    content: str


class Document(DocumentUpload):
    chunks: list[UUID] = []

    @classmethod
    def read(cls, library: Library, slug: str):
        target_file = library.storage_directory / "documents" / slug
        if target_file.exists():
            return cls.model_validate_json(target_file.read_text())
        return None

    def save(self, library: Library):
        target_file = library.storage_directory / "documents" / self.slug
        target_file.write_text(self.model_dump_json(indent=2))

    def create_chunks(self, library: Library):
        slug, content = self.slug, self.content
        embedding = library._text_embed(f"{slug}\n{content}", search=False)

        for text, vector in embedding:
            chunk = Chunk(vector=vector, text=text, document=slug)
            chunk.save(library)
            chunk.index(library)
            self.chunks.append(chunk.id)
        self.save(library)

148 src/notebook/app-create.py Normal file
@@ -0,0 +1,148 @@
import marimo

__generated_with = "0.13.15"
app = marimo.App(width="medium")


@app.cell
def _():
    from itertools import chain
    import marimo as mo
    from model import Library, Chunk, Document, LibraryDoesNotExist
    from utils import EmbeddingSize, IndexAlgorithm
    from notebook.constants import animals, adjetives, datasets

    return EmbeddingSize, IndexAlgorithm, Library, LibraryDoesNotExist, mo


@app.cell
def _(mo):
    mo.md("""# Create Library Widget""")
    return


@app.cell
def _(mo):
    get_create_state, set_create_state = mo.state(False)
    return get_create_state, set_create_state


@app.cell
def _(mo, set_create_state):
    create_button = mo.ui.button(
        label="Create Library", value=True, on_click=lambda _: set_create_state(True)
    )
    return (create_button,)


@app.cell
def _(EmbeddingSize, IndexAlgorithm, mo):
    library_name = mo.ui.text(placeholder="Library Name", label="Library Name")
    embedding_size = mo.ui.dropdown(
        EmbeddingSize.__args__, label="Embedding Size", value=EmbeddingSize.__args__[0]
    )
    chunk_size = mo.ui.number(value=128, label="Chunk Size")
    chunk_overlap = mo.ui.number(start=10, value=64, label="Chunk Overlap")
    index_blob_limit = mo.ui.number(start=10, value=100, label="Index Blob Size")
    index_type = mo.ui.dropdown(
        IndexAlgorithm.__args__, value=IndexAlgorithm.__args__[0], label="Index Type"
    )
    return (
        chunk_overlap,
        chunk_size,
        embedding_size,
        index_blob_limit,
        index_type,
        library_name,
    )


@app.cell
def _(
    chunk_overlap,
    chunk_size,
    create_button,
    embedding_size,
    index_blob_limit,
    index_type,
    library_name,
    mo,
):
    fields = [
        library_name,
        embedding_size,
        chunk_size,
        chunk_overlap,
        index_type,
        index_blob_limit,
    ]
    mo.vstack(
        [
            *fields,
            create_button,
        ]
    )
    return (fields,)


@app.cell
def _(
    Library,
    LibraryDoesNotExist,
    chunk_overlap,
    chunk_size,
    embedding_size,
    fields,
    get_create_state,
    index_blob_limit,
    index_type,
    library_name,
    mo,
    set_create_state,
):
    output = [mo.md("### Result")]
    if get_create_state() and all(map(lambda x: bool(x.value), fields)):
        try:
            library = Library.load(library_name.value)
            output += [mo.md("> Library already exists!")]
        except LibraryDoesNotExist:
            library = Library(
                slug=library_name.value,
                embedding_size=embedding_size.value,
                chunk_overlap=chunk_overlap.value,
                chunk_size=chunk_size.value,
                index_blob_limit=index_blob_limit.value,
                index=index_type.value,
            )
            library.save()
            output += [mo.md("> Created!")]

        output += [mo.ui.table(library.stat.model_dump())]
    set_create_state(False)
    mo.vstack(output)
    return


@app.cell
def _(Library, mo):
    _libraries = Library.list_libraries()
    mo.vstack(
        [
            mo.md("# Existing Libraries"),
            *[
                mo.md(
                    f"""/// details | **{library.slug}:**\n\n\n{mo.ui.table(library.stat.model_dump())}\n///"""
                )
                for library in map(Library.load, _libraries)
            ],
            *([mo.md("No libraries created yet.")] if len(_libraries) == 0 else []),
        ],
    )
    return


@app.cell
def _():
    return


if __name__ == "__main__":
    app.run()

113 src/notebook/app-insert.py Normal file
@@ -0,0 +1,113 @@
import marimo

__generated_with = "0.13.15"
app = marimo.App(width="medium")


@app.cell
def _():
    import marimo as mo
    from model import Library, Chunk, Document

    return Chunk, Library, mo


@app.cell
def _(Library, mo):
    libraries = Library.list_libraries()
    library_dropdown = mo.ui.dropdown(
        options=libraries, value=libraries[0], label="Choose Active Library"
    )
    slug_input = mo.ui.text(placeholder="Document Title", label="Title")
    content_input = mo.ui.text_area(
        placeholder="Lorem ipsum...", label="Document Content", full_width=True
    )
    return content_input, library_dropdown, slug_input


@app.cell
def _(mo):
    get_insert_state, set_insert_state = mo.state(False)
    return get_insert_state, set_insert_state


@app.cell
def _(mo, set_insert_state):
    insert_button = mo.ui.button(
        label="Insert Document", value=True, on_click=lambda _: set_insert_state(True)
    )
    return (insert_button,)


@app.cell
def _(content_input, insert_button, library_dropdown, mo, slug_input):
    mo.vstack([library_dropdown, slug_input, content_input, insert_button.right()])
    return


@app.cell
def _(Library, library_dropdown, mo):
    library = Library.load(library_dropdown.value)
    mo.vstack([mo.md("### Library Summary"), mo.ui.table(library.stat.model_dump())])
    return (library,)


@app.cell
def _(
    Chunk,
    content_input,
    get_insert_state,
    library,
    mo,
    set_insert_state,
    slug_input,
):
    output = [mo.md("### Result")]
    if get_insert_state() and slug_input.value and content_input.value:
        print("Inserting")
        document = library.add_document(slug_input.value, content_input.value)
        chunks = []
        for chunk in document.chunks:
            chunks.append(Chunk.read(library, chunk))

        output += [
            mo.md(f"""
/// details | **Chunk ID:** `{chunk.id}`

**Chunk Content**:
```
{chunk.text}
```

**Vector**:
```
{chunk.vector}
```
///
""")
            for chunk in chunks
        ]
    else:
        output += [mo.md("> Please provide the data to insert a document.")]
        if not slug_input.value:
            print("Missing Title")
        if not content_input.value:
            print("Missing document")

    set_insert_state(False)
    mo.vstack(output)
    return


@app.cell
def _():
    return


@app.cell
def _():
    return


if __name__ == "__main__":
    app.run()

102
src/notebook/app-search.py
Normal file
102
src/notebook/app-search.py
Normal file
@ -0,0 +1,102 @@
import marimo

__generated_with = "0.13.15"
app = marimo.App(width="medium")


@app.cell
def _():
    import marimo as mo
    from model import Library, Chunk, Document

    mo.md("# Search on Library Widget")
    return Library, mo


@app.cell
def _(Library, mo):
    libraries = Library.list_libraries()
    library_dropdown = mo.ui.dropdown(
        options=libraries, value=libraries[0], label="Choose Active Library"
    )
    search_query = mo.ui.text(placeholder="Search...", label="Search Query")
    n_results = mo.ui.number(start=1, stop=20, label="Results")
    return library_dropdown, n_results, search_query


@app.cell
def _(library_dropdown, mo, n_results, search_query):
    mo.hstack(
        [
            library_dropdown,
            search_query,
            n_results,
        ]
    )
    return


@app.cell
def _(Library, library_dropdown, mo, n_results, search_query):
    # Referencing the inputs makes this cell re-run whenever they change.
    search_query.value, n_results.value
    library = Library.load(library_dropdown.value)
    mo.vstack([mo.md("### Library Summary"), mo.ui.table(library.stat.model_dump())])
    return (library,)


@app.cell
def _(library, mo, n_results, search_query):
    output = []
    if search_query.value and n_results.value:
        results = library.query(search_query.value, n_results.value)

        output = [mo.md("### Search Results")] + [
            mo.md(
                f"""
/// details | **Hits:** {result.hits} **Document:** `{result.document.slug}` | **Distance:** {result.distance}

**Chunk ID:** `{result.chunk.id}`

**Chunk**:

```
{result.chunk.text}
```

**Document**:

```
{result.document.content}
```

///
"""
            )
            for result in results
        ]
    else:
        output = [
            mo.md("""
/// details | No search yet.
    type: warn

Please type a search query and hit the return key (or click away from the input).
///
""")
        ]
    mo.vstack(output)
    return


@app.cell
def _():
    return


@app.cell
def _():
    return


if __name__ == "__main__":
    app.run()
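The search cell relies entirely on `Library.query(text, k)`, whose results carry `chunk`, `document`, `distance`, and `hits` fields as read above. A minimal sketch of the same k-nearest-neighbour query outside marimo, under those assumptions:

```python
# Sketch: kNN query against a library; fields are those the cell above reads.
from model import Library

library = Library.load(Library.list_libraries()[0])
for result in library.query("a heist that happens inside dreams", 5):
    print(f"{result.distance:.3f}  {result.document.slug}  {result.chunk.text[:50]}")
```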
65
src/notebook/app-seed.py
Normal file
@ -0,0 +1,65 @@
import marimo

__generated_with = "0.13.15"
app = marimo.App(width="medium")


@app.cell
def _():
    import marimo as mo
    import pandas as pd
    from model import Library

    return Library, mo, pd


@app.cell
def _(mo):
    mo.md("""### Seed Data To Database""")
    return


@app.cell
def _(Library, mo):
    libraries = Library.list_libraries()
    library_dropdown = mo.ui.dropdown(
        options=libraries, value=libraries[0], label="Choose Active Library"
    )
    return (library_dropdown,)


@app.cell
def _(pd):
    df = pd.read_json("hf://datasets/rohitsaxena/MovieSum/train.jsonl", lines=True)
    return (df,)


@app.cell
def _(df):
    df[["movie_name", "summary"]]
    return


@app.cell
def _(Library, df, library_dropdown, mo):
    def _on_click(*args, **kwargs):
        library = Library.load(library_dropdown.value)
        total = len(df)
        for _i, row in mo.status.progress_bar(
            df[["movie_name", "summary"]].iterrows(), total=total
        ):
            title, summary = row.to_list()
            library.add_document(title, summary)

    button = mo.ui.button(label="Add these documents.", on_click=_on_click)
    return (button,)


@app.cell
def _(button, library_dropdown, mo):
    mo.vstack([library_dropdown, button])
    return


@app.cell
def _():
    return


if __name__ == "__main__":
    app.run()
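Seeding above is button-driven; as a rough script equivalent (a sketch, assuming `model.Library` behaves exactly as the notebook uses it):

```python
# Sketch: seed a library from the MovieSum training split without the widgets.
import pandas as pd

from model import Library

df = pd.read_json("hf://datasets/rohitsaxena/MovieSum/train.jsonl", lines=True)
library = Library.load(Library.list_libraries()[0])
for title, summary in df[["movie_name", "summary"]].itertuples(index=False):
    library.add_document(title, summary)
```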
756
src/notebook/constants.py
Normal file
@ -0,0 +1,756 @@
datasets = {
    "math": "hf://datasets/WhiteGiverPlus/extract_theorem_en_400/data/train-00000-of-00001.parquet",
    "stories": "hf://datasets/deven367/babylm-100M-children-stories/data/train-00000-of-00001-13f0a33230d64dd9.parquet",
    "news": "hf://datasets/fancyzhx/ag_news/data/train-00000-of-00001.parquet",
}

animals = [
    "Aardvark", "Albatross", "Alligator", "Alpaca", "Ant", "Anteater", "Antelope", "Ape",
    "Armadillo", "Donkey", "Baboon", "Badger", "Barracuda", "Bat", "Bear", "Beaver",
    "Bee", "Bison", "Boar", "Buffalo", "Butterfly", "Camel", "Capybara", "Caribou",
    "Cassowary", "Cat", "Caterpillar", "Cattle", "Chamois", "Cheetah", "Chicken", "Chimpanzee",
    "Chinchilla", "Chough", "Clam", "Cobra", "Cockroach", "Cod", "Cormorant", "Coyote",
    "Crab", "Crane", "Crocodile", "Crow", "Curlew", "Deer", "Dinosaur", "Dog",
    "Dogfish", "Dolphin", "Dotterel", "Dove", "Dragonfly", "Duck", "Dugong", "Dunlin",
    "Eagle", "Echidna", "Eel", "Eland", "Elephant", "Elk", "Emu", "Falcon",
    "Ferret", "Finch", "Fish", "Flamingo", "Fly", "Fox", "Frog", "Gaur",
    "Gazelle", "Gerbil", "Giraffe", "Gnat", "Gnu", "Goat", "Goldfinch", "Goldfish",
    "Goose", "Gorilla", "Goshawk", "Grasshopper", "Grouse", "Guanaco", "Gull", "Hamster",
    "Hare", "Hawk", "Hedgehog", "Heron", "Herring", "Hippopotamus", "Hornet", "Horse",
    "Human", "Hummingbird", "Hyena", "Ibex", "Ibis", "Jackal", "Jaguar", "Jay",
    "Jellyfish", "Kangaroo", "Kingfisher", "Koala", "Kookabura", "Kouprey", "Kudu", "Lapwing",
    "Lark", "Lemur", "Leopard", "Lion", "Llama", "Lobster", "Locust", "Loris",
    "Louse", "Lyrebird", "Magpie", "Mallard", "Manatee", "Mandrill", "Mantis", "Marten",
    "Meerkat", "Mink", "Mole", "Mongoose", "Monkey", "Moose", "Mosquito", "Mouse",
    "Mule", "Narwhal", "Newt", "Nightingale", "Octopus", "Okapi", "Opossum", "Oryx",
    "Ostrich", "Otter", "Owl", "Oyster", "Panther", "Parrot", "Partridge", "Peafowl",
    "Pelican", "Penguin", "Pheasant", "Pig", "Pigeon", "Pony", "Porcupine", "Porpoise",
    "Quail", "Quelea", "Quetzal", "Rabbit", "Raccoon", "Rail", "Ram", "Rat",
    "Raven", "Red deer", "Red panda", "Reindeer", "Rhinoceros", "Rook", "Salamander", "Salmon",
    "Sand Dollar", "Sandpiper", "Sardine", "Scorpion", "Seahorse", "Seal", "Shark", "Sheep",
    "Shrew", "Skunk", "Snail", "Snake", "Sparrow", "Spider", "Spoonbill", "Squid",
    "Squirrel", "Starling", "Stingray", "Stinkbug", "Stork", "Swallow", "Swan", "Tapir",
    "Tarsier", "Termite", "Tiger", "Toad", "Trout", "Turkey", "Turtle", "Viper",
    "Vulture", "Wallaby", "Walrus", "Wasp", "Weasel", "Whale", "Wildcat", "Wolf",
    "Wolverine", "Wombat", "Woodcock", "Woodpecker", "Worm", "Wren", "Yak", "Zebra",
]

adjetives = [
    "different", "used", "important", "every", "large", "available", "popular", "able",
    "basic", "known", "various", "difficult", "several", "united", "historical", "hot",
    "useful", "mental", "scared", "additional", "emotional", "old", "political", "similar",
    "healthy", "financial", "medical", "traditional", "federal", "entire", "strong", "actual",
    "significant", "successful", "electrical", "expensive", "pregnant", "intelligent", "interesting", "poor",
    "happy", "responsible", "cute", "helpful", "recent", "willing", "nice", "wonderful",
    "impossible", "serious", "huge", "rare", "technical", "typical", "competitive", "critical",
    "electronic", "immediate", "aware", "educational", "environmental", "global", "legal", "relevant",
    "accurate", "capable", "dangerous", "dramatic", "efficient", "powerful", "foreign", "hungry",
    "practical", "psychological", "severe", "suitable", "numerous", "sufficient", "unusual", "consistent",
    "cultural", "existing", "famous", "pure", "afraid", "obvious", "careful", "latter",
    "unhappy", "acceptable", "aggressive", "boring", "distinct", "eastern", "logical", "reasonable",
    "strict", "administrative", "automatic", "civil", "former", "massive", "southern", "unfair",
    "visible", "alive", "angry", "desperate", "exciting", "friendly", "lucky", "realistic",
    "sorry", "ugly", "unlikely", "anxious", "comprehensive", "curious", "impressive", "informal",
    "inner", "pleasant", "sexual", "sudden", "terrible", "unable", "weak", "wooden",
    "asleep", "confident", "conscious", "decent", "embarrassed", "guilty", "lonely", "mad",
    "nervous", "odd", "remarkable", "substantial", "suspicious", "tall", "tiny", "other",
    "such", "even", "new", "just", "good", "any", "each", "much",
    "own", "great", "another", "same", "few", "free", "right", "still",
    "best", "public", "human", "both", "local", "sure", "better", "general",
    "specific", "enough", "long", "small", "less", "high", "certain", "little",
    "common", "next", "simple", "hard", "past", "big", "possible", "particular",
    "real", "major", "personal", "current", "left", "national", "least", "natural",
    "physical", "short", "last", "single", "individual", "main", "potential", "professional",
    "international", "lower", "open", "according", "alternative", "special", "working", "true",
    "whole", "clear", "dry", "easy", "cold", "commercial", "full", "low",
    "primary", "worth", "necessary", "positive", "present", "close", "creative", "green",
    "late", "fit", "glad", "proper", "complex", "content", "due", "effective",
    "middle", "regular", "fast", "independent", "original", "wide", "beautiful", "complete",
    "active", "negative", "safe", "visual", "wrong", "ago", "quick", "ready",
    "straight", "white", "direct", "excellent", "extra", "junior", "pretty", "unique",
    "classic", "final", "overall", "private", "separate", "western", "alone", "familiar",
    "official", "perfect", "bright", "broad", "comfortable", "flat", "rich", "warm",
    "young", "heavy", "valuable", "correct", "leading", "slow", "clean", "fresh",
    "normal", "secret", "tough", "brown", "cheap", "deep", "objective", "secure",
    "thin", "chemical", "cool", "extreme", "exact", "fair", "fine", "formal",
    "opposite", "remote", "total", "vast", "lost", "smooth", "dark", "double",
    "equal", "firm", "frequent", "internal", "sensitive", "constant", "minor", "previous",
    "raw", "soft", "solid", "weird", "amazing", "annual", "busy", "dead",
    "false", "round", "sharp", "thick", "wise", "equivalent", "initial", "narrow",
    "nearby", "proud", "spiritual", "wild", "adult", "apart", "brief", "crazy",
    "prior", "rough", "sad", "sick", "strange", "external", "illegal", "loud",
    "mobile", "nasty", "ordinary", "royal", "senior", "super", "tight", "upper",
    "yellow", "dependent", "funny", "gross", "ill", "spare", "sweet", "upstairs",
    "usual", "brave", "calm", "dirty", "downtown", "grand", "honest", "loose",
    "male", "quiet", "brilliant", "dear", "drunk", "empty", "female", "inevitable",
    "neat", "ok", "representative", "silly", "slight", "smart", "stupid", "temporary",
    "weekly", "that", "this", "what", "which", "time", "these", "work",
    "no", "only", "then", "first", "money", "over", "business", "his",
    "game", "think", "after", "life", "day", "home", "economy", "away",
    "either", "fat", "key", "training", "top", "level", "far", "fun",
    "house", "kind", "future", "action", "live", "period", "subject", "mean",
    "stock", "chance", "beginning", "upset", "chicken", "head", "material", "salt",
    "car", "appropriate", "inside", "outside", "standard", "medium", "choice", "north",
    "square", "born", "capital", "shot", "front", "living", "plastic", "express",
    "feeling", "otherwise", "plus", "savings", "animal", "budget", "minute", "character",
    "maximum", "novel", "plenty", "select", "background", "forward", "glass", "joint",
    "master", "red", "vegetable", "ideal", "kitchen", "mother", "party", "relative",
    "signal", "street", "connect", "minimum", "sea", "south", "status", "daughter",
    "hour", "trick", "afternoon", "gold", "mission", "agent", "corner", "east",
    "neither", "parking", "routine", "swimming", "winter", "airline", "designer", "dress",
    "emergency", "evening", "extension", "holiday", "horror", "mountain", "patient", "proof",
    "west", "wine", "expert", "native", "opening", "silver", "waste", "plane",
    "leather", "purple", "specialist", "bitter", "incident", "motor", "pretend", "prize",
    "resident",
]
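The `datasets` mapping points at Hugging Face Hub parquet files that pandas can read directly over the `hf://` fsspec protocol (with `huggingface_hub` installed), and the word lists look like raw material for generated names. A small sketch; the name-building line is a guess at intent, not something this commit does:

```python
# Sketch: consuming the constants above.
import random

import pandas as pd

from constants import adjetives, animals, datasets

news = pd.read_parquet(datasets["news"])  # hf:// paths resolve via huggingface_hub
name = f"{random.choice(adjetives)}-{random.choice(animals)}".lower()  # hypothetical use
print(len(news), name)
```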
15
src/settings.py
Normal file
@ -0,0 +1,15 @@
from pathlib import Path

from pydantic import SecretStr
from pydantic_settings import BaseSettings


class Settings(BaseSettings):
    co_api_key: SecretStr
    cohere_cache: bool = True
    storage_path: Path = Path("./data/")
    fastapi_port: int = 8000
    fastapi_host: str = "localhost"


settings = Settings()
settings.storage_path.mkdir(exist_ok=True, parents=True)
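pydantic-settings fills these fields from the environment with case-insensitive matching, so `co_api_key` is read from `CO_API_KEY`; that is how the Justfile's `uv run --env-file .env` hands the API its Cohere key. A short sketch of the behaviour:

```python
# Sketch: Settings reads CO_API_KEY from the environment; SecretStr masks it.
import os

os.environ["CO_API_KEY"] = "dummy-key"  # normally supplied via .env

from settings import settings  # Settings() is instantiated on import

print(settings.co_api_key)  # masked: **********
print(settings.co_api_key.get_secret_value())  # "dummy-key"
```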
267
src/tests/test_kmeans.py
Normal file
@ -0,0 +1,267 @@
import pytest

from knn import (
    KMeansBinary,
    Point,
    calculate_centroid,
    classify_point_on_clusters,
    compare_clusters,
    vector_scalar_multiplication,
    vector_sum,
)


# Test fixtures and helper functions
@pytest.fixture
def euclidean_distance():
    def distance(p1: Point, p2: Point) -> float:
        return sum((a - b) ** 2 for a, b in zip(p1, p2)) ** 0.5

    return distance


@pytest.fixture
def manhattan_distance():
    def distance(p1: Point, p2: Point) -> float:
        return sum(abs(a - b) for a, b in zip(p1, p2))

    return distance


def points_close(p1: Point, p2: Point, tolerance: float = 1e-6) -> bool:
    """Check whether two points are approximately equal."""
    return all(abs(a - b) < tolerance for a, b in zip(p1, p2))


class TestVectorOperations:
    def test_vector_scalar_multiplication(self):
        vector = [1.0, 2.0, 3.0]
        result = vector_scalar_multiplication(2, vector)
        assert result == [2.0, 4.0, 6.0]

        # Test with zero scalar
        result = vector_scalar_multiplication(0, vector)
        assert result == [0.0, 0.0, 0.0]

        # Test with negative scalar
        result = vector_scalar_multiplication(-1.5, vector)
        assert result == [-1.5, -3.0, -4.5]

    def test_vector_sum(self):
        a = [1.0, 2.0, 3.0]
        b = [4.0, 5.0, 6.0]
        result = vector_sum(a, b)
        assert result == [5.0, 7.0, 9.0]

        # Test with zeros
        zero = [0.0, 0.0, 0.0]
        result = vector_sum(a, zero)
        assert result == a

    def test_vector_sum_different_lengths(self):
        a = [1.0, 2.0]
        b = [3.0, 4.0, 5.0]
        with pytest.raises(AssertionError):
            vector_sum(a, b)


class TestCentroid:
    def test_calculate_centroid_simple(self):
        points = [[0.0, 0.0], [2.0, 0.0], [1.0, 2.0]]
        centroid = calculate_centroid(points)
        expected = [1.0, 2.0 / 3.0]
        assert points_close(centroid, expected)

    def test_calculate_centroid_single_point(self):
        points = [[3.0, 4.0]]
        centroid = calculate_centroid(points)
        assert points_close(centroid, [3.0, 4.0])

    def test_calculate_centroid_empty_list(self):
        points = []
        centroid = calculate_centroid(points)
        assert centroid == []

    def test_calculate_centroid_identical_points(self):
        points = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]
        centroid = calculate_centroid(points)
        assert points_close(centroid, [1.0, 1.0])


class TestClassification:
    def test_classify_point_on_clusters(self, euclidean_distance):
        point = [1.0, 1.0]
        centers = ([0.0, 0.0], [3.0, 3.0])

        # Point should be closer to the first center
        result = classify_point_on_clusters(point, centers, euclidean_distance)
        assert result is True

        # Test a point closer to the second center
        point = [2.5, 2.5]
        result = classify_point_on_clusters(point, centers, euclidean_distance)
        assert result is False

    def test_classify_point_equidistant(self, euclidean_distance):
        point = [1.5, 1.5]
        centers = ([0.0, 0.0], [3.0, 3.0])

        # The point is equidistant; the strict `<` comparison returns False
        result = classify_point_on_clusters(point, centers, euclidean_distance)
        assert result is False


class TestClusterComparison:
    def test_compare_clusters_identical(self):
        cluster_a = [[1.0, 2.0], [3.0, 4.0]]
        cluster_b = [[1.0, 2.0], [3.0, 4.0]]
        assert compare_clusters(cluster_a, cluster_b) is True

    def test_compare_clusters_different_order(self):
        cluster_a = [[1.0, 2.0], [3.0, 4.0]]
        cluster_b = [[3.0, 4.0], [1.0, 2.0]]
        assert compare_clusters(cluster_a, cluster_b) is True

    def test_compare_clusters_different(self):
        cluster_a = [[1.0, 2.0], [3.0, 4.0]]
        cluster_b = [[1.0, 2.0], [5.0, 6.0]]
        assert compare_clusters(cluster_a, cluster_b) is False

    def test_compare_clusters_empty(self):
        assert compare_clusters([], []) is True
        assert compare_clusters([[1.0, 2.0]], []) is False


class TestKMeansBinary:
    def test_kmeans_binary_simple_case(self, euclidean_distance):
        # Two well-separated clusters
        points = [
            [0.0, 0.0],
            [0.1, 0.1],
            [0.2, 0.0],  # Cluster 1
            [5.0, 5.0],
            [5.1, 5.1],
            [5.0, 5.2],  # Cluster 2
        ]

        centroid1, centroid2 = KMeansBinary(points, euclidean_distance)

        # Check that the centroids are reasonable
        assert len(centroid1) == 2
        assert len(centroid2) == 2

        # One centroid should be near (0.1, 0.033), the other near (5.033, 5.1)
        c1_near_origin = abs(centroid1[0]) < 1 and abs(centroid1[1]) < 1
        c2_near_origin = abs(centroid2[0]) < 1 and abs(centroid2[1]) < 1

        # Exactly one should be near the origin
        assert c1_near_origin != c2_near_origin

    def test_kmeans_binary_single_point(self, euclidean_distance):
        points = [[1.0, 1.0]]
        centroid1, centroid2 = KMeansBinary(points, euclidean_distance)

        # One centroid is the single point; the other cluster comes back empty
        assert points_close(centroid1, [1.0, 1.0])
        assert points_close(centroid2, [])  # Empty cluster

    def test_kmeans_binary_two_points(self, euclidean_distance):
        points = [[0.0, 0.0], [2.0, 2.0]]
        centroid1, centroid2 = KMeansBinary(points, euclidean_distance)

        # Should converge to the two original points
        centroids = [centroid1, centroid2]
        assert any(points_close(c, [0.0, 0.0]) for c in centroids)
        assert any(points_close(c, [2.0, 2.0]) for c in centroids)

    def test_kmeans_binary_convergence(self, euclidean_distance):
        # Test that the algorithm converges within a reasonable number of iterations
        points = [
            [i / 10.0, i / 10.0]
            for i in range(5)  # Points along the diagonal
        ] + [
            [i / 10.0 + 5, i / 10.0 + 5]
            for i in range(5)  # Shifted cluster
        ]

        centroid1, centroid2 = KMeansBinary(
            points, euclidean_distance, max_iterations=50
        )

        # Should produce two distinct clusters
        distance_between_centroids = euclidean_distance(centroid1, centroid2)
        assert distance_between_centroids > 2.0  # Should be well separated

    def test_kmeans_binary_with_manhattan_distance(self, manhattan_distance):
        points = [[0, 0], [1, 0], [0, 1], [10, 10], [11, 10], [10, 11]]

        centroid1, centroid2 = KMeansBinary(points, manhattan_distance)

        # Should separate into two clusters
        assert len(centroid1) == 2
        assert len(centroid2) == 2

        # One centroid should be near the origin, the other near (10, 10)
        c1_near_origin = abs(centroid1[0]) < 5 and abs(centroid1[1]) < 5
        c2_near_origin = abs(centroid2[0]) < 5 and abs(centroid2[1]) < 5

        assert c1_near_origin != c2_near_origin

    def test_kmeans_binary_max_iterations(self, euclidean_distance):
        # Test that the max_iterations parameter works
        points = [[i, i] for i in range(10)]

        # Should work with very few iterations
        centroid1, centroid2 = KMeansBinary(
            points, euclidean_distance, max_iterations=1
        )
        assert len(centroid1) == 2
        assert len(centroid2) == 2

    def test_kmeans_binary_3d_points(self, euclidean_distance):
        # Test with 3D points
        points = [
            [0, 0, 0],
            [1, 0, 0],
            [0, 1, 0],  # Near the origin
            [5, 5, 5],
            [6, 5, 5],
            [5, 6, 5],  # Far from the origin
        ]

        centroid1, centroid2 = KMeansBinary(points, euclidean_distance)

        assert len(centroid1) == 3
        assert len(centroid2) == 3

        # Should separate the two groups
        c1_near_origin = all(abs(x) < 3 for x in centroid1)
        c2_near_origin = all(abs(x) < 3 for x in centroid2)

        assert c1_near_origin != c2_near_origin


class TestEdgeCases:
    def test_identical_points(self, euclidean_distance):
        # All points are identical
        points = [[1.0, 1.0]] * 5
        centroid1, centroid2 = KMeansBinary(points, euclidean_distance)

        # Both centroids should be the same point
        assert points_close(centroid1, [1.0, 1.0])
        assert points_close(centroid2, [1.0, 1.0])

    def test_collinear_points(self, euclidean_distance):
        # Points on a line
        points = [[i, 0] for i in range(6)]
        centroid1, centroid2 = KMeansBinary(points, euclidean_distance)

        # Should still produce two clusters
        assert len(centroid1) == 2
        assert len(centroid2) == 2
        assert centroid1[1] == 0  # y-coordinate should be 0
        assert centroid2[1] == 0  # y-coordinate should be 0


if __name__ == "__main__":
    pytest.main([__file__])
197
src/utils.py
Normal file
@ -0,0 +1,197 @@
from typing import Callable, Literal
from functools import reduce


Point = list[float]
Metric = Callable[[Point, Point], float]
VectorTransformation = Callable[[Point], Point]
IndexAlgorithm = Literal["blob", "hnsw"]
EmbeddingSize = Literal[256, 512, 1024, 1536]


def chunk_list(_list: list[str], chunk_size: int = 50) -> list[list[str]]:
    """
    Split a list of strings into a list of lists of strings, in which
    each entry is a list with at most `chunk_size` elements.
    """
    _range = range(0, len(_list), chunk_size)
    return [_list[i : i + chunk_size] for i in _range]


def text_to_chunks(
    text: str,
    chunk_size: int,
    chunk_overlap: int,
) -> list[str]:
    """
    Split a string into a list of strings in which each is at most
    `chunk_size` characters long, and each pair of consecutive chunks
    overlaps by roughly `chunk_overlap` characters.
    """
    if len(text) <= chunk_size:
        return [text]

    if chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be less than chunk_size")

    chunks = []
    start = 0

    while start < len(text):
        end = start + chunk_size

        if end >= len(text):
            chunk = text[start:]
            if len(chunk) < chunk_overlap and chunks:
                chunks[-1] = chunks[-1] + " " + chunk
            else:
                chunks.append(chunk)
            break

        chunk = text[start:end]

        if end < len(text):  # Not the last chunk
            last_space = chunk.rfind(" ", max(0, len(chunk) - 50))
            if last_space > len(chunk) // 2:  # Only if we find a reasonable break point
                chunk = chunk[:last_space]
                end = start + last_space

        chunks.append(chunk.strip())

        # Move the start position (with overlap)
        start = end - chunk_overlap

    chunks = [chunk for chunk in chunks if chunk.strip()]

    return chunks


# Vector math utilities.


def identity_transformation(v: Point) -> Point:
    return v


def vector_dot_product(a: Point, b: Point) -> float:
    return sum(i * j for i, j in zip(a, b))


def vector_scalar_multiplication(scalar: float | int, vector: Point) -> Point:
    return [scalar * v for v in vector]


def vector_sum(a: Point, b: Point) -> Point:
    assert len(a) == len(b), "Inconsistent arguments provided."
    return [i + j for i, j in zip(a, b)]


def l_distance_generator(L: int = 2) -> Metric:
    """
    Return a metric function implementing the L-norm distance
    (L=1 is Manhattan, L=2 is Euclidean).
    """

    def _wrapper(a: Point, b: Point) -> float:
        assert len(a) == len(b), "Inconsistent arguments provided."
        return sum(abs(i - j) ** L for i, j in zip(a, b)) ** (1 / L)

    return _wrapper


def euclidean_distance(a: Point, b: Point) -> float:
    assert len(a) == len(b), "Inconsistent arguments provided."
    difference = vector_sum(a, vector_scalar_multiplication(-1, b))
    return vector_dot_product(difference, difference) ** 0.5


def cosine_distance(a: Point, b: Point) -> float:
    """
    Cosine distance in the [0, 2] range, used as a metric. Assumes the
    inputs are unit-normalized, so their dot product is the cosine similarity.
    """
    assert len(a) == len(b), "Inconsistent arguments provided."
    return 1 - vector_dot_product(a, b)


def calculate_centroid(points: list[Point]) -> Point:
    point_count = len(points)
    if point_count == 0:
        return []
    zero = vector_scalar_multiplication(0, points[0])
    centroid = reduce(lambda a, b: vector_sum(a, b), points, zero)
    return vector_scalar_multiplication(1 / point_count, centroid)


def classify_point_on_clusters(
    point: Point,
    centers: tuple[Point, Point],
    metric: Metric,
) -> bool:
    return metric(point, centers[0]) < metric(point, centers[1])


def compare_clusters(cluster_a: list[Point], cluster_b: list[Point]) -> bool:
    return set(map(tuple, cluster_a)) == set(map(tuple, cluster_b))


def KMeansBinary(
    points: list[Point],
    metric: Metric,
    max_iterations: int = 100,
) -> tuple[Point, Point]:
    # Initial split: sort points by distance to the overall mean and halve.
    zero = vector_scalar_multiplication(0, points[0])
    center = vector_scalar_multiplication(
        1 / len(points),
        reduce(lambda a, b: vector_sum(a, b), points, zero),
    )

    point_distances = [(point, metric(center, point)) for point in points]
    point_distances.sort(key=lambda x: x[1])

    mid = len(points) // 2
    cluster_1 = [point for point, _ in point_distances[:mid]]
    cluster_2 = [point for point, _ in point_distances[mid:]]

    for _ in range(max_iterations):
        centroids = (calculate_centroid(cluster_1), calculate_centroid(cluster_2))
        new_cluster_1, new_cluster_2 = [], []
        for point in points:
            if classify_point_on_clusters(point, centroids, metric):
                new_cluster_1.append(point)
            else:
                new_cluster_2.append(point)

        same_clusters = any(
            (
                compare_clusters(cluster_1, new_cluster_1),
                compare_clusters(cluster_1, new_cluster_2),
            )
        )

        if same_clusters:
            break
        if not new_cluster_1 or not new_cluster_2:
            # One side collapsed: re-split by distance to the first centroid.
            point_distances = [(point, metric(centroids[0], point)) for point in points]
            point_distances.sort(key=lambda x: x[1])

            mid = len(points) // 2
            new_cluster_1 = [point for point, _ in point_distances[:mid]]
            new_cluster_2 = [point for point, _ in point_distances[mid:]]

            if not new_cluster_2 and len(points) > 1:
                new_cluster_2 = [new_cluster_1.pop()]

        cluster_1, cluster_2 = new_cluster_1, new_cluster_2

    centroid_1 = calculate_centroid(cluster_1)
    centroid_2 = calculate_centroid(cluster_2)

    return centroid_1, centroid_2
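A quick usage sketch of the module (assuming `src/` is importable): build a Euclidean metric with `l_distance_generator`, split two obvious groups with `KMeansBinary`, and chunk a long string with `text_to_chunks`:

```python
# Sketch: exercising the vector utilities end to end.
from utils import KMeansBinary, l_distance_generator, text_to_chunks

euclidean = l_distance_generator(L=2)
points = [[0.0, 0.0], [0.2, 0.1], [5.0, 5.0], [5.1, 4.9]]
c1, c2 = KMeansBinary(points, euclidean)
print(c1, c2)  # one centroid near the origin, the other near (5, 5)

chunks = text_to_chunks("lorem ipsum dolor sit amet " * 20, chunk_size=100, chunk_overlap=20)
print(len(chunks), [len(c) for c in chunks])
```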