move top-level modules into a real package
This commit is contained in:
@@ -0,0 +1 @@
|
||||
"""Chromy package."""
|
||||
@@ -0,0 +1,86 @@
|
||||
from typing import List
|
||||
from uuid import uuid4
|
||||
|
||||
import chromadb
|
||||
from chromadb.api import ClientAPI
|
||||
from chromadb.api.types import QueryResult
|
||||
from chromadb.errors import NotFoundError
|
||||
|
||||
from chromy.embed import EmbeddingRecord
|
||||
|
||||
|
||||
def _get_client_and_collection(
    collection_name: str,
) -> tuple[ClientAPI, chromadb.Collection]:
    """Open a persistent Chroma client and fetch one collection from it.

    Args:
        collection_name: Name of the collection to look up.

    Returns:
        The ``(client, collection)`` pair.

    Raises:
        NotFoundError: Left to propagate unchanged when ``get_collection``
            raises it, so callers can report the missing collection.
    """
    chroma_client = chromadb.PersistentClient()
    found = chroma_client.get_collection(name=collection_name)
    return chroma_client, found
|
||||
|
||||
|
||||
def list_collections() -> List[str]:
    """Return the names of the collections reported by a persistent client.

    An entry without a ``name`` attribute is represented by ``str(entry)``.
    A falsy listing yields an empty list.
    """
    found = chromadb.PersistentClient().list_collections()
    if found:
        return [getattr(entry, "name", str(entry)) for entry in found]
    return []
|
||||
|
||||
|
||||
def create_collection(name: str) -> str:
    """Create a collection through a persistent client and return its name.

    Args:
        name: Name for the new collection.

    Returns:
        The ``name`` attribute of the created object, or *name* itself when
        the object exposes no such attribute.
    """
    created = chromadb.PersistentClient().create_collection(name=name)
    return getattr(created, "name", name)
|
||||
|
||||
|
||||
def delete_collection(name: str) -> None:
    """Delete the collection called *name* through a persistent client."""
    chromadb.PersistentClient().delete_collection(name=name)
|
||||
|
||||
|
||||
def delete_data(collection_name: str, where: dict[str, str]) -> int:
    """Delete the records matching a metadata filter and report how many went.

    Args:
        collection_name: Name of the collection to delete from.
        where: Metadata filter passed to ``collection.delete(where=...)``.

    Returns:
        The number of records deleted.

    Raises:
        NotFoundError: Propagated from the collection lookup when the
            collection does not exist.
    """
    _, collection = _get_client_and_collection(collection_name)

    # Snapshot the size first: Chroma's documented ``Collection.delete``
    # returns ``None``, so calling ``.get`` on its result would fail and the
    # deleted count has to be derived from the size difference instead.
    # NOTE(review): confirm the return type against the installed chromadb.
    count_before = collection.count()
    result = collection.delete(where=where)

    if result is None:
        # Clamp at zero in case records were added between the two counts.
        return max(count_before - collection.count(), 0)

    # Keep honoring a mapping result for clients that do report a count.
    return int(result.get("deleted", 0))
|
||||
|
||||
|
||||
def count_collection(collection_name: str) -> int:
    """Return ``count()`` of the named collection.

    Raises:
        NotFoundError: Propagated from the collection lookup.
    """
    target = _get_client_and_collection(collection_name)[1]
    return target.count()
|
||||
|
||||
|
||||
def add_data(collection_name: str, data: List[EmbeddingRecord], file_name: str) -> None:
    """Add embedded records to a collection, tagging each with *file_name*.

    Args:
        collection_name: Name of the collection receiving the records.
        data: Records providing ``"text"`` and ``"embedding"`` entries.
        file_name: Value stored under the ``"file_name"`` metadata key of
            every added record.

    An empty *data* returns immediately, before the collection is looked up.
    Every record gets a freshly generated UUID4 string as its id.
    """
    if not data:
        return

    _, collection = _get_client_and_collection(collection_name)

    record_ids = [str(uuid4()) for _ in data]
    record_metadata = [{"file_name": file_name} for _ in data]
    record_texts = [entry["text"] for entry in data]
    record_vectors = [entry["embedding"] for entry in data]

    collection.add(
        ids=record_ids,
        metadatas=record_metadata,
        documents=record_texts,
        embeddings=record_vectors,
    )
