fix syntax and types

This commit is contained in:
2026-04-24 18:23:02 +02:00
parent 948f8500be
commit c5b6b196b5
8 changed files with 21 additions and 25 deletions
+4 -3
View File
@@ -58,7 +58,9 @@ def count_collection(collection_name: str) -> str:
_, collection = _get_client_and_collection(collection_name) _, collection = _get_client_and_collection(collection_name)
count = collection.count() count = collection.count()
return f"The '{collection_name}' collection contains [bold green]{count}[/] records." return (
f"The '{collection_name}' collection contains [bold green]{count}[/] records."
)
def add_data( def add_data(
@@ -71,8 +73,7 @@ def add_data(
_, collection = _get_client_and_collection(collection_name) _, collection = _get_client_and_collection(collection_name)
embeddings: list[Sequence[float]] = [record["embedding"] embeddings: list[Sequence[float]] = [record["embedding"] for record in data]
for record in data]
collection.add( collection.add(
ids=[str(uuid4()) for _ in data], ids=[str(uuid4()) for _ in data],
+3 -4
View File
@@ -3,16 +3,16 @@ from __future__ import annotations
from typing import Annotated, Callable from typing import Annotated, Callable
import typer import typer
from rich import print
from chromadb.errors import InternalError, NotFoundError from chromadb.errors import InternalError, NotFoundError
from rich import print
from chromy.handlers.import_data import handle_import
from chromy.handlers.count_collection import handle_count_collection from chromy.handlers.count_collection import handle_count_collection
from chromy.handlers.create_collection import handle_create_collection from chromy.handlers.create_collection import handle_create_collection
from chromy.handlers.delete_collection import ( from chromy.handlers.delete_collection import (
handle_delete_collection, handle_delete_collection,
handle_delete_records, handle_delete_records,
) )
from chromy.handlers.import_data import handle_import
from chromy.handlers.list_collections import handle_list_collections from chromy.handlers.list_collections import handle_list_collections
from chromy.handlers.query import handle_query from chromy.handlers.query import handle_query
@@ -114,8 +114,7 @@ def import_data(
], ],
file: Annotated[ file: Annotated[
str, str,
typer.Argument( typer.Argument(help="Path to the file to chunk and add to the collection."),
help="Path to the file to chunk and add to the collection."),
], ],
) -> None: ) -> None:
try: try:
+1
View File
@@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
from rich import print from rich import print
from chromy.chroma_functions import count_collection from chromy.chroma_functions import count_collection
+1
View File
@@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
from rich import print from rich import print
from chromy.chroma_functions import create_collection from chromy.chroma_functions import create_collection
+3 -4
View File
@@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
from rich import print from rich import print
from chromy.chroma_functions import delete_collection, delete_data from chromy.chroma_functions import delete_collection, delete_data
@@ -8,15 +9,13 @@ def _parse_where_clause(where_clause: str) -> dict[str, str]:
condition, separator, value = where_clause.partition("=") condition, separator, value = where_clause.partition("=")
if separator == "": if separator == "":
raise ValueError( raise ValueError("Invalid --where value. Expected <condition>=<value>.")
"Invalid --where value. Expected <condition>=<value>.")
condition = condition.strip() condition = condition.strip()
value = value.strip() value = value.strip()
if not condition or not value: if not condition or not value:
raise ValueError( raise ValueError("Invalid --where value. Expected <condition>=<value>.")
"Invalid --where value. Expected <condition>=<value>.")
return {condition: value} return {condition: value}
+1 -1
View File
@@ -14,7 +14,7 @@ from chromy.embed import embed
CONSOLE = Console() CONSOLE = Console()
def print_lines(lines: Sequence[str]) -> None: def print_lines(lines: Sequence[Rule | Text]) -> None:
for line in lines: for line in lines:
CONSOLE.print(line) CONSOLE.print(line)
+6 -11
View File
@@ -51,15 +51,13 @@ class CliTests(unittest.TestCase):
def test_create_collection_with_same_name(self) -> None: def test_create_collection_with_same_name(self) -> None:
with patch( with patch(
"chromy.handlers.create_collection.create_collection", "chromy.handlers.create_collection.create_collection",
side_effect=InternalError() side_effect=InternalError(),
) as create_collection: ) as create_collection:
result = _invoke(["create-collection", "notes"]) result = _invoke(["create-collection", "notes"])
create_collection.assert_called_once_with("notes") create_collection.assert_called_once_with("notes")
self.assertEqual(result.exit_code, 1) self.assertEqual(result.exit_code, 1)
self.assertEqual( self.assertEqual(result.stdout, "Error: Collection 'notes' already exists.\n")
result.stdout, "Error: Collection 'notes' already exists.\n")
def test_delete_collection(self) -> None: def test_delete_collection(self) -> None:
with patch( with patch(
@@ -74,14 +72,13 @@ class CliTests(unittest.TestCase):
def test_delete_non_existent_collection(self) -> None: def test_delete_non_existent_collection(self) -> None:
with patch( with patch(
"chromy.handlers.delete_collection.delete_collection", "chromy.handlers.delete_collection.delete_collection",
side_effect=NotFoundError() side_effect=NotFoundError(),
) as delete_collection: ) as delete_collection:
result = _invoke(["delete-collection", "notes"]) result = _invoke(["delete-collection", "notes"])
delete_collection.assert_called_once_with("notes") delete_collection.assert_called_once_with("notes")
self.assertEqual(result.exit_code, 1) self.assertEqual(result.exit_code, 1)
self.assertEqual( self.assertEqual(result.stdout, "Error: Collection 'notes' does not exist.\n")
result.stdout, "Error: Collection 'notes' does not exist.\n")
def test_count(self) -> None: def test_count(self) -> None:
with patch( with patch(
@@ -106,8 +103,7 @@ class CliTests(unittest.TestCase):
self._fixture_path("romeo_and_juliet.txt"), self._fixture_path("romeo_and_juliet.txt"),
) )
self.assertEqual(result.exit_code, 0) self.assertEqual(result.exit_code, 0)
self.assertEqual( self.assertEqual(result.stdout, "Added 3 records to collection 'notes'.\n")
result.stdout, "Added 3 records to collection 'notes'.\n")
def test_query(self) -> None: def test_query(self) -> None:
query_result = {"ids": [["1"]], "documents": [["hello"]]} query_result = {"ids": [["1"]], "documents": [["hello"]]}
@@ -139,8 +135,7 @@ class CliTests(unittest.TestCase):
self.assertEqual(result.exit_code, 0) self.assertEqual(result.exit_code, 0)
self.assertEqual( self.assertEqual(
result.stdout, result.stdout,
"Deleted 2 record(s) from collection 'notes' " "Deleted 2 record(s) from collection 'notes' where file_name=play.txt.\n",
"where file_name=play.txt.\n",
) )
def test_invalid_delete_filter_keeps_user_facing_error(self) -> None: def test_invalid_delete_filter_keeps_user_facing_error(self) -> None:
+2 -2
View File
@@ -8,13 +8,13 @@ from pathlib import Path
from typing import TypeVar from typing import TypeVar
from unittest.mock import patch from unittest.mock import patch
from chromy.handlers.import_data import handle_import
from chromy.handlers.count_collection import handle_count_collection from chromy.handlers.count_collection import handle_count_collection
from chromy.handlers.create_collection import handle_create_collection from chromy.handlers.create_collection import handle_create_collection
from chromy.handlers.delete_collection import ( from chromy.handlers.delete_collection import (
handle_delete_collection, handle_delete_collection,
handle_delete_records, handle_delete_records,
) )
from chromy.handlers.import_data import handle_import
from chromy.handlers.list_collections import handle_list_collections from chromy.handlers.list_collections import handle_list_collections
from chromy.handlers.query import handle_query from chromy.handlers.query import handle_query
@@ -154,7 +154,7 @@ class HandlerTests(unittest.TestCase):
def _capture_output( def _capture_output(
handler: Callable[..., int], handler: Callable[..., int],
*arguments: CommandT, *arguments: object,
) -> tuple[int, str]: ) -> tuple[int, str]:
output = io.StringIO() output = io.StringIO()