2026-04-21 14:32:10 +02:00
|
|
|
from typing import List
|
2026-04-21 15:28:20 +02:00
|
|
|
from uuid import uuid4
|
|
|
|
|
|
2026-04-21 17:13:43 +02:00
|
|
|
import chromadb
|
2026-04-21 18:24:49 +02:00
|
|
|
from chromadb.api import ClientAPI
|
2026-04-21 17:20:45 +02:00
|
|
|
from chromadb.api.types import QueryResult
|
2026-04-21 17:13:43 +02:00
|
|
|
from chromadb.errors import NotFoundError
|
|
|
|
|
|
2026-04-21 15:28:20 +02:00
|
|
|
from embed import EmbeddingRecord
|
2026-04-21 14:32:10 +02:00
|
|
|
|
|
|
|
|
|
2026-04-21 18:24:49 +02:00
|
|
|
def _get_client_and_collection(
    collection_name: str,
) -> tuple[ClientAPI, chromadb.Collection]:
    """Open a persistent Chroma client and fetch an existing collection.

    ``chromadb.PersistentClient()`` is constructed with no arguments, so it
    uses Chroma's default persistence location.

    Args:
        collection_name: Name of a collection that must already exist.

    Returns:
        A ``(client, collection)`` pair. The client is returned so callers
        can query client-level settings (e.g. batch limits) without opening
        a second client.

    Raises:
        NotFoundError: Propagated unchanged from ``client.get_collection``
            when no collection with that name exists.
    """
    client = chromadb.PersistentClient()
    # No try/except here: a missing collection's NotFoundError is meant to
    # reach the caller as-is, so wrapping it only to re-raise added nothing.
    collection = client.get_collection(name=collection_name)
    return client, collection
|
|
|
|
|
|
|
|
|
|
|
2026-04-21 14:32:10 +02:00
|
|
|
def list_collections() -> List[str]:
    """Return the names of all collections in the persistent Chroma store.

    Depending on the Chroma version, ``list_collections()`` yields either
    collection objects or plain strings. Objects contribute their ``name``
    attribute; anything without one is converted with ``str()``. A falsy
    result from the client produces an empty list.
    """
    found = chromadb.PersistentClient().list_collections()
    names: List[str] = []
    for entry in found or []:
        names.append(getattr(entry, "name", str(entry)))
    return names
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_collection(name: str) -> str:
    """Create a new collection in the persistent Chroma store.

    Args:
        name: Name for the new collection.

    Returns:
        The created collection's ``name`` attribute, or *name* itself if the
        returned object has no such attribute.
    """
    created = chromadb.PersistentClient().create_collection(name=name)
    return getattr(created, "name", name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def delete_collection(name: str) -> None:
    """Delete the collection called *name* from the persistent Chroma store.

    Any error raised by the client (for example, for an unknown name) is
    propagated to the caller unchanged.
    """
    store = chromadb.PersistentClient()
    store.delete_collection(name=name)
|
|
|
|
|
|
|
|
|
|
|
2026-04-21 18:24:49 +02:00
|
|
|
def count_collection(collection_name: str) -> int:
    """Return the number of records stored in an existing collection.

    Raises:
        NotFoundError: If the collection does not exist (propagated from
            ``_get_client_and_collection``).
    """
    return _get_client_and_collection(collection_name)[1].count()
|
2026-04-21 15:28:20 +02:00
|
|
|
|
|
|
|
|
|
2026-04-21 18:24:49 +02:00
|
|
|
def add_data(collection_name: str, data: List[EmbeddingRecord], file_name: str) -> None:
    """Store pre-embedded text records from one file in an existing collection.

    Each record gets a fresh random UUID4 string as its id, its ``"text"`` as
    the document, its ``"embedding"`` as the vector, and
    ``{"file_name": file_name}`` as metadata.

    Records are submitted in slices no larger than the client's maximum batch
    size, because Chroma rejects a single ``add`` call that exceeds it. Note
    that the insert is therefore not atomic: if a later slice fails, earlier
    slices remain stored.

    Args:
        collection_name: Name of a collection that must already exist.
        data: Records providing ``"text"`` and ``"embedding"`` keys. An empty
            list is a no-op; the collection is not opened.
        file_name: Source file name recorded in every record's metadata.

    Raises:
        NotFoundError: If the collection does not exist.
    """
    if not data:
        return

    client, collection = _get_client_and_collection(collection_name)
    # Guard against a non-positive limit so range() never gets a zero step.
    batch_size = max(1, client.get_max_batch_size())

    for start in range(0, len(data), batch_size):
        batch = data[start:start + batch_size]
        collection.add(
            ids=[str(uuid4()) for _ in batch],
            metadatas=[{"file_name": file_name} for _ in batch],
            documents=[record["text"] for record in batch],
            embeddings=[record["embedding"] for record in batch],
        )
|
2026-04-21 17:13:43 +02:00
|
|
|
|
|
|
|
|
|
2026-04-21 17:20:45 +02:00
|
|
|
def query_data(
    collection_name: str, texts: list[str], n_results: int = 10
) -> QueryResult:
    """Run a similarity query for each text against an existing collection.

    Args:
        collection_name: Name of a collection that must already exist.
        texts: Query strings. They are passed to Chroma as ``query_texts``,
            so Chroma embeds them with the collection's embedding function.
            NOTE(review): ``add_data`` stores precomputed vectors from the
            ``embed`` module; if those come from a different model than the
            collection's embedding function, results will be poor or the
            dimensions will not match — confirm.
        n_results: Maximum number of hits returned per query text. The
            default of 10 matches Chroma's own default, so existing callers
            see unchanged behavior.

    Returns:
        Chroma's ``QueryResult``. For an empty *texts* list, the collection is
        not opened, and an empty result with the same keys as a real
        ``QueryResult`` is returned, so callers can index any key safely.

    Raises:
        NotFoundError: If *texts* is non-empty and the collection does not
            exist.
    """
    if not texts:
        # Mirror every key of QueryResult so callers can index this result
        # exactly like one returned by Chroma.
        return {
            "ids": [],
            "embeddings": [],
            "documents": [],
            "uris": [],
            "data": [],
            "metadatas": [],
            "distances": [],
            "included": [],
        }

    _, collection = _get_client_and_collection(collection_name)
    return collection.query(query_texts=texts, n_results=n_results)
|