diff --git a/chromy/chroma_functions.py b/chromy/chroma_functions.py index 731f035..c74b6f0 100644 --- a/chromy/chroma_functions.py +++ b/chromy/chroma_functions.py @@ -58,7 +58,9 @@ def count_collection(collection_name: str) -> str: _, collection = _get_client_and_collection(collection_name) count = collection.count() - return f"The '{collection_name}' collection contains [bold green]{count}[/] records." + return ( + f"The '{collection_name}' collection contains [bold green]{count}[/] records." + ) def add_data( @@ -71,8 +73,7 @@ def add_data( _, collection = _get_client_and_collection(collection_name) - embeddings: list[Sequence[float]] = [record["embedding"] - for record in data] + embeddings: list[Sequence[float]] = [record["embedding"] for record in data] collection.add( ids=[str(uuid4()) for _ in data], diff --git a/chromy/cli.py b/chromy/cli.py index 12f0095..d142254 100644 --- a/chromy/cli.py +++ b/chromy/cli.py @@ -3,16 +3,16 @@ from __future__ import annotations from typing import Annotated, Callable import typer -from rich import print from chromadb.errors import InternalError, NotFoundError +from rich import print -from chromy.handlers.import_data import handle_import from chromy.handlers.count_collection import handle_count_collection from chromy.handlers.create_collection import handle_create_collection from chromy.handlers.delete_collection import ( handle_delete_collection, handle_delete_records, ) +from chromy.handlers.import_data import handle_import from chromy.handlers.list_collections import handle_list_collections from chromy.handlers.query import handle_query @@ -114,8 +114,7 @@ def import_data( ], file: Annotated[ str, - typer.Argument( - help="Path to the file to chunk and add to the collection."), + typer.Argument(help="Path to the file to chunk and add to the collection."), ], ) -> None: try: diff --git a/chromy/handlers/count_collection.py b/chromy/handlers/count_collection.py index 53da70f..9218d74 100644 --- a/chromy/handlers/count_collection.py +++ b/chromy/handlers/count_collection.py @@ -1,6 +1,7 @@ from __future__ import annotations from rich import print + from chromy.chroma_functions import count_collection diff --git a/chromy/handlers/create_collection.py b/chromy/handlers/create_collection.py index 2113046..7ce78aa 100644 --- a/chromy/handlers/create_collection.py +++ b/chromy/handlers/create_collection.py @@ -1,6 +1,7 @@ from __future__ import annotations from rich import print + from chromy.chroma_functions import create_collection diff --git a/chromy/handlers/delete_collection.py b/chromy/handlers/delete_collection.py index af1a41a..5bdd886 100644 --- a/chromy/handlers/delete_collection.py +++ b/chromy/handlers/delete_collection.py @@ -1,6 +1,7 @@ from __future__ import annotations from rich import print + from chromy.chroma_functions import delete_collection, delete_data @@ -8,15 +9,13 @@ def _parse_where_clause(where_clause: str) -> dict[str, str]: condition, separator, value = where_clause.partition("=") if separator == "": - raise ValueError( - "Invalid --where value. Expected =.") + raise ValueError("Invalid --where value. Expected =.") condition = condition.strip() value = value.strip() if not condition or not value: - raise ValueError( - "Invalid --where value. Expected =.") + raise ValueError("Invalid --where value. Expected =.") return {condition: value} diff --git a/chromy/utilities.py b/chromy/utilities.py index bd02b82..fdb086f 100644 --- a/chromy/utilities.py +++ b/chromy/utilities.py @@ -14,7 +14,7 @@ from chromy.embed import embed CONSOLE = Console() -def print_lines(lines: Sequence[str]) -> None: +def print_lines(lines: Sequence[Rule | Text]) -> None: for line in lines: CONSOLE.print(line) diff --git a/tests/test_cli.py b/tests/test_cli.py index 2d64187..dbc35a0 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -51,15 +51,13 @@ class CliTests(unittest.TestCase): def test_create_collection_with_same_name(self) -> None: with patch( "chromy.handlers.create_collection.create_collection", - side_effect=InternalError() - + side_effect=InternalError(), ) as create_collection: result = _invoke(["create-collection", "notes"]) create_collection.assert_called_once_with("notes") self.assertEqual(result.exit_code, 1) - self.assertEqual( - result.stdout, "Error: Collection 'notes' already exists.\n") + self.assertEqual(result.stdout, "Error: Collection 'notes' already exists.\n") def test_delete_collection(self) -> None: with patch( @@ -74,14 +72,13 @@ class CliTests(unittest.TestCase): def test_delete_non_existent_collection(self) -> None: with patch( "chromy.handlers.delete_collection.delete_collection", - side_effect=NotFoundError() + side_effect=NotFoundError(), ) as delete_collection: result = _invoke(["delete-collection", "notes"]) delete_collection.assert_called_once_with("notes") self.assertEqual(result.exit_code, 1) - self.assertEqual( - result.stdout, "Error: Collection 'notes' does not exist.\n") + self.assertEqual(result.stdout, "Error: Collection 'notes' does not exist.\n") def test_count(self) -> None: with patch( @@ -106,8 +103,7 @@ class CliTests(unittest.TestCase): self._fixture_path("romeo_and_juliet.txt"), ) self.assertEqual(result.exit_code, 0) - self.assertEqual( - result.stdout, "Added 3 records to collection 'notes'.\n") + self.assertEqual(result.stdout, "Added 3 records to collection 'notes'.\n") def test_query(self) -> None: query_result = {"ids": [["1"]], "documents": [["hello"]]} @@ -139,8 +135,7 @@ class CliTests(unittest.TestCase): self.assertEqual(result.exit_code, 0) self.assertEqual( result.stdout, - "Deleted 2 record(s) from collection 'notes' " - "where file_name=play.txt.\n", + "Deleted 2 record(s) from collection 'notes' where file_name=play.txt.\n", ) def test_invalid_delete_filter_keeps_user_facing_error(self) -> None: diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 897ab0d..d66274e 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -8,13 +8,13 @@ from pathlib import Path from typing import TypeVar from unittest.mock import patch -from chromy.handlers.import_data import handle_import from chromy.handlers.count_collection import handle_count_collection from chromy.handlers.create_collection import handle_create_collection from chromy.handlers.delete_collection import ( handle_delete_collection, handle_delete_records, ) +from chromy.handlers.import_data import handle_import from chromy.handlers.list_collections import handle_list_collections from chromy.handlers.query import handle_query @@ -154,7 +154,7 @@ class HandlerTests(unittest.TestCase): def _capture_output( handler: Callable[..., int], - *arguments: CommandT, + *arguments: object, ) -> tuple[int, str]: output = io.StringIO()