decouple core data from CLI formatting
This commit is contained in:
@@ -8,7 +8,6 @@ import chromadb
|
|||||||
from chromadb.api import ClientAPI
|
from chromadb.api import ClientAPI
|
||||||
from chromadb.api.types import QueryResult, Where
|
from chromadb.api.types import QueryResult, Where
|
||||||
from chromadb.errors import NotFoundError
|
from chromadb.errors import NotFoundError
|
||||||
from rich.text import Text
|
|
||||||
|
|
||||||
from chromy.embed import EmbeddingRecord
|
from chromy.embed import EmbeddingRecord
|
||||||
|
|
||||||
@@ -26,17 +25,14 @@ def _get_client_and_collection(
|
|||||||
return client, collection
|
return client, collection
|
||||||
|
|
||||||
|
|
||||||
def list_collections() -> list[Text]:
|
def list_collections() -> list[str]:
|
||||||
client = chromadb.PersistentClient()
|
client = chromadb.PersistentClient()
|
||||||
collections = client.list_collections()
|
collections = client.list_collections()
|
||||||
|
|
||||||
if not collections:
|
if not collections:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
return [
|
return [getattr(collection, "name", str(collection)) for collection in collections]
|
||||||
Text("· " + getattr(collection, "name", str(collection)))
|
|
||||||
for collection in collections
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def create_collection(name: str) -> str:
|
def create_collection(name: str) -> str:
|
||||||
@@ -58,13 +54,9 @@ def delete_data(collection_name: str, where: dict[str, str]) -> int:
|
|||||||
return int(result.get("deleted", 0))
|
return int(result.get("deleted", 0))
|
||||||
|
|
||||||
|
|
||||||
def count_collection(collection_name: str) -> str:
|
def count_collection(collection_name: str) -> int:
|
||||||
_, collection = _get_client_and_collection(collection_name)
|
_, collection = _get_client_and_collection(collection_name)
|
||||||
count = collection.count()
|
return collection.count()
|
||||||
|
|
||||||
return (
|
|
||||||
f"The '{collection_name}' collection contains [bold green]{count}[/] records."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def add_data(
|
def add_data(
|
||||||
|
|||||||
+2
-2
@@ -1,12 +1,12 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from plistlib import InvalidFileException
|
|
||||||
from typing import Annotated, Callable
|
from typing import Annotated, Callable
|
||||||
|
|
||||||
import typer
|
import typer
|
||||||
from chromadb.errors import InternalError, NotFoundError
|
from chromadb.errors import InternalError, NotFoundError
|
||||||
from rich import print
|
from rich import print
|
||||||
|
|
||||||
|
from chromy.errors import UnsupportedTextFileError
|
||||||
from chromy.handlers.count_collection import handle_count_collection
|
from chromy.handlers.count_collection import handle_count_collection
|
||||||
from chromy.handlers.create_collection import handle_create_collection
|
from chromy.handlers.create_collection import handle_create_collection
|
||||||
from chromy.handlers.delete_collection import (
|
from chromy.handlers.delete_collection import (
|
||||||
@@ -124,7 +124,7 @@ def import_data(
|
|||||||
_fail(f"Collection '{collection}' does not exist.")
|
_fail(f"Collection '{collection}' does not exist.")
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
_fail(f"The file '{file}' was not found.")
|
_fail(f"The file '{file}' was not found.")
|
||||||
except InvalidFileException:
|
except UnsupportedTextFileError:
|
||||||
_fail(f"The file '{file}' is not a text file.")
|
_fail(f"The file '{file}' is not a text file.")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
|
||||||
|
class UnsupportedTextFileError(Exception):
|
||||||
|
"""Raised when a file does not appear to contain supported text content."""
|
||||||
@@ -3,8 +3,9 @@ from __future__ import annotations
|
|||||||
from rich import print
|
from rich import print
|
||||||
|
|
||||||
from chromy.chroma_functions import count_collection
|
from chromy.chroma_functions import count_collection
|
||||||
|
from chromy.output import format_count_message
|
||||||
|
|
||||||
|
|
||||||
def handle_count_collection(collection: str) -> int:
|
def handle_count_collection(collection: str) -> int:
|
||||||
print(count_collection(collection))
|
print(format_count_message(collection, count_collection(collection)))
|
||||||
return 0
|
return 0
|
||||||
|
|||||||
@@ -2,10 +2,10 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from plistlib import InvalidFileException
|
|
||||||
|
|
||||||
from rich import print
|
from rich import print
|
||||||
|
|
||||||
|
from chromy.errors import UnsupportedTextFileError
|
||||||
from chromy.utilities import ingest_file
|
from chromy.utilities import ingest_file
|
||||||
|
|
||||||
from ..utilities import is_probably_text_file
|
from ..utilities import is_probably_text_file
|
||||||
@@ -33,8 +33,8 @@ def handle_import(collection: str, file: str) -> int:
|
|||||||
absolute_path = _get_absolute_path(file)
|
absolute_path = _get_absolute_path(file)
|
||||||
|
|
||||||
if not is_probably_text_file(absolute_path):
|
if not is_probably_text_file(absolute_path):
|
||||||
raise InvalidFileException()
|
raise UnsupportedTextFileError()
|
||||||
|
|
||||||
records_added = ingest_file(collection, _get_absolute_path(file))
|
records_added = ingest_file(collection, absolute_path)
|
||||||
print(f"[bold green]Added[/] {records_added} records to collection '{collection}'.")
|
print(f"[bold green]Added[/] {records_added} records to collection '{collection}'.")
|
||||||
return 0
|
return 0
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from chromy.chroma_functions import list_collections
|
from chromy.chroma_functions import list_collections
|
||||||
from chromy.utilities import print_lines
|
from chromy.output import format_collection_names, print_lines
|
||||||
|
|
||||||
|
|
||||||
def handle_list_collections() -> int:
|
def handle_list_collections() -> int:
|
||||||
@@ -11,6 +11,6 @@ def handle_list_collections() -> int:
|
|||||||
print("No collections found.")
|
print("No collections found.")
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
print_lines(collections)
|
print_lines(format_collection_names(collections))
|
||||||
|
|
||||||
return 0
|
return 0
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from chromy.utilities import format_query_result, print_lines, run_query
|
from chromy.output import format_query_result, print_lines
|
||||||
|
from chromy.utilities import run_query
|
||||||
|
|
||||||
|
|
||||||
def handle_query(collection: str, query_text: str) -> int:
|
def handle_query(collection: str, query_text: str) -> int:
|
||||||
|
|||||||
@@ -0,0 +1,70 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Mapping, Sequence
|
||||||
|
|
||||||
|
from chromadb import QueryResult
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.rule import Rule
|
||||||
|
from rich.text import Text
|
||||||
|
|
||||||
|
CONSOLE = Console()
|
||||||
|
|
||||||
|
|
||||||
|
def print_lines(lines: Sequence[Rule | Text | str]) -> None:
|
||||||
|
for line in lines:
|
||||||
|
CONSOLE.print(line)
|
||||||
|
|
||||||
|
|
||||||
|
def format_collection_names(collections: Sequence[str]) -> list[Text]:
|
||||||
|
return [Text(f"· {collection}") for collection in collections]
|
||||||
|
|
||||||
|
|
||||||
|
def format_count_message(collection_name: str, count: int) -> str:
|
||||||
|
return (
|
||||||
|
f"The '{collection_name}' collection contains [bold green]{count}[/] records."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def format_query_result(result: QueryResult) -> list[Rule | Text]:
|
||||||
|
ids = result.get("ids", [[]])
|
||||||
|
documents = result.get("documents", [[]])
|
||||||
|
distances = result.get("distances", [[]])
|
||||||
|
metadatas = result.get("metadatas", [[]])
|
||||||
|
|
||||||
|
first_ids = ids[0] if ids else []
|
||||||
|
first_documents = documents[0] if documents else []
|
||||||
|
first_distances = distances[0] if distances else []
|
||||||
|
first_metadatas = metadatas[0] if metadatas else []
|
||||||
|
|
||||||
|
if not first_ids:
|
||||||
|
return [Text.from_markup("[yellow]No results found.[/]")]
|
||||||
|
|
||||||
|
lines: list[Rule | Text] = [Rule(title="Query results")]
|
||||||
|
|
||||||
|
for index, document_id in enumerate(first_ids, start=1):
|
||||||
|
lines.append(
|
||||||
|
Text.from_markup(f"[bold]{index}[/].\t[green]id[/]\t\t{document_id}")
|
||||||
|
)
|
||||||
|
i = index - 1
|
||||||
|
|
||||||
|
if i < len(first_distances):
|
||||||
|
lines.append(
|
||||||
|
Text.from_markup(f"\t[green]distance[/]\t{first_distances[i]}")
|
||||||
|
)
|
||||||
|
|
||||||
|
if i < len(first_metadatas):
|
||||||
|
metadata = first_metadatas[i]
|
||||||
|
|
||||||
|
if isinstance(metadata, Mapping):
|
||||||
|
file_name = metadata.get("file_name")
|
||||||
|
|
||||||
|
if file_name:
|
||||||
|
lines.append(Text.from_markup(f"\t[green]file_name[/]\t{file_name}"))
|
||||||
|
|
||||||
|
if i < len(first_documents):
|
||||||
|
lines.append(Text.from_markup("\n[bold green]Retrieved contents[/]\n"))
|
||||||
|
lines.append(Text(first_documents[i]))
|
||||||
|
|
||||||
|
lines.append(Rule())
|
||||||
|
|
||||||
|
return lines
|
||||||
@@ -1,24 +1,13 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Mapping, Sequence
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from chromadb import QueryResult
|
from chromadb import QueryResult
|
||||||
from rich.console import Console
|
|
||||||
from rich.rule import Rule
|
|
||||||
from rich.text import Text
|
|
||||||
|
|
||||||
from chromy.chroma_functions import add_data, query_data
|
from chromy.chroma_functions import add_data, query_data
|
||||||
from chromy.chunk_functions import chunk_file
|
from chromy.chunk_functions import chunk_file
|
||||||
from chromy.embed import embed
|
from chromy.embed import embed
|
||||||
|
|
||||||
CONSOLE = Console()
|
|
||||||
|
|
||||||
|
|
||||||
def print_lines(lines: Sequence[Rule | Text]) -> None:
|
|
||||||
for line in lines:
|
|
||||||
CONSOLE.print(line)
|
|
||||||
|
|
||||||
|
|
||||||
def ingest_file(collection_name: str, file_path: str) -> int:
|
def ingest_file(collection_name: str, file_path: str) -> int:
|
||||||
chunks = chunk_file(file_path)
|
chunks = chunk_file(file_path)
|
||||||
@@ -31,54 +20,6 @@ def run_query(collection_name: str, query_text: str) -> QueryResult:
|
|||||||
return query_data(collection_name, [query_text])
|
return query_data(collection_name, [query_text])
|
||||||
|
|
||||||
|
|
||||||
def format_query_result(result: QueryResult) -> list[Rule | Text]:
|
|
||||||
ids = result.get("ids", [[]])
|
|
||||||
documents = result.get("documents", [[]])
|
|
||||||
distances = result.get("distances", [[]])
|
|
||||||
metadatas = result.get("metadatas", [[]])
|
|
||||||
|
|
||||||
first_ids = ids[0] if ids else []
|
|
||||||
first_documents = documents[0] if documents else []
|
|
||||||
first_distances = distances[0] if distances else []
|
|
||||||
first_metadatas = metadatas[0] if metadatas else []
|
|
||||||
|
|
||||||
if not first_ids:
|
|
||||||
return [Text.from_markup("[yellow]No results found.[/]")]
|
|
||||||
|
|
||||||
lines: list[Rule | Text] = [Rule(title="Query results")]
|
|
||||||
|
|
||||||
for index, document_id in enumerate(first_ids, start=1):
|
|
||||||
lines.append(
|
|
||||||
Text.from_markup(f"[bold]{index}[/].\t[green]id[/]\t\t{document_id}")
|
|
||||||
)
|
|
||||||
i = index - 1
|
|
||||||
|
|
||||||
if i < len(first_distances):
|
|
||||||
lines.append(
|
|
||||||
Text.from_markup(f"\t[green]distance[/]\t{first_distances[i]}")
|
|
||||||
)
|
|
||||||
|
|
||||||
if i < len(first_metadatas):
|
|
||||||
metadata = first_metadatas[i]
|
|
||||||
|
|
||||||
if isinstance(metadata, Mapping):
|
|
||||||
file_name = metadata.get("file_name")
|
|
||||||
|
|
||||||
if file_name:
|
|
||||||
lines.append(
|
|
||||||
Text.from_markup(f"\t[green]file_name[/]\t{file_name}")
|
|
||||||
)
|
|
||||||
|
|
||||||
if i < len(first_documents):
|
|
||||||
lines.append(Text.from_markup("\n[bold green]Retrieved contents[/]\n"))
|
|
||||||
lines.append(Text(first_documents[i]))
|
|
||||||
|
|
||||||
# Print a separator between documents
|
|
||||||
lines.append(Rule())
|
|
||||||
|
|
||||||
return lines
|
|
||||||
|
|
||||||
|
|
||||||
def is_probably_text_file(path: str | Path, sample_size: int = 8192) -> bool:
|
def is_probably_text_file(path: str | Path, sample_size: int = 8192) -> bool:
|
||||||
"""
|
"""
|
||||||
Return whether a file appears to contain text.
|
Return whether a file appears to contain text.
|
||||||
|
|||||||
+19
-3
@@ -31,11 +31,11 @@ class CliTests(unittest.TestCase):
|
|||||||
with patch(
|
with patch(
|
||||||
"chromy.handlers.list_collections.list_collections",
|
"chromy.handlers.list_collections.list_collections",
|
||||||
return_value=["books", "code"],
|
return_value=["books", "code"],
|
||||||
):
|
):
|
||||||
result = _invoke(["list-collections"])
|
result = _invoke(["list-collections"])
|
||||||
|
|
||||||
self.assertEqual(result.exit_code, 0)
|
self.assertEqual(result.exit_code, 0)
|
||||||
self.assertEqual(result.stdout, "books\ncode\n")
|
self.assertEqual(result.stdout, "· books\n· code\n")
|
||||||
|
|
||||||
def test_create_collection(self) -> None:
|
def test_create_collection(self) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
@@ -89,7 +89,10 @@ class CliTests(unittest.TestCase):
|
|||||||
|
|
||||||
count_collection.assert_called_once_with("notes")
|
count_collection.assert_called_once_with("notes")
|
||||||
self.assertEqual(result.exit_code, 0)
|
self.assertEqual(result.exit_code, 0)
|
||||||
self.assertEqual(result.stdout, "7\n")
|
self.assertEqual(
|
||||||
|
result.stdout,
|
||||||
|
"The 'notes' collection contains 7 records.\n",
|
||||||
|
)
|
||||||
|
|
||||||
def test_import_data(self) -> None:
|
def test_import_data(self) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
@@ -105,6 +108,19 @@ class CliTests(unittest.TestCase):
|
|||||||
self.assertEqual(result.exit_code, 0)
|
self.assertEqual(result.exit_code, 0)
|
||||||
self.assertEqual(result.stdout, "Added 3 records to collection 'notes'.\n")
|
self.assertEqual(result.stdout, "Added 3 records to collection 'notes'.\n")
|
||||||
|
|
||||||
|
def test_import_data_rejects_non_text_files(self) -> None:
|
||||||
|
with patch(
|
||||||
|
"chromy.handlers.import_data.is_probably_text_file",
|
||||||
|
return_value=False,
|
||||||
|
):
|
||||||
|
result = _invoke(["import", "notes", "romeo_and_juliet.txt"])
|
||||||
|
|
||||||
|
self.assertEqual(result.exit_code, 1)
|
||||||
|
self.assertEqual(
|
||||||
|
result.stdout,
|
||||||
|
"Error: The file 'romeo_and_juliet.txt' is not a text file.\n",
|
||||||
|
)
|
||||||
|
|
||||||
def test_query(self) -> None:
|
def test_query(self) -> None:
|
||||||
query_result = {"ids": [["1"]], "documents": [["hello"]]}
|
query_result = {"ids": [["1"]], "documents": [["hello"]]}
|
||||||
|
|
||||||
|
|||||||
+20
-2
@@ -1,11 +1,29 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from chromy.embed import embed
|
||||||
|
|
||||||
|
|
||||||
class EmbedTest(unittest.TestCase):
|
class EmbedTest(unittest.TestCase):
|
||||||
def test_embed_function(self) -> None:
|
def test_embed_returns_empty_list_for_empty_chunks(self) -> None:
|
||||||
self.assertEqual(0, 0)
|
self.assertEqual(embed([]), [])
|
||||||
|
|
||||||
|
def test_embed_pairs_text_with_list_embeddings(self) -> None:
|
||||||
|
with patch(
|
||||||
|
"chromy.embed.DefaultEmbeddingFunction",
|
||||||
|
return_value=lambda chunks: ((1.0, 2.0), (3.0, 4.0)),
|
||||||
|
):
|
||||||
|
result = embed(["first", "second"])
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
result,
|
||||||
|
[
|
||||||
|
{"text": "first", "embedding": [1.0, 2.0]},
|
||||||
|
{"text": "second", "embedding": [3.0, 4.0]},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
+13
-2
@@ -8,6 +8,7 @@ from pathlib import Path
|
|||||||
from typing import TypeVar
|
from typing import TypeVar
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from chromy.errors import UnsupportedTextFileError
|
||||||
from chromy.handlers.count_collection import handle_count_collection
|
from chromy.handlers.count_collection import handle_count_collection
|
||||||
from chromy.handlers.create_collection import handle_create_collection
|
from chromy.handlers.create_collection import handle_create_collection
|
||||||
from chromy.handlers.delete_collection import (
|
from chromy.handlers.delete_collection import (
|
||||||
@@ -47,7 +48,7 @@ class HandlerTests(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(exit_code, 0)
|
self.assertEqual(exit_code, 0)
|
||||||
self.assertEqual(output, "notes\nplays\n")
|
self.assertEqual(output, "· notes\n· plays\n")
|
||||||
|
|
||||||
def test_create_collection_uses_typed_input(self) -> None:
|
def test_create_collection_uses_typed_input(self) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
@@ -86,7 +87,7 @@ class HandlerTests(unittest.TestCase):
|
|||||||
|
|
||||||
count.assert_called_once_with("notes")
|
count.assert_called_once_with("notes")
|
||||||
self.assertEqual(exit_code, 0)
|
self.assertEqual(exit_code, 0)
|
||||||
self.assertEqual(output, "7\n")
|
self.assertEqual(output, "The 'notes' collection contains 7 records.\n")
|
||||||
|
|
||||||
def test_import_data_uses_typed_input(self) -> None:
|
def test_import_data_uses_typed_input(self) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
@@ -106,6 +107,16 @@ class HandlerTests(unittest.TestCase):
|
|||||||
self.assertEqual(exit_code, 0)
|
self.assertEqual(exit_code, 0)
|
||||||
self.assertEqual(output, "Added 3 records to collection 'notes'.\n")
|
self.assertEqual(output, "Added 3 records to collection 'notes'.\n")
|
||||||
|
|
||||||
|
def test_import_data_rejects_non_text_files(self) -> None:
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"chromy.handlers.import_data.is_probably_text_file",
|
||||||
|
return_value=False,
|
||||||
|
),
|
||||||
|
self.assertRaises(UnsupportedTextFileError),
|
||||||
|
):
|
||||||
|
handle_import("notes", "romeo_and_juliet.txt")
|
||||||
|
|
||||||
def test_query_uses_typed_input(self) -> None:
|
def test_query_uses_typed_input(self) -> None:
|
||||||
query_result = {"ids": [["1"]], "documents": [["hello"]]}
|
query_result = {"ids": [["1"]], "documents": [["hello"]]}
|
||||||
with (
|
with (
|
||||||
|
|||||||
Reference in New Issue
Block a user