From 5865c3119e4a345a7598e5f9c88afe1d865d3665 Mon Sep 17 00:00:00 2001 From: Matteo Rosati Date: Tue, 21 Apr 2026 18:24:49 +0200 Subject: [PATCH] add metadata (file_name) --- chroma_functions.py | 42 +++++++++++++++++++++--------------------- utilities.py | 29 +++++++++++++++++++++++------ 2 files changed, 44 insertions(+), 27 deletions(-) diff --git a/chroma_functions.py b/chroma_functions.py index b8211eb..ec8d045 100644 --- a/chroma_functions.py +++ b/chroma_functions.py @@ -2,12 +2,26 @@ from typing import List from uuid import uuid4 import chromadb +from chromadb.api import ClientAPI from chromadb.api.types import QueryResult from chromadb.errors import NotFoundError from embed import EmbeddingRecord +def _get_client_and_collection( + collection_name: str, +) -> tuple[ClientAPI, chromadb.Collection]: + client = chromadb.PersistentClient() + + try: + collection = client.get_collection(name=collection_name) + except NotFoundError: + raise + + return client, collection + + def list_collections() -> List[str]: client = chromadb.PersistentClient() collections = client.list_collections() @@ -30,30 +44,21 @@ def delete_collection(name: str) -> None: client.delete_collection(name=name) -def count_collection(name: str) -> int: - client = chromadb.PersistentClient() - - try: - collection = client.get_collection(name=name) - except NotFoundError: - raise +def count_collection(collection_name: str) -> int: + _, collection = _get_client_and_collection(collection_name) return collection.count() -def add_data(collection: str, data: List[EmbeddingRecord]) -> None: +def add_data(collection_name: str, data: List[EmbeddingRecord], file_name: str) -> None: if not data: return - client = chromadb.PersistentClient() + _, collection = _get_client_and_collection(collection_name) - try: - target_collection = client.get_collection(name=collection) - except NotFoundError: - raise - - target_collection.add( + collection.add( ids=[str(uuid4()) for _ in data], + metadatas=[{"file_name": file_name} for _ in data], documents=[record["text"] for record in data], embeddings=[record["embedding"] for record in data], ) @@ -69,11 +74,6 @@ def query_data(collection_name: str, texts: list[str]) -> QueryResult: "embeddings": [], } - client = chromadb.PersistentClient() - - try: - collection = client.get_collection(name=collection_name) - except NotFoundError: - raise + _, collection = _get_client_and_collection(collection_name) return collection.query(query_texts=texts) diff --git a/utilities.py b/utilities.py index 6173698..22bc6c9 100644 --- a/utilities.py +++ b/utilities.py @@ -1,4 +1,5 @@ from chromadb import QueryResult +from collections.abc import Mapping from chroma_functions import add_data, query_data from chunk_functions import chunk_file @@ -13,7 +14,7 @@ def print_lines(lines: list[str]) -> None: def ingest_file(collection_name: str, file_path: str) -> int: chunks = chunk_file(file_path) embeddings = embed(chunks) - add_data(collection_name, embeddings) + add_data(collection_name, embeddings, file_path) return len(embeddings) @@ -25,22 +26,38 @@ def format_query_result(result: QueryResult) -> list[str]: ids = result.get("ids", [[]]) documents = result.get("documents", [[]]) distances = result.get("distances", [[]]) + metadatas = result.get("metadatas", [[]]) first_ids = ids[0] if ids else [] first_documents = documents[0] if documents else [] first_distances = distances[0] if distances else [] + first_metadatas = metadatas[0] if metadatas else [] if not first_ids: return ["No results found."] lines = ["Query results:"] + for index, document_id in enumerate(first_ids, start=1): - lines.append(f"{index}. id: {document_id}") + lines.append(f"{index}.\tid: {document_id}") + i = index - 1 - if index - 1 < len(first_distances): - lines.append(f" distance: {first_distances[index - 1]}") + if i < len(first_distances): + lines.append(f"\tdistance: {first_distances[i]}") - if index - 1 < len(first_documents): - lines.append(f" document: {first_documents[index - 1]}") + if i < len(first_metadatas): + metadata = first_metadatas[i] + + if isinstance(metadata, Mapping): + file_name = metadata.get("file_name") + + if file_name: + lines.append(f"\tfile_name: {file_name}") + + if i < len(first_documents): + lines.append(f"\tdocument: {first_documents[i]}") + + # Print a separator between documents + lines.append(60 * "-") return lines