add ruff. fix all linting
This commit is contained in:
@@ -1,9 +1,12 @@
|
||||
from typing import List
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
import chromadb
|
||||
from chromadb.api import ClientAPI
|
||||
from chromadb.api.types import QueryResult
|
||||
from chromadb.api.types import QueryResult, Where
|
||||
from chromadb.errors import NotFoundError
|
||||
|
||||
from chromy.embed import EmbeddingRecord
|
||||
@@ -22,7 +25,7 @@ def _get_client_and_collection(
|
||||
return client, collection
|
||||
|
||||
|
||||
def list_collections() -> List[str]:
|
||||
def list_collections() -> list[str]:
|
||||
client = chromadb.PersistentClient()
|
||||
collections = client.list_collections()
|
||||
|
||||
@@ -46,7 +49,7 @@ def delete_collection(name: str) -> None:
|
||||
|
||||
def delete_data(collection_name: str, where: dict[str, str]) -> int:
|
||||
_, collection = _get_client_and_collection(collection_name)
|
||||
result = collection.delete(where=where)
|
||||
result = collection.delete(where=cast(Where, where))
|
||||
|
||||
return int(result.get("deleted", 0))
|
||||
|
||||
@@ -57,17 +60,19 @@ def count_collection(collection_name: str) -> int:
|
||||
return collection.count()
|
||||
|
||||
|
||||
def add_data(collection_name: str, data: List[EmbeddingRecord], file_name: str) -> None:
|
||||
def add_data(collection_name: str, data: list[EmbeddingRecord], file_name: str) -> None:
|
||||
if not data:
|
||||
return
|
||||
|
||||
_, collection = _get_client_and_collection(collection_name)
|
||||
|
||||
embeddings: list[Sequence[float]] = [record["embedding"] for record in data]
|
||||
|
||||
collection.add(
|
||||
ids=[str(uuid4()) for _ in data],
|
||||
metadatas=[{"file_name": file_name} for _ in data],
|
||||
documents=[record["text"] for record in data],
|
||||
embeddings=[record["embedding"] for record in data],
|
||||
embeddings=embeddings,
|
||||
)
|
||||
|
||||
|
||||
@@ -78,7 +83,10 @@ def query_data(collection_name: str, texts: list[str]) -> QueryResult:
|
||||
"documents": [],
|
||||
"metadatas": [],
|
||||
"distances": [],
|
||||
"embeddings": [],
|
||||
"embeddings": None,
|
||||
"uris": None,
|
||||
"data": None,
|
||||
"included": ["documents", "metadatas", "distances"],
|
||||
}
|
||||
|
||||
_, collection = _get_client_and_collection(collection_name)
|
||||
|
||||
@@ -1,17 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from typing import cast
|
||||
|
||||
import semchunk
|
||||
|
||||
|
||||
def chunk_text(text: str, chunk_size: int = 800) -> List[str]:
|
||||
def chunk_text(text: str, chunk_size: int = 800) -> list[str]:
|
||||
chunker = semchunk.chunkerify("gpt-4", chunk_size)
|
||||
chunks = chunker(text)
|
||||
|
||||
return chunks
|
||||
return cast("list[str]", chunks)
|
||||
|
||||
|
||||
def chunk_file(filename: str, chunk_size: int = 800) -> List[str]:
|
||||
def chunk_file(filename: str, chunk_size: int = 800) -> list[str]:
|
||||
contents = Path(filename).read_text()
|
||||
|
||||
return chunk_text(contents, chunk_size)
|
||||
|
||||
+18
-14
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
from argparse import Namespace
|
||||
from collections.abc import Callable, Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Generic, Protocol, TypeVar, assert_never
|
||||
from typing import Generic, TypeVar, assert_never
|
||||
|
||||
from chromadb.errors import InternalError, NotFoundError
|
||||
|
||||
@@ -27,17 +27,19 @@ from chromy.handlers.delete_collection import (
|
||||
from chromy.handlers.list_collections import handle_list_collections
|
||||
from chromy.handlers.query import handle_query
|
||||
|
||||
|
||||
CommandT = TypeVar("CommandT", bound=CommandInput)
|
||||
CollectionCommandT = TypeVar("CollectionCommandT", bound="HasCollection")
|
||||
CollectionCommandT = TypeVar(
|
||||
"CollectionCommandT",
|
||||
DeleteCollectionInput,
|
||||
CountCollectionInput,
|
||||
AddDataInput,
|
||||
QueryInput,
|
||||
DeleteRecordsInput,
|
||||
)
|
||||
CommandHandler = Callable[[CommandT], int]
|
||||
ErrorMessageBuilder = Callable[[CommandT, Exception], str]
|
||||
|
||||
|
||||
class HasCollection(Protocol):
|
||||
collection: str
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class CliErrorHandler(Generic[CommandT]):
|
||||
exception_type: type[Exception]
|
||||
@@ -93,20 +95,20 @@ def execute_command(args: Namespace) -> int:
|
||||
return _run_command(
|
||||
command_input,
|
||||
handle_delete_collection,
|
||||
(_collection_not_found_handler(),),
|
||||
(_collection_not_found_handler(DeleteCollectionInput),),
|
||||
)
|
||||
case CountCollectionInput():
|
||||
return _run_command(
|
||||
command_input,
|
||||
handle_count_collection,
|
||||
(_collection_not_found_handler(),),
|
||||
(_collection_not_found_handler(CountCollectionInput),),
|
||||
)
|
||||
case AddDataInput():
|
||||
return _run_command(
|
||||
command_input,
|
||||
handle_add_data,
|
||||
(
|
||||
_collection_not_found_handler(),
|
||||
_collection_not_found_handler(AddDataInput),
|
||||
CliErrorHandler(
|
||||
exception_type=FileNotFoundError,
|
||||
message=_file_not_found_message,
|
||||
@@ -117,14 +119,14 @@ def execute_command(args: Namespace) -> int:
|
||||
return _run_command(
|
||||
command_input,
|
||||
handle_query,
|
||||
(_collection_not_found_handler(),),
|
||||
(_collection_not_found_handler(QueryInput),),
|
||||
)
|
||||
case DeleteRecordsInput():
|
||||
return _run_command(
|
||||
command_input,
|
||||
handle_delete_records,
|
||||
(
|
||||
_collection_not_found_handler(),
|
||||
_collection_not_found_handler(DeleteRecordsInput),
|
||||
CliErrorHandler(
|
||||
exception_type=ValueError,
|
||||
message=_exception_message,
|
||||
@@ -157,14 +159,16 @@ def _collection_already_exists_message(
|
||||
return f"Collection '{command.collection}' already exists."
|
||||
|
||||
|
||||
def _collection_not_found_handler() -> CliErrorHandler[CollectionCommandT]:
|
||||
def _collection_not_found_handler(
|
||||
_: type[CollectionCommandT],
|
||||
) -> CliErrorHandler[CollectionCommandT]:
|
||||
return CliErrorHandler(
|
||||
exception_type=NotFoundError,
|
||||
message=_collection_not_found_message,
|
||||
)
|
||||
|
||||
|
||||
def _collection_not_found_message(command: HasCollection, _: Exception) -> str:
|
||||
def _collection_not_found_message(command: CollectionCommandT, _: Exception) -> str:
|
||||
return f"Collection '{command.collection}' does not exist."
|
||||
|
||||
|
||||
|
||||
+15
-8
@@ -47,7 +47,9 @@ COMMAND_SPECS: tuple[CommandSpec, ...] = (
|
||||
CommandSpec(
|
||||
name="add-data",
|
||||
aliases=("ad",),
|
||||
help="Chunk, embed, and add a file to a collection in the local Chroma database.",
|
||||
help=(
|
||||
"Chunk, embed, and add a file to a collection in the local Chroma database."
|
||||
),
|
||||
arguments=(
|
||||
ArgumentSpec("collection", "Name of the target collection."),
|
||||
ArgumentSpec(
|
||||
@@ -93,15 +95,20 @@ def _add_command(
|
||||
)
|
||||
|
||||
for argument in command.arguments:
|
||||
argument_kwargs: dict[str, object] = {"help": argument.help}
|
||||
|
||||
if argument.metavar is not None:
|
||||
argument_kwargs["metavar"] = argument.metavar
|
||||
|
||||
if argument.name.startswith("-"):
|
||||
argument_kwargs["required"] = argument.required
|
||||
subparser.add_argument(
|
||||
argument.name,
|
||||
help=argument.help,
|
||||
metavar=argument.metavar,
|
||||
required=argument.required,
|
||||
)
|
||||
continue
|
||||
|
||||
subparser.add_argument(argument.name, **argument_kwargs)
|
||||
subparser.add_argument(
|
||||
argument.name,
|
||||
help=argument.help,
|
||||
metavar=argument.metavar,
|
||||
)
|
||||
|
||||
subparser.set_defaults(command=command.name)
|
||||
|
||||
|
||||
+6
-4
@@ -1,14 +1,16 @@
|
||||
from typing import List, TypedDict
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TypedDict
|
||||
|
||||
from chromadb.utils.embedding_functions import DefaultEmbeddingFunction
|
||||
|
||||
|
||||
class EmbeddingRecord(TypedDict):
|
||||
text: str
|
||||
embedding: List[float]
|
||||
embedding: list[float]
|
||||
|
||||
|
||||
def embed(chunks: List[str]) -> List[EmbeddingRecord]:
|
||||
def embed(chunks: list[str]) -> list[EmbeddingRecord]:
|
||||
if not chunks:
|
||||
return []
|
||||
|
||||
@@ -22,5 +24,5 @@ def embed(chunks: List[str]) -> List[EmbeddingRecord]:
|
||||
embedding.tolist() if hasattr(embedding, "tolist") else list(embedding)
|
||||
),
|
||||
}
|
||||
for text, embedding in zip(chunks, embeddings)
|
||||
for text, embedding in zip(chunks, embeddings, strict=False)
|
||||
]
|
||||
|
||||
@@ -8,15 +8,13 @@ def _parse_where_clause(where_clause: str) -> dict[str, str]:
|
||||
condition, separator, value = where_clause.partition("=")
|
||||
|
||||
if separator == "":
|
||||
raise ValueError(
|
||||
"Invalid --where value. Expected <condition>=<value>.")
|
||||
raise ValueError("Invalid --where value. Expected <condition>=<value>.")
|
||||
|
||||
condition = condition.strip()
|
||||
value = value.strip()
|
||||
|
||||
if not condition or not value:
|
||||
raise ValueError(
|
||||
"Invalid --where value. Expected <condition>=<value>.")
|
||||
raise ValueError("Invalid --where value. Expected <condition>=<value>.")
|
||||
|
||||
return {condition: value}
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from chromy.command_inputs import ListCollectionsInput
|
||||
from chromy.chroma_functions import list_collections
|
||||
from chromy.command_inputs import ListCollectionsInput
|
||||
from chromy.utilities import print_lines
|
||||
|
||||
|
||||
|
||||
+2
-1
@@ -1,6 +1,7 @@
|
||||
from chromadb import QueryResult
|
||||
from collections.abc import Mapping
|
||||
|
||||
from chromadb import QueryResult
|
||||
|
||||
from chromy.chroma_functions import add_data, query_data
|
||||
from chromy.chunk_functions import chunk_file
|
||||
from chromy.embed import embed
|
||||
|
||||
Reference in New Issue
Block a user