Compare commits
4 Commits
55bbd897f4
...
d71fce7a6a
| Author | SHA1 | Date | |
|---|---|---|---|
| d71fce7a6a | |||
| c6ad060e85 | |||
| c5b6b196b5 | |||
| 948f8500be |
@@ -8,6 +8,7 @@ import chromadb
|
|||||||
from chromadb.api import ClientAPI
|
from chromadb.api import ClientAPI
|
||||||
from chromadb.api.types import QueryResult, Where
|
from chromadb.api.types import QueryResult, Where
|
||||||
from chromadb.errors import NotFoundError
|
from chromadb.errors import NotFoundError
|
||||||
|
from rich.text import Text
|
||||||
|
|
||||||
from chromy.embed import EmbeddingRecord
|
from chromy.embed import EmbeddingRecord
|
||||||
|
|
||||||
@@ -25,14 +26,17 @@ def _get_client_and_collection(
|
|||||||
return client, collection
|
return client, collection
|
||||||
|
|
||||||
|
|
||||||
def list_collections() -> list[Text]:
    """
    List the collections known to a default ``chromadb.PersistentClient``.

    Returns:
        list[Text]: One ``Text`` per collection, made of ``"· "`` followed by
            the collection's ``name`` attribute (or ``str(collection)`` when
            it has no such attribute). Empty when there are no collections.
    """
    found = chromadb.PersistentClient().list_collections()

    if not found:
        return []

    bullets: list[Text] = []
    for entry in found:
        # The string fallback is computed up front, exactly as a getattr
        # default argument would be.
        fallback = str(entry)
        label = getattr(entry, "name", fallback)
        bullets.append(Text("· " + label))
    return bullets
|
||||||
|
|
||||||
|
|
||||||
def create_collection(name: str) -> str:
|
def create_collection(name: str) -> str:
|
||||||
@@ -58,7 +62,9 @@ def count_collection(collection_name: str) -> str:
|
|||||||
_, collection = _get_client_and_collection(collection_name)
|
_, collection = _get_client_and_collection(collection_name)
|
||||||
count = collection.count()
|
count = collection.count()
|
||||||
|
|
||||||
return f"The '{collection_name}' collection contains [bold green]{count}[/] records."
|
return (
|
||||||
|
f"The '{collection_name}' collection contains [bold green]{count}[/] records."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def add_data(
|
def add_data(
|
||||||
@@ -71,8 +77,7 @@ def add_data(
|
|||||||
|
|
||||||
_, collection = _get_client_and_collection(collection_name)
|
_, collection = _get_client_and_collection(collection_name)
|
||||||
|
|
||||||
embeddings: list[Sequence[float]] = [record["embedding"]
|
embeddings: list[Sequence[float]] = [record["embedding"] for record in data]
|
||||||
for record in data]
|
|
||||||
|
|
||||||
collection.add(
|
collection.add(
|
||||||
ids=[str(uuid4()) for _ in data],
|
ids=[str(uuid4()) for _ in data],
|
||||||
|
|||||||
+7
-5
@@ -1,18 +1,19 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from plistlib import InvalidFileException
|
||||||
from typing import Annotated, Callable
|
from typing import Annotated, Callable
|
||||||
|
|
||||||
import typer
|
import typer
|
||||||
from rich import print
|
|
||||||
from chromadb.errors import InternalError, NotFoundError
|
from chromadb.errors import InternalError, NotFoundError
|
||||||
|
from rich import print
|
||||||
|
|
||||||
from chromy.handlers.import_data import handle_import
|
|
||||||
from chromy.handlers.count_collection import handle_count_collection
|
from chromy.handlers.count_collection import handle_count_collection
|
||||||
from chromy.handlers.create_collection import handle_create_collection
|
from chromy.handlers.create_collection import handle_create_collection
|
||||||
from chromy.handlers.delete_collection import (
|
from chromy.handlers.delete_collection import (
|
||||||
handle_delete_collection,
|
handle_delete_collection,
|
||||||
handle_delete_records,
|
handle_delete_records,
|
||||||
)
|
)
|
||||||
|
from chromy.handlers.import_data import handle_import
|
||||||
from chromy.handlers.list_collections import handle_list_collections
|
from chromy.handlers.list_collections import handle_list_collections
|
||||||
from chromy.handlers.query import handle_query
|
from chromy.handlers.query import handle_query
|
||||||
|
|
||||||
@@ -114,8 +115,7 @@ def import_data(
|
|||||||
],
|
],
|
||||||
file: Annotated[
|
file: Annotated[
|
||||||
str,
|
str,
|
||||||
typer.Argument(
|
typer.Argument(help="Path to the file to chunk and add to the collection."),
|
||||||
help="Path to the file to chunk and add to the collection."),
|
|
||||||
],
|
],
|
||||||
) -> None:
|
) -> None:
|
||||||
try:
|
try:
|
||||||
@@ -123,7 +123,9 @@ def import_data(
|
|||||||
except NotFoundError:
|
except NotFoundError:
|
||||||
_fail(f"Collection '{collection}' does not exist.")
|
_fail(f"Collection '{collection}' does not exist.")
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
_fail(f"The file {file} was not found.")
|
_fail(f"The file '{file}' was not found.")
|
||||||
|
except InvalidFileException:
|
||||||
|
_fail(f"The file '{file}' is not a text file.")
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------------------
|
# ------------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from rich import print
|
from rich import print
|
||||||
|
|
||||||
from chromy.chroma_functions import count_collection
|
from chromy.chroma_functions import count_collection
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from rich import print
|
from rich import print
|
||||||
|
|
||||||
from chromy.chroma_functions import create_collection
|
from chromy.chroma_functions import create_collection
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from rich import print
|
from rich import print
|
||||||
|
|
||||||
from chromy.chroma_functions import delete_collection, delete_data
|
from chromy.chroma_functions import delete_collection, delete_data
|
||||||
|
|
||||||
|
|
||||||
@@ -8,15 +9,13 @@ def _parse_where_clause(where_clause: str) -> dict[str, str]:
|
|||||||
condition, separator, value = where_clause.partition("=")
|
condition, separator, value = where_clause.partition("=")
|
||||||
|
|
||||||
if separator == "":
|
if separator == "":
|
||||||
raise ValueError(
|
raise ValueError("Invalid --where value. Expected <condition>=<value>.")
|
||||||
"Invalid --where value. Expected <condition>=<value>.")
|
|
||||||
|
|
||||||
condition = condition.strip()
|
condition = condition.strip()
|
||||||
value = value.strip()
|
value = value.strip()
|
||||||
|
|
||||||
if not condition or not value:
|
if not condition or not value:
|
||||||
raise ValueError(
|
raise ValueError("Invalid --where value. Expected <condition>=<value>.")
|
||||||
"Invalid --where value. Expected <condition>=<value>.")
|
|
||||||
|
|
||||||
return {condition: value}
|
return {condition: value}
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,15 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from plistlib import InvalidFileException
|
||||||
|
|
||||||
from rich import print
|
from rich import print
|
||||||
|
|
||||||
from chromy.utilities import ingest_file
|
from chromy.utilities import ingest_file
|
||||||
|
|
||||||
|
from ..utilities import is_probably_text_file
|
||||||
|
|
||||||
|
|
||||||
def _get_absolute_path(file: str) -> str:
|
def _get_absolute_path(file: str) -> str:
|
||||||
"""
|
"""
|
||||||
@@ -21,11 +26,15 @@ def _get_absolute_path(file: str) -> str:
|
|||||||
raise FileNotFoundError()
|
raise FileNotFoundError()
|
||||||
|
|
||||||
file_path = Path(file)
|
file_path = Path(file)
|
||||||
return str(file_path.resolve(file_path))
|
return str(file_path.resolve())
|
||||||
|
|
||||||
|
|
||||||
def handle_import(collection: str, file: str) -> int:
    """
    Import a text file into a collection and report how many records were added.

    Args:
        collection (str): Name of the collection that receives the records.
        file (str): Path to the file to ingest.

    Returns:
        int: Always ``0``.

    Raises:
        InvalidFileException: If ``is_probably_text_file`` rejects the file.
    """
    # Resolve the path once and reuse it, so the file that is checked is the
    # same file that is ingested (and the lookup is not repeated).
    absolute_path = _get_absolute_path(file)

    if not is_probably_text_file(absolute_path):
        # NOTE(review): this exception type is borrowed from plistlib; the
        # import_data command catches exactly this type, so it is kept.
        raise InvalidFileException()

    records_added = ingest_file(collection, absolute_path)
    print(f"[bold green]Added[/] {records_added} records to collection '{collection}'.")
    return 0
|
||||||
|
|||||||
@@ -6,9 +6,11 @@ from chromy.utilities import print_lines
|
|||||||
|
|
||||||
def handle_list_collections() -> int:
    """
    Print every collection, or a notice when there are none.

    Returns:
        int: Always ``0``.
    """
    names = list_collections()

    if names:
        print_lines(names)
    else:
        print("No collections found.")

    return 0
|
||||||
|
|||||||
+47
-10
@@ -1,12 +1,12 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from rich.text import Text
|
|
||||||
from rich.rule import Rule
|
|
||||||
from rich.console import Console
|
|
||||||
|
|
||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from chromadb import QueryResult
|
from chromadb import QueryResult
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.rule import Rule
|
||||||
|
from rich.text import Text
|
||||||
|
|
||||||
from chromy.chroma_functions import add_data, query_data
|
from chromy.chroma_functions import add_data, query_data
|
||||||
from chromy.chunk_functions import chunk_file
|
from chromy.chunk_functions import chunk_file
|
||||||
@@ -15,7 +15,7 @@ from chromy.embed import embed
|
|||||||
CONSOLE = Console()
|
CONSOLE = Console()
|
||||||
|
|
||||||
|
|
||||||
def print_lines(lines: Sequence[Rule | Text]) -> None:
    """
    Print each item to the module-level ``CONSOLE``, in order.

    Args:
        lines (Sequence[Rule | Text]): The renderables to print, one per call
            to ``CONSOLE.print``.
    """
    for renderable in lines:
        CONSOLE.print(renderable)
|
||||||
|
|
||||||
@@ -31,7 +31,7 @@ def run_query(collection_name: str, query_text: str) -> QueryResult:
|
|||||||
return query_data(collection_name, [query_text])
|
return query_data(collection_name, [query_text])
|
||||||
|
|
||||||
|
|
||||||
def format_query_result(result: QueryResult) -> list[str]:
|
def format_query_result(result: QueryResult) -> list[Rule | Text]:
|
||||||
ids = result.get("ids", [[]])
|
ids = result.get("ids", [[]])
|
||||||
documents = result.get("documents", [[]])
|
documents = result.get("documents", [[]])
|
||||||
distances = result.get("distances", [[]])
|
distances = result.get("distances", [[]])
|
||||||
@@ -43,12 +43,11 @@ def format_query_result(result: QueryResult) -> list[str]:
|
|||||||
first_metadatas = metadatas[0] if metadatas else []
|
first_metadatas = metadatas[0] if metadatas else []
|
||||||
|
|
||||||
if not first_ids:
|
if not first_ids:
|
||||||
return ["No results found."]
|
return [Text.from_markup("[yellow]No results found.[/]")]
|
||||||
|
|
||||||
lines = [Rule(title="Query results")]
|
lines: list[Rule | Text] = [Rule(title="Query results")]
|
||||||
|
|
||||||
for index, document_id in enumerate(first_ids, start=1):
|
for index, document_id in enumerate(first_ids, start=1):
|
||||||
# lines.append(f"{index}.\tid: {document_id}")
|
|
||||||
lines.append(
|
lines.append(
|
||||||
Text.from_markup(f"[bold]{index}[/].\t[green]id[/]\t\t{document_id}")
|
Text.from_markup(f"[bold]{index}[/].\t[green]id[/]\t\t{document_id}")
|
||||||
)
|
)
|
||||||
@@ -72,9 +71,47 @@ def format_query_result(result: QueryResult) -> list[str]:
|
|||||||
|
|
||||||
if i < len(first_documents):
|
if i < len(first_documents):
|
||||||
lines.append(Text.from_markup("\n[bold green]Retrieved contents[/]\n"))
|
lines.append(Text.from_markup("\n[bold green]Retrieved contents[/]\n"))
|
||||||
lines.append(first_documents[i])
|
lines.append(Text(first_documents[i]))
|
||||||
|
|
||||||
# Print a separator between documents
|
# Print a separator between documents
|
||||||
lines.append(Rule())
|
lines.append(Rule())
|
||||||
|
|
||||||
return lines
|
return lines
|
||||||
|
|
||||||
|
|
||||||
|
def is_probably_text_file(path: str | Path, sample_size: int = 8192) -> bool:
    """
    Return whether a file appears to contain text.

    The first ``sample_size`` bytes are inspected. A sample that starts with a
    UTF-16 or UTF-32 byte order mark is validated with that encoding; any
    other sample must be valid UTF-8 (with or without BOM) and must not
    contain NUL bytes. UTF-16/UTF-32 data without a BOM is not accepted,
    because almost any short byte string decodes as BOM-less UTF-16 and would
    let binary files through.

    Args:
        path (str | Path): The path to the file to inspect.
        sample_size (int): The maximum number of bytes to read from the file.
            A negative value reads the whole file.

    Returns:
        bool: ``True`` if the sampled bytes look like text or the file is
            empty. Otherwise, ``False``.
    """
    # Local import so the block does not depend on a change to the file-level
    # import list.
    import codecs

    with Path(path).open("rb") as f:
        sample = f.read(sample_size)

    if not sample:
        return True

    # UTF-32 is tested first: its little-endian BOM begins with the bytes of
    # the UTF-16 little-endian BOM.
    bom_encodings = (
        (codecs.BOM_UTF32_LE, "utf-32"),
        (codecs.BOM_UTF32_BE, "utf-32"),
        (codecs.BOM_UTF16_LE, "utf-16"),
        (codecs.BOM_UTF16_BE, "utf-16"),
    )
    encoding = next(
        (name for bom, name in bom_encodings if sample.startswith(bom)),
        "utf-8",
    )

    # NUL is technically valid UTF-8 but in practice marks binary content.
    if encoding == "utf-8" and b"\x00" in sample:
        return False

    # When the sample may stop short of the end of the file, a multi-byte
    # character can be cut at the boundary. Decoding incrementally with
    # final=False buffers such a tail instead of rejecting the file.
    whole_file_read = sample_size < 0 or len(sample) < sample_size
    decoder = codecs.getincrementaldecoder(encoding)()
    try:
        decoder.decode(sample, final=whole_file_read)
    except UnicodeDecodeError:
        return False

    return True
|
||||||
|
|||||||
+6
-11
@@ -51,15 +51,13 @@ class CliTests(unittest.TestCase):
|
|||||||
def test_create_collection_with_same_name(self) -> None:
|
def test_create_collection_with_same_name(self) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
"chromy.handlers.create_collection.create_collection",
|
"chromy.handlers.create_collection.create_collection",
|
||||||
side_effect=InternalError()
|
side_effect=InternalError(),
|
||||||
|
|
||||||
) as create_collection:
|
) as create_collection:
|
||||||
result = _invoke(["create-collection", "notes"])
|
result = _invoke(["create-collection", "notes"])
|
||||||
|
|
||||||
create_collection.assert_called_once_with("notes")
|
create_collection.assert_called_once_with("notes")
|
||||||
self.assertEqual(result.exit_code, 1)
|
self.assertEqual(result.exit_code, 1)
|
||||||
self.assertEqual(
|
self.assertEqual(result.stdout, "Error: Collection 'notes' already exists.\n")
|
||||||
result.stdout, "Error: Collection 'notes' already exists.\n")
|
|
||||||
|
|
||||||
def test_delete_collection(self) -> None:
|
def test_delete_collection(self) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
@@ -74,14 +72,13 @@ class CliTests(unittest.TestCase):
|
|||||||
def test_delete_non_existent_collection(self) -> None:
|
def test_delete_non_existent_collection(self) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
"chromy.handlers.delete_collection.delete_collection",
|
"chromy.handlers.delete_collection.delete_collection",
|
||||||
side_effect=NotFoundError()
|
side_effect=NotFoundError(),
|
||||||
) as delete_collection:
|
) as delete_collection:
|
||||||
result = _invoke(["delete-collection", "notes"])
|
result = _invoke(["delete-collection", "notes"])
|
||||||
|
|
||||||
delete_collection.assert_called_once_with("notes")
|
delete_collection.assert_called_once_with("notes")
|
||||||
self.assertEqual(result.exit_code, 1)
|
self.assertEqual(result.exit_code, 1)
|
||||||
self.assertEqual(
|
self.assertEqual(result.stdout, "Error: Collection 'notes' does not exist.\n")
|
||||||
result.stdout, "Error: Collection 'notes' does not exist.\n")
|
|
||||||
|
|
||||||
def test_count(self) -> None:
|
def test_count(self) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
@@ -106,8 +103,7 @@ class CliTests(unittest.TestCase):
|
|||||||
self._fixture_path("romeo_and_juliet.txt"),
|
self._fixture_path("romeo_and_juliet.txt"),
|
||||||
)
|
)
|
||||||
self.assertEqual(result.exit_code, 0)
|
self.assertEqual(result.exit_code, 0)
|
||||||
self.assertEqual(
|
self.assertEqual(result.stdout, "Added 3 records to collection 'notes'.\n")
|
||||||
result.stdout, "Added 3 records to collection 'notes'.\n")
|
|
||||||
|
|
||||||
def test_query(self) -> None:
|
def test_query(self) -> None:
|
||||||
query_result = {"ids": [["1"]], "documents": [["hello"]]}
|
query_result = {"ids": [["1"]], "documents": [["hello"]]}
|
||||||
@@ -139,8 +135,7 @@ class CliTests(unittest.TestCase):
|
|||||||
self.assertEqual(result.exit_code, 0)
|
self.assertEqual(result.exit_code, 0)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
result.stdout,
|
result.stdout,
|
||||||
"Deleted 2 record(s) from collection 'notes' "
|
"Deleted 2 record(s) from collection 'notes' where file_name=play.txt.\n",
|
||||||
"where file_name=play.txt.\n",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_invalid_delete_filter_keeps_user_facing_error(self) -> None:
|
def test_invalid_delete_filter_keeps_user_facing_error(self) -> None:
|
||||||
|
|||||||
@@ -8,13 +8,13 @@ from pathlib import Path
|
|||||||
from typing import TypeVar
|
from typing import TypeVar
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from chromy.handlers.import_data import handle_import
|
|
||||||
from chromy.handlers.count_collection import handle_count_collection
|
from chromy.handlers.count_collection import handle_count_collection
|
||||||
from chromy.handlers.create_collection import handle_create_collection
|
from chromy.handlers.create_collection import handle_create_collection
|
||||||
from chromy.handlers.delete_collection import (
|
from chromy.handlers.delete_collection import (
|
||||||
handle_delete_collection,
|
handle_delete_collection,
|
||||||
handle_delete_records,
|
handle_delete_records,
|
||||||
)
|
)
|
||||||
|
from chromy.handlers.import_data import handle_import
|
||||||
from chromy.handlers.list_collections import handle_list_collections
|
from chromy.handlers.list_collections import handle_list_collections
|
||||||
from chromy.handlers.query import handle_query
|
from chromy.handlers.query import handle_query
|
||||||
|
|
||||||
@@ -154,7 +154,7 @@ class HandlerTests(unittest.TestCase):
|
|||||||
|
|
||||||
def _capture_output(
|
def _capture_output(
|
||||||
handler: Callable[..., int],
|
handler: Callable[..., int],
|
||||||
*arguments: CommandT,
|
*arguments: object,
|
||||||
) -> tuple[int, str]:
|
) -> tuple[int, str]:
|
||||||
output = io.StringIO()
|
output = io.StringIO()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user