|
||||
|
||||
|
||||
def query_data(collection_name: str, texts: list[str]) -> QueryResult:
    """Query a collection with *texts*.

    Args:
        collection_name: Name of the collection to query.
        texts: Query strings forwarded as ``query_texts``.

    Returns:
        The result of ``collection.query``. When *texts* is empty, a result
        whose fields are all empty lists is returned without looking up the
        collection.
    """
    if texts:
        _, collection = _get_client_and_collection(collection_name)
        return collection.query(query_texts=texts)

    # Nothing to search for: build the empty result locally.
    empty_fields = ("ids", "documents", "metadatas", "distances", "embeddings")
    return {field: [] for field in empty_fields}
|
||||
@@ -0,0 +1,17 @@
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import semchunk
|
||||
|
||||
|
||||
def chunk_text(text: str, chunk_size: int = 800) -> List[str]:
    """Split *text* into chunks with a semchunk chunker.

    Args:
        text: The text to split.
        chunk_size: Size limit handed to ``semchunk.chunkerify`` together
            with the ``"gpt-4"`` tokenizer name (presumably measured in
            tokens; confirm against the semchunk documentation).

    Returns:
        Whatever the chunker returns for *text*.
    """
    split = semchunk.chunkerify("gpt-4", chunk_size)
    return split(text)
|
||||
|
||||
|
||||
def chunk_file(
    filename: str, chunk_size: int = 800, encoding: str = "utf-8"
) -> List[str]:
    """Read a text file and split its contents with ``chunk_text``.

    Args:
        filename: Path of the file to read.
        chunk_size: Chunk size forwarded to ``chunk_text``.
        encoding: Text encoding used to decode the file. Passed explicitly
            so the result does not depend on the platform's locale default.

    Returns:
        The chunks produced by ``chunk_text``.

    Raises:
        FileNotFoundError: If *filename* does not exist.
        UnicodeDecodeError: If the file is not valid in *encoding*.
    """
    contents = Path(filename).read_text(encoding=encoding)
    return chunk_text(contents, chunk_size)
|
||||
@@ -0,0 +1,114 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from argparse import Namespace
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
|
||||
from chromadb.errors import InternalError, NotFoundError
|
||||
|
||||
from chromy.handlers.add_data import handle_add_data
|
||||
from chromy.handlers.count_collection import handle_count_collection
|
||||
from chromy.handlers.create_collection import handle_create_collection
|
||||
from chromy.handlers.delete_collection import (
|
||||
handle_delete_collection,
|
||||
handle_delete_records,
|
||||
)
|
||||
from chromy.handlers.list_collections import handle_list_collections
|
||||
from chromy.handlers.query import handle_query
|
||||
|
||||
|
||||
# A sub-command handler: receives the parsed arguments and returns the
# process exit status.
CommandHandler = Callable[[Namespace], int]
# Builds the text printed for a handled error from the parsed arguments.
ErrorMessageBuilder = Callable[[Namespace], str]
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class CliErrorHandler:
|
||||
exception_type: type[BaseException]
|
||||
message: ErrorMessageBuilder
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class CommandConfig:
|
||||
handler: CommandHandler
|
||||
error_handlers: tuple[CliErrorHandler, ...] = ()
|
||||
|
||||
|
||||
def _collection_missing_message(args: Namespace) -> str:
    """Build the message printed when ``args.collection`` does not exist."""
    return f"Collection '{args.collection}' does not exist."


# One immutable handler shared by every command that looks up an existing
# collection; the dataclass is frozen, so sharing the instance is safe.
_MISSING_COLLECTION = CliErrorHandler(
    exception_type=NotFoundError,
    message=_collection_missing_message,
)

