add ruff. fix all linting

This commit is contained in:
Matteo Rosati
2026-04-22 17:03:01 +02:00
parent bd5f649663
commit 33b46c2c21
14 changed files with 294 additions and 63 deletions
+15 -7
View File
@@ -1,9 +1,12 @@
from typing import List
from __future__ import annotations
from collections.abc import Sequence
from typing import cast
from uuid import uuid4
import chromadb
from chromadb.api import ClientAPI
from chromadb.api.types import QueryResult
from chromadb.api.types import QueryResult, Where
from chromadb.errors import NotFoundError
from chromy.embed import EmbeddingRecord
@@ -22,7 +25,7 @@ def _get_client_and_collection(
return client, collection
def list_collections() -> List[str]:
def list_collections() -> list[str]:
client = chromadb.PersistentClient()
collections = client.list_collections()
@@ -46,7 +49,7 @@ def delete_collection(name: str) -> None:
def delete_data(collection_name: str, where: dict[str, str]) -> int:
_, collection = _get_client_and_collection(collection_name)
result = collection.delete(where=where)
result = collection.delete(where=cast(Where, where))
return int(result.get("deleted", 0))
@@ -57,17 +60,19 @@ def count_collection(collection_name: str) -> int:
return collection.count()
def add_data(collection_name: str, data: List[EmbeddingRecord], file_name: str) -> None:
def add_data(collection_name: str, data: list[EmbeddingRecord], file_name: str) -> None:
if not data:
return
_, collection = _get_client_and_collection(collection_name)
embeddings: list[Sequence[float]] = [record["embedding"] for record in data]
collection.add(
ids=[str(uuid4()) for _ in data],
metadatas=[{"file_name": file_name} for _ in data],
documents=[record["text"] for record in data],
embeddings=[record["embedding"] for record in data],
embeddings=embeddings,
)
@@ -78,7 +83,10 @@ def query_data(collection_name: str, texts: list[str]) -> QueryResult:
"documents": [],
"metadatas": [],
"distances": [],
"embeddings": [],
"embeddings": None,
"uris": None,
"data": None,
"included": ["documents", "metadatas", "distances"],
}
_, collection = _get_client_and_collection(collection_name)