add query
This commit is contained in:
+19
-2
@@ -2,6 +2,7 @@ from typing import List
|
||||
from uuid import uuid4
|
||||
|
||||
import chromadb
|
||||
from chromadb.api.types import QueryResult
|
||||
from chromadb.errors import NotFoundError
|
||||
|
||||
from embed import EmbeddingRecord
|
||||
@@ -58,5 +59,21 @@ def add_data(collection: str, data: List[EmbeddingRecord]) -> None:
|
||||
)
|
||||
|
||||
|
||||
def query_data(collection_name: str, texts: list[str]):
|
||||
raise NotImplementedError()
|
||||
def query_data(collection_name: str, texts: list[str]) -> QueryResult:
|
||||
if not texts:
|
||||
return {
|
||||
"ids": [],
|
||||
"documents": [],
|
||||
"metadatas": [],
|
||||
"distances": [],
|
||||
"embeddings": [],
|
||||
}
|
||||
|
||||
client = chromadb.PersistentClient()
|
||||
|
||||
try:
|
||||
collection = client.get_collection(name=collection_name)
|
||||
except NotFoundError:
|
||||
raise
|
||||
|
||||
return collection.query(query_texts=texts)
|
||||
|
||||
Reference in New Issue
Block a user