# Maps each canonical command name to its handler and the errors it reports
# as a printed message instead of a traceback.
COMMANDS: dict[str, CommandConfig] = {
    "list-collections": CommandConfig(handler=handle_list_collections),
    "create-collection": CommandConfig(
        handler=handle_create_collection,
        error_handlers=(
            CliErrorHandler(
                exception_type=InternalError,
                message=lambda args: f"Collection '{args.collection}' already exists.",
            ),
        ),
    ),
    "delete-collection": CommandConfig(
        handler=handle_delete_collection,
        error_handlers=(_MISSING_COLLECTION,),
    ),
    "count": CommandConfig(
        handler=handle_count_collection,
        error_handlers=(_MISSING_COLLECTION,),
    ),
    "add-data": CommandConfig(
        handler=handle_add_data,
        error_handlers=(
            _MISSING_COLLECTION,
            CliErrorHandler(
                exception_type=FileNotFoundError,
                message=lambda args: f"The file {args.file} was not found.",
            ),
        ),
    ),
    "query": CommandConfig(
        handler=handle_query,
        error_handlers=(_MISSING_COLLECTION,),
    ),
    "delete": CommandConfig(
        handler=handle_delete_records,
        error_handlers=(
            _MISSING_COLLECTION,
            # ``args.error_message`` is seeded by ``execute_command`` and may
            # be overwritten by the handler before the ValueError propagates.
            CliErrorHandler(
                exception_type=ValueError,
                message=lambda args: str(args.error_message),
            ),
        ),
    ),
}
|
||||
|
||||
|
||||
def execute_command(args: Namespace) -> int:
    """Run the command selected by ``args.command`` and return its exit status.

    Args:
        args: Parsed arguments; ``args.command`` must be a key of ``COMMANDS``.

    Returns:
        The handler's return value, or ``1`` after printing the message of
        the first configured error handler whose exception type matches.

    Raises:
        BaseException: Any exception from the handler that no configured
            error handler matches is re-raised unchanged.
    """
    config = COMMANDS[args.command]
    # Default text for error handlers that read ``args.error_message``.
    args.error_message = "An unexpected value was provided."

    try:
        return config.handler(args)
    except BaseException as exc:
        matching = next(
            (
                candidate
                for candidate in config.error_handlers
                if isinstance(exc, candidate.exception_type)
            ),
            None,
        )
        if matching is None:
            raise
        print(matching.message(args))
        return 1
|
||||
@@ -0,0 +1,116 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class ArgumentSpec:
|
||||
name: str
|
||||
help: str
|
||||
required: bool = False
|
||||
metavar: str | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class CommandSpec:
|
||||
name: str
|
||||
aliases: tuple[str, ...]
|
||||
help: str
|
||||
arguments: tuple[ArgumentSpec, ...] = ()
|
||||
|
||||
|
||||
# Positional argument shared by the commands that act on an existing
# collection; ``ArgumentSpec`` is frozen, so one instance can be reused.
_TARGET_COLLECTION = ArgumentSpec(
    name="collection",
    help="Name of the target collection.",
)

