add metadata (file_name)

This commit is contained in:
2026-04-21 18:24:49 +02:00
parent 1513dbb473
commit 5865c3119e
2 changed files with 44 additions and 27 deletions
+21 -21
View File
@@ -2,12 +2,26 @@ from typing import List
from uuid import uuid4
import chromadb
from chromadb.api import ClientAPI
from chromadb.api.types import QueryResult
from chromadb.errors import NotFoundError
from embed import EmbeddingRecord
def _get_client_and_collection(
collection_name: str,
) -> tuple[ClientAPI, chromadb.Collection]:
client = chromadb.PersistentClient()
try:
collection = client.get_collection(name=collection_name)
except NotFoundError:
raise
return client, collection
def list_collections() -> List[str]:
client = chromadb.PersistentClient()
collections = client.list_collections()
@@ -30,30 +44,21 @@ def delete_collection(name: str) -> None:
client.delete_collection(name=name)
def count_collection(name: str) -> int:
client = chromadb.PersistentClient()
try:
collection = client.get_collection(name=name)
except NotFoundError:
raise
def count_collection(collection_name: str) -> int:
_, collection = _get_client_and_collection(collection_name)
return collection.count()
def add_data(collection: str, data: List[EmbeddingRecord]) -> None:
def add_data(collection_name: str, data: List[EmbeddingRecord], file_name: str) -> None:
if not data:
return
client = chromadb.PersistentClient()
_, collection = _get_client_and_collection(collection_name)
try:
target_collection = client.get_collection(name=collection)
except NotFoundError:
raise
target_collection.add(
collection.add(
ids=[str(uuid4()) for _ in data],
metadatas=[{"file_name": file_name} for _ in data],
documents=[record["text"] for record in data],
embeddings=[record["embedding"] for record in data],
)
@@ -69,11 +74,6 @@ def query_data(collection_name: str, texts: list[str]) -> QueryResult:
"embeddings": [],
}
client = chromadb.PersistentClient()
try:
collection = client.get_collection(name=collection_name)
except NotFoundError:
raise
_, collection = _get_client_and_collection(collection_name)
return collection.query(query_texts=texts)