diff --git a/chromy/chroma_functions.py b/chromy/chroma_functions.py index 33b11b1..a945bee 100644 --- a/chromy/chroma_functions.py +++ b/chromy/chroma_functions.py @@ -8,7 +8,6 @@ import chromadb from chromadb.api import ClientAPI from chromadb.api.types import QueryResult, Where from chromadb.errors import NotFoundError -from rich.text import Text from chromy.embed import EmbeddingRecord @@ -26,17 +25,14 @@ def _get_client_and_collection( return client, collection -def list_collections() -> list[Text]: +def list_collections() -> list[str]: client = chromadb.PersistentClient() collections = client.list_collections() if not collections: return [] - return [ - Text("· " + getattr(collection, "name", str(collection))) - for collection in collections - ] + return [getattr(collection, "name", str(collection)) for collection in collections] def create_collection(name: str) -> str: @@ -58,13 +54,9 @@ def delete_data(collection_name: str, where: dict[str, str]) -> int: return int(result.get("deleted", 0)) -def count_collection(collection_name: str) -> str: +def count_collection(collection_name: str) -> int: _, collection = _get_client_and_collection(collection_name) - count = collection.count() - - return ( - f"The '{collection_name}' collection contains [bold green]{count}[/] records." - ) + return collection.count() def add_data( diff --git a/chromy/cli.py b/chromy/cli.py index 2f2f1b9..0dc6899 100644 --- a/chromy/cli.py +++ b/chromy/cli.py @@ -1,12 +1,12 @@ from __future__ import annotations -from plistlib import InvalidFileException from typing import Annotated, Callable import typer from chromadb.errors import InternalError, NotFoundError from rich import print +from chromy.errors import UnsupportedTextFileError from chromy.handlers.count_collection import handle_count_collection from chromy.handlers.create_collection import handle_create_collection from chromy.handlers.delete_collection import ( @@ -124,7 +124,7 @@ def import_data( _fail(f"Collection '{collection}' does not exist.") except FileNotFoundError: _fail(f"The file '{file}' was not found.") - except InvalidFileException: + except UnsupportedTextFileError: _fail(f"The file '{file}' is not a text file.") diff --git a/chromy/errors.py b/chromy/errors.py new file mode 100644 index 0000000..2cb3fcb --- /dev/null +++ b/chromy/errors.py @@ -0,0 +1,5 @@ +from __future__ import annotations + + +class UnsupportedTextFileError(Exception): + """Raised when a file does not appear to contain supported text content.""" diff --git a/chromy/handlers/count_collection.py b/chromy/handlers/count_collection.py index 9218d74..ec205fe 100644 --- a/chromy/handlers/count_collection.py +++ b/chromy/handlers/count_collection.py @@ -3,8 +3,9 @@ from __future__ import annotations from rich import print from chromy.chroma_functions import count_collection +from chromy.output import format_count_message def handle_count_collection(collection: str) -> int: - print(count_collection(collection)) + print(format_count_message(collection, count_collection(collection))) return 0 diff --git a/chromy/handlers/import_data.py b/chromy/handlers/import_data.py index 365736c..d59281f 100644 --- a/chromy/handlers/import_data.py +++ b/chromy/handlers/import_data.py @@ -2,10 +2,10 @@ from __future__ import annotations import os from pathlib import Path -from plistlib import InvalidFileException from rich import print +from chromy.errors import UnsupportedTextFileError from chromy.utilities import ingest_file from ..utilities import is_probably_text_file @@ -33,8 +33,8 @@ def handle_import(collection: str, file: str) -> int: absolute_path = _get_absolute_path(file) if not is_probably_text_file(absolute_path): - raise InvalidFileException() + raise UnsupportedTextFileError() - records_added = ingest_file(collection, _get_absolute_path(file)) + records_added = ingest_file(collection, absolute_path) print(f"[bold green]Added[/] {records_added} records to collection '{collection}'.") return 0 diff --git a/chromy/handlers/list_collections.py b/chromy/handlers/list_collections.py index 6762583..5395d4b 100644 --- a/chromy/handlers/list_collections.py +++ b/chromy/handlers/list_collections.py @@ -1,7 +1,7 @@ from __future__ import annotations from chromy.chroma_functions import list_collections -from chromy.utilities import print_lines +from chromy.output import format_collection_names, print_lines def handle_list_collections() -> int: @@ -11,6 +11,6 @@ def handle_list_collections() -> int: print("No collections found.") return 0 - print_lines(collections) + print_lines(format_collection_names(collections)) return 0 diff --git a/chromy/handlers/query.py b/chromy/handlers/query.py index b94969a..467499b 100644 --- a/chromy/handlers/query.py +++ b/chromy/handlers/query.py @@ -1,6 +1,7 @@ from __future__ import annotations -from chromy.utilities import format_query_result, print_lines, run_query +from chromy.output import format_query_result, print_lines +from chromy.utilities import run_query def handle_query(collection: str, query_text: str) -> int: diff --git a/chromy/output.py b/chromy/output.py new file mode 100644 index 0000000..a7e4d4a --- /dev/null +++ b/chromy/output.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +from collections.abc import Mapping, Sequence + +from chromadb import QueryResult +from rich.console import Console +from rich.rule import Rule +from rich.text import Text + +CONSOLE = Console() + + +def print_lines(lines: Sequence[Rule | Text | str]) -> None: + for line in lines: + CONSOLE.print(line) + + +def format_collection_names(collections: Sequence[str]) -> list[Text]: + return [Text(f"· {collection}") for collection in collections] + + +def format_count_message(collection_name: str, count: int) -> str: + return ( + f"The '{collection_name}' collection contains [bold green]{count}[/] records." + ) + + +def format_query_result(result: QueryResult) -> list[Rule | Text]: + ids = result.get("ids", [[]]) + documents = result.get("documents", [[]]) + distances = result.get("distances", [[]]) + metadatas = result.get("metadatas", [[]]) + + first_ids = ids[0] if ids else [] + first_documents = documents[0] if documents else [] + first_distances = distances[0] if distances else [] + first_metadatas = metadatas[0] if metadatas else [] + + if not first_ids: + return [Text.from_markup("[yellow]No results found.[/]")] + + lines: list[Rule | Text] = [Rule(title="Query results")] + + for index, document_id in enumerate(first_ids, start=1): + lines.append( + Text.from_markup(f"[bold]{index}[/].\t[green]id[/]\t\t{document_id}") + ) + i = index - 1 + + if i < len(first_distances): + lines.append( + Text.from_markup(f"\t[green]distance[/]\t{first_distances[i]}") + ) + + if i < len(first_metadatas): + metadata = first_metadatas[i] + + if isinstance(metadata, Mapping): + file_name = metadata.get("file_name") + + if file_name: + lines.append(Text.from_markup(f"\t[green]file_name[/]\t{file_name}")) + + if i < len(first_documents): + lines.append(Text.from_markup("\n[bold green]Retrieved contents[/]\n")) + lines.append(Text(first_documents[i])) + + lines.append(Rule()) + + return lines diff --git a/chromy/utilities.py b/chromy/utilities.py index 3ddf44c..8d940e1 100644 --- a/chromy/utilities.py +++ b/chromy/utilities.py @@ -1,24 +1,13 @@ from __future__ import annotations -from collections.abc import Mapping, Sequence from pathlib import Path from chromadb import QueryResult -from rich.console import Console -from rich.rule import Rule -from rich.text import Text from chromy.chroma_functions import add_data, query_data from chromy.chunk_functions import chunk_file from chromy.embed import embed -CONSOLE = Console() - - -def print_lines(lines: Sequence[Rule | Text]) -> None: - for line in lines: - CONSOLE.print(line) - def ingest_file(collection_name: str, file_path: str) -> int: chunks = chunk_file(file_path) @@ -31,54 +20,6 @@ def run_query(collection_name: str, query_text: str) -> QueryResult: return query_data(collection_name, [query_text]) -def format_query_result(result: QueryResult) -> list[Rule | Text]: - ids = result.get("ids", [[]]) - documents = result.get("documents", [[]]) - distances = result.get("distances", [[]]) - metadatas = result.get("metadatas", [[]]) - - first_ids = ids[0] if ids else [] - first_documents = documents[0] if documents else [] - first_distances = distances[0] if distances else [] - first_metadatas = metadatas[0] if metadatas else [] - - if not first_ids: - return [Text.from_markup("[yellow]No results found.[/]")] - - lines: list[Rule | Text] = [Rule(title="Query results")] - - for index, document_id in enumerate(first_ids, start=1): - lines.append( - Text.from_markup(f"[bold]{index}[/].\t[green]id[/]\t\t{document_id}") - ) - i = index - 1 - - if i < len(first_distances): - lines.append( - Text.from_markup(f"\t[green]distance[/]\t{first_distances[i]}") - ) - - if i < len(first_metadatas): - metadata = first_metadatas[i] - - if isinstance(metadata, Mapping): - file_name = metadata.get("file_name") - - if file_name: - lines.append( - Text.from_markup(f"\t[green]file_name[/]\t{file_name}") - ) - - if i < len(first_documents): - lines.append(Text.from_markup("\n[bold green]Retrieved contents[/]\n")) - lines.append(Text(first_documents[i])) - - # Print a separator between documents - lines.append(Rule()) - - return lines - - def is_probably_text_file(path: str | Path, sample_size: int = 8192) -> bool: """ Return whether a file appears to contain text. diff --git a/tests/test_cli.py b/tests/test_cli.py index dbc35a0..efde08f 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -31,11 +31,11 @@ class CliTests(unittest.TestCase): with patch( "chromy.handlers.list_collections.list_collections", return_value=["books", "code"], - ): + ): result = _invoke(["list-collections"]) self.assertEqual(result.exit_code, 0) - self.assertEqual(result.stdout, "books\ncode\n") + self.assertEqual(result.stdout, "· books\n· code\n") def test_create_collection(self) -> None: with patch( @@ -89,7 +89,10 @@ class CliTests(unittest.TestCase): count_collection.assert_called_once_with("notes") self.assertEqual(result.exit_code, 0) - self.assertEqual(result.stdout, "7\n") + self.assertEqual( + result.stdout, + "The 'notes' collection contains 7 records.\n", + ) def test_import_data(self) -> None: with patch( @@ -105,6 +108,19 @@ class CliTests(unittest.TestCase): self.assertEqual(result.exit_code, 0) self.assertEqual(result.stdout, "Added 3 records to collection 'notes'.\n") + def test_import_data_rejects_non_text_files(self) -> None: + with patch( + "chromy.handlers.import_data.is_probably_text_file", + return_value=False, + ): + result = _invoke(["import", "notes", "romeo_and_juliet.txt"]) + + self.assertEqual(result.exit_code, 1) + self.assertEqual( + result.stdout, + "Error: The file 'romeo_and_juliet.txt' is not a text file.\n", + ) + def test_query(self) -> None: query_result = {"ids": [["1"]], "documents": [["hello"]]} diff --git a/tests/test_embed.py b/tests/test_embed.py index 4acfdcc..165e85b 100644 --- a/tests/test_embed.py +++ b/tests/test_embed.py @@ -1,11 +1,29 @@ from __future__ import annotations import unittest +from unittest.mock import patch + +from chromy.embed import embed class EmbedTest(unittest.TestCase): - def test_embed_function(self) -> None: - self.assertEqual(0, 0) + def test_embed_returns_empty_list_for_empty_chunks(self) -> None: + self.assertEqual(embed([]), []) + + def test_embed_pairs_text_with_list_embeddings(self) -> None: + with patch( + "chromy.embed.DefaultEmbeddingFunction", + return_value=lambda chunks: ((1.0, 2.0), (3.0, 4.0)), + ): + result = embed(["first", "second"]) + + self.assertEqual( + result, + [ + {"text": "first", "embedding": [1.0, 2.0]}, + {"text": "second", "embedding": [3.0, 4.0]}, + ], + ) if __name__ == "__main__": diff --git a/tests/test_handlers.py b/tests/test_handlers.py index d66274e..7bb2d6c 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -8,6 +8,7 @@ from pathlib import Path from typing import TypeVar from unittest.mock import patch +from chromy.errors import UnsupportedTextFileError from chromy.handlers.count_collection import handle_count_collection from chromy.handlers.create_collection import handle_create_collection from chromy.handlers.delete_collection import ( @@ -47,7 +48,7 @@ class HandlerTests(unittest.TestCase): ) self.assertEqual(exit_code, 0) - self.assertEqual(output, "notes\nplays\n") + self.assertEqual(output, "· notes\n· plays\n") def test_create_collection_uses_typed_input(self) -> None: with patch( @@ -86,7 +87,7 @@ class HandlerTests(unittest.TestCase): count.assert_called_once_with("notes") self.assertEqual(exit_code, 0) - self.assertEqual(output, "7\n") + self.assertEqual(output, "The 'notes' collection contains 7 records.\n") def test_import_data_uses_typed_input(self) -> None: with patch( @@ -106,6 +107,16 @@ class HandlerTests(unittest.TestCase): self.assertEqual(exit_code, 0) self.assertEqual(output, "Added 3 records to collection 'notes'.\n") + def test_import_data_rejects_non_text_files(self) -> None: + with ( + patch( + "chromy.handlers.import_data.is_probably_text_file", + return_value=False, + ), + self.assertRaises(UnsupportedTextFileError), + ): + handle_import("notes", "romeo_and_juliet.txt") + def test_query_uses_typed_input(self) -> None: query_result = {"ids": [["1"]], "documents": [["hello"]]} with (