# Every sub-command of the CLI, in the order it is registered on the parser.
COMMAND_SPECS: tuple[CommandSpec, ...] = (
    CommandSpec(
        name="list-collections",
        aliases=("lc",),
        help="List all collections stored in the local Chroma database.",
    ),
    CommandSpec(
        name="create-collection",
        aliases=("cc",),
        help="Create a collection in the local Chroma database.",
        arguments=(
            ArgumentSpec(name="collection", help="Name of the collection to create."),
        ),
    ),
    CommandSpec(
        name="delete-collection",
        aliases=("dc",),
        help="Delete a collection from the local Chroma database.",
        arguments=(
            ArgumentSpec(name="collection", help="Name of the collection to delete."),
        ),
    ),
    CommandSpec(
        name="count",
        aliases=("co",),
        help="Count records in a collection from the local Chroma database.",
        arguments=(
            ArgumentSpec(name="collection", help="Name of the collection to count."),
        ),
    ),
    CommandSpec(
        name="add-data",
        aliases=("ad",),
        help="Chunk, embed, and add a file to a collection in the local Chroma database.",
        arguments=(
            _TARGET_COLLECTION,
            ArgumentSpec(
                name="file",
                help="Path to the file to chunk and add to the collection.",
            ),
        ),
    ),
    CommandSpec(
        name="query",
        aliases=("q",),
        help="Query a collection with the provided text.",
        arguments=(
            _TARGET_COLLECTION,
            ArgumentSpec(name="query_text", help="The text to query."),
        ),
    ),
    CommandSpec(
        name="delete",
        aliases=("del",),
        help="Delete records from a collection using a metadata filter.",
        arguments=(
            _TARGET_COLLECTION,
            ArgumentSpec(
                name="--where",
                help="Metadata filter in the format <condition>=<value>.",
                required=True,
                metavar="CONDITION=VALUE",
            ),
        ),
    ),
)
|
||||
|
||||
|
||||
def _add_command(
    subparsers: argparse._SubParsersAction[argparse.ArgumentParser],
    command: CommandSpec,
) -> None:
    """Register one sub-command and its arguments on *subparsers*.

    Args:
        subparsers: The object returned by ``parser.add_subparsers``.
        command: Specification of the sub-command to add.

    The sub-parser's ``command`` default is set to the canonical
    ``command.name``.
    """
    child = subparsers.add_parser(
        command.name,
        aliases=[*command.aliases],
        help=command.help,
        description=command.help,
    )

    for spec in command.arguments:
        options: dict[str, object] = {"help": spec.help}
        if spec.metavar is not None:
            options["metavar"] = spec.metavar
        # ``required`` is only passed for option-style names; argparse does
        # not accept that keyword for positional arguments.
        if spec.name.startswith("-"):
            options["required"] = spec.required
        child.add_argument(spec.name, **options)

    child.set_defaults(command=command.name)
|
||||
|
||||
|
||||
def build_parser() -> argparse.ArgumentParser:
    """Build the top-level parser with every entry of ``COMMAND_SPECS``.

    Returns:
        A parser that requires a sub-command and stores it under ``command``.
    """
    root = argparse.ArgumentParser(description="Inspect local Chroma collections.")
    registry = root.add_subparsers(dest="command", required=True)

    for spec in COMMAND_SPECS:
        _add_command(registry, spec)

    return root
|
||||
@@ -0,0 +1,26 @@
|
||||
from typing import List, TypedDict
|
||||
|
||||
from chromadb.utils.embedding_functions import DefaultEmbeddingFunction
|
||||
|
||||
|
||||
class EmbeddingRecord(TypedDict):
    """A text chunk together with its embedding vector."""

    # The chunk of text that was embedded.
    text: str
    # The embedding of ``text`` as a plain list of floats.
    embedding: List[float]
|
||||
|
||||
|
||||
def embed(chunks: List[str]) -> List[EmbeddingRecord]:
    """Embed text chunks with Chroma's ``DefaultEmbeddingFunction``.

    Args:
        chunks: Texts to embed.

    Returns:
        One record per chunk, in input order, pairing the text with its
        embedding as a list. An empty input yields an empty list without
        creating the embedding function.
    """
    if not chunks:
        return []

    vectors = DefaultEmbeddingFunction()(chunks)

    records: List[EmbeddingRecord] = []
    for chunk, vector in zip(chunks, vectors):
        # Prefer the object's own ``tolist`` when it has one; otherwise copy
        # the values into a list.
        values = vector.tolist() if hasattr(vector, "tolist") else list(vector)
        records.append({"text": chunk, "embedding": values})
    return records
