configurable directory
build / build (push) Successful in 47s
pytest / pytest (push) Successful in 35s

This commit is contained in:
Matteo Rosati
2026-05-06 21:23:37 +02:00
parent 28ec29f8af
commit 96ccf0396d
7 changed files with 209 additions and 7 deletions
+66 -4
View File
@@ -1,6 +1,9 @@
from __future__ import annotations
import os
from collections.abc import Sequence
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import cast
from uuid import uuid4
@@ -10,12 +13,71 @@ from chromadb.api.types import QueryResult, Where
from chromadb.errors import NotFoundError
from chromy.embedding import EmbeddingRecord
from chromy.errors import ChromaPathError
CHROMA_FOLDER_ENV_VAR = "CHROMA_FOLDER"
CHROMA_SUBDIRECTORY = "chroma"
def _resolve_persistence_path() -> Path | None:
configured_parent = os.getenv(CHROMA_FOLDER_ENV_VAR)
if configured_parent is None:
return None
trimmed_parent = configured_parent.strip()
if not trimmed_parent:
raise ChromaPathError(
f"{CHROMA_FOLDER_ENV_VAR} is set but empty. Please set a valid parent "
"directory path."
)
parent_path = Path(trimmed_parent).expanduser().resolve()
return parent_path / CHROMA_SUBDIRECTORY
def _ensure_persistence_path_is_usable(path: Path, configured_parent: str) -> None:
try:
path.mkdir(parents=True, exist_ok=True)
if not path.is_dir():
raise ChromaPathError(
f"Configured Chroma directory '{path}' is not a directory."
)
with NamedTemporaryFile(dir=path, prefix=".chromy-write-test-", delete=True):
pass
except ChromaPathError:
raise
except OSError as exc:
raise ChromaPathError(
f"Could not create or access Chroma directory '{path}' from "
f"{CHROMA_FOLDER_ENV_VAR}='{configured_parent}': {exc}"
) from exc
def get_client() -> ClientAPI:
persistence_path = _resolve_persistence_path()
if persistence_path is None:
return chromadb.PersistentClient()
configured_parent = os.getenv(CHROMA_FOLDER_ENV_VAR, "")
_ensure_persistence_path_is_usable(persistence_path, configured_parent)
try:
return chromadb.PersistentClient(path=str(persistence_path))
except Exception as exc: # pragma: no cover - defensive wrapper
raise ChromaPathError(
f"Could not initialize Chroma client at '{persistence_path}' from "
f"{CHROMA_FOLDER_ENV_VAR}='{configured_parent}': {exc}"
) from exc
def _get_client_and_collection(
collection_name: str,
) -> tuple[ClientAPI, chromadb.Collection]:
client = chromadb.PersistentClient()
client = get_client()
try:
collection = client.get_collection(name=collection_name)
@@ -26,7 +88,7 @@ def _get_client_and_collection(
def list_collections() -> list[str]:
client = chromadb.PersistentClient()
client = get_client()
collections = client.list_collections()
if not collections:
@@ -36,14 +98,14 @@ def list_collections() -> list[str]:
def create_collection(name: str) -> str:
client = chromadb.PersistentClient()
client = get_client()
collection = client.create_collection(name=name)
return getattr(collection, "name", name)
def delete_collection(name: str) -> None:
client = chromadb.PersistentClient()
client = get_client()
client.delete_collection(name=name)
+6 -1
View File
@@ -6,6 +6,7 @@ import typer
from chromadb.errors import InternalError, NotFoundError
from rich import print
from chromy.errors import ChromaPathError
from chromy.handlers.count_collection import handle_count_collection
from chromy.handlers.create_collection import handle_create_collection
from chromy.handlers.delete_collection import (
@@ -22,7 +23,11 @@ ExitCodeHandler = Callable[[], int]
def _run(handler: ExitCodeHandler) -> None:
exit_code = handler()
try:
exit_code = handler()
except ChromaPathError as exc:
_fail(str(exc))
if exit_code != 0:
raise typer.Exit(exit_code)
+4
View File
@@ -3,3 +3,7 @@ from __future__ import annotations
class UnsupportedTextFileError(Exception):
"""Raised when a file does not appear to contain supported text content."""
class ChromaPathError(Exception):
"""Raised when the configured Chroma persistence path is invalid or unusable."""
+1 -1
View File
@@ -61,7 +61,7 @@ def _truncate_file_name(file_name: str, max_length: int = 20) -> str:
if len(file_name) <= max_length:
return file_name
return f"{file_name[: max_length - 3]}"
return f"{file_name[: max_length - 3]}..."
def handle_import(collection: str, files: list[str]) -> int: