add metadata (file_name)

This commit is contained in:
2026-04-21 18:24:49 +02:00
parent 1513dbb473
commit 5865c3119e
2 changed files with 44 additions and 27 deletions
+21 -21
View File
@@ -2,12 +2,26 @@ from typing import List
from uuid import uuid4 from uuid import uuid4
import chromadb import chromadb
from chromadb.api import ClientAPI
from chromadb.api.types import QueryResult from chromadb.api.types import QueryResult
from chromadb.errors import NotFoundError from chromadb.errors import NotFoundError
from embed import EmbeddingRecord from embed import EmbeddingRecord
def _get_client_and_collection(
    collection_name: str,
) -> tuple[ClientAPI, chromadb.Collection]:
    """Open the persistent Chroma client and fetch an existing collection.

    Args:
        collection_name: Name of the collection to look up.

    Returns:
        A ``(client, collection)`` pair: the freshly created
        ``chromadb.PersistentClient`` and the collection it returned for
        ``collection_name``.

    Raises:
        NotFoundError: Propagated unchanged from ``client.get_collection``
            (presumably when no collection with that name exists; confirm
            against the installed chromadb version).
    """
    client = chromadb.PersistentClient()
    # No try/except around the lookup: a handler that only re-raises
    # NotFoundError is equivalent to letting the exception propagate.
    collection = client.get_collection(name=collection_name)
    return client, collection
def list_collections() -> List[str]: def list_collections() -> List[str]:
client = chromadb.PersistentClient() client = chromadb.PersistentClient()
collections = client.list_collections() collections = client.list_collections()
@@ -30,30 +44,21 @@ def delete_collection(name: str) -> None:
client.delete_collection(name=name) client.delete_collection(name=name)
def count_collection(name: str) -> int: def count_collection(collection_name: str) -> int:
client = chromadb.PersistentClient() _, collection = _get_client_and_collection(collection_name)
try:
collection = client.get_collection(name=name)
except NotFoundError:
raise
return collection.count() return collection.count()
def add_data(collection: str, data: List[EmbeddingRecord]) -> None: def add_data(collection_name: str, data: List[EmbeddingRecord], file_name: str) -> None:
if not data: if not data:
return return
client = chromadb.PersistentClient() _, collection = _get_client_and_collection(collection_name)
try: collection.add(
target_collection = client.get_collection(name=collection)
except NotFoundError:
raise
target_collection.add(
ids=[str(uuid4()) for _ in data], ids=[str(uuid4()) for _ in data],
metadatas=[{"file_name": file_name} for _ in data],
documents=[record["text"] for record in data], documents=[record["text"] for record in data],
embeddings=[record["embedding"] for record in data], embeddings=[record["embedding"] for record in data],
) )
@@ -69,11 +74,6 @@ def query_data(collection_name: str, texts: list[str]) -> QueryResult:
"embeddings": [], "embeddings": [],
} }
client = chromadb.PersistentClient() _, collection = _get_client_and_collection(collection_name)
try:
collection = client.get_collection(name=collection_name)
except NotFoundError:
raise
return collection.query(query_texts=texts) return collection.query(query_texts=texts)
+23 -6
View File
@@ -1,4 +1,5 @@
from chromadb import QueryResult from chromadb import QueryResult
from collections.abc import Mapping
from chroma_functions import add_data, query_data from chroma_functions import add_data, query_data
from chunk_functions import chunk_file from chunk_functions import chunk_file
@@ -13,7 +14,7 @@ def print_lines(lines: list[str]) -> None:
def ingest_file(collection_name: str, file_path: str) -> int: def ingest_file(collection_name: str, file_path: str) -> int:
chunks = chunk_file(file_path) chunks = chunk_file(file_path)
embeddings = embed(chunks) embeddings = embed(chunks)
add_data(collection_name, embeddings) add_data(collection_name, embeddings, file_path)
return len(embeddings) return len(embeddings)
@@ -25,22 +26,38 @@ def format_query_result(result: QueryResult) -> list[str]:
ids = result.get("ids", [[]]) ids = result.get("ids", [[]])
documents = result.get("documents", [[]]) documents = result.get("documents", [[]])
distances = result.get("distances", [[]]) distances = result.get("distances", [[]])
metadatas = result.get("metadatas", [[]])
first_ids = ids[0] if ids else [] first_ids = ids[0] if ids else []
first_documents = documents[0] if documents else [] first_documents = documents[0] if documents else []
first_distances = distances[0] if distances else [] first_distances = distances[0] if distances else []
first_metadatas = metadatas[0] if metadatas else []
if not first_ids: if not first_ids:
return ["No results found."] return ["No results found."]
lines = ["Query results:"] lines = ["Query results:"]
for index, document_id in enumerate(first_ids, start=1): for index, document_id in enumerate(first_ids, start=1):
lines.append(f"{index}. id: {document_id}") lines.append(f"{index}.\tid: {document_id}")
i = index - 1
if index - 1 < len(first_distances): if i < len(first_distances):
lines.append(f" distance: {first_distances[index - 1]}") lines.append(f"\tdistance: {first_distances[i]}")
if index - 1 < len(first_documents): if i < len(first_metadatas):
lines.append(f" document: {first_documents[index - 1]}") metadata = first_metadatas[i]
if isinstance(metadata, Mapping):
file_name = metadata.get("file_name")
if file_name:
lines.append(f"\tfile_name: {file_name}")
if i < len(first_documents):
lines.append(f"\tdocument: {first_documents[i]}")
# Append a separator line between documents
lines.append(60 * "-")
return lines return lines