|
||||
@@ -0,0 +1 @@
|
||||
"""Command handlers package for the Chroma CLI."""
|
||||
@@ -0,0 +1,9 @@
|
||||
from argparse import Namespace
|
||||
|
||||
from chromy.utilities import ingest_file
|
||||
|
||||
|
||||
def handle_add_data(args: Namespace) -> int:
    """Ingest ``args.file`` into ``args.collection`` and print the record count.

    Returns:
        ``0`` on success.
    """
    added = ingest_file(args.collection, args.file)
    summary = f"Added {added} records to collection '{args.collection}'."
    print(summary)
    return 0
|
||||
@@ -0,0 +1,8 @@
|
||||
from argparse import Namespace
|
||||
|
||||
from chromy.chroma_functions import count_collection
|
||||
|
||||
|
||||
def handle_count_collection(args: Namespace) -> int:
    """Print the record count of ``args.collection``.

    Returns:
        ``0`` on success.
    """
    total = count_collection(args.collection)
    print(total)
    return 0
|
||||
@@ -0,0 +1,9 @@
|
||||
from argparse import Namespace
|
||||
|
||||
from chromy.chroma_functions import create_collection
|
||||
|
||||
|
||||
def handle_create_collection(args: Namespace) -> int:
    """Create the collection named ``args.collection`` and print a confirmation.

    Returns:
        ``0`` on success.
    """
    created_name = create_collection(args.collection)
    confirmation = f"Created collection '{created_name}'."
    print(confirmation)
    return 0
|
||||
@@ -0,0 +1,40 @@
|
||||
from argparse import Namespace
|
||||
|
||||
from chromy.chroma_functions import delete_collection, delete_data
|
||||
|
||||
|
||||
def _parse_where_clause(where_clause: str) -> dict[str, str]:
    """Parse a ``<condition>=<value>`` string into a one-entry filter dict.

    The string is split at the first ``=``; both sides are stripped of
    surrounding whitespace.

    Args:
        where_clause: Raw value of the ``--where`` option.

    Returns:
        ``{condition: value}``.

    Raises:
        ValueError: If there is no ``=`` or either side is empty after
            stripping.
    """
    key, equals_sign, expected = where_clause.partition("=")
    key, expected = key.strip(), expected.strip()

    if not (equals_sign and key and expected):
        raise ValueError("Invalid --where value. Expected <condition>=<value>.")

    return {key: expected}
|
||||
|
||||
|
||||
def handle_delete_collection(args: Namespace) -> int:
    """Delete the collection named ``args.collection`` and print a confirmation.

    Returns:
        ``0`` on success.
    """
    target = args.collection
    delete_collection(target)
    print(f"Deleted collection '{target}'.")
    return 0
|
||||
|
||||
|
||||
def handle_delete_records(args: Namespace) -> int:
    """Delete the records of ``args.collection`` matching ``args.where``.

    Returns:
        ``0`` on success.

    Raises:
        ValueError: If ``args.where`` cannot be parsed. The parse error text
            is first stored on ``args.error_message`` so the caller can
            print it.
    """
    try:
        metadata_filter = _parse_where_clause(args.where)
    except ValueError as parse_error:
        args.error_message = str(parse_error)
        raise

    removed = delete_data(args.collection, metadata_filter)
    field, expected = next(iter(metadata_filter.items()))
    summary = (
        f"Deleted {removed} record(s) from collection '{args.collection}' "
        f"where {field}={expected}."
    )
    print(summary)
    return 0
|
||||
@@ -0,0 +1,14 @@
|
||||
from argparse import Namespace
|
||||
|
||||
from chromy.chroma_functions import list_collections
|
||||
from chromy.utilities import print_lines
|
||||
|
||||
|
||||
def handle_list_collections(_: Namespace) -> int:
    """Print every collection name, or a notice when there are none.

    Returns:
        ``0`` in both cases.
    """
    names = list_collections()
    if names:
        print_lines(names)
    else:
        print("No collections found.")
    return 0
