Files
Chromy/chroma_functions.py
T

87 lines
2.2 KiB
Python
Raw Normal View History

2026-04-21 14:32:10 +02:00
from typing import List
2026-04-21 15:28:20 +02:00
from uuid import uuid4
2026-04-21 17:13:43 +02:00
import chromadb
2026-04-21 18:24:49 +02:00
from chromadb.api import ClientAPI
2026-04-21 17:20:45 +02:00
from chromadb.api.types import QueryResult
2026-04-21 17:13:43 +02:00
from chromadb.errors import NotFoundError
2026-04-21 15:28:20 +02:00
from embed import EmbeddingRecord
2026-04-21 14:32:10 +02:00
2026-04-21 18:24:49 +02:00
def _get_client_and_collection(
    collection_name: str,
) -> tuple[ClientAPI, chromadb.Collection]:
    """Open the persistent Chroma client and fetch an existing collection.

    Args:
        collection_name: Name of the collection to look up.

    Returns:
        A ``(client, collection)`` pair.

    Raises:
        Whatever chromadb's ``get_collection`` raises for an unknown name
        (``chromadb.errors.NotFoundError`` in recent releases). It is not
        caught or wrapped here.
    """
    client = chromadb.PersistentClient()
    # No try/except: catching the lookup error only to re-raise it unchanged
    # added nothing, so the error now propagates directly to the caller.
    collection = client.get_collection(name=collection_name)
    return client, collection
2026-04-21 14:32:10 +02:00
def list_collections() -> List[str]:
    """Return the name of every collection in the persistent Chroma store.

    Each entry's ``name`` attribute is used when it has one. Otherwise the
    entry itself is converted with ``str``. This fallback covers chromadb
    versions whose ``list_collections`` yields plain names.

    Returns:
        A list of collection names; empty when the store holds none.
    """
    client = chromadb.PersistentClient()
    found = client.list_collections() or []
    return [getattr(item, "name", str(item)) for item in found]
def create_collection(name: str) -> str:
    """Create a new collection in the persistent Chroma store.

    Errors raised by chromadb (presumably including a name that already
    exists) are not caught and propagate to the caller.

    Args:
        name: Name for the new collection.

    Returns:
        The created collection's ``name`` attribute, or ``name`` itself when
        the returned object has no such attribute.
    """
    created = chromadb.PersistentClient().create_collection(name=name)
    return getattr(created, "name", name)
def delete_collection(name: str) -> None:
    """Remove the collection called ``name`` from the persistent Chroma store.

    Errors raised by chromadb (for example, when no such collection exists)
    are not caught and propagate to the caller.
    """
    chromadb.PersistentClient().delete_collection(name=name)
2026-04-21 21:26:40 +02:00
def delete_data(collection_name: str, where: dict[str, str]) -> int:
    """Delete every record in a collection whose metadata matches ``where``.

    Args:
        collection_name: Name of an existing collection.
        where: Chroma metadata filter, e.g. ``{"file_name": "notes.txt"}``
            to drop everything that ``add_data`` stored for one file.

    Returns:
        The number of records that were deleted (0 when nothing matched).

    Raises:
        ValueError: If ``where`` is empty. Refusing an empty filter keeps it
            from being interpreted as "match everything" and wiping the
            whole collection.
        Whatever ``_get_client_and_collection`` raises when the collection
            does not exist.
    """
    if not where:
        raise ValueError("where filter must not be empty")

    _, collection = _get_client_and_collection(collection_name)
    # Collection.delete() returns None, so the deleted count cannot be read
    # from its result. Resolve the matching ids first (ids are always
    # returned; include=[] skips documents/embeddings), then delete exactly
    # those ids so the returned count matches what was removed.
    matched_ids = collection.get(where=where, include=[])["ids"]
    if not matched_ids:
        return 0
    collection.delete(ids=matched_ids)
    return len(matched_ids)
2026-04-21 18:24:49 +02:00
def count_collection(collection_name: str) -> int:
    """Return the number of records stored in an existing collection.

    Whatever ``_get_client_and_collection`` raises for an unknown
    collection propagates unchanged.
    """
    collection = _get_client_and_collection(collection_name)[1]
    return collection.count()
2026-04-21 15:28:20 +02:00
2026-04-21 18:24:49 +02:00
def add_data(collection_name: str, data: List[EmbeddingRecord], file_name: str) -> None:
    """Store pre-computed embeddings and their source text in a collection.

    Each record gets a freshly generated random UUID string as its id and a
    metadata dict ``{"file_name": file_name}`` recording where it came from.
    When ``data`` is empty the function returns immediately without opening
    a client.

    Args:
        collection_name: Name of an existing collection.
        data: Records indexed by ``"text"`` and ``"embedding"``.
        file_name: Source file name written into every record's metadata.
    """
    if not data:
        return  # nothing to store; skip the client round-trip entirely

    _, collection = _get_client_and_collection(collection_name)

    record_ids = [str(uuid4()) for _ in data]
    record_meta = [{"file_name": file_name} for _ in data]
    texts = [record["text"] for record in data]
    vectors = [record["embedding"] for record in data]

    collection.add(
        ids=record_ids,
        metadatas=record_meta,
        documents=texts,
        embeddings=vectors,
    )
2026-04-21 17:13:43 +02:00
2026-04-21 17:20:45 +02:00
def query_data(
    collection_name: str,
    texts: list[str],
    n_results: int = 10,
) -> QueryResult:
    """Run a similarity query for each text against an existing collection.

    Args:
        collection_name: Name of the collection to search.
        texts: Query strings. They are passed as ``query_texts``, so the
            collection embeds them itself.
        n_results: Maximum number of neighbours returned per query text.
            Defaults to 10, the same as chromadb's own default, so callers
            that omit it get unchanged results.

    Returns:
        chromadb's ``QueryResult``. For an empty ``texts`` list, a result with
        empty ``ids``/``documents``/``metadatas``/``distances``/``embeddings``
        lists is returned without opening a client. This presumably avoids
        sending chromadb a query with no inputs.

    Raises:
        Whatever ``_get_client_and_collection`` raises when the collection
        does not exist.
    """
    if not texts:
        return {
            "ids": [],
            "documents": [],
            "metadatas": [],
            "distances": [],
            "embeddings": [],
        }

    _, collection = _get_client_and_collection(collection_name)
    # NOTE(review): add_data stores externally computed vectors
    # (record["embedding"]), but this call sends raw text, which Chroma embeds
    # with the collection's own embedding function. Confirm that function
    # matches the model used in ``embed``. If it does not, query with
    # ``query_embeddings`` instead.
    return collection.query(query_texts=texts, n_results=n_results)