from __future__ import annotations import io import unittest from collections.abc import Callable from contextlib import redirect_stdout from typing import TypeVar from unittest.mock import patch from chromy.command_inputs import ( AddDataInput, CountCollectionInput, CreateCollectionInput, DeleteCollectionInput, DeleteRecordsInput, ListCollectionsInput, QueryInput, ) from chromy.handlers.add_data import handle_add_data from chromy.handlers.count_collection import handle_count_collection from chromy.handlers.create_collection import handle_create_collection from chromy.handlers.delete_collection import ( handle_delete_collection, handle_delete_records, ) from chromy.handlers.list_collections import handle_list_collections from chromy.handlers.query import handle_query CommandT = TypeVar("CommandT") class HandlerTests(unittest.TestCase): def test_list_collections_prints_empty_message(self) -> None: with patch( "chromy.handlers.list_collections.list_collections", return_value=[] ): exit_code, output = _capture_output( handle_list_collections, ListCollectionsInput(), ) self.assertEqual(exit_code, 0) self.assertEqual(output, "No collections found.\n") def test_list_collections_prints_collection_names(self) -> None: with patch( "chromy.handlers.list_collections.list_collections", return_value=["notes", "plays"], ): exit_code, output = _capture_output( handle_list_collections, ListCollectionsInput(), ) self.assertEqual(exit_code, 0) self.assertEqual(output, "notes\nplays\n") def test_create_collection_uses_typed_input(self) -> None: with patch( "chromy.handlers.create_collection.create_collection", return_value="notes", ) as create_collection: exit_code, output = _capture_output( handle_create_collection, CreateCollectionInput(collection="notes"), ) create_collection.assert_called_once_with("notes") self.assertEqual(exit_code, 0) self.assertEqual(output, "Created collection 'notes'.\n") def test_delete_collection_uses_typed_input(self) -> None: with patch("chromy.handlers.delete_collection.delete_collection") as delete: exit_code, output = _capture_output( handle_delete_collection, DeleteCollectionInput(collection="notes"), ) delete.assert_called_once_with("notes") self.assertEqual(exit_code, 0) self.assertEqual(output, "Deleted collection 'notes'.\n") def test_count_collection_uses_typed_input(self) -> None: with patch( "chromy.handlers.count_collection.count_collection", return_value=7, ) as count: exit_code, output = _capture_output( handle_count_collection, CountCollectionInput(collection="notes"), ) count.assert_called_once_with("notes") self.assertEqual(exit_code, 0) self.assertEqual(output, "7\n") def test_add_data_uses_typed_input(self) -> None: with patch( "chromy.handlers.add_data.ingest_file", return_value=3, ) as ingest_file: exit_code, output = _capture_output( handle_add_data, AddDataInput(collection="notes", file="romeo_and_juliet.txt"), ) ingest_file.assert_called_once_with("notes", "romeo_and_juliet.txt") self.assertEqual(exit_code, 0) self.assertEqual(output, "Added 3 records to collection 'notes'.\n") def test_query_uses_typed_input(self) -> None: query_result = {"ids": [["1"]], "documents": [["hello"]]} with ( patch("chromy.handlers.query.run_query", return_value=query_result) as run, patch( "chromy.handlers.query.format_query_result", return_value=["Query results:", "1"], ) as format_result, ): exit_code, output = _capture_output( handle_query, QueryInput(collection="notes", query_text="hello"), ) run.assert_called_once_with("notes", "hello") format_result.assert_called_once_with(query_result) self.assertEqual(exit_code, 0) self.assertEqual(output, "Query results:\n1\n") def test_delete_records_parses_where_filter(self) -> None: with patch( "chromy.handlers.delete_collection.delete_data", return_value=2, ) as delete_data: exit_code, output = _capture_output( handle_delete_records, DeleteRecordsInput(collection="notes", where=" file_name = play.txt "), ) delete_data.assert_called_once_with("notes", {"file_name": "play.txt"}) self.assertEqual(exit_code, 0) self.assertEqual( output, "Deleted 2 record(s) from collection 'notes' where file_name=play.txt.\n", ) def test_delete_records_rejects_invalid_where_filter(self) -> None: with self.assertRaisesRegex( ValueError, "Invalid --where value. Expected =.", ): handle_delete_records( DeleteRecordsInput(collection="notes", where="file_name") ) def _capture_output( handler: Callable[[CommandT], int], command: CommandT, ) -> tuple[int, str]: output = io.StringIO() with redirect_stdout(output): exit_code = handler(command) return exit_code, output.getvalue() if __name__ == "__main__": unittest.main()