add documents
This commit is contained in:
@@ -1,6 +1,9 @@
|
|||||||
import chromadb
|
import chromadb
|
||||||
from chromadb.errors import NotFoundError
|
from chromadb.errors import NotFoundError
|
||||||
from typing import List
|
from typing import List
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from embed import EmbeddingRecord
|
||||||
|
|
||||||
|
|
||||||
def list_collections() -> List[str]:
|
def list_collections() -> List[str]:
|
||||||
@@ -34,3 +37,21 @@ def count_collection(name: str) -> int:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
return collection.count()
|
return collection.count()
|
||||||
|
|
||||||
|
|
||||||
|
def add_data(collection: str, data: List[EmbeddingRecord]) -> None:
|
||||||
|
if not data:
|
||||||
|
return
|
||||||
|
|
||||||
|
client = chromadb.PersistentClient()
|
||||||
|
|
||||||
|
try:
|
||||||
|
target_collection = client.get_collection(name=collection)
|
||||||
|
except NotFoundError:
|
||||||
|
raise
|
||||||
|
|
||||||
|
target_collection.add(
|
||||||
|
ids=[str(uuid4()) for _ in data],
|
||||||
|
documents=[record["text"] for record in data],
|
||||||
|
embeddings=[record["embedding"] for record in data],
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,10 +0,0 @@
|
|||||||
from typing import List
|
|
||||||
|
|
||||||
import semchunk
|
|
||||||
|
|
||||||
|
|
||||||
def chunk(text: str, chunk_size: int = 800) -> List[str]:
|
|
||||||
chunker = semchunk.chunkerify("gpt-4", chunk_size)
|
|
||||||
chunks = chunker(text)
|
|
||||||
|
|
||||||
return chunks
|
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import semchunk
|
||||||
|
|
||||||
|
|
||||||
|
def chunk_text(text: str, chunk_size: int = 800) -> List[str]:
|
||||||
|
chunker = semchunk.chunkerify("gpt-4", chunk_size)
|
||||||
|
chunks = chunker(text)
|
||||||
|
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
|
||||||
|
def chunk_file(filename: str, chunk_size: int = 800) -> List[str]:
|
||||||
|
contents = Path(filename).read_text()
|
||||||
|
|
||||||
|
return chunk_text(contents, chunk_size)
|
||||||
@@ -32,4 +32,12 @@ def build_parser() -> argparse.ArgumentParser:
|
|||||||
)
|
)
|
||||||
count_parser.add_argument("name", help="Name of the collection to count.")
|
count_parser.add_argument("name", help="Name of the collection to count.")
|
||||||
|
|
||||||
|
add_parser = subparsers.add_parser(
|
||||||
|
"add-data",
|
||||||
|
aliases=["ad"],
|
||||||
|
help="Chunk, embed, and add a file to a collection in the local Chroma database.",
|
||||||
|
)
|
||||||
|
add_parser.add_argument("collection", help="Name of the target collection.")
|
||||||
|
add_parser.add_argument("file", help="Path to the file to chunk and add to the collection.")
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|||||||
@@ -1,14 +1,20 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from chromadb.errors import NotFoundError, InternalError
|
from chromadb.errors import InternalError, NotFoundError
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
from chroma_functions import (
|
from chroma_functions import (
|
||||||
|
add_data,
|
||||||
count_collection,
|
count_collection,
|
||||||
create_collection,
|
create_collection,
|
||||||
delete_collection,
|
delete_collection,
|
||||||
list_collections,
|
list_collections,
|
||||||
)
|
)
|
||||||
|
from chunk_functions import chunk_file
|
||||||
from cli_parser import build_parser
|
from cli_parser import build_parser
|
||||||
|
from embed import embed
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
def main() -> int:
|
def main() -> int:
|
||||||
@@ -58,6 +64,22 @@ def main() -> int:
|
|||||||
|
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
if args.command in {"add-data", "ad"}:
|
||||||
|
try:
|
||||||
|
chunks = chunk_file(args.file)
|
||||||
|
embeddings = embed(chunks)
|
||||||
|
add_data(args.collection, embeddings)
|
||||||
|
except NotFoundError:
|
||||||
|
print(f"Collection '{args.collection}' does not exist.")
|
||||||
|
return 1
|
||||||
|
except FileNotFoundError:
|
||||||
|
print(f"The file {args.file} was not found.")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
print(f"Added {len(embeddings)} records to collection '{args.collection}'.")
|
||||||
|
|
||||||
|
return 0
|
||||||
|
|
||||||
print("Nothing to do. Use -h to see available commands.")
|
print("Nothing to do. Use -h to see available commands.")
|
||||||
|
|
||||||
return 0
|
return 0
|
||||||
|
|||||||
Reference in New Issue
Block a user