diff --git a/chroma_functions.py b/chroma_functions.py
index ec8d045..f4f4994 100644
--- a/chroma_functions.py
+++ b/chroma_functions.py
@@ -44,6 +44,28 @@ def delete_collection(name: str) -> None:
     client.delete_collection(name=name)
 
 
+def delete_data(collection_name: str, where: dict[str, str]) -> int:
+    """Delete every record in a collection that matches a metadata filter.
+
+    Args:
+        collection_name: Name of the collection to delete from.
+        where: Metadata filter forwarded to Chroma as the ``where`` clause.
+
+    Returns:
+        The number of records that matched the filter when it was evaluated.
+    """
+    _, collection = _get_client_and_collection(collection_name)
+
+    # NOTE(review): ``Collection.delete`` has historically returned ``None``,
+    # so the count comes from an ids-only lookup made before deleting.
+    matching_ids = collection.get(where=where, include=[])["ids"]
+    if not matching_ids:
+        return 0
+
+    collection.delete(where=where)
+    return len(matching_ids)
+
+
 def count_collection(collection_name: str) -> int:
     _, collection = _get_client_and_collection(collection_name)
 
diff --git a/cli_app.py b/cli_app.py
index cec76db..6ccaf8a 100644
--- a/cli_app.py
+++ b/cli_app.py
@@ -9,7 +9,10 @@ from chromadb.errors import InternalError, NotFoundError
 from handlers.add_data import handle_add_data
 from handlers.count_collection import handle_count_collection
 from handlers.create_collection import handle_create_collection
-from handlers.delete_collection import handle_delete_collection
+from handlers.delete_collection import (
+    handle_delete_collection,
+    handle_delete_records,
+)
 from handlers.list_collections import handle_list_collections
 from handlers.query import handle_query
 
@@ -81,11 +84,27 @@ COMMANDS: dict[str, CommandConfig] = {
             ),
         ),
     ),
+    "delete": CommandConfig(
+        handler=handle_delete_records,
+        error_handlers=(
+            CliErrorHandler(
+                exception_type=NotFoundError,
+                message=lambda args: f"Collection '{args.collection}' does not exist.",
+            ),
+            CliErrorHandler(
+                exception_type=ValueError,
+                message=lambda args: str(args.error_message),
+            ),
+        ),
+    ),
 }
 
 
 def execute_command(args: Namespace) -> int:
     command = COMMANDS[args.command]
+    # Fallback text for error handlers that read ``args.error_message``;
+    # a command handler may overwrite it with a more specific message.
+    args.error_message = "An unexpected value was provided."
 
     try:
         return command.handler(args)
diff --git a/cli_parser.py b/cli_parser.py
index d8456f3..569772e 100644
--- a/cli_parser.py
+++ b/cli_parser.py
@@ -8,6 +8,10 @@ from dataclasses import dataclass
 class ArgumentSpec:
     name: str
     help: str
+    # Only forwarded to argparse for dash-prefixed (optional) arguments.
+    required: bool = False
+    # Placeholder for usage/help text; ``None`` keeps argparse's default.
+    metavar: str | None = None
 
 
 @dataclass(frozen=True, slots=True)
@@ -62,6 +66,20 @@ COMMAND_SPECS: tuple[CommandSpec, ...] = (
             ArgumentSpec("query_text", "The text to query."),
         ),
     ),
+    CommandSpec(
+        name="delete",
+        aliases=("del",),
+        help="Delete records from a collection using a metadata filter.",
+        arguments=(
+            ArgumentSpec("collection", "Name of the target collection."),
+            ArgumentSpec(
+                "--where",
+                "Metadata filter in the format <condition>=<value>.",
+                required=True,
+                metavar="CONDITION=VALUE",
+            ),
+        ),
+    ),
 )
 
 
@@ -77,7 +95,17 @@ def _add_command(
     )
 
     for argument in command.arguments:
-        subparser.add_argument(argument.name, help=argument.help)
+        argument_kwargs: dict[str, object] = {"help": argument.help}
+
+        if argument.metavar is not None:
+            argument_kwargs["metavar"] = argument.metavar
+
+        # argparse rejects ``required`` for positionals, so pass it only for
+        # dash-prefixed options.
+        if argument.name.startswith("-"):
+            argument_kwargs["required"] = argument.required
+
+        subparser.add_argument(argument.name, **argument_kwargs)
 
     subparser.set_defaults(command=command.name)
 
diff --git a/handlers/delete_collection.py b/handlers/delete_collection.py
index 058c779..c24ed96 100644
--- a/handlers/delete_collection.py
+++ b/handlers/delete_collection.py
@@ -1,9 +1,55 @@
 from argparse import Namespace
 
-from chroma_functions import delete_collection
+from chroma_functions import delete_collection, delete_data
+
+
+def _parse_where_clause(where_clause: str) -> dict[str, str]:
+    """Parse a ``<condition>=<value>`` string into a one-entry filter dict.
+
+    The clause is split on the first ``=`` only, so the value itself may
+    contain ``=``. Both sides are stripped of surrounding whitespace.
+
+    Raises:
+        ValueError: If there is no ``=`` or either side is empty.
+    """
+    condition, separator, value = where_clause.partition("=")
+    condition = condition.strip()
+    value = value.strip()
+
+    # One guard covers a missing "=", an empty condition and an empty value;
+    # all three report the same message.
+    if not separator or not condition or not value:
+        raise ValueError("Invalid --where value. Expected <condition>=<value>.")
+
+    return {condition: value}
 
 
 def handle_delete_collection(args: Namespace) -> int:
     delete_collection(args.collection)
     print(f"Deleted collection '{args.collection}'.")
     return 0
+
+
+def handle_delete_records(args: Namespace) -> int:
+    """Delete the records matching ``args.where`` from ``args.collection``.
+
+    An invalid ``--where`` value is re-raised as ``ValueError`` after its
+    message has been stored on ``args.error_message``.
+
+    Returns:
+        ``0`` once the deletion has been reported on stdout.
+    """
+    try:
+        where = _parse_where_clause(args.where)
+    except ValueError as exc:
+        # Stash the specific message on the namespace before propagating.
+        args.error_message = str(exc)
+        raise
+
+    deleted = delete_data(args.collection, where)
+    condition, value = next(iter(where.items()))
+    print(
+        f"Deleted {deleted} record(s) from collection '{args.collection}' "
+        f"where {condition}={value}."
+    )
+    return 0