add metadata (file_name)
This commit is contained in:
+21
-21
@@ -2,12 +2,26 @@ from typing import List
|
|||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
import chromadb
|
import chromadb
|
||||||
|
from chromadb.api import ClientAPI
|
||||||
from chromadb.api.types import QueryResult
|
from chromadb.api.types import QueryResult
|
||||||
from chromadb.errors import NotFoundError
|
from chromadb.errors import NotFoundError
|
||||||
|
|
||||||
from embed import EmbeddingRecord
|
from embed import EmbeddingRecord
|
||||||
|
|
||||||
|
|
||||||
|
def _get_client_and_collection(
    collection_name: str,
) -> tuple[ClientAPI, chromadb.Collection]:
    """Open a persistent Chroma client and fetch an existing collection.

    Args:
        collection_name: Name passed to ``client.get_collection``.

    Returns:
        A ``(client, collection)`` pair: the ``chromadb.PersistentClient``
        created with default arguments, and the collection looked up on it.

    Raises:
        NotFoundError: Propagated unchanged from ``client.get_collection``.
    """
    chroma_client = chromadb.PersistentClient()
    # No handler is needed here: a NotFoundError raised by the lookup simply
    # propagates to the caller.
    found = chroma_client.get_collection(name=collection_name)
    return chroma_client, found
|
||||||
|
|
||||||
|
|
||||||
def list_collections() -> List[str]:
|
def list_collections() -> List[str]:
|
||||||
client = chromadb.PersistentClient()
|
client = chromadb.PersistentClient()
|
||||||
collections = client.list_collections()
|
collections = client.list_collections()
|
||||||
@@ -30,30 +44,21 @@ def delete_collection(name: str) -> None:
|
|||||||
client.delete_collection(name=name)
|
client.delete_collection(name=name)
|
||||||
|
|
||||||
|
|
||||||
def count_collection(collection_name: str) -> int:
    """Return ``collection.count()`` for the named collection.

    Args:
        collection_name: Name of the collection to look up.

    Returns:
        The value reported by the collection's ``count()`` method.

    Raises:
        NotFoundError: Propagated from ``_get_client_and_collection``.
    """
    # Only the collection is needed; the client half of the pair is unused.
    lookup = _get_client_and_collection(collection_name)
    return lookup[1].count()
|
||||||
|
|
||||||
|
|
||||||
def add_data(collection_name: str, data: List[EmbeddingRecord], file_name: str) -> None:
    """Add embedding records to a collection, tagging each with a file name.

    Args:
        collection_name: Name of the collection to add to.
        data: Records to store. Each record is read via ``record["text"]``
            and ``record["embedding"]``.
        file_name: Value stored under the ``"file_name"`` metadata key of
            every added record.

    Raises:
        NotFoundError: Propagated from ``_get_client_and_collection``.
    """
    # Nothing to store: return before touching the database at all.
    if not data:
        return

    _, target = _get_client_and_collection(collection_name)

    # One freshly generated UUID string per record.
    record_ids = [str(uuid4()) for _ in data]
    # Every record gets its own metadata dict carrying the same file name.
    record_metadatas = [{"file_name": file_name} for _ in data]
    texts = [item["text"] for item in data]
    vectors = [item["embedding"] for item in data]

    target.add(
        ids=record_ids,
        metadatas=record_metadatas,
        documents=texts,
        embeddings=vectors,
    )
|
||||||
@@ -69,11 +74,6 @@ def query_data(collection_name: str, texts: list[str]) -> QueryResult:
|
|||||||
"embeddings": [],
|
"embeddings": [],
|
||||||
}
|
}
|
||||||
|
|
||||||
client = chromadb.PersistentClient()
|
_, collection = _get_client_and_collection(collection_name)
|
||||||
|
|
||||||
try:
|
|
||||||
collection = client.get_collection(name=collection_name)
|
|
||||||
except NotFoundError:
|
|
||||||
raise
|
|
||||||
|
|
||||||
return collection.query(query_texts=texts)
|
return collection.query(query_texts=texts)
|
||||||
|
|||||||
+23
-6
@@ -1,4 +1,5 @@
|
|||||||
from chromadb import QueryResult
|
from chromadb import QueryResult
|
||||||
|
from collections.abc import Mapping
|
||||||
|
|
||||||
from chroma_functions import add_data, query_data
|
from chroma_functions import add_data, query_data
|
||||||
from chunk_functions import chunk_file
|
from chunk_functions import chunk_file
|
||||||
@@ -13,7 +14,7 @@ def print_lines(lines: list[str]) -> None:
|
|||||||
def ingest_file(collection_name: str, file_path: str) -> int:
    """Chunk a file, embed the chunks, and store them in a collection.

    Args:
        collection_name: Collection name forwarded to ``add_data``.
        file_path: Path handed to ``chunk_file``; the same string is also
            passed to ``add_data`` as the ``file_name`` argument.

    Returns:
        The number of items returned by ``embed``.
    """
    records = embed(chunk_file(file_path))
    # The path itself doubles as the file_name recorded with each record.
    add_data(collection_name, records, file_path)
    return len(records)
|
||||||
|
|
||||||
|
|
||||||
@@ -25,22 +26,38 @@ def format_query_result(result: QueryResult) -> list[str]:
|
|||||||
ids = result.get("ids", [[]])
|
ids = result.get("ids", [[]])
|
||||||
documents = result.get("documents", [[]])
|
documents = result.get("documents", [[]])
|
||||||
distances = result.get("distances", [[]])
|
distances = result.get("distances", [[]])
|
||||||
|
metadatas = result.get("metadatas", [[]])
|
||||||
|
|
||||||
first_ids = ids[0] if ids else []
|
first_ids = ids[0] if ids else []
|
||||||
first_documents = documents[0] if documents else []
|
first_documents = documents[0] if documents else []
|
||||||
first_distances = distances[0] if distances else []
|
first_distances = distances[0] if distances else []
|
||||||
|
first_metadatas = metadatas[0] if metadatas else []
|
||||||
|
|
||||||
if not first_ids:
|
if not first_ids:
|
||||||
return ["No results found."]
|
return ["No results found."]
|
||||||
|
|
||||||
lines = ["Query results:"]
|
lines = ["Query results:"]
|
||||||
|
|
||||||
for index, document_id in enumerate(first_ids, start=1):
|
for index, document_id in enumerate(first_ids, start=1):
|
||||||
lines.append(f"{index}. id: {document_id}")
|
lines.append(f"{index}.\tid: {document_id}")
|
||||||
|
i = index - 1
|
||||||
|
|
||||||
if index - 1 < len(first_distances):
|
if i < len(first_distances):
|
||||||
lines.append(f" distance: {first_distances[index - 1]}")
|
lines.append(f"\tdistance: {first_distances[i]}")
|
||||||
|
|
||||||
if index - 1 < len(first_documents):
|
if i < len(first_metadatas):
|
||||||
lines.append(f" document: {first_documents[index - 1]}")
|
metadata = first_metadatas[i]
|
||||||
|
|
||||||
|
if isinstance(metadata, Mapping):
|
||||||
|
file_name = metadata.get("file_name")
|
||||||
|
|
||||||
|
if file_name:
|
||||||
|
lines.append(f"\tfile_name: {file_name}")
|
||||||
|
|
||||||
|
if i < len(first_documents):
|
||||||
|
lines.append(f"\tdocument: {first_documents[i]}")
|
||||||
|
|
||||||
|
# Print a separator between documents
|
||||||
|
lines.append(60 * "-")
|
||||||
|
|
||||||
return lines
|
return lines
|
||||||
|
|||||||
Reference in New Issue
Block a user