Files
Chromy/chromy/chroma_functions.py
T

168 lines
4.6 KiB
Python
Raw Normal View History

2026-04-22 17:03:01 +02:00
from __future__ import annotations
2026-05-06 21:23:37 +02:00
import os
2026-04-22 17:03:01 +02:00
from collections.abc import Sequence
2026-05-06 21:23:37 +02:00
from pathlib import Path
from tempfile import NamedTemporaryFile
2026-04-22 17:03:01 +02:00
from typing import cast
2026-04-21 15:28:20 +02:00
from uuid import uuid4
2026-04-21 17:13:43 +02:00
import chromadb
2026-04-21 18:24:49 +02:00
from chromadb.api import ClientAPI
2026-04-22 17:03:01 +02:00
from chromadb.api.types import QueryResult, Where
2026-04-21 17:13:43 +02:00
from chromadb.errors import NotFoundError
from chromy.embedding import EmbeddingRecord
2026-05-06 21:23:37 +02:00
from chromy.errors import ChromaPathError
# Name of the environment variable that holds the parent directory for Chroma data.
CHROMA_FOLDER_ENV_VAR = "CHROMA_FOLDER"
# Subdirectory appended to the configured parent directory to form the storage path.
CHROMA_SUBDIRECTORY = "chroma"
def _resolve_persistence_path() -> Path | None:
    """Derive the Chroma storage directory from the environment.

    Returns:
        ``None`` when the ``CHROMA_FOLDER_ENV_VAR`` variable is not set at all;
        otherwise the ``CHROMA_SUBDIRECTORY`` folder below the configured
        parent, with ``~`` expanded and the parent resolved to an absolute path.

    Raises:
        ChromaPathError: If the variable is set but contains only whitespace.
    """
    raw_value = os.getenv(CHROMA_FOLDER_ENV_VAR)
    if raw_value is None:
        return None

    parent_text = raw_value.strip()
    if parent_text:
        base_dir = Path(parent_text).expanduser().resolve()
        return base_dir / CHROMA_SUBDIRECTORY

    # Set-but-blank is treated as a configuration mistake, not as "unset".
    raise ChromaPathError(
        f"{CHROMA_FOLDER_ENV_VAR} is set but empty. Please set a valid parent "
        "directory path."
    )
def _ensure_persistence_path_is_usable(path: Path, configured_parent: str) -> None:
    """Create *path* if necessary and confirm that files can be written in it.

    Args:
        path: Directory that should hold the Chroma data.
        configured_parent: Raw environment value, used only in error messages.

    Raises:
        ChromaPathError: If *path* is not a directory, or if creating it or
            writing a temporary file inside it fails with an ``OSError``.
    """
    try:
        path.mkdir(parents=True, exist_ok=True)
        directory_ok = path.is_dir()
        if directory_ok:
            # Creating (and auto-deleting) a scratch file proves write access.
            with NamedTemporaryFile(
                dir=path, prefix=".chromy-write-test-", delete=True
            ):
                pass
    except OSError as exc:
        raise ChromaPathError(
            f"Could not create or access Chroma directory '{path}' from "
            f"{CHROMA_FOLDER_ENV_VAR}='{configured_parent}': {exc}"
        ) from exc

    # Raised outside the ``try`` so it can never be re-wrapped by the handler.
    if not directory_ok:
        raise ChromaPathError(
            f"Configured Chroma directory '{path}' is not a directory."
        )
def get_client() -> ClientAPI:
    """Build a persistent Chroma client.

    When no persistence path is configured, the client is created with
    chromadb's own defaults. Otherwise the configured directory is validated
    first and then passed to the client as its storage path.

    Raises:
        ChromaPathError: If the configured directory is unusable or the client
            cannot be initialized at that location.
    """
    target_dir = _resolve_persistence_path()
    if target_dir is not None:
        # Re-read the raw value purely so error messages can echo it back.
        env_value = os.getenv(CHROMA_FOLDER_ENV_VAR, "")
        _ensure_persistence_path_is_usable(target_dir, env_value)
        try:
            return chromadb.PersistentClient(path=str(target_dir))
        except Exception as exc:  # pragma: no cover - defensive wrapper
            raise ChromaPathError(
                f"Could not initialize Chroma client at '{target_dir}' from "
                f"{CHROMA_FOLDER_ENV_VAR}='{env_value}': {exc}"
            ) from exc

    return chromadb.PersistentClient()
2026-04-21 14:32:10 +02:00
2026-04-21 18:24:49 +02:00
def _get_client_and_collection(
    collection_name: str,
) -> tuple[ClientAPI, chromadb.Collection]:
    """Return a client together with the existing collection *collection_name*.

    Args:
        collection_name: Name of the collection to look up.

    Returns:
        A ``(client, collection)`` tuple.

    Raises:
        NotFoundError: Propagated unchanged from ``client.get_collection`` when
            the lookup fails; no wrapping or translation is applied here.
        ChromaPathError: If the client itself cannot be created.
    """
    client = get_client()
    # No try/except here: a handler that only re-raises adds nothing, so the
    # lookup error is simply allowed to propagate to the caller.
    collection = client.get_collection(name=collection_name)
    return client, collection
