decouple core data from CLI formatting
build / build (push) Successful in 49s
pytest / pytest (push) Successful in 30s

This commit is contained in:
Matteo Rosati
2026-04-29 12:44:28 +02:00
parent 615ab14a1a
commit d1b1238897
12 changed files with 142 additions and 87 deletions
+4 -12
View File
@@ -8,7 +8,6 @@ import chromadb
from chromadb.api import ClientAPI
from chromadb.api.types import QueryResult, Where
from chromadb.errors import NotFoundError
from rich.text import Text
from chromy.embed import EmbeddingRecord
@@ -26,17 +25,14 @@ def _get_client_and_collection(
return client, collection
def list_collections() -> list[Text]:
def list_collections() -> list[str]:
client = chromadb.PersistentClient()
collections = client.list_collections()
if not collections:
return []
return [
Text("· " + getattr(collection, "name", str(collection)))
for collection in collections
]
return [getattr(collection, "name", str(collection)) for collection in collections]
def create_collection(name: str) -> str:
@@ -58,13 +54,9 @@ def delete_data(collection_name: str, where: dict[str, str]) -> int:
return int(result.get("deleted", 0))
def count_collection(collection_name: str) -> str:
def count_collection(collection_name: str) -> int:
_, collection = _get_client_and_collection(collection_name)
count = collection.count()
return (
f"The '{collection_name}' collection contains [bold green]{count}[/] records."
)
return collection.count()
def add_data(
+2 -2
View File
@@ -1,12 +1,12 @@
from __future__ import annotations
from plistlib import InvalidFileException
from typing import Annotated, Callable
import typer
from chromadb.errors import InternalError, NotFoundError
from rich import print
from chromy.errors import UnsupportedTextFileError
from chromy.handlers.count_collection import handle_count_collection
from chromy.handlers.create_collection import handle_create_collection
from chromy.handlers.delete_collection import (
@@ -124,7 +124,7 @@ def import_data(
_fail(f"Collection '{collection}' does not exist.")
except FileNotFoundError:
_fail(f"The file '{file}' was not found.")
except InvalidFileException:
except UnsupportedTextFileError:
_fail(f"The file '{file}' is not a text file.")
+5
View File
@@ -0,0 +1,5 @@
from __future__ import annotations
class UnsupportedTextFileError(Exception):
"""Raised when a file does not appear to contain supported text content."""
+2 -1
View File
@@ -3,8 +3,9 @@ from __future__ import annotations
from rich import print
from chromy.chroma_functions import count_collection
from chromy.output import format_count_message
def handle_count_collection(collection: str) -> int:
print(count_collection(collection))
print(format_count_message(collection, count_collection(collection)))
return 0
+3 -3
View File
@@ -2,10 +2,10 @@ from __future__ import annotations
import os
from pathlib import Path
from plistlib import InvalidFileException
from rich import print
from chromy.errors import UnsupportedTextFileError
from chromy.utilities import ingest_file
from ..utilities import is_probably_text_file
@@ -33,8 +33,8 @@ def handle_import(collection: str, file: str) -> int:
absolute_path = _get_absolute_path(file)
if not is_probably_text_file(absolute_path):
raise InvalidFileException()
raise UnsupportedTextFileError()
records_added = ingest_file(collection, _get_absolute_path(file))
records_added = ingest_file(collection, absolute_path)
print(f"[bold green]Added[/] {records_added} records to collection '{collection}'.")
return 0
+2 -2
View File
@@ -1,7 +1,7 @@
from __future__ import annotations
from chromy.chroma_functions import list_collections
from chromy.utilities import print_lines
from chromy.output import format_collection_names, print_lines
def handle_list_collections() -> int:
@@ -11,6 +11,6 @@ def handle_list_collections() -> int:
print("No collections found.")
return 0
print_lines(collections)
print_lines(format_collection_names(collections))
return 0
+2 -1
View File
@@ -1,6 +1,7 @@
from __future__ import annotations
from chromy.utilities import format_query_result, print_lines, run_query
from chromy.output import format_query_result, print_lines
from chromy.utilities import run_query
def handle_query(collection: str, query_text: str) -> int:
+70
View File
@@ -0,0 +1,70 @@
from __future__ import annotations
from collections.abc import Mapping, Sequence
from chromadb import QueryResult
from rich.console import Console
from rich.rule import Rule
from rich.text import Text
CONSOLE = Console()
def print_lines(lines: Sequence[Rule | Text | str]) -> None:
for line in lines:
CONSOLE.print(line)
def format_collection_names(collections: Sequence[str]) -> list[Text]:
return [Text(f"· {collection}") for collection in collections]
def format_count_message(collection_name: str, count: int) -> str:
return (
f"The '{collection_name}' collection contains [bold green]{count}[/] records."
)
def format_query_result(result: QueryResult) -> list[Rule | Text]:
ids = result.get("ids", [[]])
documents = result.get("documents", [[]])
distances = result.get("distances", [[]])
metadatas = result.get("metadatas", [[]])
first_ids = ids[0] if ids else []
first_documents = documents[0] if documents else []
first_distances = distances[0] if distances else []
first_metadatas = metadatas[0] if metadatas else []
if not first_ids:
return [Text.from_markup("[yellow]No results found.[/]")]
lines: list[Rule | Text] = [Rule(title="Query results")]
for index, document_id in enumerate(first_ids, start=1):
lines.append(
Text.from_markup(f"[bold]{index}[/].\t[green]id[/]\t\t{document_id}")
)
i = index - 1
if i < len(first_distances):
lines.append(
Text.from_markup(f"\t[green]distance[/]\t{first_distances[i]}")
)
if i < len(first_metadatas):
metadata = first_metadatas[i]
if isinstance(metadata, Mapping):
file_name = metadata.get("file_name")
if file_name:
lines.append(Text.from_markup(f"\t[green]file_name[/]\t{file_name}"))
if i < len(first_documents):
lines.append(Text.from_markup("\n[bold green]Retrieved contents[/]\n"))
lines.append(Text(first_documents[i]))
lines.append(Rule())
return lines
-59
View File
@@ -1,24 +1,13 @@
from __future__ import annotations
from collections.abc import Mapping, Sequence
from pathlib import Path
from chromadb import QueryResult
from rich.console import Console
from rich.rule import Rule
from rich.text import Text
from chromy.chroma_functions import add_data, query_data
from chromy.chunk_functions import chunk_file
from chromy.embed import embed
CONSOLE = Console()
def print_lines(lines: Sequence[Rule | Text]) -> None:
for line in lines:
CONSOLE.print(line)
def ingest_file(collection_name: str, file_path: str) -> int:
chunks = chunk_file(file_path)
@@ -31,54 +20,6 @@ def run_query(collection_name: str, query_text: str) -> QueryResult:
return query_data(collection_name, [query_text])
def format_query_result(result: QueryResult) -> list[Rule | Text]:
ids = result.get("ids", [[]])
documents = result.get("documents", [[]])
distances = result.get("distances", [[]])
metadatas = result.get("metadatas", [[]])
first_ids = ids[0] if ids else []
first_documents = documents[0] if documents else []
first_distances = distances[0] if distances else []
first_metadatas = metadatas[0] if metadatas else []
if not first_ids:
return [Text.from_markup("[yellow]No results found.[/]")]
lines: list[Rule | Text] = [Rule(title="Query results")]
for index, document_id in enumerate(first_ids, start=1):
lines.append(
Text.from_markup(f"[bold]{index}[/].\t[green]id[/]\t\t{document_id}")
)
i = index - 1
if i < len(first_distances):
lines.append(
Text.from_markup(f"\t[green]distance[/]\t{first_distances[i]}")
)
if i < len(first_metadatas):
metadata = first_metadatas[i]
if isinstance(metadata, Mapping):
file_name = metadata.get("file_name")
if file_name:
lines.append(
Text.from_markup(f"\t[green]file_name[/]\t{file_name}")
)
if i < len(first_documents):
lines.append(Text.from_markup("\n[bold green]Retrieved contents[/]\n"))
lines.append(Text(first_documents[i]))
# Print a separator between documents
lines.append(Rule())
return lines
def is_probably_text_file(path: str | Path, sample_size: int = 8192) -> bool:
"""
Return whether a file appears to contain text.
+18 -2
View File
@@ -35,7 +35,7 @@ class CliTests(unittest.TestCase):
result = _invoke(["list-collections"])
self.assertEqual(result.exit_code, 0)
self.assertEqual(result.stdout, "books\ncode\n")
self.assertEqual(result.stdout, "· books\n· code\n")
def test_create_collection(self) -> None:
with patch(
@@ -89,7 +89,10 @@ class CliTests(unittest.TestCase):
count_collection.assert_called_once_with("notes")
self.assertEqual(result.exit_code, 0)
self.assertEqual(result.stdout, "7\n")
self.assertEqual(
result.stdout,
"The 'notes' collection contains 7 records.\n",
)
def test_import_data(self) -> None:
with patch(
@@ -105,6 +108,19 @@ class CliTests(unittest.TestCase):
self.assertEqual(result.exit_code, 0)
self.assertEqual(result.stdout, "Added 3 records to collection 'notes'.\n")
def test_import_data_rejects_non_text_files(self) -> None:
with patch(
"chromy.handlers.import_data.is_probably_text_file",
return_value=False,
):
result = _invoke(["import", "notes", "romeo_and_juliet.txt"])
self.assertEqual(result.exit_code, 1)
self.assertEqual(
result.stdout,
"Error: The file 'romeo_and_juliet.txt' is not a text file.\n",
)
def test_query(self) -> None:
query_result = {"ids": [["1"]], "documents": [["hello"]]}
+20 -2
View File
@@ -1,11 +1,29 @@
from __future__ import annotations
import unittest
from unittest.mock import patch
from chromy.embed import embed
class EmbedTest(unittest.TestCase):
def test_embed_function(self) -> None:
self.assertEqual(0, 0)
def test_embed_returns_empty_list_for_empty_chunks(self) -> None:
self.assertEqual(embed([]), [])
def test_embed_pairs_text_with_list_embeddings(self) -> None:
with patch(
"chromy.embed.DefaultEmbeddingFunction",
return_value=lambda chunks: ((1.0, 2.0), (3.0, 4.0)),
):
result = embed(["first", "second"])
self.assertEqual(
result,
[
{"text": "first", "embedding": [1.0, 2.0]},
{"text": "second", "embedding": [3.0, 4.0]},
],
)
if __name__ == "__main__":
+13 -2
View File
@@ -8,6 +8,7 @@ from pathlib import Path
from typing import TypeVar
from unittest.mock import patch
from chromy.errors import UnsupportedTextFileError
from chromy.handlers.count_collection import handle_count_collection
from chromy.handlers.create_collection import handle_create_collection
from chromy.handlers.delete_collection import (
@@ -47,7 +48,7 @@ class HandlerTests(unittest.TestCase):
)
self.assertEqual(exit_code, 0)
self.assertEqual(output, "notes\nplays\n")
self.assertEqual(output, "· notes\n· plays\n")
def test_create_collection_uses_typed_input(self) -> None:
with patch(
@@ -86,7 +87,7 @@ class HandlerTests(unittest.TestCase):
count.assert_called_once_with("notes")
self.assertEqual(exit_code, 0)
self.assertEqual(output, "7\n")
self.assertEqual(output, "The 'notes' collection contains 7 records.\n")
def test_import_data_uses_typed_input(self) -> None:
with patch(
@@ -106,6 +107,16 @@ class HandlerTests(unittest.TestCase):
self.assertEqual(exit_code, 0)
self.assertEqual(output, "Added 3 records to collection 'notes'.\n")
def test_import_data_rejects_non_text_files(self) -> None:
with (
patch(
"chromy.handlers.import_data.is_probably_text_file",
return_value=False,
),
self.assertRaises(UnsupportedTextFileError),
):
handle_import("notes", "romeo_and_juliet.txt")
def test_query_uses_typed_input(self) -> None:
query_result = {"ids": [["1"]], "documents": [["hello"]]}
with (