|
||||
@@ -0,0 +1,9 @@
|
||||
from argparse import Namespace
|
||||
|
||||
from chromy.utilities import format_query_result, print_lines, run_query
|
||||
|
||||
|
||||
def handle_query(args: Namespace) -> int:
    """Query ``args.collection`` with ``args.query_text`` and print the result.

    Returns:
        ``0`` on success.
    """
    report = format_query_result(run_query(args.collection, args.query_text))
    print_lines(report)
    return 0
|
||||
@@ -0,0 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from chromy.cli_app import execute_command
|
||||
from chromy.cli_parser import build_parser
|
||||
|
||||
|
||||
def main() -> int:
    """Entry point of the CLI.

    Calls ``load_dotenv()``, parses the command line and runs the selected
    command.

    Returns:
        The exit status reported by ``execute_command``.
    """
    load_dotenv()
    parsed = build_parser().parse_args()
    exit_status = execute_command(parsed)
    return exit_status


# Use the command's return value as the process exit status when run as a script.
if __name__ == "__main__":
    raise SystemExit(main())
|
||||
@@ -0,0 +1,63 @@
|
||||
from chromadb import QueryResult
|
||||
from collections.abc import Mapping
|
||||
|
||||
from chromy.chroma_functions import add_data, query_data
|
||||
from chromy.chunk_functions import chunk_file
|
||||
from chromy.embed import embed
|
||||
|
||||
|
||||
def print_lines(lines: list[str]) -> None:
    """Print each entry of *lines* on its own line; print nothing when empty."""
    # The guard keeps an empty list from producing a blank line.
    if lines:
        print(*lines, sep="\n")
|
||||
|
||||
|
||||
def ingest_file(collection_name: str, file_path: str) -> int:
    """Chunk a file, embed the chunks and add them to a collection.

    Args:
        collection_name: Name of the collection receiving the records.
        file_path: Path of the file to ingest; also passed to ``add_data``
            as the file name recorded with each record.

    Returns:
        The number of embedded records.
    """
    records = embed(chunk_file(file_path))
    add_data(collection_name, records, file_path)
    return len(records)
|
||||
|
||||
|
||||
def run_query(collection_name: str, query_text: str) -> QueryResult:
    """Query a collection with a single text.

    Args:
        collection_name: Name of the collection to query.
        query_text: The one query string, wrapped in a list for ``query_data``.

    Returns:
        The result of ``query_data``.
    """
    single_query = [query_text]
    return query_data(collection_name, single_query)
|
||||
|
||||
|
||||
def format_query_result(result: QueryResult) -> list[str]:
    """Render the hits of the first query in *result* as printable lines.

    Args:
        result: Mapping with optional ``ids``, ``documents``, ``distances``
            and ``metadatas`` entries; only the first row of each is used.

    Returns:
        ``["No results found."]`` when the first row of ids is empty.
        Otherwise a header line followed, per hit, by its numbered id, then
        distance, ``file_name`` metadata and document where available, and
        a line of 60 dashes.
    """

    def first_row(field: str) -> list:
        # A missing key behaves like one empty row; a falsy value like no rows.
        rows = result.get(field, [[]])
        return rows[0] if rows else []

    hit_ids = first_row("ids")
    texts = first_row("documents")
    scores = first_row("distances")
    metadata_rows = first_row("metadatas")

    if not hit_ids:
        return ["No results found."]

    lines = ["Query results:"]

    for position, document_id in enumerate(hit_ids):
        lines.append(f"{position + 1}.\tid: {document_id}")

        if position < len(scores):
            lines.append(f"\tdistance: {scores[position]}")

        if position < len(metadata_rows):
            metadata = metadata_rows[position]
            # Entries that are not mappings carry no file name to show.
            file_name = metadata.get("file_name") if isinstance(metadata, Mapping) else None
            if file_name:
                lines.append(f"\tfile_name: {file_name}")

        if position < len(texts):
            lines.append(f"\tdocument: {texts[position]}")

        # Separator line after every hit.
        lines.append("-" * 60)

    return lines
|
||||
Reference in New Issue
Block a user