add ruff. fix all linting

This commit is contained in:
Matteo Rosati
2026-04-22 17:03:01 +02:00
parent bd5f649663
commit 33b46c2c21
14 changed files with 294 additions and 63 deletions
+15 -7
View File
@@ -1,9 +1,12 @@
from typing import List
from __future__ import annotations
from collections.abc import Sequence
from typing import cast
from uuid import uuid4
import chromadb
from chromadb.api import ClientAPI
from chromadb.api.types import QueryResult
from chromadb.api.types import QueryResult, Where
from chromadb.errors import NotFoundError
from chromy.embed import EmbeddingRecord
@@ -22,7 +25,7 @@ def _get_client_and_collection(
return client, collection
def list_collections() -> List[str]:
def list_collections() -> list[str]:
client = chromadb.PersistentClient()
collections = client.list_collections()
@@ -46,7 +49,7 @@ def delete_collection(name: str) -> None:
def delete_data(collection_name: str, where: dict[str, str]) -> int:
_, collection = _get_client_and_collection(collection_name)
result = collection.delete(where=where)
result = collection.delete(where=cast(Where, where))
return int(result.get("deleted", 0))
@@ -57,17 +60,19 @@ def count_collection(collection_name: str) -> int:
return collection.count()
def add_data(collection_name: str, data: List[EmbeddingRecord], file_name: str) -> None:
def add_data(collection_name: str, data: list[EmbeddingRecord], file_name: str) -> None:
if not data:
return
_, collection = _get_client_and_collection(collection_name)
embeddings: list[Sequence[float]] = [record["embedding"] for record in data]
collection.add(
ids=[str(uuid4()) for _ in data],
metadatas=[{"file_name": file_name} for _ in data],
documents=[record["text"] for record in data],
embeddings=[record["embedding"] for record in data],
embeddings=embeddings,
)
@@ -78,7 +83,10 @@ def query_data(collection_name: str, texts: list[str]) -> QueryResult:
"documents": [],
"metadatas": [],
"distances": [],
"embeddings": [],
"embeddings": None,
"uris": None,
"data": None,
"included": ["documents", "metadatas", "distances"],
}
_, collection = _get_client_and_collection(collection_name)
+6 -4
View File
@@ -1,17 +1,19 @@
from __future__ import annotations
from pathlib import Path
from typing import List
from typing import cast
import semchunk
def chunk_text(text: str, chunk_size: int = 800) -> List[str]:
def chunk_text(text: str, chunk_size: int = 800) -> list[str]:
chunker = semchunk.chunkerify("gpt-4", chunk_size)
chunks = chunker(text)
return chunks
return cast("list[str]", chunks)
def chunk_file(filename: str, chunk_size: int = 800) -> List[str]:
def chunk_file(filename: str, chunk_size: int = 800) -> list[str]:
contents = Path(filename).read_text()
return chunk_text(contents, chunk_size)
+18 -14
View File
@@ -3,7 +3,7 @@ from __future__ import annotations
from argparse import Namespace
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from typing import Generic, Protocol, TypeVar, assert_never
from typing import Generic, TypeVar, assert_never
from chromadb.errors import InternalError, NotFoundError
@@ -27,17 +27,19 @@ from chromy.handlers.delete_collection import (
from chromy.handlers.list_collections import handle_list_collections
from chromy.handlers.query import handle_query
CommandT = TypeVar("CommandT", bound=CommandInput)
CollectionCommandT = TypeVar("CollectionCommandT", bound="HasCollection")
CollectionCommandT = TypeVar(
"CollectionCommandT",
DeleteCollectionInput,
CountCollectionInput,
AddDataInput,
QueryInput,
DeleteRecordsInput,
)
CommandHandler = Callable[[CommandT], int]
ErrorMessageBuilder = Callable[[CommandT, Exception], str]
class HasCollection(Protocol):
collection: str
@dataclass(frozen=True, slots=True)
class CliErrorHandler(Generic[CommandT]):
exception_type: type[Exception]
@@ -93,20 +95,20 @@ def execute_command(args: Namespace) -> int:
return _run_command(
command_input,
handle_delete_collection,
(_collection_not_found_handler(),),
(_collection_not_found_handler(DeleteCollectionInput),),
)
case CountCollectionInput():
return _run_command(
command_input,
handle_count_collection,
(_collection_not_found_handler(),),
(_collection_not_found_handler(CountCollectionInput),),
)
case AddDataInput():
return _run_command(
command_input,
handle_add_data,
(
_collection_not_found_handler(),
_collection_not_found_handler(AddDataInput),
CliErrorHandler(
exception_type=FileNotFoundError,
message=_file_not_found_message,
@@ -117,14 +119,14 @@ def execute_command(args: Namespace) -> int:
return _run_command(
command_input,
handle_query,
(_collection_not_found_handler(),),
(_collection_not_found_handler(QueryInput),),
)
case DeleteRecordsInput():
return _run_command(
command_input,
handle_delete_records,
(
_collection_not_found_handler(),
_collection_not_found_handler(DeleteRecordsInput),
CliErrorHandler(
exception_type=ValueError,
message=_exception_message,
@@ -157,14 +159,16 @@ def _collection_already_exists_message(
return f"Collection '{command.collection}' already exists."
def _collection_not_found_handler() -> CliErrorHandler[CollectionCommandT]:
def _collection_not_found_handler(
_: type[CollectionCommandT],
) -> CliErrorHandler[CollectionCommandT]:
return CliErrorHandler(
exception_type=NotFoundError,
message=_collection_not_found_message,
)
def _collection_not_found_message(command: HasCollection, _: Exception) -> str:
def _collection_not_found_message(command: CollectionCommandT, _: Exception) -> str:
return f"Collection '{command.collection}' does not exist."
+15 -8
View File
@@ -47,7 +47,9 @@ COMMAND_SPECS: tuple[CommandSpec, ...] = (
CommandSpec(
name="add-data",
aliases=("ad",),
help="Chunk, embed, and add a file to a collection in the local Chroma database.",
help=(
"Chunk, embed, and add a file to a collection in the local Chroma database."
),
arguments=(
ArgumentSpec("collection", "Name of the target collection."),
ArgumentSpec(
@@ -93,15 +95,20 @@ def _add_command(
)
for argument in command.arguments:
argument_kwargs: dict[str, object] = {"help": argument.help}
if argument.metavar is not None:
argument_kwargs["metavar"] = argument.metavar
if argument.name.startswith("-"):
argument_kwargs["required"] = argument.required
subparser.add_argument(
argument.name,
help=argument.help,
metavar=argument.metavar,
required=argument.required,
)
continue
subparser.add_argument(argument.name, **argument_kwargs)
subparser.add_argument(
argument.name,
help=argument.help,
metavar=argument.metavar,
)
subparser.set_defaults(command=command.name)
+6 -4
View File
@@ -1,14 +1,16 @@
from typing import List, TypedDict
from __future__ import annotations
from typing import TypedDict
from chromadb.utils.embedding_functions import DefaultEmbeddingFunction
class EmbeddingRecord(TypedDict):
text: str
embedding: List[float]
embedding: list[float]
def embed(chunks: List[str]) -> List[EmbeddingRecord]:
def embed(chunks: list[str]) -> list[EmbeddingRecord]:
if not chunks:
return []
@@ -22,5 +24,5 @@ def embed(chunks: List[str]) -> List[EmbeddingRecord]:
embedding.tolist() if hasattr(embedding, "tolist") else list(embedding)
),
}
for text, embedding in zip(chunks, embeddings)
for text, embedding in zip(chunks, embeddings, strict=False)
]
+2 -4
View File
@@ -8,15 +8,13 @@ def _parse_where_clause(where_clause: str) -> dict[str, str]:
condition, separator, value = where_clause.partition("=")
if separator == "":
raise ValueError(
"Invalid --where value. Expected <condition>=<value>.")
raise ValueError("Invalid --where value. Expected <condition>=<value>.")
condition = condition.strip()
value = value.strip()
if not condition or not value:
raise ValueError(
"Invalid --where value. Expected <condition>=<value>.")
raise ValueError("Invalid --where value. Expected <condition>=<value>.")
return {condition: value}
+1 -1
View File
@@ -1,7 +1,7 @@
from __future__ import annotations
from chromy.command_inputs import ListCollectionsInput
from chromy.chroma_functions import list_collections
from chromy.command_inputs import ListCollectionsInput
from chromy.utilities import print_lines
+2 -1
View File
@@ -1,6 +1,7 @@
from chromadb import QueryResult
from collections.abc import Mapping
from chromadb import QueryResult
from chromy.chroma_functions import add_data, query_data
from chromy.chunk_functions import chunk_file
from chromy.embed import embed