From ad73a6a985e741d4ac05cdaa41ab35243bbd5c8d Mon Sep 17 00:00:00 2001 From: Matteo Rosati Date: Tue, 21 Apr 2026 17:20:45 +0200 Subject: [PATCH] add query --- chroma_functions.py | 21 +++++++++++++++++++-- main.py | 12 ++++++++++++ 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/chroma_functions.py b/chroma_functions.py index df9ec93..b8211eb 100644 --- a/chroma_functions.py +++ b/chroma_functions.py @@ -2,6 +2,7 @@ from typing import List from uuid import uuid4 import chromadb +from chromadb.api.types import QueryResult from chromadb.errors import NotFoundError from embed import EmbeddingRecord @@ -58,5 +59,21 @@ def add_data(collection: str, data: List[EmbeddingRecord]) -> None: ) -def query_data(collection_name: str, texts: list[str]): - raise NotImplementedError() +def query_data(collection_name: str, texts: list[str]) -> QueryResult: + if not texts: + return { + "ids": [], + "documents": [], + "metadatas": [], + "distances": [], + "embeddings": [], + } + + client = chromadb.PersistentClient() + + try: + collection = client.get_collection(name=collection_name) + except NotFoundError: + raise + + return collection.query(query_texts=texts) diff --git a/main.py b/main.py index 5007ef3..bcb345f 100644 --- a/main.py +++ b/main.py @@ -9,6 +9,7 @@ from chroma_functions import ( create_collection, delete_collection, list_collections, + query_data, ) from chunk_functions import chunk_file from cli_parser import build_parser @@ -80,6 +81,17 @@ def main() -> int: return 0 + if args.command in {"query", "q"}: + try: + result = query_data(args.collection, [args.texts]) + except NotFoundError: + print(f"Collection '{args.collection}' does not exist.") + return 1 + + print(result) + + return 0 + print("Nothing to do. Use -h to see available commands.") return 0