decouple core data from CLI formatting
build / build (push) Successful in 49s
pytest / pytest (push) Successful in 30s

This commit is contained in:
Matteo Rosati
2026-04-29 12:44:28 +02:00
parent 615ab14a1a
commit d1b1238897
12 changed files with 142 additions and 87 deletions
+4 -12
View File
@@ -8,7 +8,6 @@ import chromadb
from chromadb.api import ClientAPI from chromadb.api import ClientAPI
from chromadb.api.types import QueryResult, Where from chromadb.api.types import QueryResult, Where
from chromadb.errors import NotFoundError from chromadb.errors import NotFoundError
from rich.text import Text
from chromy.embed import EmbeddingRecord from chromy.embed import EmbeddingRecord
@@ -26,17 +25,14 @@ def _get_client_and_collection(
return client, collection return client, collection
def list_collections() -> list[Text]: def list_collections() -> list[str]:
client = chromadb.PersistentClient() client = chromadb.PersistentClient()
collections = client.list_collections() collections = client.list_collections()
if not collections: if not collections:
return [] return []
return [ return [getattr(collection, "name", str(collection)) for collection in collections]
Text("· " + getattr(collection, "name", str(collection)))
for collection in collections
]
def create_collection(name: str) -> str: def create_collection(name: str) -> str:
@@ -58,13 +54,9 @@ def delete_data(collection_name: str, where: dict[str, str]) -> int:
return int(result.get("deleted", 0)) return int(result.get("deleted", 0))
def count_collection(collection_name: str) -> str: def count_collection(collection_name: str) -> int:
_, collection = _get_client_and_collection(collection_name) _, collection = _get_client_and_collection(collection_name)
count = collection.count() return collection.count()
return (
f"The '{collection_name}' collection contains [bold green]{count}[/] records."
)
def add_data( def add_data(
+2 -2
View File
@@ -1,12 +1,12 @@
from __future__ import annotations from __future__ import annotations
from plistlib import InvalidFileException
from typing import Annotated, Callable from typing import Annotated, Callable
import typer import typer
from chromadb.errors import InternalError, NotFoundError from chromadb.errors import InternalError, NotFoundError
from rich import print from rich import print
from chromy.errors import UnsupportedTextFileError
from chromy.handlers.count_collection import handle_count_collection from chromy.handlers.count_collection import handle_count_collection
from chromy.handlers.create_collection import handle_create_collection from chromy.handlers.create_collection import handle_create_collection
from chromy.handlers.delete_collection import ( from chromy.handlers.delete_collection import (
@@ -124,7 +124,7 @@ def import_data(
_fail(f"Collection '{collection}' does not exist.") _fail(f"Collection '{collection}' does not exist.")
except FileNotFoundError: except FileNotFoundError:
_fail(f"The file '{file}' was not found.") _fail(f"The file '{file}' was not found.")
except InvalidFileException: except UnsupportedTextFileError:
_fail(f"The file '{file}' is not a text file.") _fail(f"The file '{file}' is not a text file.")
+5
View File
@@ -0,0 +1,5 @@
from __future__ import annotations
class UnsupportedTextFileError(Exception):
    """Signal that a file was rejected because it does not look like text.

    Carries no extra state; callers distinguish it from other failures by
    type alone.
    """
+2 -1
View File
@@ -3,8 +3,9 @@ from __future__ import annotations
from rich import print from rich import print
from chromy.chroma_functions import count_collection from chromy.chroma_functions import count_collection
from chromy.output import format_count_message
def handle_count_collection(collection: str) -> int: def handle_count_collection(collection: str) -> int:
print(count_collection(collection)) print(format_count_message(collection, count_collection(collection)))
return 0 return 0
+3 -3
View File
@@ -2,10 +2,10 @@ from __future__ import annotations
import os import os
from pathlib import Path from pathlib import Path
from plistlib import InvalidFileException
from rich import print from rich import print
from chromy.errors import UnsupportedTextFileError
from chromy.utilities import ingest_file from chromy.utilities import ingest_file
from ..utilities import is_probably_text_file from ..utilities import is_probably_text_file
@@ -33,8 +33,8 @@ def handle_import(collection: str, file: str) -> int:
absolute_path = _get_absolute_path(file) absolute_path = _get_absolute_path(file)
if not is_probably_text_file(absolute_path): if not is_probably_text_file(absolute_path):
raise InvalidFileException() raise UnsupportedTextFileError()
records_added = ingest_file(collection, _get_absolute_path(file)) records_added = ingest_file(collection, absolute_path)
print(f"[bold green]Added[/] {records_added} records to collection '{collection}'.") print(f"[bold green]Added[/] {records_added} records to collection '{collection}'.")
return 0 return 0
+2 -2
View File
@@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from chromy.chroma_functions import list_collections from chromy.chroma_functions import list_collections
from chromy.utilities import print_lines from chromy.output import format_collection_names, print_lines
def handle_list_collections() -> int: def handle_list_collections() -> int:
@@ -11,6 +11,6 @@ def handle_list_collections() -> int:
print("No collections found.") print("No collections found.")
return 0 return 0
print_lines(collections) print_lines(format_collection_names(collections))
return 0 return 0
+2 -1
View File
@@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
from chromy.utilities import format_query_result, print_lines, run_query from chromy.output import format_query_result, print_lines
from chromy.utilities import run_query
def handle_query(collection: str, query_text: str) -> int: def handle_query(collection: str, query_text: str) -> int:
+70
View File
@@ -0,0 +1,70 @@
from __future__ import annotations
from collections.abc import Mapping, Sequence
from chromadb import QueryResult
from rich.console import Console
from rich.rule import Rule
from rich.text import Text
CONSOLE = Console()
def print_lines(lines: Sequence[Rule | Text | str]) -> None:
    """Print every item of ``lines`` through the module console, in order."""
    for renderable in lines:
        CONSOLE.print(renderable)
def format_collection_names(collections: Sequence[str]) -> list[Text]:
    """Turn collection names into bulleted ``Text`` lines, preserving order."""
    bulleted: list[Text] = []
    for name in collections:
        bulleted.append(Text(f"· {name}"))
    return bulleted
def format_count_message(collection_name: str, count: int) -> str:
    """Build the markup sentence reporting how many records a collection holds.

    The count is wrapped in ``[bold green]...[/]`` markup; the string is
    returned unrendered.
    """
    highlighted_count = f"[bold green]{count}[/]"
    return f"The '{collection_name}' collection contains {highlighted_count} records."
def format_query_result(result: QueryResult) -> list[Rule | Text]:
    """Render the hits of the first query in ``result`` as console lines.

    Only the first entry of ``ids``, ``documents``, ``distances`` and
    ``metadatas`` is rendered; a missing or empty field is treated as
    having no entries.

    Args:
        result: Mapping holding the query output fields listed above.

    Returns:
        A single "No results found." line when there are no ids. Otherwise
        a titled rule followed, for each hit, by its id, its distance (if
        present), its ``file_name`` metadata (if present and truthy), its
        document text (if present) and a separating rule.
    """
    # Imported locally so the block carries its own dependency.
    from rich.markup import escape

    ids = result.get("ids", [[]])
    documents = result.get("documents", [[]])
    distances = result.get("distances", [[]])
    metadatas = result.get("metadatas", [[]])

    first_ids = ids[0] if ids else []
    first_documents = documents[0] if documents else []
    first_distances = distances[0] if distances else []
    first_metadatas = metadatas[0] if metadatas else []

    if not first_ids:
        return [Text.from_markup("[yellow]No results found.[/]")]

    lines: list[Rule | Text] = [Rule(title="Query results")]

    for position, document_id in enumerate(first_ids):
        # Stored values are data, not markup: escape them so that square
        # brackets in an id or file name are shown literally instead of
        # being parsed as style tags (or raising on a stray closing tag).
        safe_id = escape(str(document_id))
        lines.append(
            Text.from_markup(f"[bold]{position + 1}[/].\t[green]id[/]\t\t{safe_id}")
        )

        if position < len(first_distances):
            safe_distance = escape(str(first_distances[position]))
            lines.append(Text.from_markup(f"\t[green]distance[/]\t{safe_distance}"))

        if position < len(first_metadatas):
            metadata = first_metadatas[position]
            if isinstance(metadata, Mapping):
                file_name = metadata.get("file_name")
                if file_name:
                    safe_name = escape(str(file_name))
                    lines.append(
                        Text.from_markup(f"\t[green]file_name[/]\t{safe_name}")
                    )

        if position < len(first_documents):
            lines.append(Text.from_markup("\n[bold green]Retrieved contents[/]\n"))
            # Plain Text (no markup parsing) keeps document content verbatim.
            lines.append(Text(first_documents[position]))

        # Separator between consecutive hits.
        lines.append(Rule())

    return lines
-59
View File
@@ -1,24 +1,13 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Mapping, Sequence
from pathlib import Path from pathlib import Path
from chromadb import QueryResult from chromadb import QueryResult
from rich.console import Console
from rich.rule import Rule
from rich.text import Text
from chromy.chroma_functions import add_data, query_data from chromy.chroma_functions import add_data, query_data
from chromy.chunk_functions import chunk_file from chromy.chunk_functions import chunk_file
from chromy.embed import embed from chromy.embed import embed
CONSOLE = Console()
def print_lines(lines: Sequence[Rule | Text]) -> None:
for line in lines:
CONSOLE.print(line)
def ingest_file(collection_name: str, file_path: str) -> int: def ingest_file(collection_name: str, file_path: str) -> int:
chunks = chunk_file(file_path) chunks = chunk_file(file_path)
@@ -31,54 +20,6 @@ def run_query(collection_name: str, query_text: str) -> QueryResult:
return query_data(collection_name, [query_text]) return query_data(collection_name, [query_text])
def format_query_result(result: QueryResult) -> list[Rule | Text]:
ids = result.get("ids", [[]])
documents = result.get("documents", [[]])
distances = result.get("distances", [[]])
metadatas = result.get("metadatas", [[]])
first_ids = ids[0] if ids else []
first_documents = documents[0] if documents else []
first_distances = distances[0] if distances else []
first_metadatas = metadatas[0] if metadatas else []
if not first_ids:
return [Text.from_markup("[yellow]No results found.[/]")]
lines: list[Rule | Text] = [Rule(title="Query results")]
for index, document_id in enumerate(first_ids, start=1):
lines.append(
Text.from_markup(f"[bold]{index}[/].\t[green]id[/]\t\t{document_id}")
)
i = index - 1
if i < len(first_distances):
lines.append(
Text.from_markup(f"\t[green]distance[/]\t{first_distances[i]}")
)
if i < len(first_metadatas):
metadata = first_metadatas[i]
if isinstance(metadata, Mapping):
file_name = metadata.get("file_name")
if file_name:
lines.append(
Text.from_markup(f"\t[green]file_name[/]\t{file_name}")
)
if i < len(first_documents):
lines.append(Text.from_markup("\n[bold green]Retrieved contents[/]\n"))
lines.append(Text(first_documents[i]))
# Print a separator between documents
lines.append(Rule())
return lines
def is_probably_text_file(path: str | Path, sample_size: int = 8192) -> bool: def is_probably_text_file(path: str | Path, sample_size: int = 8192) -> bool:
""" """
Return whether a file appears to contain text. Return whether a file appears to contain text.
+18 -2
View File
@@ -35,7 +35,7 @@ class CliTests(unittest.TestCase):
result = _invoke(["list-collections"]) result = _invoke(["list-collections"])
self.assertEqual(result.exit_code, 0) self.assertEqual(result.exit_code, 0)
self.assertEqual(result.stdout, "books\ncode\n") self.assertEqual(result.stdout, "· books\n· code\n")
def test_create_collection(self) -> None: def test_create_collection(self) -> None:
with patch( with patch(
@@ -89,7 +89,10 @@ class CliTests(unittest.TestCase):
count_collection.assert_called_once_with("notes") count_collection.assert_called_once_with("notes")
self.assertEqual(result.exit_code, 0) self.assertEqual(result.exit_code, 0)
self.assertEqual(result.stdout, "7\n") self.assertEqual(
result.stdout,
"The 'notes' collection contains 7 records.\n",
)
def test_import_data(self) -> None: def test_import_data(self) -> None:
with patch( with patch(
@@ -105,6 +108,19 @@ class CliTests(unittest.TestCase):
self.assertEqual(result.exit_code, 0) self.assertEqual(result.exit_code, 0)
self.assertEqual(result.stdout, "Added 3 records to collection 'notes'.\n") self.assertEqual(result.stdout, "Added 3 records to collection 'notes'.\n")
def test_import_data_rejects_non_text_files(self) -> None:
    """The ``import`` command exits with 1 and reports a non-text file."""
    text_check = patch(
        "chromy.handlers.import_data.is_probably_text_file",
        return_value=False,
    )
    with text_check:
        result = _invoke(["import", "notes", "romeo_and_juliet.txt"])

    expected_output = "Error: The file 'romeo_and_juliet.txt' is not a text file.\n"
    self.assertEqual(result.exit_code, 1)
    self.assertEqual(result.stdout, expected_output)
def test_query(self) -> None: def test_query(self) -> None:
query_result = {"ids": [["1"]], "documents": [["hello"]]} query_result = {"ids": [["1"]], "documents": [["hello"]]}
+20 -2
View File
@@ -1,11 +1,29 @@
from __future__ import annotations from __future__ import annotations
import unittest import unittest
from unittest.mock import patch
from chromy.embed import embed
class EmbedTest(unittest.TestCase): class EmbedTest(unittest.TestCase):
def test_embed_function(self) -> None: def test_embed_returns_empty_list_for_empty_chunks(self) -> None:
self.assertEqual(0, 0) self.assertEqual(embed([]), [])
def test_embed_pairs_text_with_list_embeddings(self) -> None:
    """Each chunk is paired with its embedding, returned as a list."""
    expected = [
        {"text": "first", "embedding": [1.0, 2.0]},
        {"text": "second", "embedding": [3.0, 4.0]},
    ]
    # The stand-in embedding function hands back tuples on purpose; the
    # expected records hold lists.
    fake_function = patch(
        "chromy.embed.DefaultEmbeddingFunction",
        return_value=lambda chunks: ((1.0, 2.0), (3.0, 4.0)),
    )
    with fake_function:
        result = embed(["first", "second"])

    self.assertEqual(result, expected)
if __name__ == "__main__": if __name__ == "__main__":
+13 -2
View File
@@ -8,6 +8,7 @@ from pathlib import Path
from typing import TypeVar from typing import TypeVar
from unittest.mock import patch from unittest.mock import patch
from chromy.errors import UnsupportedTextFileError
from chromy.handlers.count_collection import handle_count_collection from chromy.handlers.count_collection import handle_count_collection
from chromy.handlers.create_collection import handle_create_collection from chromy.handlers.create_collection import handle_create_collection
from chromy.handlers.delete_collection import ( from chromy.handlers.delete_collection import (
@@ -47,7 +48,7 @@ class HandlerTests(unittest.TestCase):
) )
self.assertEqual(exit_code, 0) self.assertEqual(exit_code, 0)
self.assertEqual(output, "notes\nplays\n") self.assertEqual(output, "· notes\n· plays\n")
def test_create_collection_uses_typed_input(self) -> None: def test_create_collection_uses_typed_input(self) -> None:
with patch( with patch(
@@ -86,7 +87,7 @@ class HandlerTests(unittest.TestCase):
count.assert_called_once_with("notes") count.assert_called_once_with("notes")
self.assertEqual(exit_code, 0) self.assertEqual(exit_code, 0)
self.assertEqual(output, "7\n") self.assertEqual(output, "The 'notes' collection contains 7 records.\n")
def test_import_data_uses_typed_input(self) -> None: def test_import_data_uses_typed_input(self) -> None:
with patch( with patch(
@@ -106,6 +107,16 @@ class HandlerTests(unittest.TestCase):
self.assertEqual(exit_code, 0) self.assertEqual(exit_code, 0)
self.assertEqual(output, "Added 3 records to collection 'notes'.\n") self.assertEqual(output, "Added 3 records to collection 'notes'.\n")
def test_import_data_rejects_non_text_files(self) -> None:
    """``handle_import`` raises ``UnsupportedTextFileError`` for non-text files."""
    non_text_check = patch(
        "chromy.handlers.import_data.is_probably_text_file",
        return_value=False,
    )
    with non_text_check, self.assertRaises(UnsupportedTextFileError):
        handle_import("notes", "romeo_and_juliet.txt")
def test_query_uses_typed_input(self) -> None: def test_query_uses_typed_input(self) -> None:
query_result = {"ids": [["1"]], "documents": [["hello"]]} query_result = {"ids": [["1"]], "documents": [["hello"]]}
with ( with (