add ruff. fix all linting
This commit is contained in:
@@ -1,9 +1,12 @@
|
||||
from typing import List
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
import chromadb
|
||||
from chromadb.api import ClientAPI
|
||||
from chromadb.api.types import QueryResult
|
||||
from chromadb.api.types import QueryResult, Where
|
||||
from chromadb.errors import NotFoundError
|
||||
|
||||
from chromy.embed import EmbeddingRecord
|
||||
@@ -22,7 +25,7 @@ def _get_client_and_collection(
|
||||
return client, collection
|
||||
|
||||
|
||||
def list_collections() -> List[str]:
|
||||
def list_collections() -> list[str]:
|
||||
client = chromadb.PersistentClient()
|
||||
collections = client.list_collections()
|
||||
|
||||
@@ -46,7 +49,7 @@ def delete_collection(name: str) -> None:
|
||||
|
||||
def delete_data(collection_name: str, where: dict[str, str]) -> int:
|
||||
_, collection = _get_client_and_collection(collection_name)
|
||||
result = collection.delete(where=where)
|
||||
result = collection.delete(where=cast(Where, where))
|
||||
|
||||
return int(result.get("deleted", 0))
|
||||
|
||||
@@ -57,17 +60,19 @@ def count_collection(collection_name: str) -> int:
|
||||
return collection.count()
|
||||
|
||||
|
||||
def add_data(collection_name: str, data: List[EmbeddingRecord], file_name: str) -> None:
|
||||
def add_data(collection_name: str, data: list[EmbeddingRecord], file_name: str) -> None:
|
||||
if not data:
|
||||
return
|
||||
|
||||
_, collection = _get_client_and_collection(collection_name)
|
||||
|
||||
embeddings: list[Sequence[float]] = [record["embedding"] for record in data]
|
||||
|
||||
collection.add(
|
||||
ids=[str(uuid4()) for _ in data],
|
||||
metadatas=[{"file_name": file_name} for _ in data],
|
||||
documents=[record["text"] for record in data],
|
||||
embeddings=[record["embedding"] for record in data],
|
||||
embeddings=embeddings,
|
||||
)
|
||||
|
||||
|
||||
@@ -78,7 +83,10 @@ def query_data(collection_name: str, texts: list[str]) -> QueryResult:
|
||||
"documents": [],
|
||||
"metadatas": [],
|
||||
"distances": [],
|
||||
"embeddings": [],
|
||||
"embeddings": None,
|
||||
"uris": None,
|
||||
"data": None,
|
||||
"included": ["documents", "metadatas", "distances"],
|
||||
}
|
||||
|
||||
_, collection = _get_client_and_collection(collection_name)
|
||||
|
||||
Reference in New Issue
Block a user