replace argparse.Namespace plumbing with typed command inputs

This commit is contained in:
Matteo Rosati
2026-04-22 16:03:51 +02:00
parent 8ebab832d5
commit 2962a2e088
15 changed files with 560 additions and 115 deletions
+144 -82
View File
@@ -1,11 +1,22 @@
from __future__ import annotations
from argparse import Namespace
from collections.abc import Callable
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from typing import Generic, Protocol, TypeVar, assert_never
from chromadb.errors import InternalError, NotFoundError
from chromy.command_inputs import (
AddDataInput,
CommandInput,
CountCollectionInput,
CreateCollectionInput,
DeleteCollectionInput,
DeleteRecordsInput,
ListCollectionsInput,
QueryInput,
)
from chromy.handlers.add_data import handle_add_data
from chromy.handlers.count_collection import handle_count_collection
from chromy.handlers.create_collection import handle_create_collection
@@ -17,98 +28,149 @@ from chromy.handlers.list_collections import handle_list_collections
from chromy.handlers.query import handle_query
CommandHandler = Callable[[Namespace], int]
ErrorMessageBuilder = Callable[[Namespace], str]
CommandT = TypeVar("CommandT", bound=CommandInput)
CollectionCommandT = TypeVar("CollectionCommandT", bound="HasCollection")
CommandHandler = Callable[[CommandT], int]
ErrorMessageBuilder = Callable[[CommandT, Exception], str]
class HasCollection(Protocol):
collection: str
@dataclass(frozen=True, slots=True)
class CliErrorHandler:
exception_type: type[BaseException]
message: ErrorMessageBuilder
class CliErrorHandler(Generic[CommandT]):
exception_type: type[Exception]
message: ErrorMessageBuilder[CommandT]
@dataclass(frozen=True, slots=True)
class CommandConfig:
handler: CommandHandler
error_handlers: tuple[CliErrorHandler, ...] = ()
def build_command_input(args: Namespace) -> CommandInput:
command = str(args.command)
COMMANDS: dict[str, CommandConfig] = {
"list-collections": CommandConfig(handler=handle_list_collections),
"create-collection": CommandConfig(
handler=handle_create_collection,
error_handlers=(
CliErrorHandler(
exception_type=InternalError,
message=lambda args: f"Collection '{args.collection}' already exists.",
),
),
),
"delete-collection": CommandConfig(
handler=handle_delete_collection,
error_handlers=(
CliErrorHandler(
exception_type=NotFoundError,
message=lambda args: f"Collection '{args.collection}' does not exist.",
),
),
),
"count": CommandConfig(
handler=handle_count_collection,
error_handlers=(
CliErrorHandler(
exception_type=NotFoundError,
message=lambda args: f"Collection '{args.collection}' does not exist.",
),
),
),
"add-data": CommandConfig(
handler=handle_add_data,
error_handlers=(
CliErrorHandler(
exception_type=NotFoundError,
message=lambda args: f"Collection '{args.collection}' does not exist.",
),
CliErrorHandler(
exception_type=FileNotFoundError,
message=lambda args: f"The file {args.file} was not found.",
),
),
),
"query": CommandConfig(
handler=handle_query,
error_handlers=(
CliErrorHandler(
exception_type=NotFoundError,
message=lambda args: f"Collection '{args.collection}' does not exist.",
),
),
),
"delete": CommandConfig(
handler=handle_delete_records,
error_handlers=(
CliErrorHandler(
exception_type=NotFoundError,
message=lambda args: f"Collection '{args.collection}' does not exist.",
),
CliErrorHandler(
exception_type=ValueError,
message=lambda args: str(args.error_message),
),
),
),
}
match command:
case "list-collections":
return ListCollectionsInput()
case "create-collection":
return CreateCollectionInput(collection=str(args.collection))
case "delete-collection":
return DeleteCollectionInput(collection=str(args.collection))
case "count":
return CountCollectionInput(collection=str(args.collection))
case "add-data":
return AddDataInput(collection=str(args.collection), file=str(args.file))
case "query":
return QueryInput(
collection=str(args.collection),
query_text=str(args.query_text),
)
case "delete":
return DeleteRecordsInput(
collection=str(args.collection),
where=str(args.where),
)
case _:
raise ValueError(f"Unknown command: {command}")
def execute_command(args: Namespace) -> int:
command = COMMANDS[args.command]
args.error_message = "An unexpected value was provided."
command_input = build_command_input(args)
match command_input:
case ListCollectionsInput():
return _run_command(command_input, handle_list_collections)
case CreateCollectionInput():
return _run_command(
command_input,
handle_create_collection,
(
CliErrorHandler(
exception_type=InternalError,
message=_collection_already_exists_message,
),
),
)
case DeleteCollectionInput():
return _run_command(
command_input,
handle_delete_collection,
(_collection_not_found_handler(),),
)
case CountCollectionInput():
return _run_command(
command_input,
handle_count_collection,
(_collection_not_found_handler(),),
)
case AddDataInput():
return _run_command(
command_input,
handle_add_data,
(
_collection_not_found_handler(),
CliErrorHandler(
exception_type=FileNotFoundError,
message=_file_not_found_message,
),
),
)
case QueryInput():
return _run_command(
command_input,
handle_query,
(_collection_not_found_handler(),),
)
case DeleteRecordsInput():
return _run_command(
command_input,
handle_delete_records,
(
_collection_not_found_handler(),
CliErrorHandler(
exception_type=ValueError,
message=_exception_message,
),
),
)
assert_never(command_input)
def _run_command(
command_input: CommandT,
handler: CommandHandler[CommandT],
error_handlers: Sequence[CliErrorHandler[CommandT]] = (),
) -> int:
try:
return command.handler(args)
except BaseException as exc:
for error_handler in command.error_handlers:
return handler(command_input)
except Exception as exc:
for error_handler in error_handlers:
if isinstance(exc, error_handler.exception_type):
print(error_handler.message(args))
print(error_handler.message(command_input, exc))
return 1
raise
def _collection_already_exists_message(
command: CreateCollectionInput,
_: Exception,
) -> str:
return f"Collection '{command.collection}' already exists."
def _collection_not_found_handler() -> CliErrorHandler[CollectionCommandT]:
return CliErrorHandler(
exception_type=NotFoundError,
message=_collection_not_found_message,
)
def _collection_not_found_message(command: HasCollection, _: Exception) -> str:
return f"Collection '{command.collection}' does not exist."
def _file_not_found_message(command: AddDataInput, _: Exception) -> str:
return f"The file {command.file} was not found."
def _exception_message(_: DeleteRecordsInput, exc: Exception) -> str:
return str(exc)
+52
View File
@@ -0,0 +1,52 @@
from __future__ import annotations
from dataclasses import dataclass
@dataclass(frozen=True, slots=True)
class ListCollectionsInput:
pass
@dataclass(frozen=True, slots=True)
class CreateCollectionInput:
collection: str
@dataclass(frozen=True, slots=True)
class DeleteCollectionInput:
collection: str
@dataclass(frozen=True, slots=True)
class CountCollectionInput:
collection: str
@dataclass(frozen=True, slots=True)
class AddDataInput:
collection: str
file: str
@dataclass(frozen=True, slots=True)
class QueryInput:
collection: str
query_text: str
@dataclass(frozen=True, slots=True)
class DeleteRecordsInput:
collection: str
where: str
CommandInput = (
ListCollectionsInput
| CreateCollectionInput
| DeleteCollectionInput
| CountCollectionInput
| AddDataInput
| QueryInput
| DeleteRecordsInput
)
+5 -4
View File
@@ -1,9 +1,10 @@
from argparse import Namespace
from __future__ import annotations
from chromy.command_inputs import AddDataInput
from chromy.utilities import ingest_file
def handle_add_data(args: Namespace) -> int:
records_added = ingest_file(args.collection, args.file)
print(f"Added {records_added} records to collection '{args.collection}'.")
def handle_add_data(command: AddDataInput) -> int:
records_added = ingest_file(command.collection, command.file)
print(f"Added {records_added} records to collection '{command.collection}'.")
return 0
+4 -3
View File
@@ -1,8 +1,9 @@
from argparse import Namespace
from __future__ import annotations
from chromy.chroma_functions import count_collection
from chromy.command_inputs import CountCollectionInput
def handle_count_collection(args: Namespace) -> int:
print(count_collection(args.collection))
def handle_count_collection(command: CountCollectionInput) -> int:
print(count_collection(command.collection))
return 0
+4 -3
View File
@@ -1,9 +1,10 @@
from argparse import Namespace
from __future__ import annotations
from chromy.chroma_functions import create_collection
from chromy.command_inputs import CreateCollectionInput
def handle_create_collection(args: Namespace) -> int:
collection_name = create_collection(args.collection)
def handle_create_collection(command: CreateCollectionInput) -> int:
collection_name = create_collection(command.collection)
print(f"Created collection '{collection_name}'.")
return 0
+13 -15
View File
@@ -1,40 +1,38 @@
from argparse import Namespace
from __future__ import annotations
from chromy.chroma_functions import delete_collection, delete_data
from chromy.command_inputs import DeleteCollectionInput, DeleteRecordsInput
def _parse_where_clause(where_clause: str) -> dict[str, str]:
condition, separator, value = where_clause.partition("=")
if separator == "":
raise ValueError("Invalid --where value. Expected <condition>=<value>.")
raise ValueError(
"Invalid --where value. Expected <condition>=<value>.")
condition = condition.strip()
value = value.strip()
if not condition or not value:
raise ValueError("Invalid --where value. Expected <condition>=<value>.")
raise ValueError(
"Invalid --where value. Expected <condition>=<value>.")
return {condition: value}
def handle_delete_collection(args: Namespace) -> int:
delete_collection(args.collection)
print(f"Deleted collection '{args.collection}'.")
def handle_delete_collection(command: DeleteCollectionInput) -> int:
delete_collection(command.collection)
print(f"Deleted collection '{command.collection}'.")
return 0
def handle_delete_records(args: Namespace) -> int:
try:
where = _parse_where_clause(args.where)
except ValueError as exc:
args.error_message = str(exc)
raise
deleted = delete_data(args.collection, where)
def handle_delete_records(command: DeleteRecordsInput) -> int:
where = _parse_where_clause(command.where)
deleted = delete_data(command.collection, where)
condition, value = next(iter(where.items()))
print(
f"Deleted {deleted} record(s) from collection '{args.collection}' "
f"Deleted {deleted} record(s) from collection '{command.collection}' "
f"where {condition}={value}."
)
return 0
+3 -2
View File
@@ -1,10 +1,11 @@
from argparse import Namespace
from __future__ import annotations
from chromy.command_inputs import ListCollectionsInput
from chromy.chroma_functions import list_collections
from chromy.utilities import print_lines
def handle_list_collections(_: Namespace) -> int:
def handle_list_collections(_: ListCollectionsInput) -> int:
collections = list_collections()
if not collections:
print("No collections found.")
+4 -3
View File
@@ -1,9 +1,10 @@
from argparse import Namespace
from __future__ import annotations
from chromy.command_inputs import QueryInput
from chromy.utilities import format_query_result, print_lines, run_query
def handle_query(args: Namespace) -> int:
result = run_query(args.collection, args.query_text)
def handle_query(command: QueryInput) -> int:
result = run_query(command.collection, command.query_text)
print_lines(format_query_result(result))
return 0