From b52952a2eb1c26544dc8a872dab9545e6373e94a Mon Sep 17 00:00:00 2001 From: Matteo Rosati Date: Wed, 22 Apr 2026 22:14:26 +0200 Subject: [PATCH] simplify the app using typer --- README.md | 7 ++ chromy/cli.py | 159 +++++++++++++++++++++++ chromy/cli_app.py | 180 --------------------------- chromy/cli_parser.py | 123 ------------------ chromy/command_inputs.py | 52 -------- chromy/handlers/add_data.py | 7 +- chromy/handlers/count_collection.py | 5 +- chromy/handlers/create_collection.py | 5 +- chromy/handlers/delete_collection.py | 15 ++- chromy/handlers/list_collections.py | 3 +- chromy/handlers/query.py | 5 +- chromy/main.py | 10 +- pyproject.toml | 1 + tests/test_cli.py | 131 +++++++++++++++++++ tests/test_cli_command_inputs.py | 98 --------------- tests/test_handlers.py | 36 ++---- uv.lock | 2 + 17 files changed, 334 insertions(+), 505 deletions(-) create mode 100644 chromy/cli.py delete mode 100644 chromy/cli_app.py delete mode 100644 chromy/cli_parser.py delete mode 100644 chromy/command_inputs.py create mode 100644 tests/test_cli.py delete mode 100644 tests/test_cli_command_inputs.py diff --git a/README.md b/README.md index a2bb941..6f3e1ce 100644 --- a/README.md +++ b/README.md @@ -122,6 +122,7 @@ delete-collection | dc count | co add-data | ad query | q +delete | del --where = ``` ### Examples @@ -162,6 +163,12 @@ Delete a collection: chromy delete-collection notes ``` +Delete records by metadata: + +```bash +chromy delete notes --where file_name=example.txt +``` + ## How ingestion works When you run `add-data`, the file is: diff --git a/chromy/cli.py b/chromy/cli.py new file mode 100644 index 0000000..80c3be9 --- /dev/null +++ b/chromy/cli.py @@ -0,0 +1,159 @@ +from __future__ import annotations + +from typing import Annotated, Callable + +import typer +from chromadb.errors import InternalError, NotFoundError + +from chromy.handlers.add_data import handle_add_data +from chromy.handlers.count_collection import handle_count_collection +from chromy.handlers.create_collection import handle_create_collection +from chromy.handlers.delete_collection import ( + handle_delete_collection, + handle_delete_records, +) +from chromy.handlers.list_collections import handle_list_collections +from chromy.handlers.query import handle_query + +app = typer.Typer(help="Inspect local Chroma collections.") + +ExitCodeHandler = Callable[[], int] + + +def _run(handler: ExitCodeHandler) -> None: + exit_code = handler() + if exit_code != 0: + raise typer.Exit(exit_code) + + +def _fail(message: str) -> None: + typer.echo(message) + raise typer.Exit(1) + + +@app.command("lc", help="List all collections stored in the local Chroma database.") +@app.command( + "list-collections", + help="List all collections stored in the local Chroma database.", +) +def list_collections() -> None: + _run(handle_list_collections) + + +@app.command("cc", help="Create a collection in the local Chroma database.") +@app.command( + "create-collection", + help="Create a collection in the local Chroma database.", +) +def create_collection( + collection: Annotated[ + str, + typer.Argument(help="Name of the collection to create."), + ], +) -> None: + try: + _run(lambda: handle_create_collection(collection)) + except InternalError: + _fail(f"Collection '{collection}' already exists.") + + +@app.command("dc", help="Delete a collection from the local Chroma database.") +@app.command( + "delete-collection", + help="Delete a collection from the local Chroma database.", +) +def delete_collection( + collection: Annotated[ + str, + typer.Argument(help="Name of the collection to delete."), + ], +) -> None: + try: + _run(lambda: handle_delete_collection(collection)) + except NotFoundError: + _fail(f"Collection '{collection}' does not exist.") + + +@app.command("co", help="Count records in a collection from the local Chroma database.") +@app.command( + "count", + help="Count records in a collection from the local Chroma database.", +) +def count( + collection: Annotated[ + str, + typer.Argument(help="Name of the collection to count."), + ], +) -> None: + try: + _run(lambda: handle_count_collection(collection)) + except NotFoundError: + _fail(f"Collection '{collection}' does not exist.") + + +@app.command( + "ad", + help="Chunk, embed, and add a file to a collection in the local Chroma database.", +) +@app.command( + "add-data", + help="Chunk, embed, and add a file to a collection in the local Chroma database.", +) +def add_data( + collection: Annotated[ + str, + typer.Argument(help="Name of the target collection."), + ], + file: Annotated[ + str, + typer.Argument(help="Path to the file to chunk and add to the collection."), + ], +) -> None: + try: + _run(lambda: handle_add_data(collection, file)) + except NotFoundError: + _fail(f"Collection '{collection}' does not exist.") + except FileNotFoundError: + _fail(f"The file {file} was not found.") + + +@app.command("q", help="Query a collection with the provided text.") +@app.command("query", help="Query a collection with the provided text.") +def query( + collection: Annotated[ + str, + typer.Argument(help="Name of the target collection."), + ], + query_text: Annotated[ + str, + typer.Argument(help="The text to query."), + ], +) -> None: + try: + _run(lambda: handle_query(collection, query_text)) + except NotFoundError: + _fail(f"Collection '{collection}' does not exist.") + + +@app.command("del", help="Delete records from a collection using a metadata filter.") +@app.command("delete", help="Delete records from a collection using a metadata filter.") +def delete_records( + collection: Annotated[ + str, + typer.Argument(help="Name of the target collection."), + ], + where: Annotated[ + str, + typer.Option( + "--where", + help="Metadata filter in the format =.", + metavar="CONDITION=VALUE", + ), + ], +) -> None: + try: + _run(lambda: handle_delete_records(collection, where)) + except NotFoundError: + _fail(f"Collection '{collection}' does not exist.") + except ValueError as exc: + _fail(str(exc)) diff --git a/chromy/cli_app.py b/chromy/cli_app.py deleted file mode 100644 index 58e2270..0000000 --- a/chromy/cli_app.py +++ /dev/null @@ -1,180 +0,0 @@ -from __future__ import annotations - -from argparse import Namespace -from collections.abc import Callable, Sequence -from dataclasses import dataclass -from typing import Generic, TypeVar, assert_never - -from chromadb.errors import InternalError, NotFoundError - -from chromy.command_inputs import ( - AddDataInput, - CommandInput, - CountCollectionInput, - CreateCollectionInput, - DeleteCollectionInput, - DeleteRecordsInput, - ListCollectionsInput, - QueryInput, -) -from chromy.handlers.add_data import handle_add_data -from chromy.handlers.count_collection import handle_count_collection -from chromy.handlers.create_collection import handle_create_collection -from chromy.handlers.delete_collection import ( - handle_delete_collection, - handle_delete_records, -) -from chromy.handlers.list_collections import handle_list_collections -from chromy.handlers.query import handle_query - -CommandT = TypeVar("CommandT", bound=CommandInput) -CollectionCommandT = TypeVar( - "CollectionCommandT", - DeleteCollectionInput, - CountCollectionInput, - AddDataInput, - QueryInput, - DeleteRecordsInput, -) -CommandHandler = Callable[[CommandT], int] -ErrorMessageBuilder = Callable[[CommandT, Exception], str] - - -@dataclass(frozen=True, slots=True) -class CliErrorHandler(Generic[CommandT]): - exception_type: type[Exception] - message: ErrorMessageBuilder[CommandT] - - -def build_command_input(args: Namespace) -> CommandInput: - command = str(args.command) - - match command: - case "list-collections": - return ListCollectionsInput() - case "create-collection": - return CreateCollectionInput(collection=str(args.collection)) - case "delete-collection": - return DeleteCollectionInput(collection=str(args.collection)) - case "count": - return CountCollectionInput(collection=str(args.collection)) - case "add-data": - return AddDataInput(collection=str(args.collection), file=str(args.file)) - case "query": - return QueryInput( - collection=str(args.collection), - query_text=str(args.query_text), - ) - case "delete": - return DeleteRecordsInput( - collection=str(args.collection), - where=str(args.where), - ) - case _: - raise ValueError(f"Unknown command: {command}") - - -def execute_command(args: Namespace) -> int: - command_input = build_command_input(args) - - match command_input: - case ListCollectionsInput(): - return _run_command(command_input, handle_list_collections) - case CreateCollectionInput(): - return _run_command( - command_input, - handle_create_collection, - ( - CliErrorHandler( - exception_type=InternalError, - message=_collection_already_exists_message, - ), - ), - ) - case DeleteCollectionInput(): - return _run_command( - command_input, - handle_delete_collection, - (_collection_not_found_handler(DeleteCollectionInput),), - ) - case CountCollectionInput(): - return _run_command( - command_input, - handle_count_collection, - (_collection_not_found_handler(CountCollectionInput),), - ) - case AddDataInput(): - return _run_command( - command_input, - handle_add_data, - ( - _collection_not_found_handler(AddDataInput), - CliErrorHandler( - exception_type=FileNotFoundError, - message=_file_not_found_message, - ), - ), - ) - case QueryInput(): - return _run_command( - command_input, - handle_query, - (_collection_not_found_handler(QueryInput),), - ) - case DeleteRecordsInput(): - return _run_command( - command_input, - handle_delete_records, - ( - _collection_not_found_handler(DeleteRecordsInput), - CliErrorHandler( - exception_type=ValueError, - message=_exception_message, - ), - ), - ) - - assert_never(command_input) - - -def _run_command( - command_input: CommandT, - handler: CommandHandler[CommandT], - error_handlers: Sequence[CliErrorHandler[CommandT]] = (), -) -> int: - try: - return handler(command_input) - except Exception as exc: - for error_handler in error_handlers: - if isinstance(exc, error_handler.exception_type): - print(error_handler.message(command_input, exc)) - return 1 - raise - - -def _collection_already_exists_message( - command: CreateCollectionInput, - _: Exception, -) -> str: - return f"Collection '{command.collection}' already exists." - - -def _collection_not_found_handler( - _: type[CollectionCommandT], -) -> CliErrorHandler[CollectionCommandT]: - return CliErrorHandler( - exception_type=NotFoundError, - message=_collection_not_found_message, - ) - - -def _collection_not_found_message(command: CollectionCommandT, _: Exception) -> str: - return f"Collection '{command.collection}' does not exist." - - -def _file_not_found_message(command: AddDataInput, _: Exception) -> str: - return f"The file {command.file} was not found." - - -def _exception_message(_: DeleteRecordsInput, exc: Exception) -> str: - return str(exc) diff --git a/chromy/cli_parser.py b/chromy/cli_parser.py deleted file mode 100644 index b860408..0000000 --- a/chromy/cli_parser.py +++ /dev/null @@ -1,123 +0,0 @@ -from __future__ import annotations - -import argparse -from dataclasses import dataclass - - -@dataclass(frozen=True, slots=True) -class ArgumentSpec: - name: str - help: str - required: bool = False - metavar: str | None = None - - -@dataclass(frozen=True, slots=True) -class CommandSpec: - name: str - aliases: tuple[str, ...] - help: str - arguments: tuple[ArgumentSpec, ...] = () - - -COMMAND_SPECS: tuple[CommandSpec, ...] = ( - CommandSpec( - name="list-collections", - aliases=("lc",), - help="List all collections stored in the local Chroma database.", - ), - CommandSpec( - name="create-collection", - aliases=("cc",), - help="Create a collection in the local Chroma database.", - arguments=(ArgumentSpec("collection", "Name of the collection to create."),), - ), - CommandSpec( - name="delete-collection", - aliases=("dc",), - help="Delete a collection from the local Chroma database.", - arguments=(ArgumentSpec("collection", "Name of the collection to delete."),), - ), - CommandSpec( - name="count", - aliases=("co",), - help="Count records in a collection from the local Chroma database.", - arguments=(ArgumentSpec("collection", "Name of the collection to count."),), - ), - CommandSpec( - name="add-data", - aliases=("ad",), - help=( - "Chunk, embed, and add a file to a collection in the local Chroma database." - ), - arguments=( - ArgumentSpec("collection", "Name of the target collection."), - ArgumentSpec( - "file", "Path to the file to chunk and add to the collection." - ), - ), - ), - CommandSpec( - name="query", - aliases=("q",), - help="Query a collection with the provided text.", - arguments=( - ArgumentSpec("collection", "Name of the target collection."), - ArgumentSpec("query_text", "The text to query."), - ), - ), - CommandSpec( - name="delete", - aliases=("del",), - help="Delete records from a collection using a metadata filter.", - arguments=( - ArgumentSpec("collection", "Name of the target collection."), - ArgumentSpec( - "--where", - "Metadata filter in the format =.", - required=True, - metavar="CONDITION=VALUE", - ), - ), - ), -) - - -def _add_command( - subparsers: argparse._SubParsersAction[argparse.ArgumentParser], - command: CommandSpec, -) -> None: - subparser = subparsers.add_parser( - command.name, - aliases=list(command.aliases), - help=command.help, - description=command.help, - ) - - for argument in command.arguments: - if argument.name.startswith("-"): - subparser.add_argument( - argument.name, - help=argument.help, - metavar=argument.metavar, - required=argument.required, - ) - continue - - subparser.add_argument( - argument.name, - help=argument.help, - metavar=argument.metavar, - ) - - subparser.set_defaults(command=command.name) - - -def build_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser(description="Inspect local Chroma collections.") - subparsers = parser.add_subparsers(dest="command", required=True) - - for command in COMMAND_SPECS: - _add_command(subparsers, command) - - return parser diff --git a/chromy/command_inputs.py b/chromy/command_inputs.py deleted file mode 100644 index 2a0a3f6..0000000 --- a/chromy/command_inputs.py +++ /dev/null @@ -1,52 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass - - -@dataclass(frozen=True, slots=True) -class ListCollectionsInput: - pass - - -@dataclass(frozen=True, slots=True) -class CreateCollectionInput: - collection: str - - -@dataclass(frozen=True, slots=True) -class DeleteCollectionInput: - collection: str - - -@dataclass(frozen=True, slots=True) -class CountCollectionInput: - collection: str - - -@dataclass(frozen=True, slots=True) -class AddDataInput: - collection: str - file: str - - -@dataclass(frozen=True, slots=True) -class QueryInput: - collection: str - query_text: str - - -@dataclass(frozen=True, slots=True) -class DeleteRecordsInput: - collection: str - where: str - - -CommandInput = ( - ListCollectionsInput - | CreateCollectionInput - | DeleteCollectionInput - | CountCollectionInput - | AddDataInput - | QueryInput - | DeleteRecordsInput -) diff --git a/chromy/handlers/add_data.py b/chromy/handlers/add_data.py index 2afec79..e59b925 100644 --- a/chromy/handlers/add_data.py +++ b/chromy/handlers/add_data.py @@ -1,10 +1,9 @@ from __future__ import annotations -from chromy.command_inputs import AddDataInput from chromy.utilities import ingest_file -def handle_add_data(command: AddDataInput) -> int: - records_added = ingest_file(command.collection, command.file) - print(f"Added {records_added} records to collection '{command.collection}'.") +def handle_add_data(collection: str, file: str) -> int: + records_added = ingest_file(collection, file) + print(f"Added {records_added} records to collection '{collection}'.") return 0 diff --git a/chromy/handlers/count_collection.py b/chromy/handlers/count_collection.py index 004dc9c..559c92b 100644 --- a/chromy/handlers/count_collection.py +++ b/chromy/handlers/count_collection.py @@ -1,9 +1,8 @@ from __future__ import annotations from chromy.chroma_functions import count_collection -from chromy.command_inputs import CountCollectionInput -def handle_count_collection(command: CountCollectionInput) -> int: - print(count_collection(command.collection)) +def handle_count_collection(collection: str) -> int: + print(count_collection(collection)) return 0 diff --git a/chromy/handlers/create_collection.py b/chromy/handlers/create_collection.py index 050f7e0..b42723d 100644 --- a/chromy/handlers/create_collection.py +++ b/chromy/handlers/create_collection.py @@ -1,10 +1,9 @@ from __future__ import annotations from chromy.chroma_functions import create_collection -from chromy.command_inputs import CreateCollectionInput -def handle_create_collection(command: CreateCollectionInput) -> int: - collection_name = create_collection(command.collection) +def handle_create_collection(collection: str) -> int: + collection_name = create_collection(collection) print(f"Created collection '{collection_name}'.") return 0 diff --git a/chromy/handlers/delete_collection.py b/chromy/handlers/delete_collection.py index 4525106..f762983 100644 --- a/chromy/handlers/delete_collection.py +++ b/chromy/handlers/delete_collection.py @@ -1,7 +1,6 @@ from __future__ import annotations from chromy.chroma_functions import delete_collection, delete_data -from chromy.command_inputs import DeleteCollectionInput, DeleteRecordsInput def _parse_where_clause(where_clause: str) -> dict[str, str]: @@ -19,18 +18,18 @@ def _parse_where_clause(where_clause: str) -> dict[str, str]: return {condition: value} -def handle_delete_collection(command: DeleteCollectionInput) -> int: - delete_collection(command.collection) - print(f"Deleted collection '{command.collection}'.") +def handle_delete_collection(collection: str) -> int: + delete_collection(collection) + print(f"Deleted collection '{collection}'.") return 0 -def handle_delete_records(command: DeleteRecordsInput) -> int: - where = _parse_where_clause(command.where) - deleted = delete_data(command.collection, where) +def handle_delete_records(collection: str, where_clause: str) -> int: + where = _parse_where_clause(where_clause) + deleted = delete_data(collection, where) condition, value = next(iter(where.items())) print( - f"Deleted {deleted} record(s) from collection '{command.collection}' " + f"Deleted {deleted} record(s) from collection '{collection}' " f"where {condition}={value}." ) return 0 diff --git a/chromy/handlers/list_collections.py b/chromy/handlers/list_collections.py index 89acf6d..40b251f 100644 --- a/chromy/handlers/list_collections.py +++ b/chromy/handlers/list_collections.py @@ -1,11 +1,10 @@ from __future__ import annotations from chromy.chroma_functions import list_collections -from chromy.command_inputs import ListCollectionsInput from chromy.utilities import print_lines -def handle_list_collections(_: ListCollectionsInput) -> int: +def handle_list_collections() -> int: collections = list_collections() if not collections: print("No collections found.") diff --git a/chromy/handlers/query.py b/chromy/handlers/query.py index 972c502..b94969a 100644 --- a/chromy/handlers/query.py +++ b/chromy/handlers/query.py @@ -1,10 +1,9 @@ from __future__ import annotations -from chromy.command_inputs import QueryInput from chromy.utilities import format_query_result, print_lines, run_query -def handle_query(command: QueryInput) -> int: - result = run_query(command.collection, command.query_text) +def handle_query(collection: str, query_text: str) -> int: + result = run_query(collection, query_text) print_lines(format_query_result(result)) return 0 diff --git a/chromy/main.py b/chromy/main.py index a152e2d..702774b 100644 --- a/chromy/main.py +++ b/chromy/main.py @@ -2,15 +2,13 @@ from __future__ import annotations from dotenv import load_dotenv -from chromy.cli_app import execute_command -from chromy.cli_parser import build_parser +from chromy.cli import app -def main() -> int: +def main() -> None: load_dotenv() - args = build_parser().parse_args() - return execute_command(args) + app() if __name__ == "__main__": - raise SystemExit(main()) + main() diff --git a/pyproject.toml b/pyproject.toml index 1e7cb80..745db33 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ dependencies = [ "semchunk>=4.0.0", "tiktoken>=0.12.0", "transformers>=5.5.4", + "typer>=0.24.1", ] [project.scripts] diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..2bbb703 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,131 @@ +from __future__ import annotations + +import unittest +from collections.abc import Sequence +from unittest.mock import patch + +from click.testing import Result +from typer.testing import CliRunner + +from chromy.cli import app + + +class CliTests(unittest.TestCase): + def test_list_collections_and_alias(self) -> None: + for command in ("list-collections", "lc"): + with patch( + "chromy.handlers.list_collections.list_collections", + return_value=[], + ): + result = _invoke([command]) + + self.assertEqual(result.exit_code, 0) + self.assertEqual(result.stdout, "No collections found.\n") + + def test_create_collection_and_alias(self) -> None: + for command in ("create-collection", "cc"): + with patch( + "chromy.handlers.create_collection.create_collection", + return_value="notes", + ) as create_collection: + result = _invoke([command, "notes"]) + + create_collection.assert_called_once_with("notes") + self.assertEqual(result.exit_code, 0) + self.assertEqual(result.stdout, "Created collection 'notes'.\n") + + def test_delete_collection_and_alias(self) -> None: + for command in ("delete-collection", "dc"): + with patch( + "chromy.handlers.delete_collection.delete_collection", + ) as delete_collection: + result = _invoke([command, "notes"]) + + delete_collection.assert_called_once_with("notes") + self.assertEqual(result.exit_code, 0) + self.assertEqual(result.stdout, "Deleted collection 'notes'.\n") + + def test_count_and_alias(self) -> None: + for command in ("count", "co"): + with patch( + "chromy.handlers.count_collection.count_collection", + return_value=7, + ) as count_collection: + result = _invoke([command, "notes"]) + + count_collection.assert_called_once_with("notes") + self.assertEqual(result.exit_code, 0) + self.assertEqual(result.stdout, "7\n") + + def test_add_data_and_alias(self) -> None: + for command in ("add-data", "ad"): + with patch( + "chromy.handlers.add_data.ingest_file", + return_value=3, + ) as ingest_file: + result = _invoke([command, "notes", "romeo_and_juliet.txt"]) + + ingest_file.assert_called_once_with("notes", "romeo_and_juliet.txt") + self.assertEqual(result.exit_code, 0) + self.assertEqual(result.stdout, "Added 3 records to collection 'notes'.\n") + + def test_query_and_alias(self) -> None: + query_result = {"ids": [["1"]], "documents": [["hello"]]} + + for command in ("query", "q"): + with ( + patch( + "chromy.handlers.query.run_query", return_value=query_result + ) as run, + patch( + "chromy.handlers.query.format_query_result", + return_value=["Query results:", "1"], + ) as format_result, + ): + result = _invoke([command, "notes", "Where is Romeo?"]) + + run.assert_called_once_with("notes", "Where is Romeo?") + format_result.assert_called_once_with(query_result) + self.assertEqual(result.exit_code, 0) + self.assertEqual(result.stdout, "Query results:\n1\n") + + def test_delete_records_and_alias(self) -> None: + for command in ("delete", "del"): + with patch( + "chromy.handlers.delete_collection.delete_data", + return_value=2, + ) as delete_data: + result = _invoke( + [command, "notes", "--where", " file_name = play.txt "], + ) + + delete_data.assert_called_once_with("notes", {"file_name": "play.txt"}) + self.assertEqual(result.exit_code, 0) + self.assertEqual( + result.stdout, + "Deleted 2 record(s) from collection 'notes' " + "where file_name=play.txt.\n", + ) + + def test_invalid_delete_filter_keeps_user_facing_error(self) -> None: + result = _invoke(["delete", "notes", "--where", "file_name"]) + + self.assertEqual(result.exit_code, 1) + self.assertEqual( + result.stdout, + "Invalid --where value. Expected =.\n", + ) + + def test_delete_requires_where_option(self) -> None: + result = _invoke(["delete", "notes"]) + + self.assertNotEqual(result.exit_code, 0) + self.assertIn("Missing option", result.output) + + +def _invoke(arguments: Sequence[str]) -> Result: + return CliRunner().invoke(app, list(arguments)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_cli_command_inputs.py b/tests/test_cli_command_inputs.py deleted file mode 100644 index b1f4dca..0000000 --- a/tests/test_cli_command_inputs.py +++ /dev/null @@ -1,98 +0,0 @@ -from __future__ import annotations - -import io -import unittest -from argparse import Namespace -from collections.abc import Sequence -from contextlib import redirect_stdout - -from chromy.cli_app import build_command_input, execute_command -from chromy.cli_parser import build_parser -from chromy.command_inputs import ( - AddDataInput, - CountCollectionInput, - CreateCollectionInput, - DeleteCollectionInput, - DeleteRecordsInput, - ListCollectionsInput, - QueryInput, -) - - -class BuildCommandInputTests(unittest.TestCase): - def test_parser_converts_list_collections_and_alias(self) -> None: - self.assertEqual(_parse_input(["list-collections"]), ListCollectionsInput()) - self.assertEqual(_parse_input(["lc"]), ListCollectionsInput()) - - def test_parser_converts_create_collection_and_alias(self) -> None: - expected = CreateCollectionInput(collection="notes") - - self.assertEqual(_parse_input(["create-collection", "notes"]), expected) - self.assertEqual(_parse_input(["cc", "notes"]), expected) - - def test_parser_converts_delete_collection_and_alias(self) -> None: - expected = DeleteCollectionInput(collection="notes") - - self.assertEqual(_parse_input(["delete-collection", "notes"]), expected) - self.assertEqual(_parse_input(["dc", "notes"]), expected) - - def test_parser_converts_count_and_alias(self) -> None: - expected = CountCollectionInput(collection="notes") - - self.assertEqual(_parse_input(["count", "notes"]), expected) - self.assertEqual(_parse_input(["co", "notes"]), expected) - - def test_parser_converts_add_data_and_alias(self) -> None: - expected = AddDataInput(collection="notes", file="romeo_and_juliet.txt") - - self.assertEqual( - _parse_input(["add-data", "notes", "romeo_and_juliet.txt"]), - expected, - ) - self.assertEqual( - _parse_input(["ad", "notes", "romeo_and_juliet.txt"]), - expected, - ) - - def test_parser_converts_query_and_alias(self) -> None: - expected = QueryInput(collection="notes", query_text="Where is Romeo?") - - self.assertEqual( - _parse_input(["query", "notes", "Where is Romeo?"]), - expected, - ) - self.assertEqual(_parse_input(["q", "notes", "Where is Romeo?"]), expected) - - def test_parser_converts_delete_records_and_alias(self) -> None: - expected = DeleteRecordsInput(collection="notes", where="file_name=play.txt") - - self.assertEqual( - _parse_input(["delete", "notes", "--where", "file_name=play.txt"]), - expected, - ) - self.assertEqual( - _parse_input(["del", "notes", "--where", "file_name=play.txt"]), - expected, - ) - - def test_invalid_delete_filter_keeps_user_facing_error(self) -> None: - args = Namespace(command="delete", collection="notes", where="file_name") - output = io.StringIO() - - with redirect_stdout(output): - exit_code = execute_command(args) - - self.assertEqual(exit_code, 1) - self.assertEqual( - output.getvalue().strip(), - "Invalid --where value. Expected =.", - ) - self.assertFalse(hasattr(args, "error_message")) - - -def _parse_input(argv: Sequence[str]) -> object: - return build_command_input(build_parser().parse_args(argv)) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_handlers.py b/tests/test_handlers.py index a125d79..5ebbedb 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -7,15 +7,6 @@ from contextlib import redirect_stdout from typing import TypeVar from unittest.mock import patch -from chromy.command_inputs import ( - AddDataInput, - CountCollectionInput, - CreateCollectionInput, - DeleteCollectionInput, - DeleteRecordsInput, - ListCollectionsInput, - QueryInput, -) from chromy.handlers.add_data import handle_add_data from chromy.handlers.count_collection import handle_count_collection from chromy.handlers.create_collection import handle_create_collection @@ -36,7 +27,6 @@ class HandlerTests(unittest.TestCase): ): exit_code, output = _capture_output( handle_list_collections, - ListCollectionsInput(), ) self.assertEqual(exit_code, 0) @@ -49,7 +39,6 @@ class HandlerTests(unittest.TestCase): ): exit_code, output = _capture_output( handle_list_collections, - ListCollectionsInput(), ) self.assertEqual(exit_code, 0) @@ -62,7 +51,7 @@ class HandlerTests(unittest.TestCase): ) as create_collection: exit_code, output = _capture_output( handle_create_collection, - CreateCollectionInput(collection="notes"), + "notes", ) create_collection.assert_called_once_with("notes") @@ -73,7 +62,7 @@ class HandlerTests(unittest.TestCase): with patch("chromy.handlers.delete_collection.delete_collection") as delete: exit_code, output = _capture_output( handle_delete_collection, - DeleteCollectionInput(collection="notes"), + "notes", ) delete.assert_called_once_with("notes") @@ -87,7 +76,7 @@ class HandlerTests(unittest.TestCase): ) as count: exit_code, output = _capture_output( handle_count_collection, - CountCollectionInput(collection="notes"), + "notes", ) count.assert_called_once_with("notes") @@ -101,7 +90,8 @@ class HandlerTests(unittest.TestCase): ) as ingest_file: exit_code, output = _capture_output( handle_add_data, - AddDataInput(collection="notes", file="romeo_and_juliet.txt"), + "notes", + "romeo_and_juliet.txt", ) ingest_file.assert_called_once_with("notes", "romeo_and_juliet.txt") @@ -119,7 +109,8 @@ class HandlerTests(unittest.TestCase): ): exit_code, output = _capture_output( handle_query, - QueryInput(collection="notes", query_text="hello"), + "notes", + "hello", ) run.assert_called_once_with("notes", "hello") @@ -134,7 +125,8 @@ class HandlerTests(unittest.TestCase): ) as delete_data: exit_code, output = _capture_output( handle_delete_records, - DeleteRecordsInput(collection="notes", where=" file_name = play.txt "), + "notes", + " file_name = play.txt ", ) delete_data.assert_called_once_with("notes", {"file_name": "play.txt"}) @@ -149,19 +141,17 @@ class HandlerTests(unittest.TestCase): ValueError, "Invalid --where value. Expected =.", ): - handle_delete_records( - DeleteRecordsInput(collection="notes", where="file_name") - ) + handle_delete_records("notes", "file_name") def _capture_output( - handler: Callable[[CommandT], int], - command: CommandT, + handler: Callable[..., int], + *arguments: CommandT, ) -> tuple[int, str]: output = io.StringIO() with redirect_stdout(output): - exit_code = handler(command) + exit_code = handler(*arguments) return exit_code, output.getvalue() diff --git a/uv.lock b/uv.lock index e878edf..c949307 100644 --- a/uv.lock +++ b/uv.lock @@ -264,6 +264,7 @@ dependencies = [ { name = "semchunk" }, { name = "tiktoken" }, { name = "transformers" }, + { name = "typer" }, ] [package.dev-dependencies] @@ -283,6 +284,7 @@ requires-dist = [ { name = "semchunk", specifier = ">=4.0.0" }, { name = "tiktoken", specifier = ">=0.12.0" }, { name = "transformers", specifier = ">=5.5.4" }, + { name = "typer", specifier = ">=0.24.1" }, ] [package.metadata.requires-dev]