simplify the app using typer

This commit is contained in:
Matteo Rosati
2026-04-22 22:14:26 +02:00
parent 2dfaa68466
commit b52952a2eb
17 changed files with 334 additions and 505 deletions
+7
View File
@@ -122,6 +122,7 @@ delete-collection | dc <collection>
count | co <collection> count | co <collection>
add-data | ad <collection> <file> add-data | ad <collection> <file>
query | q <collection> <query_text> query | q <collection> <query_text>
delete | del <collection> --where <condition>=<value>
``` ```
### Examples ### Examples
@@ -162,6 +163,12 @@ Delete a collection:
chromy delete-collection notes chromy delete-collection notes
``` ```
Delete records by metadata:
```bash
chromy delete notes --where file_name=example.txt
```
## How ingestion works ## How ingestion works
When you run `add-data`, the file is: When you run `add-data`, the file is:
+159
View File
@@ -0,0 +1,159 @@
from __future__ import annotations
from typing import Annotated, Callable
import typer
from chromadb.errors import InternalError, NotFoundError
from chromy.handlers.add_data import handle_add_data
from chromy.handlers.count_collection import handle_count_collection
from chromy.handlers.create_collection import handle_create_collection
from chromy.handlers.delete_collection import (
handle_delete_collection,
handle_delete_records,
)
from chromy.handlers.list_collections import handle_list_collections
from chromy.handlers.query import handle_query
app = typer.Typer(help="Inspect local Chroma collections.")
ExitCodeHandler = Callable[[], int]
def _run(handler: ExitCodeHandler) -> None:
exit_code = handler()
if exit_code != 0:
raise typer.Exit(exit_code)
def _fail(message: str) -> None:
typer.echo(message)
raise typer.Exit(1)
@app.command("lc", help="List all collections stored in the local Chroma database.")
@app.command(
"list-collections",
help="List all collections stored in the local Chroma database.",
)
def list_collections() -> None:
_run(handle_list_collections)
@app.command("cc", help="Create a collection in the local Chroma database.")
@app.command(
"create-collection",
help="Create a collection in the local Chroma database.",
)
def create_collection(
collection: Annotated[
str,
typer.Argument(help="Name of the collection to create."),
],
) -> None:
try:
_run(lambda: handle_create_collection(collection))
except InternalError:
_fail(f"Collection '{collection}' already exists.")
@app.command("dc", help="Delete a collection from the local Chroma database.")
@app.command(
"delete-collection",
help="Delete a collection from the local Chroma database.",
)
def delete_collection(
collection: Annotated[
str,
typer.Argument(help="Name of the collection to delete."),
],
) -> None:
try:
_run(lambda: handle_delete_collection(collection))
except NotFoundError:
_fail(f"Collection '{collection}' does not exist.")
@app.command("co", help="Count records in a collection from the local Chroma database.")
@app.command(
"count",
help="Count records in a collection from the local Chroma database.",
)
def count(
collection: Annotated[
str,
typer.Argument(help="Name of the collection to count."),
],
) -> None:
try:
_run(lambda: handle_count_collection(collection))
except NotFoundError:
_fail(f"Collection '{collection}' does not exist.")
@app.command(
"ad",
help="Chunk, embed, and add a file to a collection in the local Chroma database.",
)
@app.command(
"add-data",
help="Chunk, embed, and add a file to a collection in the local Chroma database.",
)
def add_data(
collection: Annotated[
str,
typer.Argument(help="Name of the target collection."),
],
file: Annotated[
str,
typer.Argument(help="Path to the file to chunk and add to the collection."),
],
) -> None:
try:
_run(lambda: handle_add_data(collection, file))
except NotFoundError:
_fail(f"Collection '{collection}' does not exist.")
except FileNotFoundError:
_fail(f"The file {file} was not found.")
@app.command("q", help="Query a collection with the provided text.")
@app.command("query", help="Query a collection with the provided text.")
def query(
collection: Annotated[
str,
typer.Argument(help="Name of the target collection."),
],
query_text: Annotated[
str,
typer.Argument(help="The text to query."),
],
) -> None:
try:
_run(lambda: handle_query(collection, query_text))
except NotFoundError:
_fail(f"Collection '{collection}' does not exist.")
@app.command("del", help="Delete records from a collection using a metadata filter.")
@app.command("delete", help="Delete records from a collection using a metadata filter.")
def delete_records(
collection: Annotated[
str,
typer.Argument(help="Name of the target collection."),
],
where: Annotated[
str,
typer.Option(
"--where",
help="Metadata filter in the format <condition>=<value>.",
metavar="CONDITION=VALUE",
),
],
) -> None:
try:
_run(lambda: handle_delete_records(collection, where))
except NotFoundError:
_fail(f"Collection '{collection}' does not exist.")
except ValueError as exc:
_fail(str(exc))
-180
View File
@@ -1,180 +0,0 @@
from __future__ import annotations
from argparse import Namespace
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from typing import Generic, TypeVar, assert_never
from chromadb.errors import InternalError, NotFoundError
from chromy.command_inputs import (
AddDataInput,
CommandInput,
CountCollectionInput,
CreateCollectionInput,
DeleteCollectionInput,
DeleteRecordsInput,
ListCollectionsInput,
QueryInput,
)
from chromy.handlers.add_data import handle_add_data
from chromy.handlers.count_collection import handle_count_collection
from chromy.handlers.create_collection import handle_create_collection
from chromy.handlers.delete_collection import (
handle_delete_collection,
handle_delete_records,
)
from chromy.handlers.list_collections import handle_list_collections
from chromy.handlers.query import handle_query
CommandT = TypeVar("CommandT", bound=CommandInput)
CollectionCommandT = TypeVar(
"CollectionCommandT",
DeleteCollectionInput,
CountCollectionInput,
AddDataInput,
QueryInput,
DeleteRecordsInput,
)
CommandHandler = Callable[[CommandT], int]
ErrorMessageBuilder = Callable[[CommandT, Exception], str]
@dataclass(frozen=True, slots=True)
class CliErrorHandler(Generic[CommandT]):
exception_type: type[Exception]
message: ErrorMessageBuilder[CommandT]
def build_command_input(args: Namespace) -> CommandInput:
command = str(args.command)
match command:
case "list-collections":
return ListCollectionsInput()
case "create-collection":
return CreateCollectionInput(collection=str(args.collection))
case "delete-collection":
return DeleteCollectionInput(collection=str(args.collection))
case "count":
return CountCollectionInput(collection=str(args.collection))
case "add-data":
return AddDataInput(collection=str(args.collection), file=str(args.file))
case "query":
return QueryInput(
collection=str(args.collection),
query_text=str(args.query_text),
)
case "delete":
return DeleteRecordsInput(
collection=str(args.collection),
where=str(args.where),
)
case _:
raise ValueError(f"Unknown command: {command}")
def execute_command(args: Namespace) -> int:
command_input = build_command_input(args)
match command_input:
case ListCollectionsInput():
return _run_command(command_input, handle_list_collections)
case CreateCollectionInput():
return _run_command(
command_input,
handle_create_collection,
(
CliErrorHandler(
exception_type=InternalError,
message=_collection_already_exists_message,
),
),
)
case DeleteCollectionInput():
return _run_command(
command_input,
handle_delete_collection,
(_collection_not_found_handler(DeleteCollectionInput),),
)
case CountCollectionInput():
return _run_command(
command_input,
handle_count_collection,
(_collection_not_found_handler(CountCollectionInput),),
)
case AddDataInput():
return _run_command(
command_input,
handle_add_data,
(
_collection_not_found_handler(AddDataInput),
CliErrorHandler(
exception_type=FileNotFoundError,
message=_file_not_found_message,
),
),
)
case QueryInput():
return _run_command(
command_input,
handle_query,
(_collection_not_found_handler(QueryInput),),
)
case DeleteRecordsInput():
return _run_command(
command_input,
handle_delete_records,
(
_collection_not_found_handler(DeleteRecordsInput),
CliErrorHandler(
exception_type=ValueError,
message=_exception_message,
),
),
)
assert_never(command_input)
def _run_command(
command_input: CommandT,
handler: CommandHandler[CommandT],
error_handlers: Sequence[CliErrorHandler[CommandT]] = (),
) -> int:
try:
return handler(command_input)
except Exception as exc:
for error_handler in error_handlers:
if isinstance(exc, error_handler.exception_type):
print(error_handler.message(command_input, exc))
return 1
raise
def _collection_already_exists_message(
command: CreateCollectionInput,
_: Exception,
) -> str:
return f"Collection '{command.collection}' already exists."
def _collection_not_found_handler(
_: type[CollectionCommandT],
) -> CliErrorHandler[CollectionCommandT]:
return CliErrorHandler(
exception_type=NotFoundError,
message=_collection_not_found_message,
)
def _collection_not_found_message(command: CollectionCommandT, _: Exception) -> str:
return f"Collection '{command.collection}' does not exist."
def _file_not_found_message(command: AddDataInput, _: Exception) -> str:
return f"The file {command.file} was not found."
def _exception_message(_: DeleteRecordsInput, exc: Exception) -> str:
return str(exc)
-123
View File
@@ -1,123 +0,0 @@
from __future__ import annotations
import argparse
from dataclasses import dataclass
@dataclass(frozen=True, slots=True)
class ArgumentSpec:
name: str
help: str
required: bool = False
metavar: str | None = None
@dataclass(frozen=True, slots=True)
class CommandSpec:
name: str
aliases: tuple[str, ...]
help: str
arguments: tuple[ArgumentSpec, ...] = ()
COMMAND_SPECS: tuple[CommandSpec, ...] = (
CommandSpec(
name="list-collections",
aliases=("lc",),
help="List all collections stored in the local Chroma database.",
),
CommandSpec(
name="create-collection",
aliases=("cc",),
help="Create a collection in the local Chroma database.",
arguments=(ArgumentSpec("collection", "Name of the collection to create."),),
),
CommandSpec(
name="delete-collection",
aliases=("dc",),
help="Delete a collection from the local Chroma database.",
arguments=(ArgumentSpec("collection", "Name of the collection to delete."),),
),
CommandSpec(
name="count",
aliases=("co",),
help="Count records in a collection from the local Chroma database.",
arguments=(ArgumentSpec("collection", "Name of the collection to count."),),
),
CommandSpec(
name="add-data",
aliases=("ad",),
help=(
"Chunk, embed, and add a file to a collection in the local Chroma database."
),
arguments=(
ArgumentSpec("collection", "Name of the target collection."),
ArgumentSpec(
"file", "Path to the file to chunk and add to the collection."
),
),
),
CommandSpec(
name="query",
aliases=("q",),
help="Query a collection with the provided text.",
arguments=(
ArgumentSpec("collection", "Name of the target collection."),
ArgumentSpec("query_text", "The text to query."),
),
),
CommandSpec(
name="delete",
aliases=("del",),
help="Delete records from a collection using a metadata filter.",
arguments=(
ArgumentSpec("collection", "Name of the target collection."),
ArgumentSpec(
"--where",
"Metadata filter in the format <condition>=<value>.",
required=True,
metavar="CONDITION=VALUE",
),
),
),
)
def _add_command(
subparsers: argparse._SubParsersAction[argparse.ArgumentParser],
command: CommandSpec,
) -> None:
subparser = subparsers.add_parser(
command.name,
aliases=list(command.aliases),
help=command.help,
description=command.help,
)
for argument in command.arguments:
if argument.name.startswith("-"):
subparser.add_argument(
argument.name,
help=argument.help,
metavar=argument.metavar,
required=argument.required,
)
continue
subparser.add_argument(
argument.name,
help=argument.help,
metavar=argument.metavar,
)
subparser.set_defaults(command=command.name)
def build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description="Inspect local Chroma collections.")
subparsers = parser.add_subparsers(dest="command", required=True)
for command in COMMAND_SPECS:
_add_command(subparsers, command)
return parser
-52
View File
@@ -1,52 +0,0 @@
from __future__ import annotations
from dataclasses import dataclass
@dataclass(frozen=True, slots=True)
class ListCollectionsInput:
pass
@dataclass(frozen=True, slots=True)
class CreateCollectionInput:
collection: str
@dataclass(frozen=True, slots=True)
class DeleteCollectionInput:
collection: str
@dataclass(frozen=True, slots=True)
class CountCollectionInput:
collection: str
@dataclass(frozen=True, slots=True)
class AddDataInput:
collection: str
file: str
@dataclass(frozen=True, slots=True)
class QueryInput:
collection: str
query_text: str
@dataclass(frozen=True, slots=True)
class DeleteRecordsInput:
collection: str
where: str
CommandInput = (
ListCollectionsInput
| CreateCollectionInput
| DeleteCollectionInput
| CountCollectionInput
| AddDataInput
| QueryInput
| DeleteRecordsInput
)
+3 -4
View File
@@ -1,10 +1,9 @@
from __future__ import annotations from __future__ import annotations
from chromy.command_inputs import AddDataInput
from chromy.utilities import ingest_file from chromy.utilities import ingest_file
def handle_add_data(command: AddDataInput) -> int: def handle_add_data(collection: str, file: str) -> int:
records_added = ingest_file(command.collection, command.file) records_added = ingest_file(collection, file)
print(f"Added {records_added} records to collection '{command.collection}'.") print(f"Added {records_added} records to collection '{collection}'.")
return 0 return 0
+2 -3
View File
@@ -1,9 +1,8 @@
from __future__ import annotations from __future__ import annotations
from chromy.chroma_functions import count_collection from chromy.chroma_functions import count_collection
from chromy.command_inputs import CountCollectionInput
def handle_count_collection(command: CountCollectionInput) -> int: def handle_count_collection(collection: str) -> int:
print(count_collection(command.collection)) print(count_collection(collection))
return 0 return 0
+2 -3
View File
@@ -1,10 +1,9 @@
from __future__ import annotations from __future__ import annotations
from chromy.chroma_functions import create_collection from chromy.chroma_functions import create_collection
from chromy.command_inputs import CreateCollectionInput
def handle_create_collection(command: CreateCollectionInput) -> int: def handle_create_collection(collection: str) -> int:
collection_name = create_collection(command.collection) collection_name = create_collection(collection)
print(f"Created collection '{collection_name}'.") print(f"Created collection '{collection_name}'.")
return 0 return 0
+7 -8
View File
@@ -1,7 +1,6 @@
from __future__ import annotations from __future__ import annotations
from chromy.chroma_functions import delete_collection, delete_data from chromy.chroma_functions import delete_collection, delete_data
from chromy.command_inputs import DeleteCollectionInput, DeleteRecordsInput
def _parse_where_clause(where_clause: str) -> dict[str, str]: def _parse_where_clause(where_clause: str) -> dict[str, str]:
@@ -19,18 +18,18 @@ def _parse_where_clause(where_clause: str) -> dict[str, str]:
return {condition: value} return {condition: value}
def handle_delete_collection(command: DeleteCollectionInput) -> int: def handle_delete_collection(collection: str) -> int:
delete_collection(command.collection) delete_collection(collection)
print(f"Deleted collection '{command.collection}'.") print(f"Deleted collection '{collection}'.")
return 0 return 0
def handle_delete_records(command: DeleteRecordsInput) -> int: def handle_delete_records(collection: str, where_clause: str) -> int:
where = _parse_where_clause(command.where) where = _parse_where_clause(where_clause)
deleted = delete_data(command.collection, where) deleted = delete_data(collection, where)
condition, value = next(iter(where.items())) condition, value = next(iter(where.items()))
print( print(
f"Deleted {deleted} record(s) from collection '{command.collection}' " f"Deleted {deleted} record(s) from collection '{collection}' "
f"where {condition}={value}." f"where {condition}={value}."
) )
return 0 return 0
+1 -2
View File
@@ -1,11 +1,10 @@
from __future__ import annotations from __future__ import annotations
from chromy.chroma_functions import list_collections from chromy.chroma_functions import list_collections
from chromy.command_inputs import ListCollectionsInput
from chromy.utilities import print_lines from chromy.utilities import print_lines
def handle_list_collections(_: ListCollectionsInput) -> int: def handle_list_collections() -> int:
collections = list_collections() collections = list_collections()
if not collections: if not collections:
print("No collections found.") print("No collections found.")
+2 -3
View File
@@ -1,10 +1,9 @@
from __future__ import annotations from __future__ import annotations
from chromy.command_inputs import QueryInput
from chromy.utilities import format_query_result, print_lines, run_query from chromy.utilities import format_query_result, print_lines, run_query
def handle_query(command: QueryInput) -> int: def handle_query(collection: str, query_text: str) -> int:
result = run_query(command.collection, command.query_text) result = run_query(collection, query_text)
print_lines(format_query_result(result)) print_lines(format_query_result(result))
return 0 return 0
+4 -6
View File
@@ -2,15 +2,13 @@ from __future__ import annotations
from dotenv import load_dotenv from dotenv import load_dotenv
from chromy.cli_app import execute_command from chromy.cli import app
from chromy.cli_parser import build_parser
def main() -> int: def main() -> None:
load_dotenv() load_dotenv()
args = build_parser().parse_args() app()
return execute_command(args)
if __name__ == "__main__": if __name__ == "__main__":
raise SystemExit(main()) main()
+1
View File
@@ -16,6 +16,7 @@ dependencies = [
"semchunk>=4.0.0", "semchunk>=4.0.0",
"tiktoken>=0.12.0", "tiktoken>=0.12.0",
"transformers>=5.5.4", "transformers>=5.5.4",
"typer>=0.24.1",
] ]
[project.scripts] [project.scripts]
+131
View File
@@ -0,0 +1,131 @@
from __future__ import annotations
import unittest
from collections.abc import Sequence
from unittest.mock import patch
from click.testing import Result
from typer.testing import CliRunner
from chromy.cli import app
class CliTests(unittest.TestCase):
def test_list_collections_and_alias(self) -> None:
for command in ("list-collections", "lc"):
with patch(
"chromy.handlers.list_collections.list_collections",
return_value=[],
):
result = _invoke([command])
self.assertEqual(result.exit_code, 0)
self.assertEqual(result.stdout, "No collections found.\n")
def test_create_collection_and_alias(self) -> None:
for command in ("create-collection", "cc"):
with patch(
"chromy.handlers.create_collection.create_collection",
return_value="notes",
) as create_collection:
result = _invoke([command, "notes"])
create_collection.assert_called_once_with("notes")
self.assertEqual(result.exit_code, 0)
self.assertEqual(result.stdout, "Created collection 'notes'.\n")
def test_delete_collection_and_alias(self) -> None:
for command in ("delete-collection", "dc"):
with patch(
"chromy.handlers.delete_collection.delete_collection",
) as delete_collection:
result = _invoke([command, "notes"])
delete_collection.assert_called_once_with("notes")
self.assertEqual(result.exit_code, 0)
self.assertEqual(result.stdout, "Deleted collection 'notes'.\n")
def test_count_and_alias(self) -> None:
for command in ("count", "co"):
with patch(
"chromy.handlers.count_collection.count_collection",
return_value=7,
) as count_collection:
result = _invoke([command, "notes"])
count_collection.assert_called_once_with("notes")
self.assertEqual(result.exit_code, 0)
self.assertEqual(result.stdout, "7\n")
def test_add_data_and_alias(self) -> None:
for command in ("add-data", "ad"):
with patch(
"chromy.handlers.add_data.ingest_file",
return_value=3,
) as ingest_file:
result = _invoke([command, "notes", "romeo_and_juliet.txt"])
ingest_file.assert_called_once_with("notes", "romeo_and_juliet.txt")
self.assertEqual(result.exit_code, 0)
self.assertEqual(result.stdout, "Added 3 records to collection 'notes'.\n")
def test_query_and_alias(self) -> None:
query_result = {"ids": [["1"]], "documents": [["hello"]]}
for command in ("query", "q"):
with (
patch(
"chromy.handlers.query.run_query", return_value=query_result
) as run,
patch(
"chromy.handlers.query.format_query_result",
return_value=["Query results:", "1"],
) as format_result,
):
result = _invoke([command, "notes", "Where is Romeo?"])
run.assert_called_once_with("notes", "Where is Romeo?")
format_result.assert_called_once_with(query_result)
self.assertEqual(result.exit_code, 0)
self.assertEqual(result.stdout, "Query results:\n1\n")
def test_delete_records_and_alias(self) -> None:
for command in ("delete", "del"):
with patch(
"chromy.handlers.delete_collection.delete_data",
return_value=2,
) as delete_data:
result = _invoke(
[command, "notes", "--where", " file_name = play.txt "],
)
delete_data.assert_called_once_with("notes", {"file_name": "play.txt"})
self.assertEqual(result.exit_code, 0)
self.assertEqual(
result.stdout,
"Deleted 2 record(s) from collection 'notes' "
"where file_name=play.txt.\n",
)
def test_invalid_delete_filter_keeps_user_facing_error(self) -> None:
result = _invoke(["delete", "notes", "--where", "file_name"])
self.assertEqual(result.exit_code, 1)
self.assertEqual(
result.stdout,
"Invalid --where value. Expected <condition>=<value>.\n",
)
def test_delete_requires_where_option(self) -> None:
result = _invoke(["delete", "notes"])
self.assertNotEqual(result.exit_code, 0)
self.assertIn("Missing option", result.output)
def _invoke(arguments: Sequence[str]) -> Result:
return CliRunner().invoke(app, list(arguments))
if __name__ == "__main__":
unittest.main()
-98
View File
@@ -1,98 +0,0 @@
from __future__ import annotations
import io
import unittest
from argparse import Namespace
from collections.abc import Sequence
from contextlib import redirect_stdout
from chromy.cli_app import build_command_input, execute_command
from chromy.cli_parser import build_parser
from chromy.command_inputs import (
AddDataInput,
CountCollectionInput,
CreateCollectionInput,
DeleteCollectionInput,
DeleteRecordsInput,
ListCollectionsInput,
QueryInput,
)
class BuildCommandInputTests(unittest.TestCase):
def test_parser_converts_list_collections_and_alias(self) -> None:
self.assertEqual(_parse_input(["list-collections"]), ListCollectionsInput())
self.assertEqual(_parse_input(["lc"]), ListCollectionsInput())
def test_parser_converts_create_collection_and_alias(self) -> None:
expected = CreateCollectionInput(collection="notes")
self.assertEqual(_parse_input(["create-collection", "notes"]), expected)
self.assertEqual(_parse_input(["cc", "notes"]), expected)
def test_parser_converts_delete_collection_and_alias(self) -> None:
expected = DeleteCollectionInput(collection="notes")
self.assertEqual(_parse_input(["delete-collection", "notes"]), expected)
self.assertEqual(_parse_input(["dc", "notes"]), expected)
def test_parser_converts_count_and_alias(self) -> None:
expected = CountCollectionInput(collection="notes")
self.assertEqual(_parse_input(["count", "notes"]), expected)
self.assertEqual(_parse_input(["co", "notes"]), expected)
def test_parser_converts_add_data_and_alias(self) -> None:
expected = AddDataInput(collection="notes", file="romeo_and_juliet.txt")
self.assertEqual(
_parse_input(["add-data", "notes", "romeo_and_juliet.txt"]),
expected,
)
self.assertEqual(
_parse_input(["ad", "notes", "romeo_and_juliet.txt"]),
expected,
)
def test_parser_converts_query_and_alias(self) -> None:
expected = QueryInput(collection="notes", query_text="Where is Romeo?")
self.assertEqual(
_parse_input(["query", "notes", "Where is Romeo?"]),
expected,
)
self.assertEqual(_parse_input(["q", "notes", "Where is Romeo?"]), expected)
def test_parser_converts_delete_records_and_alias(self) -> None:
expected = DeleteRecordsInput(collection="notes", where="file_name=play.txt")
self.assertEqual(
_parse_input(["delete", "notes", "--where", "file_name=play.txt"]),
expected,
)
self.assertEqual(
_parse_input(["del", "notes", "--where", "file_name=play.txt"]),
expected,
)
def test_invalid_delete_filter_keeps_user_facing_error(self) -> None:
args = Namespace(command="delete", collection="notes", where="file_name")
output = io.StringIO()
with redirect_stdout(output):
exit_code = execute_command(args)
self.assertEqual(exit_code, 1)
self.assertEqual(
output.getvalue().strip(),
"Invalid --where value. Expected <condition>=<value>.",
)
self.assertFalse(hasattr(args, "error_message"))
def _parse_input(argv: Sequence[str]) -> object:
return build_command_input(build_parser().parse_args(argv))
if __name__ == "__main__":
unittest.main()
+13 -23
View File
@@ -7,15 +7,6 @@ from contextlib import redirect_stdout
from typing import TypeVar from typing import TypeVar
from unittest.mock import patch from unittest.mock import patch
from chromy.command_inputs import (
AddDataInput,
CountCollectionInput,
CreateCollectionInput,
DeleteCollectionInput,
DeleteRecordsInput,
ListCollectionsInput,
QueryInput,
)
from chromy.handlers.add_data import handle_add_data from chromy.handlers.add_data import handle_add_data
from chromy.handlers.count_collection import handle_count_collection from chromy.handlers.count_collection import handle_count_collection
from chromy.handlers.create_collection import handle_create_collection from chromy.handlers.create_collection import handle_create_collection
@@ -36,7 +27,6 @@ class HandlerTests(unittest.TestCase):
): ):
exit_code, output = _capture_output( exit_code, output = _capture_output(
handle_list_collections, handle_list_collections,
ListCollectionsInput(),
) )
self.assertEqual(exit_code, 0) self.assertEqual(exit_code, 0)
@@ -49,7 +39,6 @@ class HandlerTests(unittest.TestCase):
): ):
exit_code, output = _capture_output( exit_code, output = _capture_output(
handle_list_collections, handle_list_collections,
ListCollectionsInput(),
) )
self.assertEqual(exit_code, 0) self.assertEqual(exit_code, 0)
@@ -62,7 +51,7 @@ class HandlerTests(unittest.TestCase):
) as create_collection: ) as create_collection:
exit_code, output = _capture_output( exit_code, output = _capture_output(
handle_create_collection, handle_create_collection,
CreateCollectionInput(collection="notes"), "notes",
) )
create_collection.assert_called_once_with("notes") create_collection.assert_called_once_with("notes")
@@ -73,7 +62,7 @@ class HandlerTests(unittest.TestCase):
with patch("chromy.handlers.delete_collection.delete_collection") as delete: with patch("chromy.handlers.delete_collection.delete_collection") as delete:
exit_code, output = _capture_output( exit_code, output = _capture_output(
handle_delete_collection, handle_delete_collection,
DeleteCollectionInput(collection="notes"), "notes",
) )
delete.assert_called_once_with("notes") delete.assert_called_once_with("notes")
@@ -87,7 +76,7 @@ class HandlerTests(unittest.TestCase):
) as count: ) as count:
exit_code, output = _capture_output( exit_code, output = _capture_output(
handle_count_collection, handle_count_collection,
CountCollectionInput(collection="notes"), "notes",
) )
count.assert_called_once_with("notes") count.assert_called_once_with("notes")
@@ -101,7 +90,8 @@ class HandlerTests(unittest.TestCase):
) as ingest_file: ) as ingest_file:
exit_code, output = _capture_output( exit_code, output = _capture_output(
handle_add_data, handle_add_data,
AddDataInput(collection="notes", file="romeo_and_juliet.txt"), "notes",
"romeo_and_juliet.txt",
) )
ingest_file.assert_called_once_with("notes", "romeo_and_juliet.txt") ingest_file.assert_called_once_with("notes", "romeo_and_juliet.txt")
@@ -119,7 +109,8 @@ class HandlerTests(unittest.TestCase):
): ):
exit_code, output = _capture_output( exit_code, output = _capture_output(
handle_query, handle_query,
QueryInput(collection="notes", query_text="hello"), "notes",
"hello",
) )
run.assert_called_once_with("notes", "hello") run.assert_called_once_with("notes", "hello")
@@ -134,7 +125,8 @@ class HandlerTests(unittest.TestCase):
) as delete_data: ) as delete_data:
exit_code, output = _capture_output( exit_code, output = _capture_output(
handle_delete_records, handle_delete_records,
DeleteRecordsInput(collection="notes", where=" file_name = play.txt "), "notes",
" file_name = play.txt ",
) )
delete_data.assert_called_once_with("notes", {"file_name": "play.txt"}) delete_data.assert_called_once_with("notes", {"file_name": "play.txt"})
@@ -149,19 +141,17 @@ class HandlerTests(unittest.TestCase):
ValueError, ValueError,
"Invalid --where value. Expected <condition>=<value>.", "Invalid --where value. Expected <condition>=<value>.",
): ):
handle_delete_records( handle_delete_records("notes", "file_name")
DeleteRecordsInput(collection="notes", where="file_name")
)
def _capture_output( def _capture_output(
handler: Callable[[CommandT], int], handler: Callable[..., int],
command: CommandT, *arguments: CommandT,
) -> tuple[int, str]: ) -> tuple[int, str]:
output = io.StringIO() output = io.StringIO()
with redirect_stdout(output): with redirect_stdout(output):
exit_code = handler(command) exit_code = handler(*arguments)
return exit_code, output.getvalue() return exit_code, output.getvalue()
Generated
+2
View File
@@ -264,6 +264,7 @@ dependencies = [
{ name = "semchunk" }, { name = "semchunk" },
{ name = "tiktoken" }, { name = "tiktoken" },
{ name = "transformers" }, { name = "transformers" },
{ name = "typer" },
] ]
[package.dev-dependencies] [package.dev-dependencies]
@@ -283,6 +284,7 @@ requires-dist = [
{ name = "semchunk", specifier = ">=4.0.0" }, { name = "semchunk", specifier = ">=4.0.0" },
{ name = "tiktoken", specifier = ">=0.12.0" }, { name = "tiktoken", specifier = ">=0.12.0" },
{ name = "transformers", specifier = ">=5.5.4" }, { name = "transformers", specifier = ">=5.5.4" },
{ name = "typer", specifier = ">=0.24.1" },
] ]
[package.metadata.requires-dev] [package.metadata.requires-dev]