2026-04-29 12:44:28 +02:00
def list_collections() -> list[str]:
    """Return the names of all collections known to the client.

    Each entry reported by the client contributes its ``name`` attribute; an
    entry without that attribute is converted with ``str()`` instead. A falsy
    result from the client yields an empty list.
    """
    found = get_client().list_collections()
    return [getattr(entry, "name", str(entry)) for entry in (found or [])]
2026-04-21 14:32:10 +02:00
def create_collection(name: str) -> str:
    """Create a collection called *name* and return its name.

    The returned value is the ``name`` attribute of the created collection,
    falling back to the requested *name* if that attribute is missing.
    """
    created = get_client().create_collection(name=name)
    return getattr(created, "name", name)
def delete_collection(name: str) -> None:
    """Delete the collection called *name* via a freshly created client."""
    get_client().delete_collection(name=name)
2026-04-21 21:26:40 +02:00
def delete_data(collection_name: str, where: dict[str, str]) -> int:
    """Delete records matching *where* and return the reported deletion count.

    Args:
        collection_name: Collection to delete from.
        where: Metadata filter handed to ``collection.delete``.

    Returns:
        The integer stored under ``"deleted"`` in the delete result, or ``0``
        when that key is absent.

    NOTE(review): this assumes ``collection.delete`` returns a mapping with a
    ``"deleted"`` entry -- confirm against the installed chromadb version.
    """
    _client, target = _get_client_and_collection(collection_name)
    outcome = target.delete(where=cast(Where, where))
    deleted_count = outcome.get("deleted", 0)
    return int(deleted_count)
2026-04-29 14:46:41 +02:00
def has_data_for_file(collection_name: str, file_name: str) -> bool:
    """Report whether the collection holds any record for *file_name*.

    Args:
        collection_name: Collection to inspect.
        file_name: Value matched against the ``file_name`` metadata field.

    Returns:
        ``True`` if at least one matching record id is returned, else ``False``.
    """
    _, collection = _get_client_and_collection(collection_name)
    # Only existence matters: ask for at most one match and no payload
    # (documents/metadata) instead of loading every record for the file.
    result = collection.get(
        where=cast(Where, {"file_name": file_name}),
        limit=1,
        include=[],
    )
    return bool(result.get("ids"))
2026-04-29 12:44:28 +02:00
def count_collection(collection_name: str) -> int:
    """Return whatever ``count()`` reports for the named collection."""
    _client, target = _get_client_and_collection(collection_name)
    total = target.count()
    return total
2026-04-21 15:28:20 +02:00
2026-04-22 17:19:14 +02:00
def add_data(
    collection_name: str,
    data: Sequence[EmbeddingRecord],
    file_name: str,
) -> None:
    """Store embedding records in a collection, tagged with their source file.

    Args:
        collection_name: Collection that receives the records.
        data: Records providing an ``"embedding"`` and a ``"text"`` entry each.
        file_name: Stored as the ``file_name`` metadata of every record.

    Empty *data* is a no-op: the collection is not even looked up. Every
    record is stored under a freshly generated UUID4 string id.
    """
    if len(data) == 0:
        return

    _client, target = _get_client_and_collection(collection_name)

    vectors: list[Sequence[float]] = [item["embedding"] for item in data]
    record_ids = [str(uuid4()) for _ in range(len(data))]
    tags = [{"file_name": file_name} for _ in range(len(data))]
    texts = [item["text"] for item in data]

    target.add(
        ids=record_ids,
        metadatas=tags,
        documents=texts,
        embeddings=vectors,
    )
2026-04-21 17:13:43 +02:00
2026-04-22 17:19:14 +02:00
def query_data(collection_name: str, texts: Sequence[str]) -> QueryResult:
    """Query a collection with the given texts.

    Args:
        collection_name: Collection to search.
        texts: Query strings; passed on as a list via ``query_texts``.

    Returns:
        The result of ``collection.query``. When *texts* is empty, a
        hand-built empty result is returned and the collection is never
        looked up.
    """
    if texts:
        _client, target = _get_client_and_collection(collection_name)
        return target.query(query_texts=list(texts))

    # Nothing to search for: return an empty result without touching Chroma.
    empty_result = {
        "ids": [],
        "documents": [],
        "metadatas": [],
        "distances": [],
        "embeddings": None,
        "uris": None,
        "data": None,
        "included": ["documents", "metadatas", "distances"],
    }
    return empty_result