diff --git a/README.md b/README.md
index f562055..45f21dd 100644
--- a/README.md
+++ b/README.md
@@ -195,3 +195,4 @@ Query results include the stored document chunk, its id, distance, and file name
 - quoted glob patterns such as `"*.md"` are treated as literal paths and are not expanded by `chromy`
 - unmatched unquoted globs may behave differently by shell: `zsh` commonly fails before `chromy` starts, while `bash` may pass the literal pattern through depending on shell settings
 - the CLI reports file-specific import failures and continues with the remaining files
+- when importing multiple files in an interactive terminal, the CLI shows a Rich progress bar
diff --git a/chromy/handlers/import_data.py b/chromy/handlers/import_data.py
index 15f8fec..9a4b319 100644
--- a/chromy/handlers/import_data.py
+++ b/chromy/handlers/import_data.py
@@ -1,10 +1,18 @@
 from __future__ import annotations
 
 import os
+import sys
 from pathlib import Path
 from typing import Final
 
 from rich import print
+from rich.progress import (
+    BarColumn,
+    MofNCompleteColumn,
+    Progress,
+    SpinnerColumn,
+    TextColumn,
+)
 
 from chromy.errors import UnsupportedTextFileError
 from chromy.utilities import ingest_file
@@ -42,34 +50,85 @@ def _import_one(collection: str, file: str) -> int:
     if not is_probably_text_file(absolute_path):
         raise UnsupportedTextFileError()
 
-    records_added = ingest_file(collection, absolute_path)
-    print(
-        "[bold green]Added[/] "
-        f"{records_added} records from '{file}' to collection '{collection}'."
-    )
-    return SUCCESS_EXIT_CODE
+    return ingest_file(collection, absolute_path)
+
+
+def _should_show_progress(file_count: int) -> bool:
+    """Return True when more than one file is imported and stdout is a TTY."""
+    return file_count > 1 and sys.stdout.isatty()
+
+
+def _truncate_file_name(file_name: str, max_length: int = 20) -> str:
+    """Shorten ``file_name`` to ``max_length`` characters for display.
+
+    Names that already fit are returned unchanged; longer names are cut and
+    end with ``...`` so the truncation is visible in the progress bar.
+    """
+    if len(file_name) <= max_length:
+        return file_name
+
+    # Reserve three characters for the ellipsis appended below.
+    return f"{file_name[: max_length - 3]}..."
 
 
 def handle_import(collection: str, files: list[str]) -> int:
     successful_imports = 0
     failed_imports = 0
     seen_paths: set[str] = set()
+    unique_files: list[str] = []
 
     for file in files:
         try:
             absolute_path = _get_absolute_path(file)
-            if absolute_path in seen_paths:
-                continue
-
-            seen_paths.add(absolute_path)
-            _import_one(collection, file)
-            successful_imports += 1
         except FileNotFoundError:
-            failed_imports += 1
-            print(f"[bold red]Error[/]: The file '{file}' was not found.")
-        except UnsupportedTextFileError:
-            failed_imports += 1
-            print(f"[bold red]Error[/]: The file '{file}' is not a text file.")
+            # Keep the entry so the import loop below can report the failure.
+            unique_files.append(file)
+            continue
+
+        if absolute_path in seen_paths:
+            continue
+
+        seen_paths.add(absolute_path)
+        unique_files.append(file)
+
+    show_progress = _should_show_progress(len(unique_files))
+
+    with Progress(
+        SpinnerColumn(),
+        TextColumn("[progress.description]{task.description}"),
+        BarColumn(),
+        MofNCompleteColumn(),
+        transient=True,
+        disable=not show_progress,
+    ) as progress:
+        task_id = progress.add_task("Importing files...", total=len(unique_files))
+
+        for file in unique_files:
+            file_name = _truncate_file_name(Path(file).name)
+            description = f"Importing [bold]{file_name}[/]..."
+            progress.update(task_id, description=description)
+            try:
+                records_added = _import_one(collection, file)
+                successful_imports += 1
+                # Print per-file success lines only when no progress bar is shown.
+                if not show_progress:
+                    progress.console.print(
+                        "[bold green]Added[/] "
+                        f"{records_added} records from '{file}' to "
+                        f"collection '{collection}'."
+                    )
+            except FileNotFoundError:
+                failed_imports += 1
+                progress.console.print(
+                    f"[bold red]Error[/]: The file '{file}' was not found."
+                )
+            except UnsupportedTextFileError:
+                failed_imports += 1
+                progress.console.print(
+                    f"[bold red]Error[/]: The file '{file}' is not a text file."
+                )
+            finally:
+                progress.advance(task_id)
 
     print(
         f"Imported {successful_imports} file(s) successfully; {failed_imports} failed."
diff --git a/tests/test_handlers.py b/tests/test_handlers.py
index 53df14a..6ff39f4 100644
--- a/tests/test_handlers.py
+++ b/tests/test_handlers.py
@@ -6,7 +6,7 @@ from collections.abc import Callable
 from contextlib import redirect_stdout
 from pathlib import Path
 from typing import TypeVar
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
 
 from chromy.handlers.count_collection import handle_count_collection
 from chromy.handlers.create_collection import handle_create_collection
@@ -173,6 +173,64 @@ class HandlerTests(unittest.TestCase):
             "Imported 1 file(s) successfully; 0 failed.\n",
         )
 
+    def test_import_data_suppresses_per_file_output_with_progress(self) -> None:
+        progress = MagicMock()
+        progress.__enter__.return_value = progress
+        progress.__exit__.return_value = None
+        progress.console.print = print
+        progress.add_task.return_value = 1
+
+        with (
+            patch("chromy.handlers.import_data.ingest_file", side_effect=[3, 2]),
+            patch(
+                "chromy.handlers.import_data._should_show_progress",
+                return_value=True,
+            ),
+            patch("chromy.handlers.import_data.Progress", return_value=progress),
+        ):
+            exit_code, output = _capture_output(
+                handle_import,
+                "notes",
+                ["romeo_and_juliet.txt", "README.md"],
+            )
+
+        self.assertEqual(exit_code, 0)
+        self.assertEqual(output, "Imported 2 file(s) successfully; 0 failed.\n")
+
+    def test_import_data_truncates_long_file_names_in_progress(self) -> None:
+        progress = MagicMock()
+        progress.__enter__.return_value = progress
+        progress.__exit__.return_value = None
+        progress.console.print = print
+        progress.add_task.return_value = 1
+
+        with (
+            patch(
+                "chromy.handlers.import_data._get_absolute_path",
+                side_effect=[
+                    "/tmp/this_is_a_very_long_file_name.txt",
+                    self._fixture_path("README.md"),
+                    "/tmp/this_is_a_very_long_file_name.txt",
+                    self._fixture_path("README.md"),
+                ],
+            ),
+            patch("chromy.handlers.import_data._import_one", return_value=3),
+            patch(
+                "chromy.handlers.import_data._should_show_progress",
+                return_value=True,
+            ),
+            patch("chromy.handlers.import_data.Progress", return_value=progress),
+        ):
+            handle_import(
+                "notes",
+                ["this_is_a_very_long_file_name.txt", "README.md"],
+            )
+
+        progress.update.assert_any_call(
+            1,
+            description="Importing [bold]this_is_a_very_lo...[/]...",
+        )
+
     def test_query_uses_typed_input(self) -> None:
         query_result = {"ids": [["1"]], "documents": [["hello"]]}
         with (