Compare commits

...

4 Commits

Author SHA1 Message Date
mrosati d71fce7a6a cannot import non-text files!
build / build (push) Successful in 39s
pytest / pytest (push) Successful in 35s
2026-04-24 18:40:51 +02:00
mrosati c6ad060e85 fix types and print middle dot in collections list 2026-04-24 18:28:03 +02:00
mrosati c5b6b196b5 fix syntax and types 2026-04-24 18:23:02 +02:00
mrosati 948f8500be types cleanup 2026-04-24 18:20:22 +02:00
10 changed files with 91 additions and 40 deletions
+10 -5
View File
@@ -8,6 +8,7 @@ import chromadb
from chromadb.api import ClientAPI
from chromadb.api.types import QueryResult, Where
from chromadb.errors import NotFoundError
from rich.text import Text
from chromy.embed import EmbeddingRecord
@@ -25,14 +26,17 @@ def _get_client_and_collection(
return client, collection
def list_collections() -> list[str]:
def list_collections() -> list[Text]:
client = chromadb.PersistentClient()
collections = client.list_collections()
if not collections:
return []
return [getattr(collection, "name", str(collection)) for collection in collections]
return [
Text("· " + getattr(collection, "name", str(collection)))
for collection in collections
]
def create_collection(name: str) -> str:
@@ -58,7 +62,9 @@ def count_collection(collection_name: str) -> str:
_, collection = _get_client_and_collection(collection_name)
count = collection.count()
return f"The '{collection_name}' collection contains [bold green]{count}[/] records."
return (
f"The '{collection_name}' collection contains [bold green]{count}[/] records."
)
def add_data(
@@ -71,8 +77,7 @@ def add_data(
_, collection = _get_client_and_collection(collection_name)
embeddings: list[Sequence[float]] = [record["embedding"]
for record in data]
embeddings: list[Sequence[float]] = [record["embedding"] for record in data]
collection.add(
ids=[str(uuid4()) for _ in data],
+7 -5
View File
@@ -1,18 +1,19 @@
from __future__ import annotations
from plistlib import InvalidFileException
from typing import Annotated, Callable
import typer
from rich import print
from chromadb.errors import InternalError, NotFoundError
from rich import print
from chromy.handlers.import_data import handle_import
from chromy.handlers.count_collection import handle_count_collection
from chromy.handlers.create_collection import handle_create_collection
from chromy.handlers.delete_collection import (
handle_delete_collection,
handle_delete_records,
)
from chromy.handlers.import_data import handle_import
from chromy.handlers.list_collections import handle_list_collections
from chromy.handlers.query import handle_query
@@ -114,8 +115,7 @@ def import_data(
],
file: Annotated[
str,
typer.Argument(
help="Path to the file to chunk and add to the collection."),
typer.Argument(help="Path to the file to chunk and add to the collection."),
],
) -> None:
try:
@@ -123,7 +123,9 @@ def import_data(
except NotFoundError:
_fail(f"Collection '{collection}' does not exist.")
except FileNotFoundError:
_fail(f"The file {file} was not found.")
_fail(f"The file '{file}' was not found.")
except InvalidFileException:
_fail(f"The file '{file}' is not a text file.")
# ------------------------------------------------------------------------------
+1
View File
@@ -1,6 +1,7 @@
from __future__ import annotations
from rich import print
from chromy.chroma_functions import count_collection
+1
View File
@@ -1,6 +1,7 @@
from __future__ import annotations
from rich import print
from chromy.chroma_functions import create_collection
+3 -4
View File
@@ -1,6 +1,7 @@
from __future__ import annotations
from rich import print
from chromy.chroma_functions import delete_collection, delete_data
@@ -8,15 +9,13 @@ def _parse_where_clause(where_clause: str) -> dict[str, str]:
condition, separator, value = where_clause.partition("=")
if separator == "":
raise ValueError(
"Invalid --where value. Expected <condition>=<value>.")
raise ValueError("Invalid --where value. Expected <condition>=<value>.")
condition = condition.strip()
value = value.strip()
if not condition or not value:
raise ValueError(
"Invalid --where value. Expected <condition>=<value>.")
raise ValueError("Invalid --where value. Expected <condition>=<value>.")
return {condition: value}
+12 -3
View File
@@ -1,10 +1,15 @@
from __future__ import annotations
import os
from pathlib import Path
from plistlib import InvalidFileException
from rich import print
from chromy.utilities import ingest_file
from ..utilities import is_probably_text_file
def _get_absolute_path(file: str) -> str:
"""
@@ -21,11 +26,15 @@ def _get_absolute_path(file: str) -> str:
raise FileNotFoundError()
file_path = Path(file)
return str(file_path.resolve(file_path))
return str(file_path.resolve())
def handle_import(collection: str, file: str) -> int:
absolute_path = _get_absolute_path(file)
if not is_probably_text_file(absolute_path):
raise InvalidFileException()
records_added = ingest_file(collection, _get_absolute_path(file))
print(
f"[bold green]Added[/] {records_added} records to collection '{collection}'.")
print(f"[bold green]Added[/] {records_added} records to collection '{collection}'.")
return 0
+2
View File
@@ -6,9 +6,11 @@ from chromy.utilities import print_lines
def handle_list_collections() -> int:
collections = list_collections()
if not collections:
print("No collections found.")
return 0
print_lines(collections)
return 0
+47 -10
View File
@@ -1,12 +1,12 @@
from __future__ import annotations
from rich.text import Text
from rich.rule import Rule
from rich.console import Console
from collections.abc import Mapping, Sequence
from pathlib import Path
from chromadb import QueryResult
from rich.console import Console
from rich.rule import Rule
from rich.text import Text
from chromy.chroma_functions import add_data, query_data
from chromy.chunk_functions import chunk_file
@@ -15,7 +15,7 @@ from chromy.embed import embed
CONSOLE = Console()
def print_lines(lines: Sequence[str]) -> None:
def print_lines(lines: Sequence[Rule | Text]) -> None:
for line in lines:
CONSOLE.print(line)
@@ -31,7 +31,7 @@ def run_query(collection_name: str, query_text: str) -> QueryResult:
return query_data(collection_name, [query_text])
def format_query_result(result: QueryResult) -> list[str]:
def format_query_result(result: QueryResult) -> list[Rule | Text]:
ids = result.get("ids", [[]])
documents = result.get("documents", [[]])
distances = result.get("distances", [[]])
@@ -43,12 +43,11 @@ def format_query_result(result: QueryResult) -> list[str]:
first_metadatas = metadatas[0] if metadatas else []
if not first_ids:
return ["No results found."]
return [Text.from_markup("[yellow]No results found.[/]")]
lines = [Rule(title="Query results")]
lines: list[Rule | Text] = [Rule(title="Query results")]
for index, document_id in enumerate(first_ids, start=1):
# lines.append(f"{index}.\tid: {document_id}")
lines.append(
Text.from_markup(f"[bold]{index}[/].\t[green]id[/]\t\t{document_id}")
)
@@ -72,9 +71,47 @@ def format_query_result(result: QueryResult) -> list[str]:
if i < len(first_documents):
lines.append(Text.from_markup("\n[bold green]Retrieved contents[/]\n"))
lines.append(first_documents[i])
lines.append(Text(first_documents[i]))
# Print a separator between documents
lines.append(Rule())
return lines
def is_probably_text_file(path: str | Path, sample_size: int = 8192) -> bool:
"""
Return whether a file appears to contain text.
Args:
path (str | Path): The path to the file to inspect.
sample_size (int): The maximum number of bytes to read from the file.
Returns:
bool: ``True`` if the sampled bytes decode as UTF-8, UTF-8 with BOM,
UTF-16, or UTF-32, or if the file is empty. Otherwise, ``False``.
"""
path = Path(path)
with path.open("rb") as f:
sample = f.read(sample_size)
if not sample:
return True
encodings = (
"utf-8",
"utf-8-sig",
"utf-16",
"utf-32",
)
for encoding in encodings:
try:
sample.decode(encoding)
return True
except UnicodeDecodeError:
pass
return False
+6 -11
View File
@@ -51,15 +51,13 @@ class CliTests(unittest.TestCase):
def test_create_collection_with_same_name(self) -> None:
with patch(
"chromy.handlers.create_collection.create_collection",
side_effect=InternalError()
side_effect=InternalError(),
) as create_collection:
result = _invoke(["create-collection", "notes"])
create_collection.assert_called_once_with("notes")
self.assertEqual(result.exit_code, 1)
self.assertEqual(
result.stdout, "Error: Collection 'notes' already exists.\n")
self.assertEqual(result.stdout, "Error: Collection 'notes' already exists.\n")
def test_delete_collection(self) -> None:
with patch(
@@ -74,14 +72,13 @@ class CliTests(unittest.TestCase):
def test_delete_non_existent_collection(self) -> None:
with patch(
"chromy.handlers.delete_collection.delete_collection",
side_effect=NotFoundError()
side_effect=NotFoundError(),
) as delete_collection:
result = _invoke(["delete-collection", "notes"])
delete_collection.assert_called_once_with("notes")
self.assertEqual(result.exit_code, 1)
self.assertEqual(
result.stdout, "Error: Collection 'notes' does not exist.\n")
self.assertEqual(result.stdout, "Error: Collection 'notes' does not exist.\n")
def test_count(self) -> None:
with patch(
@@ -106,8 +103,7 @@ class CliTests(unittest.TestCase):
self._fixture_path("romeo_and_juliet.txt"),
)
self.assertEqual(result.exit_code, 0)
self.assertEqual(
result.stdout, "Added 3 records to collection 'notes'.\n")
self.assertEqual(result.stdout, "Added 3 records to collection 'notes'.\n")
def test_query(self) -> None:
query_result = {"ids": [["1"]], "documents": [["hello"]]}
@@ -139,8 +135,7 @@ class CliTests(unittest.TestCase):
self.assertEqual(result.exit_code, 0)
self.assertEqual(
result.stdout,
"Deleted 2 record(s) from collection 'notes' "
"where file_name=play.txt.\n",
"Deleted 2 record(s) from collection 'notes' where file_name=play.txt.\n",
)
def test_invalid_delete_filter_keeps_user_facing_error(self) -> None:
+2 -2
View File
@@ -8,13 +8,13 @@ from pathlib import Path
from typing import TypeVar
from unittest.mock import patch
from chromy.handlers.import_data import handle_import
from chromy.handlers.count_collection import handle_count_collection
from chromy.handlers.create_collection import handle_create_collection
from chromy.handlers.delete_collection import (
handle_delete_collection,
handle_delete_records,
)
from chromy.handlers.import_data import handle_import
from chromy.handlers.list_collections import handle_list_collections
from chromy.handlers.query import handle_query
@@ -154,7 +154,7 @@ class HandlerTests(unittest.TestCase):
def _capture_output(
handler: Callable[..., int],
*arguments: CommandT,
*arguments: object,
) -> tuple[int, str]:
output = io.StringIO()