Compare commits

..

4 Commits

Author SHA1 Message Date
mrosati d71fce7a6a cannot import non-text files!
build / build (push) Successful in 39s
pytest / pytest (push) Successful in 35s
2026-04-24 18:40:51 +02:00
mrosati c6ad060e85 fix types and print middle dot in collections list 2026-04-24 18:28:03 +02:00
mrosati c5b6b196b5 fix syntax and types 2026-04-24 18:23:02 +02:00
mrosati 948f8500be types cleanup 2026-04-24 18:20:22 +02:00
10 changed files with 91 additions and 40 deletions
+10 -5
View File
@@ -8,6 +8,7 @@ import chromadb
from chromadb.api import ClientAPI from chromadb.api import ClientAPI
from chromadb.api.types import QueryResult, Where from chromadb.api.types import QueryResult, Where
from chromadb.errors import NotFoundError from chromadb.errors import NotFoundError
from rich.text import Text
from chromy.embed import EmbeddingRecord from chromy.embed import EmbeddingRecord
@@ -25,14 +26,17 @@ def _get_client_and_collection(
return client, collection return client, collection
def list_collections() -> list[str]: def list_collections() -> list[Text]:
client = chromadb.PersistentClient() client = chromadb.PersistentClient()
collections = client.list_collections() collections = client.list_collections()
if not collections: if not collections:
return [] return []
return [getattr(collection, "name", str(collection)) for collection in collections] return [
Text("· " + getattr(collection, "name", str(collection)))
for collection in collections
]
def create_collection(name: str) -> str: def create_collection(name: str) -> str:
@@ -58,7 +62,9 @@ def count_collection(collection_name: str) -> str:
_, collection = _get_client_and_collection(collection_name) _, collection = _get_client_and_collection(collection_name)
count = collection.count() count = collection.count()
return f"The '{collection_name}' collection contains [bold green]{count}[/] records." return (
f"The '{collection_name}' collection contains [bold green]{count}[/] records."
)
def add_data( def add_data(
@@ -71,8 +77,7 @@ def add_data(
_, collection = _get_client_and_collection(collection_name) _, collection = _get_client_and_collection(collection_name)
embeddings: list[Sequence[float]] = [record["embedding"] embeddings: list[Sequence[float]] = [record["embedding"] for record in data]
for record in data]
collection.add( collection.add(
ids=[str(uuid4()) for _ in data], ids=[str(uuid4()) for _ in data],
+7 -5
View File
@@ -1,18 +1,19 @@
from __future__ import annotations from __future__ import annotations
from plistlib import InvalidFileException
from typing import Annotated, Callable from typing import Annotated, Callable
import typer import typer
from rich import print
from chromadb.errors import InternalError, NotFoundError from chromadb.errors import InternalError, NotFoundError
from rich import print
from chromy.handlers.import_data import handle_import
from chromy.handlers.count_collection import handle_count_collection from chromy.handlers.count_collection import handle_count_collection
from chromy.handlers.create_collection import handle_create_collection from chromy.handlers.create_collection import handle_create_collection
from chromy.handlers.delete_collection import ( from chromy.handlers.delete_collection import (
handle_delete_collection, handle_delete_collection,
handle_delete_records, handle_delete_records,
) )
from chromy.handlers.import_data import handle_import
from chromy.handlers.list_collections import handle_list_collections from chromy.handlers.list_collections import handle_list_collections
from chromy.handlers.query import handle_query from chromy.handlers.query import handle_query
@@ -114,8 +115,7 @@ def import_data(
], ],
file: Annotated[ file: Annotated[
str, str,
typer.Argument( typer.Argument(help="Path to the file to chunk and add to the collection."),
help="Path to the file to chunk and add to the collection."),
], ],
) -> None: ) -> None:
try: try:
@@ -123,7 +123,9 @@ def import_data(
except NotFoundError: except NotFoundError:
_fail(f"Collection '{collection}' does not exist.") _fail(f"Collection '{collection}' does not exist.")
except FileNotFoundError: except FileNotFoundError:
_fail(f"The file {file} was not found.") _fail(f"The file '{file}' was not found.")
except InvalidFileException:
_fail(f"The file '{file}' is not a text file.")
# ------------------------------------------------------------------------------ # ------------------------------------------------------------------------------
+1
View File
@@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
from rich import print from rich import print
from chromy.chroma_functions import count_collection from chromy.chroma_functions import count_collection
+1
View File
@@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
from rich import print from rich import print
from chromy.chroma_functions import create_collection from chromy.chroma_functions import create_collection
+3 -4
View File
@@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
from rich import print from rich import print
from chromy.chroma_functions import delete_collection, delete_data from chromy.chroma_functions import delete_collection, delete_data
@@ -8,15 +9,13 @@ def _parse_where_clause(where_clause: str) -> dict[str, str]:
condition, separator, value = where_clause.partition("=") condition, separator, value = where_clause.partition("=")
if separator == "": if separator == "":
raise ValueError( raise ValueError("Invalid --where value. Expected <condition>=<value>.")
"Invalid --where value. Expected <condition>=<value>.")
condition = condition.strip() condition = condition.strip()
value = value.strip() value = value.strip()
if not condition or not value: if not condition or not value:
raise ValueError( raise ValueError("Invalid --where value. Expected <condition>=<value>.")
"Invalid --where value. Expected <condition>=<value>.")
return {condition: value} return {condition: value}
+12 -3
View File
@@ -1,10 +1,15 @@
from __future__ import annotations from __future__ import annotations
import os import os
from pathlib import Path from pathlib import Path
from plistlib import InvalidFileException
from rich import print from rich import print
from chromy.utilities import ingest_file from chromy.utilities import ingest_file
from ..utilities import is_probably_text_file
def _get_absolute_path(file: str) -> str: def _get_absolute_path(file: str) -> str:
""" """
@@ -21,11 +26,15 @@ def _get_absolute_path(file: str) -> str:
raise FileNotFoundError() raise FileNotFoundError()
file_path = Path(file) file_path = Path(file)
return str(file_path.resolve(file_path)) return str(file_path.resolve())
def handle_import(collection: str, file: str) -> int: def handle_import(collection: str, file: str) -> int:
absolute_path = _get_absolute_path(file)
if not is_probably_text_file(absolute_path):
raise InvalidFileException()
records_added = ingest_file(collection, _get_absolute_path(file)) records_added = ingest_file(collection, _get_absolute_path(file))
print( print(f"[bold green]Added[/] {records_added} records to collection '{collection}'.")
f"[bold green]Added[/] {records_added} records to collection '{collection}'.")
return 0 return 0
+2
View File
@@ -6,9 +6,11 @@ from chromy.utilities import print_lines
def handle_list_collections() -> int: def handle_list_collections() -> int:
collections = list_collections() collections = list_collections()
if not collections: if not collections:
print("No collections found.") print("No collections found.")
return 0 return 0
print_lines(collections) print_lines(collections)
return 0 return 0
+47 -10
View File
@@ -1,12 +1,12 @@
from __future__ import annotations from __future__ import annotations
from rich.text import Text
from rich.rule import Rule
from rich.console import Console
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from pathlib import Path
from chromadb import QueryResult from chromadb import QueryResult
from rich.console import Console
from rich.rule import Rule
from rich.text import Text
from chromy.chroma_functions import add_data, query_data from chromy.chroma_functions import add_data, query_data
from chromy.chunk_functions import chunk_file from chromy.chunk_functions import chunk_file
@@ -15,7 +15,7 @@ from chromy.embed import embed
CONSOLE = Console() CONSOLE = Console()
def print_lines(lines: Sequence[str]) -> None: def print_lines(lines: Sequence[Rule | Text]) -> None:
for line in lines: for line in lines:
CONSOLE.print(line) CONSOLE.print(line)
@@ -31,7 +31,7 @@ def run_query(collection_name: str, query_text: str) -> QueryResult:
return query_data(collection_name, [query_text]) return query_data(collection_name, [query_text])
def format_query_result(result: QueryResult) -> list[str]: def format_query_result(result: QueryResult) -> list[Rule | Text]:
ids = result.get("ids", [[]]) ids = result.get("ids", [[]])
documents = result.get("documents", [[]]) documents = result.get("documents", [[]])
distances = result.get("distances", [[]]) distances = result.get("distances", [[]])
@@ -43,12 +43,11 @@ def format_query_result(result: QueryResult) -> list[str]:
first_metadatas = metadatas[0] if metadatas else [] first_metadatas = metadatas[0] if metadatas else []
if not first_ids: if not first_ids:
return ["No results found."] return [Text.from_markup("[yellow]No results found.[/]")]
lines = [Rule(title="Query results")] lines: list[Rule | Text] = [Rule(title="Query results")]
for index, document_id in enumerate(first_ids, start=1): for index, document_id in enumerate(first_ids, start=1):
# lines.append(f"{index}.\tid: {document_id}")
lines.append( lines.append(
Text.from_markup(f"[bold]{index}[/].\t[green]id[/]\t\t{document_id}") Text.from_markup(f"[bold]{index}[/].\t[green]id[/]\t\t{document_id}")
) )
@@ -72,9 +71,47 @@ def format_query_result(result: QueryResult) -> list[str]:
if i < len(first_documents): if i < len(first_documents):
lines.append(Text.from_markup("\n[bold green]Retrieved contents[/]\n")) lines.append(Text.from_markup("\n[bold green]Retrieved contents[/]\n"))
lines.append(first_documents[i]) lines.append(Text(first_documents[i]))
# Print a separator between documents # Print a separator between documents
lines.append(Rule()) lines.append(Rule())
return lines return lines
def is_probably_text_file(path: str | Path, sample_size: int = 8192) -> bool:
"""
Return whether a file appears to contain text.
Args:
path (str | Path): The path to the file to inspect.
sample_size (int): The maximum number of bytes to read from the file.
Returns:
bool: ``True`` if the sampled bytes decode as UTF-8, UTF-8 with BOM,
UTF-16, or UTF-32, or if the file is empty. Otherwise, ``False``.
"""
path = Path(path)
with path.open("rb") as f:
sample = f.read(sample_size)
if not sample:
return True
encodings = (
"utf-8",
"utf-8-sig",
"utf-16",
"utf-32",
)
for encoding in encodings:
try:
sample.decode(encoding)
return True
except UnicodeDecodeError:
pass
return False
+6 -11
View File
@@ -51,15 +51,13 @@ class CliTests(unittest.TestCase):
def test_create_collection_with_same_name(self) -> None: def test_create_collection_with_same_name(self) -> None:
with patch( with patch(
"chromy.handlers.create_collection.create_collection", "chromy.handlers.create_collection.create_collection",
side_effect=InternalError() side_effect=InternalError(),
) as create_collection: ) as create_collection:
result = _invoke(["create-collection", "notes"]) result = _invoke(["create-collection", "notes"])
create_collection.assert_called_once_with("notes") create_collection.assert_called_once_with("notes")
self.assertEqual(result.exit_code, 1) self.assertEqual(result.exit_code, 1)
self.assertEqual( self.assertEqual(result.stdout, "Error: Collection 'notes' already exists.\n")
result.stdout, "Error: Collection 'notes' already exists.\n")
def test_delete_collection(self) -> None: def test_delete_collection(self) -> None:
with patch( with patch(
@@ -74,14 +72,13 @@ class CliTests(unittest.TestCase):
def test_delete_non_existent_collection(self) -> None: def test_delete_non_existent_collection(self) -> None:
with patch( with patch(
"chromy.handlers.delete_collection.delete_collection", "chromy.handlers.delete_collection.delete_collection",
side_effect=NotFoundError() side_effect=NotFoundError(),
) as delete_collection: ) as delete_collection:
result = _invoke(["delete-collection", "notes"]) result = _invoke(["delete-collection", "notes"])
delete_collection.assert_called_once_with("notes") delete_collection.assert_called_once_with("notes")
self.assertEqual(result.exit_code, 1) self.assertEqual(result.exit_code, 1)
self.assertEqual( self.assertEqual(result.stdout, "Error: Collection 'notes' does not exist.\n")
result.stdout, "Error: Collection 'notes' does not exist.\n")
def test_count(self) -> None: def test_count(self) -> None:
with patch( with patch(
@@ -106,8 +103,7 @@ class CliTests(unittest.TestCase):
self._fixture_path("romeo_and_juliet.txt"), self._fixture_path("romeo_and_juliet.txt"),
) )
self.assertEqual(result.exit_code, 0) self.assertEqual(result.exit_code, 0)
self.assertEqual( self.assertEqual(result.stdout, "Added 3 records to collection 'notes'.\n")
result.stdout, "Added 3 records to collection 'notes'.\n")
def test_query(self) -> None: def test_query(self) -> None:
query_result = {"ids": [["1"]], "documents": [["hello"]]} query_result = {"ids": [["1"]], "documents": [["hello"]]}
@@ -139,8 +135,7 @@ class CliTests(unittest.TestCase):
self.assertEqual(result.exit_code, 0) self.assertEqual(result.exit_code, 0)
self.assertEqual( self.assertEqual(
result.stdout, result.stdout,
"Deleted 2 record(s) from collection 'notes' " "Deleted 2 record(s) from collection 'notes' where file_name=play.txt.\n",
"where file_name=play.txt.\n",
) )
def test_invalid_delete_filter_keeps_user_facing_error(self) -> None: def test_invalid_delete_filter_keeps_user_facing_error(self) -> None:
+2 -2
View File
@@ -8,13 +8,13 @@ from pathlib import Path
from typing import TypeVar from typing import TypeVar
from unittest.mock import patch from unittest.mock import patch
from chromy.handlers.import_data import handle_import
from chromy.handlers.count_collection import handle_count_collection from chromy.handlers.count_collection import handle_count_collection
from chromy.handlers.create_collection import handle_create_collection from chromy.handlers.create_collection import handle_create_collection
from chromy.handlers.delete_collection import ( from chromy.handlers.delete_collection import (
handle_delete_collection, handle_delete_collection,
handle_delete_records, handle_delete_records,
) )
from chromy.handlers.import_data import handle_import
from chromy.handlers.list_collections import handle_list_collections from chromy.handlers.list_collections import handle_list_collections
from chromy.handlers.query import handle_query from chromy.handlers.query import handle_query
@@ -154,7 +154,7 @@ class HandlerTests(unittest.TestCase):
def _capture_output( def _capture_output(
handler: Callable[..., int], handler: Callable[..., int],
*arguments: CommandT, *arguments: object,
) -> tuple[int, str]: ) -> tuple[int, str]:
output = io.StringIO() output = io.StringIO()