diff --git a/backend/app/core/config.py b/backend/app/core/config.py index 20c6722..9ffc019 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -104,6 +104,37 @@ class Settings(BaseSettings): enable_cudnn_benchmark: bool = Field(default=True) # Optimize convolution algorithms num_threads: int = Field(default=4) # CPU threads for preprocessing + # ===== Enhanced Memory Management Configuration ===== + # Memory thresholds (as ratio of total GPU memory) + memory_warning_threshold: float = Field(default=0.80) # 80% - start warning + memory_critical_threshold: float = Field(default=0.95) # 95% - throttle operations + memory_emergency_threshold: float = Field(default=0.98) # 98% - emergency cleanup + + # Memory monitoring + memory_check_interval_seconds: int = Field(default=30) # Background check interval + enable_memory_alerts: bool = Field(default=True) # Enable memory alerts + + # Model lifecycle management + enable_model_lifecycle_management: bool = Field(default=True) # Use ModelManager + pp_structure_idle_timeout_seconds: int = Field(default=300) # Unload PP-Structure after idle + structure_model_memory_mb: int = Field(default=2000) # Estimated memory for PP-StructureV3 + ocr_model_memory_mb: int = Field(default=500) # Estimated memory per OCR language model + + # Service pool configuration + enable_service_pool: bool = Field(default=True) # Use OCRServicePool + max_services_per_device: int = Field(default=1) # Max OCRService per GPU + max_total_services: int = Field(default=2) # Max total OCRService instances + service_acquire_timeout_seconds: float = Field(default=300.0) # Timeout for acquiring service + max_queue_size: int = Field(default=50) # Max pending tasks per device + + # Concurrency control + max_concurrent_predictions: int = Field(default=2) # Max concurrent PP-StructureV3 predictions + enable_cpu_fallback: bool = Field(default=True) # Fall back to CPU when GPU memory low + + # Emergency recovery + enable_emergency_cleanup: bool = Field(default=True) # Auto-cleanup on memory pressure + enable_worker_restart: bool = Field(default=False) # Restart workers on OOM (requires supervisor) + # ===== File Upload Configuration ===== max_upload_size: int = Field(default=52428800) # 50MB allowed_extensions: str = Field(default="png,jpg,jpeg,pdf,bmp,tiff,doc,docx,ppt,pptx") diff --git a/backend/app/main.py b/backend/app/main.py index e888224..51e8eef 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -7,10 +7,103 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from contextlib import asynccontextmanager import logging +import signal +import sys +import asyncio from pathlib import Path +from typing import Optional from app.core.config import settings + +# ============================================================================= +# Section 6.1: Signal Handlers +# ============================================================================= + +# Flag to indicate graceful shutdown is in progress +_shutdown_requested = False +_shutdown_complete = asyncio.Event() + +# Track active connections for draining +_active_connections = 0 +_connection_lock = asyncio.Lock() + + +async def increment_connections(): + """Track active connection count""" + global _active_connections + async with _connection_lock: + _active_connections += 1 + + +async def decrement_connections(): + """Track active connection count""" + global _active_connections + async with _connection_lock: + _active_connections -= 1 + + +def get_active_connections() -> 
int: + """Get current active connection count""" + return _active_connections + + +def is_shutdown_requested() -> bool: + """Check if graceful shutdown has been requested""" + return _shutdown_requested + + +def _signal_handler(signum: int, frame) -> None: + """ + Signal handler for SIGTERM and SIGINT. + + Initiates graceful shutdown by setting the shutdown flag. + The actual cleanup is handled by the lifespan context manager. + """ + global _shutdown_requested + signal_name = signal.Signals(signum).name + logger = logging.getLogger(__name__) + + if _shutdown_requested: + logger.warning(f"Received {signal_name} again, forcing immediate exit...") + sys.exit(1) + + logger.info(f"Received {signal_name}, initiating graceful shutdown...") + _shutdown_requested = True + + # Try to stop the event loop gracefully + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + # Schedule shutdown event + loop.call_soon_threadsafe(_shutdown_complete.set) + except RuntimeError: + pass # No event loop running + + +def setup_signal_handlers(): + """ + Set up signal handlers for graceful shutdown. + + Handles: + - SIGTERM: Standard termination signal (from systemd, docker, etc.) + - SIGINT: Keyboard interrupt (Ctrl+C) + """ + logger = logging.getLogger(__name__) + + try: + # SIGTERM - Standard termination signal + signal.signal(signal.SIGTERM, _signal_handler) + logger.info("SIGTERM handler installed") + + # SIGINT - Keyboard interrupt + signal.signal(signal.SIGINT, _signal_handler) + logger.info("SIGINT handler installed") + + except (ValueError, OSError) as e: + # Signal handling may not be available in all contexts + logger.warning(f"Could not install signal handlers: {e}") + # Ensure log directory exists before configuring logging Path(settings.log_file).parent.mkdir(parents=True, exist_ok=True) @@ -38,16 +131,91 @@ for logger_name in ['app', 'app.services', 'app.services.pdf_generator_service', logger = logging.getLogger(__name__) +async def drain_connections(timeout: float = 30.0): + """ + Wait for active connections to complete (connection draining). + + Args: + timeout: Maximum time to wait for connections to drain + """ + logger.info(f"Draining connections (timeout={timeout}s)...") + start_time = asyncio.get_event_loop().time() + + while get_active_connections() > 0: + elapsed = asyncio.get_event_loop().time() - start_time + if elapsed >= timeout: + logger.warning( + f"Connection drain timeout after {timeout}s. " + f"{get_active_connections()} connections still active." 
+ ) + break + + logger.info(f"Waiting for {get_active_connections()} active connections...") + await asyncio.sleep(1.0) + + if get_active_connections() == 0: + logger.info("All connections drained successfully") + + @asynccontextmanager async def lifespan(app: FastAPI): """Application lifespan events""" # Startup logger.info("Starting Tool_OCR V2 application...") + # Set up signal handlers for graceful shutdown + setup_signal_handlers() + # Ensure all directories exist settings.ensure_directories() logger.info("All directories created/verified") + # Initialize memory management if enabled + if settings.enable_model_lifecycle_management: + try: + from app.services.memory_manager import get_model_manager, MemoryConfig + + memory_config = MemoryConfig( + warning_threshold=settings.memory_warning_threshold, + critical_threshold=settings.memory_critical_threshold, + emergency_threshold=settings.memory_emergency_threshold, + model_idle_timeout_seconds=settings.pp_structure_idle_timeout_seconds, + memory_check_interval_seconds=settings.memory_check_interval_seconds, + enable_auto_cleanup=settings.enable_memory_optimization, + enable_emergency_cleanup=settings.enable_emergency_cleanup, + max_concurrent_predictions=settings.max_concurrent_predictions, + enable_cpu_fallback=settings.enable_cpu_fallback, + gpu_memory_limit_mb=settings.gpu_memory_limit_mb, + ) + get_model_manager(memory_config) + logger.info("Memory management initialized") + except Exception as e: + logger.warning(f"Failed to initialize memory management: {e}") + + # Initialize service pool if enabled + if settings.enable_service_pool: + try: + from app.services.service_pool import get_service_pool, PoolConfig + + pool_config = PoolConfig( + max_services_per_device=settings.max_services_per_device, + max_total_services=settings.max_total_services, + acquire_timeout_seconds=settings.service_acquire_timeout_seconds, + max_queue_size=settings.max_queue_size, + ) + get_service_pool(pool_config) + logger.info("OCR service pool initialized") + except Exception as e: + logger.warning(f"Failed to initialize service pool: {e}") + + # Initialize prediction semaphore for controlling concurrent PP-StructureV3 predictions + try: + from app.services.memory_manager import get_prediction_semaphore + get_prediction_semaphore(max_concurrent=settings.max_concurrent_predictions) + logger.info(f"Prediction semaphore initialized (max_concurrent={settings.max_concurrent_predictions})") + except Exception as e: + logger.warning(f"Failed to initialize prediction semaphore: {e}") + logger.info("Application startup complete") yield @@ -55,6 +223,45 @@ async def lifespan(app: FastAPI): # Shutdown logger.info("Shutting down Tool_OCR application...") + # Connection draining - wait for active requests to complete + await drain_connections(timeout=30.0) + + # Shutdown recovery manager if initialized + try: + from app.services.memory_manager import shutdown_recovery_manager + shutdown_recovery_manager() + logger.info("Recovery manager shutdown complete") + except Exception as e: + logger.debug(f"Recovery manager shutdown skipped: {e}") + + # Shutdown service pool + if settings.enable_service_pool: + try: + from app.services.service_pool import shutdown_service_pool + shutdown_service_pool() + logger.info("Service pool shutdown complete") + except Exception as e: + logger.warning(f"Error shutting down service pool: {e}") + + # Shutdown prediction semaphore + try: + from app.services.memory_manager import shutdown_prediction_semaphore + shutdown_prediction_semaphore() 
+ logger.info("Prediction semaphore shutdown complete") + except Exception as e: + logger.warning(f"Error shutting down prediction semaphore: {e}") + + # Shutdown memory manager + if settings.enable_model_lifecycle_management: + try: + from app.services.memory_manager import shutdown_model_manager + shutdown_model_manager() + logger.info("Memory manager shutdown complete") + except Exception as e: + logger.warning(f"Error shutting down memory manager: {e}") + + logger.info("Tool_OCR shutdown complete") + # Create FastAPI application app = FastAPI( @@ -77,9 +284,7 @@ app.add_middleware( # Health check endpoint @app.get("/health") async def health_check(): - """Health check endpoint with GPU status""" - from app.services.ocr_service import OCRService - + """Health check endpoint with GPU status and memory management info""" response = { "status": "healthy", "service": "Tool_OCR V2", @@ -88,10 +293,31 @@ async def health_check(): # Add GPU status information try: - # Create temporary OCRService instance to get GPU status - # In production, this should be a singleton service - ocr_service = OCRService() - gpu_status = ocr_service.get_gpu_status() + # Use service pool if available to avoid creating new instances + gpu_status = None + if settings.enable_service_pool: + try: + from app.services.service_pool import get_service_pool + pool = get_service_pool() + pool_stats = pool.get_pool_stats() + response["service_pool"] = pool_stats + + # Get GPU status from first available service + for device, services in pool.services.items(): + for pooled in services: + if hasattr(pooled.service, 'get_gpu_status'): + gpu_status = pooled.service.get_gpu_status() + break + if gpu_status: + break + except Exception as e: + logger.debug(f"Could not get service pool stats: {e}") + + # Fallback: create temporary instance if no pool or no service available + if gpu_status is None: + from app.services.ocr_service import OCRService + ocr_service = OCRService() + gpu_status = ocr_service.get_gpu_status() response["gpu"] = { "available": gpu_status.get("gpu_available", False), @@ -120,6 +346,15 @@ async def health_check(): "error": str(e), } + # Add memory management status + if settings.enable_model_lifecycle_management: + try: + from app.services.memory_manager import get_model_manager + model_manager = get_model_manager() + response["memory_management"] = model_manager.get_model_stats() + except Exception as e: + logger.debug(f"Could not get memory management stats: {e}") + return response diff --git a/backend/app/models/unified_document.py b/backend/app/models/unified_document.py index ce236ea..bd7cd72 100644 --- a/backend/app/models/unified_document.py +++ b/backend/app/models/unified_document.py @@ -212,26 +212,44 @@ class TableData: if self.caption: html.append(f"{self.caption}") - # Group cells by row - rows_data = {} + # Group cells by row and column for quick lookup + cell_map = {} for cell in self.cells: - if cell.row not in rows_data: - rows_data[cell.row] = [] - rows_data[cell.row].append(cell) + cell_map[(cell.row, cell.col)] = cell - # Generate HTML + # Track which cells are covered by row/col spans + covered = set() + for cell in self.cells: + if cell.row_span > 1 or cell.col_span > 1: + for r in range(cell.row, cell.row + cell.row_span): + for c in range(cell.col, cell.col + cell.col_span): + if (r, c) != (cell.row, cell.col): + covered.add((r, c)) + + # Generate HTML with proper column filling for row_idx in range(self.rows): html.append("") - if row_idx in rows_data: - for cell in 
sorted(rows_data[row_idx], key=lambda c: c.col): + for col_idx in range(self.cols): + # Skip cells covered by row/col spans + if (row_idx, col_idx) in covered: + continue + + cell = cell_map.get((row_idx, col_idx)) + tag = "th" if row_idx == 0 and self.headers else "td" + + if cell: span_attrs = [] if cell.row_span > 1: span_attrs.append(f'rowspan="{cell.row_span}"') if cell.col_span > 1: span_attrs.append(f'colspan="{cell.col_span}"') span_str = " ".join(span_attrs) - tag = "th" if row_idx == 0 and self.headers else "td" - html.append(f'<{tag} {span_str}>{cell.content}</{tag}>') + content = cell.content if cell.content else "" + html.append(f'<{tag} {span_str}>{content}</{tag}>') + else: + # Fill in empty cell for missing positions + html.append(f'<{tag}></{tag}>') + html.append("</tr>") html.append("</table>") diff --git a/backend/app/routers/tasks.py b/backend/app/routers/tasks.py index 8e0505c..38a170b 100644 --- a/backend/app/routers/tasks.py +++ b/backend/app/routers/tasks.py @@ -39,6 +39,7 @@ from app.schemas.task import ( from app.services.task_service import task_service from app.services.file_access_service import file_access_service from app.services.ocr_service import OCRService +from app.services.service_pool import get_service_pool, PoolConfig # Import dual-track components try: @@ -47,6 +48,13 @@ try: except ImportError: DUAL_TRACK_AVAILABLE = False +# Service pool availability +SERVICE_POOL_AVAILABLE = True +try: + from app.services.memory_manager import get_model_manager +except ImportError: + SERVICE_POOL_AVAILABLE = False + logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/v2/tasks", tags=["Tasks"]) @@ -63,7 +71,10 @@ def process_task_ocr( pp_structure_params: Optional[dict] = None ): """ - Background task to process OCR for a task with dual-track support + Background task to process OCR for a task with dual-track support. + + Uses OCRServicePool to acquire a shared service instance instead of + creating a new one, preventing GPU memory proliferation. 
Args: task_id: Task UUID string @@ -80,6 +91,7 @@ def process_task_ocr( db = SessionLocal() start_time = datetime.now() + pooled_service = None try: logger.info(f"Starting OCR processing for task {task_id}, file: {filename}") @@ -91,16 +103,39 @@ def process_task_ocr( logger.error(f"Task {task_id} not found in database") return - # Initialize OCR service - ocr_service = OCRService() + # Acquire OCR service from pool (or create new if pool disabled) + ocr_service = None + if settings.enable_service_pool and SERVICE_POOL_AVAILABLE: + try: + service_pool = get_service_pool() + pooled_service = service_pool.acquire( + device="GPU:0", + timeout=settings.service_acquire_timeout_seconds, + task_id=task_id + ) + if pooled_service: + ocr_service = pooled_service.service + logger.info(f"Acquired OCR service from pool for task {task_id}") + else: + logger.warning(f"Timeout acquiring service from pool, creating new instance") + except Exception as e: + logger.warning(f"Failed to acquire service from pool: {e}, creating new instance") + + # Fallback: create new instance if pool acquisition failed + if ocr_service is None: + logger.info("Creating new OCRService instance (pool disabled or unavailable)") + ocr_service = OCRService() # Create result directory before OCR processing (needed for saving extracted images) result_dir = Path(settings.result_dir) / task_id result_dir.mkdir(parents=True, exist_ok=True) - # Process the file with OCR (use dual-track if available) - if use_dual_track and hasattr(ocr_service, 'process'): - # Use new dual-track processing + # Process the file with OCR + # Use dual-track processing if: + # 1. use_dual_track is True (auto-detection) + # 2. OR force_track is specified (explicit track selection) + if (use_dual_track or force_track) and hasattr(ocr_service, 'process'): + # Use new dual-track processing (or forced track) ocr_result = ocr_service.process( file_path=Path(file_path), lang=language, @@ -111,7 +146,7 @@ def process_task_ocr( pp_structure_params=pp_structure_params ) else: - # Fall back to traditional processing + # Fall back to traditional processing (no force_track support) ocr_result = ocr_service.process_image( image_path=Path(file_path), lang=language, @@ -131,6 +166,16 @@ def process_task_ocr( source_file_path=Path(file_path) ) + # Release service back to pool (success case) + if pooled_service: + try: + service_pool = get_service_pool() + service_pool.release(pooled_service, error=None) + logger.info(f"Released OCR service back to pool for task {task_id}") + pooled_service = None # Prevent double release in finally + except Exception as e: + logger.warning(f"Failed to release service to pool: {e}") + # Close old session and create fresh one to avoid MySQL timeout # (long OCR processing may cause connection to become stale) db.close() @@ -158,6 +203,15 @@ def process_task_ocr( except Exception as e: logger.exception(f"OCR processing failed for task {task_id}") + # Release service back to pool with error + if pooled_service: + try: + service_pool = get_service_pool() + service_pool.release(pooled_service, error=e) + pooled_service = None + except Exception as release_error: + logger.warning(f"Failed to release service to pool: {release_error}") + # Update task status to failed (direct database update) try: task = db.query(Task).filter(Task.id == task_db_id).first() @@ -170,6 +224,13 @@ def process_task_ocr( logger.error(f"Failed to update task status: {update_error}") finally: + # Ensure service is released in case of any missed release + if pooled_service: + 
try: + service_pool = get_service_pool() + service_pool.release(pooled_service, error=None) + except Exception: + pass db.close() @@ -330,7 +391,13 @@ async def get_task( with open(result_path) as f: result_data = json.load(f) metadata = result_data.get("metadata", {}) - processing_track = metadata.get("processing_track") + track_str = metadata.get("processing_track") + # Convert string to enum to avoid Pydantic serialization warning + if track_str: + try: + processing_track = ProcessingTrackEnum(track_str) + except ValueError: + processing_track = None except Exception: pass # Silently ignore errors reading the result file diff --git a/backend/app/services/direct_extraction_engine.py b/backend/app/services/direct_extraction_engine.py index 0ed9405..d0b23b9 100644 --- a/backend/app/services/direct_extraction_engine.py +++ b/backend/app/services/direct_extraction_engine.py @@ -247,9 +247,11 @@ class DirectExtractionEngine: element_counter += len(image_elements) # Extract vector graphics (charts, diagrams) from drawing commands + # Pass table_bboxes to filter out table border drawings before clustering if self.enable_image_extraction: vector_elements = self._extract_vector_graphics( - page, page_num, document_id, element_counter, output_dir + page, page_num, document_id, element_counter, output_dir, + table_bboxes=table_bboxes ) elements.extend(vector_elements) element_counter += len(vector_elements) @@ -705,40 +707,52 @@ class DirectExtractionEngine: y1=bbox_data[3] ) - # Extract column widths from table cells + # Extract column widths from table cells by analyzing X boundaries column_widths = [] if hasattr(table, 'cells') and table.cells: - # Group cells by column - cols_x = {} + # Collect all unique X boundaries (both left and right edges) + x_boundaries = set() for cell in table.cells: - col_idx = None - # Determine column index by x0 position - for idx, x0 in enumerate(sorted(set(c[0] for c in table.cells))): - if abs(cell[0] - x0) < 1.0: # Within 1pt tolerance - col_idx = idx - break + x_boundaries.add(round(cell[0], 1)) # x0 (left edge) + x_boundaries.add(round(cell[2], 1)) # x1 (right edge) - if col_idx is not None: - if col_idx not in cols_x: - cols_x[col_idx] = {'x0': cell[0], 'x1': cell[2]} - else: - cols_x[col_idx]['x1'] = max(cols_x[col_idx]['x1'], cell[2]) + # Sort boundaries to get column edges + sorted_x = sorted(x_boundaries) - # Calculate width for each column - for col_idx in sorted(cols_x.keys()): - width = cols_x[col_idx]['x1'] - cols_x[col_idx]['x0'] - column_widths.append(width) + # Calculate column widths from adjacent boundaries + if len(sorted_x) >= 2: + column_widths = [sorted_x[i+1] - sorted_x[i] for i in range(len(sorted_x)-1)] + logger.debug(f"Calculated column widths from {len(sorted_x)} boundaries: {column_widths}") + + # Extract row heights from table cells by analyzing Y boundaries + row_heights = [] + if hasattr(table, 'cells') and table.cells: + # Collect all unique Y boundaries (both top and bottom edges) + y_boundaries = set() + for cell in table.cells: + y_boundaries.add(round(cell[1], 1)) # y0 (top edge) + y_boundaries.add(round(cell[3], 1)) # y1 (bottom edge) + + # Sort boundaries to get row edges + sorted_y = sorted(y_boundaries) + + # Calculate row heights from adjacent boundaries + if len(sorted_y) >= 2: + row_heights = [sorted_y[i+1] - sorted_y[i] for i in range(len(sorted_y)-1)] + logger.debug(f"Calculated row heights from {len(sorted_y)} boundaries: {row_heights}") # Create table cells + # Note: Include ALL cells (even empty ones) to preserve 
table structure + # This is critical for correct HTML generation and PDF rendering cells = [] for row_idx, row in enumerate(data): for col_idx, cell_text in enumerate(row): - if cell_text: - cells.append(TableCell( - row=row_idx, - col=col_idx, - content=str(cell_text) if cell_text else "" - )) + # Always add cell, even if empty, to maintain table structure + cells.append(TableCell( + row=row_idx, + col=col_idx, + content=str(cell_text) if cell_text else "" + )) # Create table data table_data = TableData( @@ -748,8 +762,13 @@ class DirectExtractionEngine: headers=data[0] if data else None # Assume first row is header ) - # Store column widths in metadata - metadata = {"column_widths": column_widths} if column_widths else None + # Store column widths and row heights in metadata + metadata = {} + if column_widths: + metadata["column_widths"] = column_widths + if row_heights: + metadata["row_heights"] = row_heights + metadata = metadata if metadata else None return DocumentElement( element_id=f"table_{page_num}_{counter}", @@ -978,7 +997,9 @@ class DirectExtractionEngine: image_filename = f"{document_id}_p{page_num}_img{img_idx}.png" image_path = output_dir / image_filename pix.save(str(image_path)) - image_data["saved_path"] = str(image_path) + # Store relative filename only (consistent with OCR track) + # PDF generator will join with result_dir to get full path + image_data["saved_path"] = image_filename logger.debug(f"Saved image to {image_path}") element = DocumentElement( @@ -1001,12 +1022,272 @@ class DirectExtractionEngine: return elements + def has_missing_images(self, page: fitz.Page) -> bool: + """ + Detect if a page likely has images that weren't extracted. + + This checks for inline image blocks (type=1 in text dict) which indicate + graphics composed of many small image blocks (like logos) that + page.get_images() cannot detect. + + Args: + page: PyMuPDF page object + + Returns: + True if there are likely missing images that need OCR extraction + """ + try: + # Check if get_images found anything + standard_images = page.get_images() + if standard_images: + return False # Standard images were found, no need for fallback + + # Check for inline image blocks (type=1) + text_dict = page.get_text("dict", sort=True) + blocks = text_dict.get("blocks", []) + + image_block_count = sum(1 for b in blocks if b.get("type") == 1) + + # If there are many inline image blocks, likely there's a logo or graphic + if image_block_count >= 10: + logger.info(f"Detected {image_block_count} inline image blocks - may need OCR for image extraction") + return True + + return False + + except Exception as e: + logger.warning(f"Error checking for missing images: {e}") + return False + + def check_document_for_missing_images(self, pdf_path: Path) -> List[int]: + """ + Check a PDF document for pages that likely have missing images. + + This opens the PDF and checks each page for inline image blocks + that weren't extracted by get_images(). 
+ + Args: + pdf_path: Path to the PDF file + + Returns: + List of page numbers (1-indexed) that have missing images + """ + pages_with_missing_images = [] + + try: + doc = fitz.open(str(pdf_path)) + for page_num in range(len(doc)): + page = doc[page_num] + if self.has_missing_images(page): + pages_with_missing_images.append(page_num + 1) # 1-indexed + doc.close() + + if pages_with_missing_images: + logger.info(f"Document has missing images on pages: {pages_with_missing_images}") + + except Exception as e: + logger.error(f"Error checking document for missing images: {e}") + + return pages_with_missing_images + + def render_inline_image_regions( + self, + pdf_path: Path, + unified_doc: 'UnifiedDocument', + pages: List[int], + output_dir: Optional[Path] = None + ) -> int: + """ + Render inline image regions and add them to the unified document. + + This is a fallback when OCR doesn't detect images. It clusters inline + image blocks (type=1) and renders them as images. + + Args: + pdf_path: Path to the PDF file + unified_doc: UnifiedDocument to add images to + pages: List of page numbers (1-indexed) to process + output_dir: Directory to save rendered images + + Returns: + Number of images added + """ + images_added = 0 + + try: + doc = fitz.open(str(pdf_path)) + + for page_num in pages: + if page_num < 1 or page_num > len(doc): + continue + + page = doc[page_num - 1] # 0-indexed + page_rect = page.rect + + # Get inline image blocks + text_dict = page.get_text("dict", sort=True) + blocks = text_dict.get("blocks", []) + + image_blocks = [] + for block in blocks: + if block.get("type") == 1: # Image block + bbox = block.get("bbox") + if bbox: + image_blocks.append(fitz.Rect(bbox)) + + if len(image_blocks) < 5: # Reduced from 10 + logger.debug(f"Page {page_num}: Only {len(image_blocks)} inline image blocks, skipping") + continue + + logger.info(f"Page {page_num}: Found {len(image_blocks)} inline image blocks") + + # Cluster nearby image blocks + regions = self._cluster_nearby_rects(image_blocks, tolerance=5.0) + logger.info(f"Page {page_num}: Clustered into {len(regions)} regions") + + # Find the corresponding page in unified_doc + target_page = None + for p in unified_doc.pages: + if p.page_number == page_num: + target_page = p + break + + if not target_page: + continue + + for region_idx, region_rect in enumerate(regions): + logger.info(f"Page {page_num} region {region_idx}: {region_rect} (w={region_rect.width:.1f}, h={region_rect.height:.1f})") + + # Skip very small regions + if region_rect.width < 30 or region_rect.height < 30: + logger.info(f" -> Skipped: too small (min 30x30)") + continue + + # Skip regions that are primarily in the table area (below top 40%) + # But allow regions that START in the top portion + page_30_pct = page_rect.height * 0.3 + page_40_pct = page_rect.height * 0.4 + if region_rect.y0 > page_40_pct: + logger.info(f" -> Skipped: y0={region_rect.y0:.1f} > 40% of page ({page_40_pct:.1f})") + continue + + logger.info(f"Rendering inline image region {region_idx} on page {page_num}: {region_rect}") + + try: + # Add small padding + clip_rect = region_rect + (-2, -2, 2, 2) + clip_rect.intersect(page_rect) + + # Render at 2x resolution + mat = fitz.Matrix(2, 2) + pix = page.get_pixmap(clip=clip_rect, matrix=mat, alpha=False) + + # Create bounding box + bbox = BoundingBox( + x0=clip_rect.x0, + y0=clip_rect.y0, + x1=clip_rect.x1, + y1=clip_rect.y1 + ) + + image_data = { + "width": pix.width, + "height": pix.height, + "colorspace": "rgb", + "type": "inline_region" + } + + # Save 
image if output directory provided + if output_dir: + output_dir.mkdir(parents=True, exist_ok=True) + doc_id = unified_doc.document_id or "unknown" + image_filename = f"{doc_id}_p{page_num}_logo{region_idx}.png" + image_path = output_dir / image_filename + pix.save(str(image_path)) + image_data["saved_path"] = image_filename + logger.info(f"Saved inline image region to {image_path}") + + element = DocumentElement( + element_id=f"logo_{page_num}_{region_idx}", + type=ElementType.LOGO, + content=image_data, + bbox=bbox, + confidence=0.9, + metadata={ + "region_type": "inline_image_blocks", + "block_count": len(image_blocks) + } + ) + target_page.elements.append(element) + images_added += 1 + + pix = None # Free memory + + except Exception as e: + logger.error(f"Error rendering inline image region {region_idx}: {e}") + + doc.close() + + if images_added > 0: + current_images = unified_doc.metadata.total_images or 0 + unified_doc.metadata.total_images = current_images + images_added + logger.info(f"Added {images_added} inline image regions to document") + + except Exception as e: + logger.error(f"Error rendering inline image regions: {e}") + + return images_added + + def _cluster_nearby_rects(self, rects: List[fitz.Rect], tolerance: float = 5.0) -> List[fitz.Rect]: + """Cluster nearby rectangles into regions.""" + if not rects: + return [] + + sorted_rects = sorted(rects, key=lambda r: (r.y0, r.x0)) + + merged = [] + for rect in sorted_rects: + merged_with_existing = False + for i, region in enumerate(merged): + expanded = region + (-tolerance, -tolerance, tolerance, tolerance) + if expanded.intersects(rect): + merged[i] = region | rect + merged_with_existing = True + break + if not merged_with_existing: + merged.append(rect) + + # Second pass: merge any regions that now overlap + changed = True + while changed: + changed = False + new_merged = [] + skip = set() + + for i, r1 in enumerate(merged): + if i in skip: + continue + current = r1 + for j, r2 in enumerate(merged[i+1:], start=i+1): + if j in skip: + continue + expanded = current + (-tolerance, -tolerance, tolerance, tolerance) + if expanded.intersects(r2): + current = current | r2 + skip.add(j) + changed = True + new_merged.append(current) + merged = new_merged + + return merged + def _extract_vector_graphics(self, page: fitz.Page, page_num: int, document_id: str, counter: int, - output_dir: Optional[Path]) -> List[DocumentElement]: + output_dir: Optional[Path], + table_bboxes: Optional[List[BoundingBox]] = None) -> List[DocumentElement]: """ Extract vector graphics (charts, diagrams) from page. 
@@ -1020,6 +1301,7 @@ class DirectExtractionEngine: document_id: Unique document identifier counter: Starting counter for element IDs output_dir: Directory to save rendered graphics + table_bboxes: List of table bounding boxes to exclude table border drawings Returns: List of DocumentElement objects representing vector graphics @@ -1034,16 +1316,25 @@ class DirectExtractionEngine: logger.debug(f"Page {page_num} contains {len(drawings)} vector drawing commands") + # Filter out drawings that are likely table borders + # Table borders are typically thin rectangular lines within table regions + non_table_drawings = self._filter_table_border_drawings(drawings, table_bboxes) + logger.debug(f"After filtering table borders: {len(non_table_drawings)} drawings remain") + + if not non_table_drawings: + logger.debug("All drawings appear to be table borders, no vector graphics to extract") + return elements + # Cluster drawings into groups (charts, diagrams, etc.) try: - # PyMuPDF's cluster_drawings() groups nearby drawings automatically - drawing_clusters = page.cluster_drawings() + # Use custom clustering that only considers non-table drawings + drawing_clusters = self._cluster_non_table_drawings(page, non_table_drawings) logger.debug(f"Clustered into {len(drawing_clusters)} groups") except (AttributeError, TypeError) as e: # cluster_drawings not available or has different signature # Fallback: try to identify charts by analyzing drawing density - logger.warning(f"cluster_drawings() failed ({e}), using fallback method") - drawing_clusters = self._cluster_drawings_fallback(page, drawings) + logger.warning(f"Custom clustering failed ({e}), using fallback method") + drawing_clusters = self._cluster_drawings_fallback(page, non_table_drawings) for cluster_idx, bbox in enumerate(drawing_clusters): # Ignore small regions (likely noise or separator lines) @@ -1148,6 +1439,124 @@ class DirectExtractionEngine: return filtered_clusters + def _filter_table_border_drawings(self, drawings: list, table_bboxes: Optional[List[BoundingBox]]) -> list: + """ + Filter out drawings that are likely table borders. + + Table borders are typically: + - Thin rectangular lines (height or width < 5pt) + - Located within or on the edge of table bounding boxes + + Args: + drawings: List of PyMuPDF drawing objects + table_bboxes: List of table bounding boxes + + Returns: + List of drawings that are NOT table borders (likely logos, charts, etc.) 
+ """ + if not table_bboxes: + return drawings + + non_table_drawings = [] + table_border_count = 0 + + for drawing in drawings: + rect = drawing.get('rect') + if not rect: + continue + + draw_rect = fitz.Rect(rect) + + # Check if this drawing is a thin line (potential table border) + is_thin_line = draw_rect.width < 5 or draw_rect.height < 5 + + # Check if drawing overlaps significantly with any table + overlaps_table = False + for table_bbox in table_bboxes: + table_rect = fitz.Rect(table_bbox.x0, table_bbox.y0, table_bbox.x1, table_bbox.y1) + + # Expand table rect slightly to include border lines on edges + expanded_table = table_rect + (-5, -5, 5, 5) + + if expanded_table.contains(draw_rect) or expanded_table.intersects(draw_rect): + # Calculate overlap ratio + intersection = draw_rect & expanded_table + if not intersection.is_empty: + overlap_ratio = intersection.get_area() / draw_rect.get_area() if draw_rect.get_area() > 0 else 0 + + # If drawing is mostly inside table region, it's likely a border + if overlap_ratio > 0.8: + overlaps_table = True + break + + # Keep drawing if it's NOT (thin line AND overlapping table) + # This keeps: logos (complex shapes), charts outside tables, etc. + if is_thin_line and overlaps_table: + table_border_count += 1 + else: + non_table_drawings.append(drawing) + + if table_border_count > 0: + logger.debug(f"Filtered out {table_border_count} table border drawings") + + return non_table_drawings + + def _cluster_non_table_drawings(self, page: fitz.Page, drawings: list) -> list: + """ + Cluster non-table drawings into groups. + + This method clusters drawings that have been pre-filtered to exclude table borders. + It uses a more conservative clustering approach suitable for logos and charts. + + Args: + page: PyMuPDF page object + drawings: Pre-filtered list of drawings (excluding table borders) + + Returns: + List of fitz.Rect representing clustered drawing regions + """ + if not drawings: + return [] + + # Collect all drawing bounding boxes + bboxes = [] + for drawing in drawings: + rect = drawing.get('rect') + if rect: + bboxes.append(fitz.Rect(rect)) + + if not bboxes: + return [] + + # More conservative clustering with smaller tolerance + # This prevents grouping distant graphics together + clusters = [] + tolerance = 10 # Smaller tolerance than fallback (was 20) + + for bbox in bboxes: + # Try to merge with existing cluster + merged = False + for i, cluster in enumerate(clusters): + # Check if bbox is close to this cluster + expanded_cluster = cluster + (-tolerance, -tolerance, tolerance, tolerance) + if expanded_cluster.intersects(bbox): + # Merge bbox into cluster + clusters[i] = cluster | bbox # Union of rectangles + merged = True + break + + if not merged: + # Create new cluster + clusters.append(bbox) + + # Filter out very small clusters (noise) + # Keep minimum 30x30 for logos (smaller than default 50x50) + filtered_clusters = [c for c in clusters if c.width >= 30 and c.height >= 30] + + logger.debug(f"Non-table clustering: {len(bboxes)} drawings -> {len(clusters)} clusters -> {len(filtered_clusters)} filtered") + + return filtered_clusters + def _deduplicate_table_chart_overlap(self, elements: List[DocumentElement]) -> List[DocumentElement]: """ Intelligently resolve TABLE-CHART overlaps based on table structure completeness. 
diff --git a/backend/app/services/memory_manager.py b/backend/app/services/memory_manager.py new file mode 100644 index 0000000..57ce976 --- /dev/null +++ b/backend/app/services/memory_manager.py @@ -0,0 +1,2269 @@ +""" +Tool_OCR - Memory Management System +Provides centralized model lifecycle management with reference counting, +idle timeout, and GPU memory monitoring. +""" + +import asyncio +import gc +import logging +import threading +import time +from dataclasses import dataclass, field +from datetime import datetime +from enum import Enum +from typing import Any, Callable, Dict, List, Optional, Tuple +from weakref import WeakValueDictionary + +import paddle + +# Optional torch import for additional GPU memory management +try: + import torch + TORCH_AVAILABLE = True +except ImportError: + TORCH_AVAILABLE = False + +# Optional pynvml import for NVIDIA GPU monitoring +try: + import pynvml + PYNVML_AVAILABLE = True +except ImportError: + PYNVML_AVAILABLE = False + +logger = logging.getLogger(__name__) + + +class MemoryBackend(Enum): + """Available memory query backends""" + PADDLE = "paddle" + TORCH = "torch" + PYNVML = "pynvml" + NONE = "none" + + +@dataclass +class MemoryStats: + """GPU/CPU memory statistics""" + gpu_used_mb: float = 0.0 + gpu_free_mb: float = 0.0 + gpu_total_mb: float = 0.0 + gpu_used_ratio: float = 0.0 + cpu_used_mb: float = 0.0 + cpu_available_mb: float = 0.0 + timestamp: float = field(default_factory=time.time) + backend: MemoryBackend = MemoryBackend.NONE + + +@dataclass +class ModelEntry: + """Entry for a managed model""" + model: Any + model_id: str + ref_count: int = 0 + last_used: float = field(default_factory=time.time) + created_at: float = field(default_factory=time.time) + estimated_memory_mb: float = 0.0 + is_loading: bool = False + cleanup_callback: Optional[Callable] = None + + +class MemoryConfig: + """Configuration for memory management""" + + def __init__( + self, + warning_threshold: float = 0.80, + critical_threshold: float = 0.95, + emergency_threshold: float = 0.98, + model_idle_timeout_seconds: int = 300, + memory_check_interval_seconds: int = 30, + enable_auto_cleanup: bool = True, + enable_emergency_cleanup: bool = True, + max_concurrent_predictions: int = 2, + enable_cpu_fallback: bool = True, + gpu_memory_limit_mb: int = 6144, + ): + self.warning_threshold = warning_threshold + self.critical_threshold = critical_threshold + self.emergency_threshold = emergency_threshold + self.model_idle_timeout_seconds = model_idle_timeout_seconds + self.memory_check_interval_seconds = memory_check_interval_seconds + self.enable_auto_cleanup = enable_auto_cleanup + self.enable_emergency_cleanup = enable_emergency_cleanup + self.max_concurrent_predictions = max_concurrent_predictions + self.enable_cpu_fallback = enable_cpu_fallback + self.gpu_memory_limit_mb = gpu_memory_limit_mb + + +class MemoryGuard: + """ + Monitor GPU/CPU memory usage and trigger preventive actions. 
+ + Supports multiple backends: paddle.device.cuda, pynvml, torch + """ + + def __init__(self, config: Optional[MemoryConfig] = None): + self.config = config or MemoryConfig() + self.backend = self._detect_backend() + self._history: List[MemoryStats] = [] + self._max_history = 100 + self._alerts: List[Dict] = [] + self._lock = threading.Lock() + + # Initialize pynvml if available + self._nvml_handle = None + if self.backend == MemoryBackend.PYNVML: + self._init_pynvml() + + logger.info(f"MemoryGuard initialized with backend: {self.backend.value}") + + def _detect_backend(self) -> MemoryBackend: + """Detect the best available memory query backend""" + # Prefer pynvml for accurate GPU memory info + if PYNVML_AVAILABLE: + try: + pynvml.nvmlInit() + pynvml.nvmlShutdown() + return MemoryBackend.PYNVML + except Exception: + pass + + # Fall back to torch if available + if TORCH_AVAILABLE and torch.cuda.is_available(): + return MemoryBackend.TORCH + + # Fall back to paddle + if paddle.is_compiled_with_cuda(): + try: + if paddle.device.cuda.device_count() > 0: + return MemoryBackend.PADDLE + except Exception: + pass + + return MemoryBackend.NONE + + def _init_pynvml(self): + """Initialize pynvml for GPU monitoring""" + try: + pynvml.nvmlInit() + self._nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(0) + logger.info("pynvml initialized for GPU monitoring") + except Exception as e: + logger.warning(f"Failed to initialize pynvml: {e}") + self.backend = MemoryBackend.PADDLE if paddle.is_compiled_with_cuda() else MemoryBackend.NONE + + def get_memory_stats(self, device_id: int = 0) -> MemoryStats: + """ + Get current memory statistics. + + Args: + device_id: GPU device ID (default 0) + + Returns: + MemoryStats with current memory usage + """ + stats = MemoryStats(backend=self.backend) + + try: + if self.backend == MemoryBackend.PYNVML and self._nvml_handle: + mem_info = pynvml.nvmlDeviceGetMemoryInfo(self._nvml_handle) + stats.gpu_total_mb = mem_info.total / (1024**2) + stats.gpu_used_mb = mem_info.used / (1024**2) + stats.gpu_free_mb = mem_info.free / (1024**2) + stats.gpu_used_ratio = mem_info.used / mem_info.total if mem_info.total > 0 else 0 + + elif self.backend == MemoryBackend.TORCH: + stats.gpu_total_mb = torch.cuda.get_device_properties(device_id).total_memory / (1024**2) + stats.gpu_used_mb = torch.cuda.memory_allocated(device_id) / (1024**2) + stats.gpu_free_mb = (torch.cuda.get_device_properties(device_id).total_memory - + torch.cuda.memory_allocated(device_id)) / (1024**2) + stats.gpu_used_ratio = stats.gpu_used_mb / stats.gpu_total_mb if stats.gpu_total_mb > 0 else 0 + + elif self.backend == MemoryBackend.PADDLE: + # Paddle doesn't provide total memory directly, use allocated/reserved + stats.gpu_used_mb = paddle.device.cuda.memory_allocated(device_id) / (1024**2) + stats.gpu_free_mb = 0 # Not directly available + stats.gpu_total_mb = self.config.gpu_memory_limit_mb + stats.gpu_used_ratio = stats.gpu_used_mb / stats.gpu_total_mb if stats.gpu_total_mb > 0 else 0 + + # Get CPU memory info + try: + import psutil + mem = psutil.virtual_memory() + stats.cpu_used_mb = mem.used / (1024**2) + stats.cpu_available_mb = mem.available / (1024**2) + except ImportError: + pass + + except Exception as e: + logger.warning(f"Failed to get memory stats: {e}") + + # Store in history + with self._lock: + self._history.append(stats) + if len(self._history) > self._max_history: + self._history.pop(0) + + return stats + + def check_memory(self, required_mb: int = 0, device_id: int = 0) -> Tuple[bool, 
MemoryStats]: + """ + Check if sufficient GPU memory is available. + + Args: + required_mb: Required memory in MB (0 for just checking thresholds) + device_id: GPU device ID + + Returns: + Tuple of (is_available, current_stats) + """ + stats = self.get_memory_stats(device_id) + + # Check if we have enough free memory + if required_mb > 0 and stats.gpu_free_mb > 0: + if stats.gpu_free_mb < required_mb: + self._add_alert("insufficient_memory", + f"Required {required_mb}MB but only {stats.gpu_free_mb:.0f}MB available") + return False, stats + + # Check threshold levels + if stats.gpu_used_ratio > self.config.emergency_threshold: + self._add_alert("emergency", + f"GPU memory at {stats.gpu_used_ratio*100:.1f}% (emergency threshold)") + return False, stats + + if stats.gpu_used_ratio > self.config.critical_threshold: + self._add_alert("critical", + f"GPU memory at {stats.gpu_used_ratio*100:.1f}% (critical threshold)") + return False, stats + + if stats.gpu_used_ratio > self.config.warning_threshold: + self._add_alert("warning", + f"GPU memory at {stats.gpu_used_ratio*100:.1f}% (warning threshold)") + + return True, stats + + def _add_alert(self, level: str, message: str): + """Add an alert to the alert history""" + alert = { + "level": level, + "message": message, + "timestamp": time.time() + } + with self._lock: + self._alerts.append(alert) + # Keep last 50 alerts + if len(self._alerts) > 50: + self._alerts.pop(0) + + if level == "emergency": + logger.error(f"MEMORY ALERT [{level}]: {message}") + elif level == "critical": + logger.warning(f"MEMORY ALERT [{level}]: {message}") + else: + logger.info(f"MEMORY ALERT [{level}]: {message}") + + def get_alerts(self, since_timestamp: float = 0) -> List[Dict]: + """Get alerts since a given timestamp""" + with self._lock: + return [a for a in self._alerts if a["timestamp"] > since_timestamp] + + def clear_gpu_cache(self): + """Clear GPU memory cache""" + try: + if TORCH_AVAILABLE and torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + logger.debug("Cleared PyTorch GPU cache") + + if paddle.is_compiled_with_cuda(): + paddle.device.cuda.empty_cache() + logger.debug("Cleared PaddlePaddle GPU cache") + + gc.collect() + + except Exception as e: + logger.warning(f"GPU cache clear failed: {e}") + + def shutdown(self): + """Clean up resources""" + if PYNVML_AVAILABLE and self._nvml_handle: + try: + pynvml.nvmlShutdown() + except Exception: + pass + + +class PredictionSemaphore: + """ + Semaphore for controlling concurrent PP-StructureV3 predictions. + + PP-StructureV3.predict() is memory-intensive. Running multiple predictions + simultaneously can cause OOM errors. This class limits concurrent predictions + and provides timeout handling. 
+ """ + + _instance = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + """Singleton pattern - ensure only one PredictionSemaphore exists""" + with cls._lock: + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._initialized = False + return cls._instance + + def __init__(self, max_concurrent: int = 2, default_timeout: float = 300.0): + if self._initialized: + return + + self._max_concurrent = max_concurrent + self._default_timeout = default_timeout + self._semaphore = threading.Semaphore(max_concurrent) + self._condition = threading.Condition() + self._queue_depth = 0 + self._active_predictions = 0 + + # Metrics + self._total_predictions = 0 + self._total_timeouts = 0 + self._total_wait_time = 0.0 + self._metrics_lock = threading.Lock() + + self._initialized = True + logger.info(f"PredictionSemaphore initialized (max_concurrent={max_concurrent})") + + def acquire(self, timeout: Optional[float] = None, task_id: Optional[str] = None) -> bool: + """ + Acquire a prediction slot. + + Args: + timeout: Timeout in seconds (None for default, 0 for non-blocking) + task_id: Optional task identifier for logging + + Returns: + True if acquired, False if timed out + """ + timeout = timeout if timeout is not None else self._default_timeout + start_time = time.time() + + with self._condition: + self._queue_depth += 1 + + task_str = f" for task {task_id}" if task_id else "" + logger.debug(f"Waiting for prediction slot{task_str} (queue_depth={self._queue_depth})") + + try: + acquired = self._semaphore.acquire(timeout=timeout if timeout > 0 else None) + + wait_time = time.time() - start_time + with self._metrics_lock: + self._total_wait_time += wait_time + if acquired: + self._total_predictions += 1 + self._active_predictions += 1 + else: + self._total_timeouts += 1 + + with self._condition: + self._queue_depth -= 1 + + if acquired: + logger.debug(f"Prediction slot acquired{task_str} (waited {wait_time:.2f}s, active={self._active_predictions})") + else: + logger.warning(f"Prediction slot timeout{task_str} after {timeout}s") + + return acquired + + except Exception as e: + with self._condition: + self._queue_depth -= 1 + logger.error(f"Error acquiring prediction slot: {e}") + return False + + def release(self, task_id: Optional[str] = None): + """ + Release a prediction slot. + + Args: + task_id: Optional task identifier for logging + """ + self._semaphore.release() + + with self._metrics_lock: + self._active_predictions = max(0, self._active_predictions - 1) + + task_str = f" for task {task_id}" if task_id else "" + logger.debug(f"Prediction slot released{task_str} (active={self._active_predictions})") + + def get_stats(self) -> Dict: + """Get prediction semaphore statistics""" + with self._metrics_lock: + avg_wait = self._total_wait_time / max(1, self._total_predictions) + return { + "max_concurrent": self._max_concurrent, + "active_predictions": self._active_predictions, + "queue_depth": self._queue_depth, + "total_predictions": self._total_predictions, + "total_timeouts": self._total_timeouts, + "average_wait_seconds": round(avg_wait, 3), + } + + def reset_metrics(self): + """Reset metrics counters""" + with self._metrics_lock: + self._total_predictions = 0 + self._total_timeouts = 0 + self._total_wait_time = 0.0 + + +class PredictionContext: + """ + Context manager for PP-StructureV3 predictions with semaphore control. 
+ + Usage: + with prediction_context(task_id="task_123") as acquired: + if acquired: + result = structure_engine.predict(image_path) + else: + # Handle timeout + """ + + def __init__( + self, + semaphore: PredictionSemaphore, + timeout: Optional[float] = None, + task_id: Optional[str] = None + ): + self._semaphore = semaphore + self._timeout = timeout + self._task_id = task_id + self._acquired = False + + def __enter__(self) -> bool: + self._acquired = self._semaphore.acquire( + timeout=self._timeout, + task_id=self._task_id + ) + return self._acquired + + def __exit__(self, exc_type, exc_val, exc_tb): + if self._acquired: + self._semaphore.release(task_id=self._task_id) + return False # Don't suppress exceptions + + +# Global prediction semaphore instance +_prediction_semaphore: Optional[PredictionSemaphore] = None + + +def get_prediction_semaphore(max_concurrent: Optional[int] = None) -> PredictionSemaphore: + """ + Get the global PredictionSemaphore instance. + + Args: + max_concurrent: Max concurrent predictions (only used on first call) + + Returns: + PredictionSemaphore singleton instance + """ + global _prediction_semaphore + if _prediction_semaphore is None: + from app.core.config import settings + max_conc = max_concurrent or settings.max_concurrent_predictions + _prediction_semaphore = PredictionSemaphore(max_concurrent=max_conc) + return _prediction_semaphore + + +def shutdown_prediction_semaphore(): + """Reset the global PredictionSemaphore instance""" + global _prediction_semaphore + if _prediction_semaphore is not None: + # Reset singleton for clean state + PredictionSemaphore._instance = None + PredictionSemaphore._lock = threading.Lock() + _prediction_semaphore = None + + +def prediction_context( + timeout: Optional[float] = None, + task_id: Optional[str] = None +) -> PredictionContext: + """ + Create a prediction context manager. + + Args: + timeout: Timeout in seconds for acquiring slot + task_id: Optional task identifier for logging + + Returns: + PredictionContext context manager + """ + semaphore = get_prediction_semaphore() + return PredictionContext(semaphore, timeout, task_id) + + +class ModelManager: + """ + Centralized model lifecycle management with reference counting and idle timeout. 
+ + Features: + - Reference counting for shared model instances + - Idle timeout for automatic unloading + - LRU eviction when memory pressure + - Thread-safe operations + """ + + _instance = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + """Singleton pattern - ensure only one ModelManager exists""" + with cls._lock: + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._initialized = False + return cls._instance + + def __init__(self, config: Optional[MemoryConfig] = None): + if self._initialized: + return + + self.config = config or MemoryConfig() + self.models: Dict[str, ModelEntry] = {} + self.memory_guard = MemoryGuard(self.config) + self._model_lock = threading.RLock() + self._loading_locks: Dict[str, threading.Lock] = {} + + # Start background timeout monitor + self._monitor_running = True + self._monitor_thread = threading.Thread( + target=self._timeout_monitor_loop, + daemon=True, + name="ModelManager-TimeoutMonitor" + ) + self._monitor_thread.start() + + self._initialized = True + logger.info("ModelManager initialized") + + def _timeout_monitor_loop(self): + """Background thread to monitor and unload idle models""" + while self._monitor_running: + try: + time.sleep(self.config.memory_check_interval_seconds) + if self.config.enable_auto_cleanup: + self._cleanup_idle_models() + except Exception as e: + logger.error(f"Error in timeout monitor: {e}") + + def _cleanup_idle_models(self): + """Unload models that have been idle longer than the timeout""" + current_time = time.time() + models_to_unload = [] + + with self._model_lock: + for model_id, entry in self.models.items(): + # Only unload if no active references and idle timeout exceeded + if entry.ref_count <= 0: + idle_time = current_time - entry.last_used + if idle_time > self.config.model_idle_timeout_seconds: + models_to_unload.append(model_id) + + for model_id in models_to_unload: + self.unload_model(model_id, force=False) + + def get_or_load_model( + self, + model_id: str, + loader_func: Callable[[], Any], + estimated_memory_mb: float = 0, + cleanup_callback: Optional[Callable] = None + ) -> Any: + """ + Get a model by ID, loading it if not already loaded. 
+ + Args: + model_id: Unique identifier for the model + loader_func: Function to call to load the model if not cached + estimated_memory_mb: Estimated memory usage for this model + cleanup_callback: Optional callback to run before unloading + + Returns: + The model instance + """ + with self._model_lock: + # Check if model is already loaded + if model_id in self.models: + entry = self.models[model_id] + if not entry.is_loading: + entry.ref_count += 1 + entry.last_used = time.time() + logger.debug(f"Model {model_id} acquired (ref_count={entry.ref_count})") + return entry.model + + # Create loading lock for this model if not exists + if model_id not in self._loading_locks: + self._loading_locks[model_id] = threading.Lock() + + # Load model outside the main lock to allow concurrent operations + loading_lock = self._loading_locks[model_id] + + with loading_lock: + # Double-check after acquiring loading lock + with self._model_lock: + if model_id in self.models and not self.models[model_id].is_loading: + entry = self.models[model_id] + entry.ref_count += 1 + entry.last_used = time.time() + return entry.model + + # Mark as loading + self.models[model_id] = ModelEntry( + model=None, + model_id=model_id, + is_loading=True, + estimated_memory_mb=estimated_memory_mb, + cleanup_callback=cleanup_callback + ) + + try: + # Check memory before loading + if estimated_memory_mb > 0: + is_available, stats = self.memory_guard.check_memory(int(estimated_memory_mb)) + if not is_available and self.config.enable_emergency_cleanup: + logger.warning(f"Memory low, attempting cleanup before loading {model_id}") + self._evict_lru_models(required_mb=estimated_memory_mb) + + # Load the model + logger.info(f"Loading model {model_id} (estimated {estimated_memory_mb}MB)") + start_time = time.time() + model = loader_func() + load_time = time.time() - start_time + logger.info(f"Model {model_id} loaded in {load_time:.2f}s") + + # Update entry + with self._model_lock: + self.models[model_id] = ModelEntry( + model=model, + model_id=model_id, + ref_count=1, + estimated_memory_mb=estimated_memory_mb, + cleanup_callback=cleanup_callback, + is_loading=False + ) + + return model + + except Exception as e: + # Clean up failed entry + with self._model_lock: + if model_id in self.models: + del self.models[model_id] + logger.error(f"Failed to load model {model_id}: {e}") + raise + + def release_model(self, model_id: str): + """ + Release a reference to a model. + + Args: + model_id: Model identifier + """ + with self._model_lock: + if model_id in self.models: + entry = self.models[model_id] + entry.ref_count = max(0, entry.ref_count - 1) + entry.last_used = time.time() + logger.debug(f"Model {model_id} released (ref_count={entry.ref_count})") + + def unload_model(self, model_id: str, force: bool = False) -> bool: + """ + Unload a model from memory. 
+ + Args: + model_id: Model identifier + force: Force unload even if references exist + + Returns: + True if model was unloaded + """ + with self._model_lock: + if model_id not in self.models: + return False + + entry = self.models[model_id] + + # Don't unload if there are active references (unless forced) + if entry.ref_count > 0 and not force: + logger.warning(f"Cannot unload {model_id}: {entry.ref_count} active references") + return False + + # Run cleanup callback if provided + if entry.cleanup_callback: + try: + entry.cleanup_callback() + except Exception as e: + logger.warning(f"Cleanup callback failed for {model_id}: {e}") + + # Delete model + del self.models[model_id] + logger.info(f"Model {model_id} unloaded") + + # Clear GPU cache after unloading + self.memory_guard.clear_gpu_cache() + return True + + def _evict_lru_models(self, required_mb: float = 0): + """ + Evict least recently used models to free memory. + + Args: + required_mb: Target amount of memory to free + """ + with self._model_lock: + # Sort models by last_used (oldest first), excluding those with references + eviction_candidates = [ + (model_id, entry) + for model_id, entry in self.models.items() + if entry.ref_count <= 0 and not entry.is_loading + ] + eviction_candidates.sort(key=lambda x: x[1].last_used) + + freed_mb = 0 + for model_id, entry in eviction_candidates: + if self.unload_model(model_id, force=False): + freed_mb += entry.estimated_memory_mb + logger.info(f"Evicted LRU model {model_id}, freed ~{entry.estimated_memory_mb}MB") + + if required_mb > 0 and freed_mb >= required_mb: + break + + def get_model_stats(self) -> Dict: + """Get statistics about loaded models""" + with self._model_lock: + return { + "total_models": len(self.models), + "models": { + model_id: { + "ref_count": entry.ref_count, + "last_used": entry.last_used, + "estimated_memory_mb": entry.estimated_memory_mb, + "is_loading": entry.is_loading, + "idle_seconds": time.time() - entry.last_used + } + for model_id, entry in self.models.items() + }, + "total_estimated_memory_mb": sum( + e.estimated_memory_mb for e in self.models.values() + ), + "memory_stats": self.memory_guard.get_memory_stats().__dict__ + } + + def teardown(self): + """ + Clean up all models and resources. + Called during application shutdown. + """ + logger.info("ModelManager teardown started") + + # Stop monitor thread + self._monitor_running = False + + # Unload all models + with self._model_lock: + model_ids = list(self.models.keys()) + + for model_id in model_ids: + self.unload_model(model_id, force=True) + + # Clean up memory guard + self.memory_guard.shutdown() + + logger.info("ModelManager teardown completed") + + +# Global singleton instance +_model_manager: Optional[ModelManager] = None + + +def get_model_manager(config: Optional[MemoryConfig] = None) -> ModelManager: + """ + Get the global ModelManager instance. 
+ + Args: + config: Optional configuration (only used on first call) + + Returns: + ModelManager singleton instance + """ + global _model_manager + if _model_manager is None: + _model_manager = ModelManager(config) + return _model_manager + + +def shutdown_model_manager(): + """Shutdown the global ModelManager instance""" + global _model_manager + if _model_manager is not None: + _model_manager.teardown() + _model_manager = None + + +# ============================================================================= +# Section 4.2: Batch Processing and Progressive Loading +# ============================================================================= + +class BatchPriority(Enum): + """Priority levels for batch operations""" + LOW = 0 + NORMAL = 1 + HIGH = 2 + CRITICAL = 3 + + +@dataclass +class BatchItem: + """Item in a processing batch""" + item_id: str + data: Any + priority: BatchPriority = BatchPriority.NORMAL + created_at: float = field(default_factory=time.time) + estimated_memory_mb: float = 0.0 + metadata: Dict = field(default_factory=dict) + + +@dataclass +class BatchResult: + """Result of batch processing""" + item_id: str + success: bool + result: Any = None + error: Optional[str] = None + processing_time_ms: float = 0.0 + + +class BatchProcessor: + """ + Process items in batches to optimize memory usage for large documents. + + Features: + - Memory-aware batch sizing + - Priority-based processing + - Progress tracking + - Automatic memory cleanup between batches + """ + + def __init__( + self, + max_batch_size: int = 5, + max_memory_per_batch_mb: float = 2000.0, + memory_guard: Optional[MemoryGuard] = None, + cleanup_between_batches: bool = True + ): + """ + Initialize BatchProcessor. + + Args: + max_batch_size: Maximum items per batch + max_memory_per_batch_mb: Maximum memory allowed per batch + memory_guard: MemoryGuard instance for memory monitoring + cleanup_between_batches: Whether to clear GPU cache between batches + """ + self.max_batch_size = max_batch_size + self.max_memory_per_batch_mb = max_memory_per_batch_mb + self.memory_guard = memory_guard or MemoryGuard() + self.cleanup_between_batches = cleanup_between_batches + + self._queue: List[BatchItem] = [] + self._lock = threading.Lock() + self._processing = False + + # Statistics + self._total_processed = 0 + self._total_batches = 0 + self._total_failures = 0 + + logger.info(f"BatchProcessor initialized (max_batch_size={max_batch_size}, max_memory={max_memory_per_batch_mb}MB)") + + def add_item(self, item: BatchItem): + """Add an item to the processing queue""" + with self._lock: + self._queue.append(item) + # Sort by priority (highest first), then by creation time (oldest first) + self._queue.sort(key=lambda x: (-x.priority.value, x.created_at)) + logger.debug(f"Added item {item.item_id} to batch queue (queue_size={len(self._queue)})") + + def add_items(self, items: List[BatchItem]): + """Add multiple items to the processing queue""" + with self._lock: + self._queue.extend(items) + self._queue.sort(key=lambda x: (-x.priority.value, x.created_at)) + logger.debug(f"Added {len(items)} items to batch queue (queue_size={len(self._queue)})") + + def _create_batch(self) -> List[BatchItem]: + """Create a batch from the queue based on size and memory constraints""" + batch = [] + batch_memory = 0.0 + + with self._lock: + remaining = [] + for item in self._queue: + # Check if adding this item would exceed limits + if len(batch) >= self.max_batch_size: + remaining.append(item) + continue + + if batch_memory + 
item.estimated_memory_mb > self.max_memory_per_batch_mb and batch: + remaining.append(item) + continue + + batch.append(item) + batch_memory += item.estimated_memory_mb + + self._queue = remaining + + return batch + + def process_batch( + self, + processor_func: Callable[[Any], Any], + progress_callback: Optional[Callable[[int, int, BatchResult], None]] = None + ) -> List[BatchResult]: + """ + Process a single batch of items. + + Args: + processor_func: Function to process each item (receives item.data) + progress_callback: Optional callback(current, total, result) + + Returns: + List of BatchResult for each item in the batch + """ + batch = self._create_batch() + if not batch: + return [] + + self._processing = True + results = [] + total = len(batch) + + try: + for i, item in enumerate(batch): + start_time = time.time() + result = BatchResult(item_id=item.item_id, success=False) + + try: + # Check memory before processing + is_available, stats = self.memory_guard.check_memory( + int(item.estimated_memory_mb) + ) + if not is_available: + logger.warning(f"Insufficient memory for item {item.item_id}, cleaning up...") + self.memory_guard.clear_gpu_cache() + gc.collect() + + # Process item + result.result = processor_func(item.data) + result.success = True + + except Exception as e: + result.error = str(e) + self._total_failures += 1 + logger.error(f"Failed to process item {item.item_id}: {e}") + + result.processing_time_ms = (time.time() - start_time) * 1000 + results.append(result) + self._total_processed += 1 + + # Call progress callback + if progress_callback: + progress_callback(i + 1, total, result) + + self._total_batches += 1 + + finally: + self._processing = False + + # Clean up after batch + if self.cleanup_between_batches: + self.memory_guard.clear_gpu_cache() + gc.collect() + + return results + + def process_all( + self, + processor_func: Callable[[Any], Any], + progress_callback: Optional[Callable[[int, int, BatchResult], None]] = None + ) -> List[BatchResult]: + """ + Process all items in the queue. + + Args: + processor_func: Function to process each item + progress_callback: Optional progress callback + + Returns: + List of all BatchResults + """ + all_results = [] + + while True: + with self._lock: + if not self._queue: + break + + batch_results = self.process_batch(processor_func, progress_callback) + all_results.extend(batch_results) + + return all_results + + def get_queue_size(self) -> int: + """Get current queue size""" + with self._lock: + return len(self._queue) + + def get_stats(self) -> Dict: + """Get processing statistics""" + with self._lock: + return { + "queue_size": len(self._queue), + "total_processed": self._total_processed, + "total_batches": self._total_batches, + "total_failures": self._total_failures, + "is_processing": self._processing, + "max_batch_size": self.max_batch_size, + "max_memory_per_batch_mb": self.max_memory_per_batch_mb, + } + + def clear_queue(self): + """Clear the processing queue""" + with self._lock: + self._queue.clear() + logger.info("Batch queue cleared") + + +class ProgressiveLoader: + """ + Progressive page loader for multi-page documents. + + Loads and processes pages incrementally to minimize memory usage. + Supports lookahead loading for better performance. + """ + + def __init__( + self, + lookahead_pages: int = 2, + memory_guard: Optional[MemoryGuard] = None, + cleanup_after_pages: int = 5 + ): + """ + Initialize ProgressiveLoader. 
+ + Args: + lookahead_pages: Number of pages to load ahead + memory_guard: MemoryGuard instance + cleanup_after_pages: Trigger cleanup after this many pages + """ + self.lookahead_pages = lookahead_pages + self.memory_guard = memory_guard or MemoryGuard() + self.cleanup_after_pages = cleanup_after_pages + + self._loaded_pages: Dict[int, Any] = {} + self._lock = threading.Lock() + self._current_page = 0 + self._total_pages = 0 + self._pages_since_cleanup = 0 + + logger.info(f"ProgressiveLoader initialized (lookahead={lookahead_pages})") + + def initialize(self, total_pages: int): + """Initialize loader with total page count""" + with self._lock: + self._total_pages = total_pages + self._current_page = 0 + self._loaded_pages.clear() + self._pages_since_cleanup = 0 + logger.info(f"ProgressiveLoader initialized for {total_pages} pages") + + def load_page( + self, + page_num: int, + loader_func: Callable[[int], Any], + unload_distant: bool = True + ) -> Any: + """ + Load a specific page. + + Args: + page_num: Page number to load (0-indexed) + loader_func: Function to load page (receives page_num) + unload_distant: Unload pages far from current position + + Returns: + Loaded page data + """ + with self._lock: + # Check if already loaded + if page_num in self._loaded_pages: + self._current_page = page_num + return self._loaded_pages[page_num] + + # Load the page + logger.debug(f"Loading page {page_num}") + page_data = loader_func(page_num) + + with self._lock: + self._loaded_pages[page_num] = page_data + self._current_page = page_num + self._pages_since_cleanup += 1 + + # Unload distant pages to save memory + if unload_distant: + self._unload_distant_pages() + + # Trigger cleanup if needed + if self._pages_since_cleanup >= self.cleanup_after_pages: + self.memory_guard.clear_gpu_cache() + gc.collect() + self._pages_since_cleanup = 0 + + return page_data + + def _unload_distant_pages(self): + """Unload pages far from current position""" + keep_range = range( + max(0, self._current_page - 1), + min(self._total_pages, self._current_page + self.lookahead_pages + 1) + ) + + pages_to_unload = [ + p for p in self._loaded_pages.keys() + if p not in keep_range + ] + + for page_num in pages_to_unload: + del self._loaded_pages[page_num] + logger.debug(f"Unloaded distant page {page_num}") + + def prefetch_pages( + self, + start_page: int, + loader_func: Callable[[int], Any] + ): + """ + Prefetch upcoming pages in background. + + Args: + start_page: Starting page number + loader_func: Function to load page + """ + for i in range(self.lookahead_pages): + page_num = start_page + i + 1 + if page_num >= self._total_pages: + break + + with self._lock: + if page_num in self._loaded_pages: + continue + + try: + self.load_page(page_num, loader_func, unload_distant=False) + except Exception as e: + logger.warning(f"Prefetch failed for page {page_num}: {e}") + + def iterate_pages( + self, + loader_func: Callable[[int], Any], + processor_func: Callable[[int, Any], Any], + progress_callback: Optional[Callable[[int, int], None]] = None + ) -> List[Any]: + """ + Iterate through all pages with progressive loading. 
+ + Args: + loader_func: Function to load a page + processor_func: Function to process page (receives page_num, data) + progress_callback: Optional callback(current_page, total_pages) + + Returns: + List of results from processor_func + """ + results = [] + + for page_num in range(self._total_pages): + # Load page + page_data = self.load_page(page_num, loader_func) + + # Process page + result = processor_func(page_num, page_data) + results.append(result) + + # Report progress + if progress_callback: + progress_callback(page_num + 1, self._total_pages) + + # Start prefetching next pages in background + if self.lookahead_pages > 0: + # Use thread for prefetching to not block + prefetch_thread = threading.Thread( + target=self.prefetch_pages, + args=(page_num, loader_func), + daemon=True + ) + prefetch_thread.start() + + return results + + def get_loaded_pages(self) -> List[int]: + """Get list of currently loaded page numbers""" + with self._lock: + return list(self._loaded_pages.keys()) + + def get_stats(self) -> Dict: + """Get loader statistics""" + with self._lock: + return { + "total_pages": self._total_pages, + "current_page": self._current_page, + "loaded_pages_count": len(self._loaded_pages), + "loaded_pages": list(self._loaded_pages.keys()), + "lookahead_pages": self.lookahead_pages, + "pages_since_cleanup": self._pages_since_cleanup, + } + + def clear(self): + """Clear all loaded pages""" + with self._lock: + self._loaded_pages.clear() + self._current_page = 0 + self._pages_since_cleanup = 0 + self.memory_guard.clear_gpu_cache() + gc.collect() + + +class PriorityOperationQueue: + """ + Priority queue for OCR operations. + + Higher priority operations are processed first. + Supports timeout and cancellation. + """ + + def __init__(self, max_size: int = 100): + """ + Initialize priority queue. + + Args: + max_size: Maximum queue size (0 for unlimited) + """ + self.max_size = max_size + self._queue: List[Tuple[BatchPriority, float, str, Any]] = [] # (priority, timestamp, id, data) + self._lock = threading.Lock() + self._condition = threading.Condition(self._lock) + self._cancelled: set = set() + + # Statistics + self._total_enqueued = 0 + self._total_dequeued = 0 + self._total_cancelled = 0 + + logger.info(f"PriorityOperationQueue initialized (max_size={max_size})") + + def enqueue( + self, + item_id: str, + data: Any, + priority: BatchPriority = BatchPriority.NORMAL, + timeout: Optional[float] = None + ) -> bool: + """ + Add an operation to the queue. 
+ + Args: + item_id: Unique identifier for the operation + data: Operation data + priority: Operation priority + timeout: Optional timeout to wait for space in queue + + Returns: + True if enqueued successfully + """ + with self._condition: + # Wait for space if queue is full + if self.max_size > 0 and len(self._queue) >= self.max_size: + if timeout is not None: + result = self._condition.wait_for( + lambda: len(self._queue) < self.max_size, + timeout=timeout + ) + if not result: + logger.warning(f"Queue full, timeout waiting to enqueue {item_id}") + return False + else: + logger.warning(f"Queue full, cannot enqueue {item_id}") + return False + + # Add to queue (negative priority for max-heap behavior) + import heapq + heapq.heappush( + self._queue, + (-priority.value, time.time(), item_id, data) + ) + self._total_enqueued += 1 + self._condition.notify() + + logger.debug(f"Enqueued operation {item_id} with priority {priority.name}") + return True + + def dequeue(self, timeout: Optional[float] = None) -> Optional[Tuple[str, Any, BatchPriority]]: + """ + Get the highest priority operation from the queue. + + Args: + timeout: Optional timeout to wait for an item + + Returns: + Tuple of (item_id, data, priority) or None if timeout + """ + import heapq + + with self._condition: + # Wait for an item + if not self._queue: + if timeout is not None: + result = self._condition.wait_for( + lambda: len(self._queue) > 0, + timeout=timeout + ) + if not result: + return None + else: + return None + + # Get highest priority item + neg_priority, _, item_id, data = heapq.heappop(self._queue) + priority = BatchPriority(-neg_priority) + + # Skip if cancelled + if item_id in self._cancelled: + self._cancelled.discard(item_id) + self._total_cancelled += 1 + self._condition.notify() + return self.dequeue(timeout=0) # Try next item + + self._total_dequeued += 1 + self._condition.notify() + + logger.debug(f"Dequeued operation {item_id} with priority {priority.name}") + return item_id, data, priority + + def cancel(self, item_id: str) -> bool: + """ + Cancel a pending operation. 
+ + Args: + item_id: Operation identifier to cancel + + Returns: + True if the operation was found and marked for cancellation + """ + with self._lock: + # Check if item is in queue + for _, _, qid, _ in self._queue: + if qid == item_id: + self._cancelled.add(item_id) + logger.info(f"Operation {item_id} marked for cancellation") + return True + return False + + def get_size(self) -> int: + """Get current queue size""" + with self._lock: + return len(self._queue) + + def get_stats(self) -> Dict: + """Get queue statistics""" + with self._lock: + # Count by priority + priority_counts = {p.name: 0 for p in BatchPriority} + for neg_priority, _, _, _ in self._queue: + priority = BatchPriority(-neg_priority) + priority_counts[priority.name] += 1 + + return { + "queue_size": len(self._queue), + "max_size": self.max_size, + "total_enqueued": self._total_enqueued, + "total_dequeued": self._total_dequeued, + "total_cancelled": self._total_cancelled, + "pending_cancellations": len(self._cancelled), + "by_priority": priority_counts, + } + + def clear(self): + """Clear the queue""" + with self._lock: + self._queue.clear() + self._cancelled.clear() + logger.info("Priority queue cleared") + + +# ============================================================================= +# Section 5.2: Recovery Mechanisms +# ============================================================================= + +@dataclass +class RecoveryState: + """State of recovery mechanism""" + last_recovery_time: float = 0.0 + recovery_count: int = 0 + in_cooldown: bool = False + cooldown_until: float = 0.0 + last_error: Optional[str] = None + + +class RecoveryManager: + """ + Manages recovery mechanisms for memory issues and failures. + + Features: + - Emergency memory release + - Cooldown period after recovery + - Recovery attempt limits + """ + + def __init__( + self, + cooldown_seconds: float = 30.0, + max_recovery_attempts: int = 3, + recovery_window_seconds: float = 300.0, + memory_guard: Optional[MemoryGuard] = None + ): + """ + Initialize RecoveryManager. 
+ + Args: + cooldown_seconds: Cooldown period after recovery + max_recovery_attempts: Max recovery attempts within window + recovery_window_seconds: Window for counting recovery attempts + memory_guard: MemoryGuard instance + """ + self.cooldown_seconds = cooldown_seconds + self.max_recovery_attempts = max_recovery_attempts + self.recovery_window_seconds = recovery_window_seconds + self.memory_guard = memory_guard or MemoryGuard() + + self._state = RecoveryState() + self._lock = threading.Lock() + self._recovery_times: List[float] = [] + + # Callbacks + self._on_recovery_start: List[Callable] = [] + self._on_recovery_complete: List[Callable[[bool], None]] = [] + + logger.info(f"RecoveryManager initialized (cooldown={cooldown_seconds}s)") + + def register_callbacks( + self, + on_start: Optional[Callable] = None, + on_complete: Optional[Callable[[bool], None]] = None + ): + """Register recovery event callbacks""" + if on_start: + self._on_recovery_start.append(on_start) + if on_complete: + self._on_recovery_complete.append(on_complete) + + def is_in_cooldown(self) -> bool: + """Check if currently in cooldown period""" + with self._lock: + if not self._state.in_cooldown: + return False + + if time.time() >= self._state.cooldown_until: + self._state.in_cooldown = False + return False + + return True + + def get_cooldown_remaining(self) -> float: + """Get remaining cooldown time in seconds""" + with self._lock: + if not self._state.in_cooldown: + return 0.0 + return max(0, self._state.cooldown_until - time.time()) + + def _count_recent_recoveries(self) -> int: + """Count recovery attempts within the window""" + cutoff = time.time() - self.recovery_window_seconds + with self._lock: + # Clean old entries + self._recovery_times = [t for t in self._recovery_times if t > cutoff] + return len(self._recovery_times) + + def can_attempt_recovery(self) -> Tuple[bool, str]: + """ + Check if recovery can be attempted. + + Returns: + Tuple of (can_recover, reason) + """ + if self.is_in_cooldown(): + remaining = self.get_cooldown_remaining() + return False, f"In cooldown period ({remaining:.1f}s remaining)" + + recent_count = self._count_recent_recoveries() + if recent_count >= self.max_recovery_attempts: + return False, f"Max recovery attempts ({self.max_recovery_attempts}) reached" + + return True, "Recovery allowed" + + def attempt_recovery(self, error: Optional[str] = None) -> bool: + """ + Attempt memory recovery. + + Args: + error: Optional error message that triggered recovery + + Returns: + True if recovery was successful + """ + can_recover, reason = self.can_attempt_recovery() + if not can_recover: + logger.warning(f"Cannot attempt recovery: {reason}") + return False + + logger.info("Starting memory recovery...") + + # Notify callbacks + for callback in self._on_recovery_start: + try: + callback() + except Exception as e: + logger.warning(f"Recovery start callback failed: {e}") + + success = False + try: + # Step 1: Clear GPU cache + self.memory_guard.clear_gpu_cache() + + # Step 2: Force garbage collection + gc.collect() + + # Step 3: Check memory status + is_available, stats = self.memory_guard.check_memory() + success = is_available or stats.gpu_used_ratio < 0.9 + + if success: + logger.info( + f"Memory recovery successful. GPU: {stats.gpu_used_ratio*100:.1f}% used" + ) + else: + logger.warning( + f"Memory recovery incomplete. 
GPU still at {stats.gpu_used_ratio*100:.1f}%" + ) + + except Exception as e: + logger.error(f"Recovery failed with error: {e}") + success = False + + # Update state + with self._lock: + self._state.last_recovery_time = time.time() + self._state.recovery_count += 1 + self._state.last_error = error + self._recovery_times.append(time.time()) + + # Enter cooldown + self._state.in_cooldown = True + self._state.cooldown_until = time.time() + self.cooldown_seconds + + logger.info(f"Entering cooldown period ({self.cooldown_seconds}s)") + + # Notify callbacks + for callback in self._on_recovery_complete: + try: + callback(success) + except Exception as e: + logger.warning(f"Recovery complete callback failed: {e}") + + return success + + def emergency_release(self, model_manager: Optional['ModelManager'] = None) -> bool: + """ + Emergency memory release - more aggressive than normal recovery. + + Args: + model_manager: Optional ModelManager to unload models from + + Returns: + True if significant memory was freed + """ + logger.warning("Initiating EMERGENCY memory release") + + initial_stats = self.memory_guard.get_memory_stats() + + # Step 1: Unload all models if model_manager provided + if model_manager: + logger.info("Unloading all models...") + try: + model_ids = list(model_manager.models.keys()) + for model_id in model_ids: + model_manager.unload_model(model_id, force=True) + except Exception as e: + logger.error(f"Failed to unload models: {e}") + + # Step 2: Clear all caches + self.memory_guard.clear_gpu_cache() + + # Step 3: Multiple rounds of garbage collection + for i in range(3): + gc.collect() + time.sleep(0.1) + + # Step 4: Check improvement + final_stats = self.memory_guard.get_memory_stats() + freed_mb = initial_stats.gpu_used_mb - final_stats.gpu_used_mb + + logger.info( + f"Emergency release complete. Freed ~{freed_mb:.0f}MB. " + f"GPU: {final_stats.gpu_used_mb:.0f}MB / {final_stats.gpu_total_mb:.0f}MB " + f"({final_stats.gpu_used_ratio*100:.1f}%)" + ) + + return freed_mb > 100 # Consider success if freed >100MB + + def get_state(self) -> Dict: + """Get current recovery state""" + with self._lock: + return { + "last_recovery_time": self._state.last_recovery_time, + "recovery_count": self._state.recovery_count, + "in_cooldown": self._state.in_cooldown, + "cooldown_remaining_seconds": self.get_cooldown_remaining(), + "recent_recoveries": self._count_recent_recoveries(), + "max_recovery_attempts": self.max_recovery_attempts, + "last_error": self._state.last_error, + } + + +# Global recovery manager instance +_recovery_manager: Optional[RecoveryManager] = None + + +def get_recovery_manager( + cooldown_seconds: float = 30.0, + max_recovery_attempts: int = 3 +) -> RecoveryManager: + """ + Get the global RecoveryManager instance. 
+ + Args: + cooldown_seconds: Cooldown period after recovery + max_recovery_attempts: Max recovery attempts within window + + Returns: + RecoveryManager singleton instance + """ + global _recovery_manager + if _recovery_manager is None: + _recovery_manager = RecoveryManager( + cooldown_seconds=cooldown_seconds, + max_recovery_attempts=max_recovery_attempts + ) + return _recovery_manager + + +def shutdown_recovery_manager(): + """Shutdown the global RecoveryManager instance""" + global _recovery_manager + _recovery_manager = None + + +# ============================================================================= +# Section 5.2: Memory Dump for Debugging +# ============================================================================= + +@dataclass +class MemoryDumpEntry: + """Entry in a memory dump""" + object_type: str + object_id: str + size_bytes: int + ref_count: int + details: Dict = field(default_factory=dict) + + +@dataclass +class MemoryDump: + """Complete memory dump for debugging""" + timestamp: float + total_gpu_memory_mb: float + used_gpu_memory_mb: float + free_gpu_memory_mb: float + total_cpu_memory_mb: float + used_cpu_memory_mb: float + loaded_models: List[Dict] + active_predictions: int + queue_depth: int + service_pool_stats: Dict + recovery_state: Dict + python_objects: List[MemoryDumpEntry] = field(default_factory=list) + gc_stats: Dict = field(default_factory=dict) + + +class MemoryDumper: + """ + Creates memory dumps for debugging memory issues. + + Captures comprehensive memory state including: + - GPU/CPU memory usage + - Loaded models and their references + - Active predictions and queue state + - Python garbage collector statistics + - Large object tracking + """ + + def __init__(self, memory_guard: Optional[MemoryGuard] = None): + """ + Initialize MemoryDumper. + + Args: + memory_guard: MemoryGuard instance for memory queries + """ + self.memory_guard = memory_guard or MemoryGuard() + self._dump_history: List[MemoryDump] = [] + self._max_history = 10 + self._lock = threading.Lock() + + logger.info("MemoryDumper initialized") + + def create_dump( + self, + include_python_objects: bool = False, + min_object_size: int = 1048576 # 1MB + ) -> MemoryDump: + """ + Create a memory dump capturing current state. 
+ + Args: + include_python_objects: Include large Python objects in dump + min_object_size: Minimum object size to include (bytes) + + Returns: + MemoryDump with current memory state + """ + logger.info("Creating memory dump...") + + # Get memory stats + stats = self.memory_guard.get_memory_stats() + + # Get model manager stats + loaded_models = [] + try: + model_manager = get_model_manager() + model_stats = model_manager.get_model_stats() + loaded_models = [ + { + "model_id": model_id, + "ref_count": info["ref_count"], + "estimated_memory_mb": info["estimated_memory_mb"], + "idle_seconds": info["idle_seconds"], + "is_loading": info["is_loading"], + } + for model_id, info in model_stats.get("models", {}).items() + ] + except Exception as e: + logger.debug(f"Could not get model stats: {e}") + + # Get prediction semaphore stats + active_predictions = 0 + queue_depth = 0 + try: + semaphore = get_prediction_semaphore() + sem_stats = semaphore.get_stats() + active_predictions = sem_stats.get("active_predictions", 0) + queue_depth = sem_stats.get("queue_depth", 0) + except Exception as e: + logger.debug(f"Could not get semaphore stats: {e}") + + # Get service pool stats + service_pool_stats = {} + try: + from app.services.service_pool import get_service_pool + pool = get_service_pool() + service_pool_stats = pool.get_pool_stats() + except Exception as e: + logger.debug(f"Could not get service pool stats: {e}") + + # Get recovery state + recovery_state = {} + try: + recovery_manager = get_recovery_manager() + recovery_state = recovery_manager.get_state() + except Exception as e: + logger.debug(f"Could not get recovery state: {e}") + + # Get GC stats + gc_stats = { + "counts": gc.get_count(), + "threshold": gc.get_threshold(), + "is_tracking": gc.isenabled(), + } + + # Create dump + dump = MemoryDump( + timestamp=time.time(), + total_gpu_memory_mb=stats.gpu_total_mb, + used_gpu_memory_mb=stats.gpu_used_mb, + free_gpu_memory_mb=stats.gpu_free_mb, + total_cpu_memory_mb=stats.cpu_used_mb + stats.cpu_available_mb, + used_cpu_memory_mb=stats.cpu_used_mb, + loaded_models=loaded_models, + active_predictions=active_predictions, + queue_depth=queue_depth, + service_pool_stats=service_pool_stats, + recovery_state=recovery_state, + gc_stats=gc_stats, + ) + + # Optionally include large Python objects + if include_python_objects: + dump.python_objects = self._get_large_objects(min_object_size) + + # Store in history + with self._lock: + self._dump_history.append(dump) + if len(self._dump_history) > self._max_history: + self._dump_history.pop(0) + + logger.info( + f"Memory dump created: GPU {stats.gpu_used_mb:.0f}/{stats.gpu_total_mb:.0f}MB, " + f"{len(loaded_models)} models, {active_predictions} active predictions" + ) + + return dump + + def _get_large_objects(self, min_size: int) -> List[MemoryDumpEntry]: + """Get list of large Python objects for debugging""" + large_objects = [] + + try: + import sys + + # Get all objects tracked by GC + for obj in gc.get_objects(): + try: + size = sys.getsizeof(obj) + if size >= min_size: + entry = MemoryDumpEntry( + object_type=type(obj).__name__, + object_id=str(id(obj)), + size_bytes=size, + ref_count=sys.getrefcount(obj), + details={ + "module": getattr(type(obj), "__module__", "unknown"), + } + ) + large_objects.append(entry) + except Exception: + pass # Skip objects that can't be measured + + # Sort by size descending + large_objects.sort(key=lambda x: x.size_bytes, reverse=True) + + # Limit to top 100 + return large_objects[:100] + + except Exception as e: + 
logger.warning(f"Failed to get large objects: {e}") + return [] + + def get_dump_history(self) -> List[MemoryDump]: + """Get recent dump history""" + with self._lock: + return list(self._dump_history) + + def get_latest_dump(self) -> Optional[MemoryDump]: + """Get the most recent dump""" + with self._lock: + return self._dump_history[-1] if self._dump_history else None + + def compare_dumps( + self, + dump1: MemoryDump, + dump2: MemoryDump + ) -> Dict: + """ + Compare two memory dumps to identify changes. + + Args: + dump1: First (earlier) dump + dump2: Second (later) dump + + Returns: + Dictionary with comparison results + """ + return { + "time_delta_seconds": dump2.timestamp - dump1.timestamp, + "gpu_memory_change_mb": dump2.used_gpu_memory_mb - dump1.used_gpu_memory_mb, + "cpu_memory_change_mb": dump2.used_cpu_memory_mb - dump1.used_cpu_memory_mb, + "model_count_change": len(dump2.loaded_models) - len(dump1.loaded_models), + "prediction_count_change": dump2.active_predictions - dump1.active_predictions, + "dump1_timestamp": dump1.timestamp, + "dump2_timestamp": dump2.timestamp, + } + + def to_dict(self, dump: MemoryDump) -> Dict: + """Convert a MemoryDump to a dictionary for JSON serialization""" + return { + "timestamp": dump.timestamp, + "gpu": { + "total_mb": dump.total_gpu_memory_mb, + "used_mb": dump.used_gpu_memory_mb, + "free_mb": dump.free_gpu_memory_mb, + "utilization_percent": ( + dump.used_gpu_memory_mb / dump.total_gpu_memory_mb * 100 + if dump.total_gpu_memory_mb > 0 else 0 + ), + }, + "cpu": { + "total_mb": dump.total_cpu_memory_mb, + "used_mb": dump.used_cpu_memory_mb, + }, + "models": dump.loaded_models, + "predictions": { + "active": dump.active_predictions, + "queue_depth": dump.queue_depth, + }, + "service_pool": dump.service_pool_stats, + "recovery": dump.recovery_state, + "gc": dump.gc_stats, + "large_objects_count": len(dump.python_objects), + } + + +# Global memory dumper instance +_memory_dumper: Optional[MemoryDumper] = None + + +def get_memory_dumper() -> MemoryDumper: + """Get the global MemoryDumper instance""" + global _memory_dumper + if _memory_dumper is None: + _memory_dumper = MemoryDumper() + return _memory_dumper + + +def shutdown_memory_dumper(): + """Shutdown the global MemoryDumper instance""" + global _memory_dumper + _memory_dumper = None + + +# ============================================================================= +# Section 7.2: Prometheus Metrics Export +# ============================================================================= + +class PrometheusMetrics: + """ + Prometheus metrics exporter for memory management. + + Exposes metrics in Prometheus text format for monitoring: + - GPU/CPU memory usage + - Model lifecycle metrics + - Prediction semaphore metrics + - Service pool metrics + - Recovery metrics + """ + + # Metric names + METRIC_PREFIX = "tool_ocr_memory_" + + def __init__(self): + """Initialize PrometheusMetrics""" + self._custom_metrics: Dict[str, float] = {} + self._lock = threading.Lock() + logger.info("PrometheusMetrics initialized") + + def set_custom_metric(self, name: str, value: float, labels: Optional[Dict[str, str]] = None): + """ + Set a custom metric value. 
+ + Args: + name: Metric name + value: Metric value + labels: Optional labels for the metric + """ + with self._lock: + key = name if not labels else f"{name}{{{self._format_labels(labels)}}}" + self._custom_metrics[key] = value + + def _format_labels(self, labels: Dict[str, str]) -> str: + """Format labels for Prometheus""" + return ",".join(f'{k}="{v}"' for k, v in sorted(labels.items())) + + def _format_metric(self, name: str, value: float, help_text: str, metric_type: str = "gauge") -> str: + """Format a single metric in Prometheus format""" + lines = [ + f"# HELP {self.METRIC_PREFIX}{name} {help_text}", + f"# TYPE {self.METRIC_PREFIX}{name} {metric_type}", + f"{self.METRIC_PREFIX}{name} {value}", + ] + return "\n".join(lines) + + def _format_metric_with_labels( + self, + name: str, + values: List[Tuple[Dict[str, str], float]], + help_text: str, + metric_type: str = "gauge" + ) -> str: + """Format a metric with labels in Prometheus format""" + lines = [ + f"# HELP {self.METRIC_PREFIX}{name} {help_text}", + f"# TYPE {self.METRIC_PREFIX}{name} {metric_type}", + ] + for labels, value in values: + label_str = self._format_labels(labels) + lines.append(f"{self.METRIC_PREFIX}{name}{{{label_str}}} {value}") + return "\n".join(lines) + + def export_metrics(self) -> str: + """ + Export all metrics in Prometheus text format. + + Returns: + String containing metrics in Prometheus exposition format + """ + metrics = [] + + # GPU Memory metrics + try: + guard = MemoryGuard() + stats = guard.get_memory_stats() + + metrics.append(self._format_metric( + "gpu_total_bytes", + stats.gpu_total_mb * 1024 * 1024, + "Total GPU memory in bytes" + )) + metrics.append(self._format_metric( + "gpu_used_bytes", + stats.gpu_used_mb * 1024 * 1024, + "Used GPU memory in bytes" + )) + metrics.append(self._format_metric( + "gpu_free_bytes", + stats.gpu_free_mb * 1024 * 1024, + "Free GPU memory in bytes" + )) + metrics.append(self._format_metric( + "gpu_utilization_ratio", + stats.gpu_used_ratio, + "GPU memory utilization ratio (0-1)" + )) + metrics.append(self._format_metric( + "cpu_used_bytes", + stats.cpu_used_mb * 1024 * 1024, + "Used CPU memory in bytes" + )) + metrics.append(self._format_metric( + "cpu_available_bytes", + stats.cpu_available_mb * 1024 * 1024, + "Available CPU memory in bytes" + )) + + guard.shutdown() + except Exception as e: + logger.debug(f"Could not get memory stats for metrics: {e}") + + # Model Manager metrics + try: + model_manager = get_model_manager() + model_stats = model_manager.get_model_stats() + + metrics.append(self._format_metric( + "models_loaded_total", + model_stats.get("total_models", 0), + "Total number of loaded models" + )) + metrics.append(self._format_metric( + "models_memory_bytes", + model_stats.get("total_estimated_memory_mb", 0) * 1024 * 1024, + "Estimated total memory used by loaded models in bytes" + )) + + # Per-model metrics + model_values = [] + for model_id, info in model_stats.get("models", {}).items(): + model_values.append(( + {"model_id": model_id}, + info.get("ref_count", 0) + )) + if model_values: + metrics.append(self._format_metric_with_labels( + "model_ref_count", + model_values, + "Reference count per model" + )) + + except Exception as e: + logger.debug(f"Could not get model stats for metrics: {e}") + + # Prediction Semaphore metrics + try: + semaphore = get_prediction_semaphore() + sem_stats = semaphore.get_stats() + + metrics.append(self._format_metric( + "predictions_active", + sem_stats.get("active_predictions", 0), + "Number of currently active 
predictions" + )) + metrics.append(self._format_metric( + "predictions_queue_depth", + sem_stats.get("queue_depth", 0), + "Number of predictions waiting in queue" + )) + metrics.append(self._format_metric( + "predictions_total", + sem_stats.get("total_predictions", 0), + "Total number of predictions processed", + metric_type="counter" + )) + metrics.append(self._format_metric( + "predictions_timeouts_total", + sem_stats.get("total_timeouts", 0), + "Total number of prediction timeouts", + metric_type="counter" + )) + metrics.append(self._format_metric( + "predictions_avg_wait_seconds", + sem_stats.get("average_wait_seconds", 0), + "Average wait time for predictions in seconds" + )) + metrics.append(self._format_metric( + "predictions_max_concurrent", + sem_stats.get("max_concurrent", 2), + "Maximum concurrent predictions allowed" + )) + + except Exception as e: + logger.debug(f"Could not get semaphore stats for metrics: {e}") + + # Service Pool metrics + try: + from app.services.service_pool import get_service_pool + pool = get_service_pool() + pool_stats = pool.get_pool_stats() + + metrics.append(self._format_metric( + "pool_services_total", + pool_stats.get("total_services", 0), + "Total number of services in pool" + )) + metrics.append(self._format_metric( + "pool_services_available", + pool_stats.get("available_services", 0), + "Number of available services in pool" + )) + metrics.append(self._format_metric( + "pool_services_in_use", + pool_stats.get("in_use_services", 0), + "Number of services currently in use" + )) + + pool_metrics = pool_stats.get("metrics", {}) + metrics.append(self._format_metric( + "pool_acquisitions_total", + pool_metrics.get("total_acquisitions", 0), + "Total number of service acquisitions", + metric_type="counter" + )) + metrics.append(self._format_metric( + "pool_releases_total", + pool_metrics.get("total_releases", 0), + "Total number of service releases", + metric_type="counter" + )) + metrics.append(self._format_metric( + "pool_timeouts_total", + pool_metrics.get("total_timeouts", 0), + "Total number of acquisition timeouts", + metric_type="counter" + )) + metrics.append(self._format_metric( + "pool_errors_total", + pool_metrics.get("total_errors", 0), + "Total number of pool errors", + metric_type="counter" + )) + + except Exception as e: + logger.debug(f"Could not get pool stats for metrics: {e}") + + # Recovery Manager metrics + try: + recovery_manager = get_recovery_manager() + recovery_state = recovery_manager.get_state() + + metrics.append(self._format_metric( + "recovery_count_total", + recovery_state.get("recovery_count", 0), + "Total number of recovery attempts", + metric_type="counter" + )) + metrics.append(self._format_metric( + "recovery_in_cooldown", + 1 if recovery_state.get("in_cooldown", False) else 0, + "Whether recovery is currently in cooldown (1=yes, 0=no)" + )) + metrics.append(self._format_metric( + "recovery_cooldown_remaining_seconds", + recovery_state.get("cooldown_remaining_seconds", 0), + "Remaining cooldown time in seconds" + )) + metrics.append(self._format_metric( + "recovery_recent_count", + recovery_state.get("recent_recoveries", 0), + "Number of recent recovery attempts within window" + )) + + except Exception as e: + logger.debug(f"Could not get recovery stats for metrics: {e}") + + # Custom metrics + with self._lock: + for name, value in self._custom_metrics.items(): + if "{" in name: + # Metric with labels + base_name = name.split("{")[0] + metrics.append(f"{self.METRIC_PREFIX}{name} {value}") + else: + 
metrics.append(f"{self.METRIC_PREFIX}{name} {value}") + + return "\n\n".join(metrics) + "\n" + + +# Global Prometheus metrics instance +_prometheus_metrics: Optional[PrometheusMetrics] = None + + +def get_prometheus_metrics() -> PrometheusMetrics: + """Get the global PrometheusMetrics instance""" + global _prometheus_metrics + if _prometheus_metrics is None: + _prometheus_metrics = PrometheusMetrics() + return _prometheus_metrics + + +def shutdown_prometheus_metrics(): + """Shutdown the global PrometheusMetrics instance""" + global _prometheus_metrics + _prometheus_metrics = None diff --git a/backend/app/services/ocr_service.py b/backend/app/services/ocr_service.py index 1b57dd7..a419a6b 100644 --- a/backend/app/services/ocr_service.py +++ b/backend/app/services/ocr_service.py @@ -25,6 +25,7 @@ except ImportError: from app.core.config import settings from app.services.office_converter import OfficeConverter, OfficeConverterError +from app.services.memory_manager import get_model_manager, MemoryConfig, MemoryGuard, prediction_context # Import dual-track components try: @@ -96,6 +97,26 @@ class OCRService: self._model_last_used = {} # Track last usage time for each model self._memory_warning_logged = False + # Initialize MemoryGuard for enhanced memory monitoring + self._memory_guard = None + if settings.enable_model_lifecycle_management: + try: + memory_config = MemoryConfig( + warning_threshold=settings.memory_warning_threshold, + critical_threshold=settings.memory_critical_threshold, + emergency_threshold=settings.memory_emergency_threshold, + model_idle_timeout_seconds=settings.pp_structure_idle_timeout_seconds, + gpu_memory_limit_mb=settings.gpu_memory_limit_mb, + enable_cpu_fallback=settings.enable_cpu_fallback, + ) + self._memory_guard = MemoryGuard(memory_config) + logger.debug("MemoryGuard initialized for OCRService") + except Exception as e: + logger.warning(f"Failed to initialize MemoryGuard: {e}") + + # Track if CPU fallback was activated + self._cpu_fallback_active = False + self._detect_and_configure_gpu() # Log GPU optimization settings @@ -217,53 +238,91 @@ class OCRService: def _check_gpu_memory_usage(self): """ Check GPU memory usage and log warnings if approaching limits. - Implements memory optimization for RTX 4060 8GB. + Uses MemoryGuard for enhanced monitoring with multiple backends. 
""" if not self.use_gpu or not settings.enable_memory_optimization: return try: - device_id = self.gpu_info.get('device_id', 0) - memory_allocated = paddle.device.cuda.memory_allocated(device_id) - memory_allocated_mb = memory_allocated / (1024**2) - memory_limit_mb = settings.gpu_memory_limit_mb + # Use MemoryGuard if available for better monitoring + if self._memory_guard: + stats = self._memory_guard.get_memory_stats() - utilization = (memory_allocated_mb / memory_limit_mb * 100) if memory_limit_mb > 0 else 0 + # Log based on usage ratio + if stats.gpu_used_ratio > 0.90 and not self._memory_warning_logged: + logger.warning( + f"GPU memory usage critical: {stats.gpu_used_mb:.0f}MB / {stats.gpu_total_mb:.0f}MB " + f"({stats.gpu_used_ratio*100:.1f}%)" + ) + logger.warning("Consider enabling auto_unload_unused_models or reducing batch size") + self._memory_warning_logged = True - if utilization > 90 and not self._memory_warning_logged: - logger.warning(f"GPU memory usage high: {memory_allocated_mb:.0f}MB / {memory_limit_mb}MB ({utilization:.1f}%)") - logger.warning("Consider enabling auto_unload_unused_models or reducing batch size") - self._memory_warning_logged = True - elif utilization > 75: - logger.info(f"GPU memory: {memory_allocated_mb:.0f}MB / {memory_limit_mb}MB ({utilization:.1f}%)") + # Trigger emergency cleanup if enabled + if settings.enable_emergency_cleanup: + self._cleanup_unused_models() + self._memory_guard.clear_gpu_cache() + + elif stats.gpu_used_ratio > 0.75: + logger.info( + f"GPU memory: {stats.gpu_used_mb:.0f}MB / {stats.gpu_total_mb:.0f}MB " + f"({stats.gpu_used_ratio*100:.1f}%)" + ) + else: + # Fallback to original implementation + device_id = self.gpu_info.get('device_id', 0) + memory_allocated = paddle.device.cuda.memory_allocated(device_id) + memory_allocated_mb = memory_allocated / (1024**2) + memory_limit_mb = settings.gpu_memory_limit_mb + + utilization = (memory_allocated_mb / memory_limit_mb * 100) if memory_limit_mb > 0 else 0 + + if utilization > 90 and not self._memory_warning_logged: + logger.warning(f"GPU memory usage high: {memory_allocated_mb:.0f}MB / {memory_limit_mb}MB ({utilization:.1f}%)") + logger.warning("Consider enabling auto_unload_unused_models or reducing batch size") + self._memory_warning_logged = True + elif utilization > 75: + logger.info(f"GPU memory: {memory_allocated_mb:.0f}MB / {memory_limit_mb}MB ({utilization:.1f}%)") except Exception as e: logger.debug(f"Memory check failed: {e}") def _cleanup_unused_models(self): """ - Clean up unused language models to free GPU memory. + Clean up unused models (including PP-StructureV3) to free GPU memory. Models idle longer than model_idle_timeout_seconds will be unloaded. + + Note: PP-StructureV3 is NO LONGER exempted from cleanup - it will be + unloaded based on pp_structure_idle_timeout_seconds configuration. 
""" if not settings.auto_unload_unused_models: return current_time = datetime.now() - timeout = settings.model_idle_timeout_seconds models_to_remove = [] for lang, last_used in self._model_last_used.items(): - if lang == 'structure': # Don't unload structure engine - continue + # Use different timeout for structure engine vs language models + if lang == 'structure': + timeout = settings.pp_structure_idle_timeout_seconds + else: + timeout = settings.model_idle_timeout_seconds + idle_seconds = (current_time - last_used).total_seconds() if idle_seconds > timeout: models_to_remove.append(lang) - for lang in models_to_remove: - if lang in self.ocr_engines: - logger.info(f"Unloading idle OCR engine for {lang} (idle {timeout}s)") - del self.ocr_engines[lang] - del self._model_last_used[lang] + for model_key in models_to_remove: + if model_key == 'structure': + if self.structure_engine is not None: + logger.info(f"Unloading idle PP-StructureV3 engine (idle {settings.pp_structure_idle_timeout_seconds}s)") + self._unload_structure_engine() + if model_key in self._model_last_used: + del self._model_last_used[model_key] + elif model_key in self.ocr_engines: + logger.info(f"Unloading idle OCR engine for {model_key} (idle {settings.model_idle_timeout_seconds}s)") + del self.ocr_engines[model_key] + if model_key in self._model_last_used: + del self._model_last_used[model_key] if models_to_remove and self.use_gpu: # Clear CUDA cache @@ -273,6 +332,41 @@ class OCRService: except Exception as e: logger.debug(f"Cache clear failed: {e}") + def _unload_structure_engine(self): + """ + Properly unload PP-StructureV3 engine and free GPU memory. + """ + if self.structure_engine is None: + return + + try: + # Clear internal engine components + if hasattr(self.structure_engine, 'table_engine'): + self.structure_engine.table_engine = None + if hasattr(self.structure_engine, 'text_detector'): + self.structure_engine.text_detector = None + if hasattr(self.structure_engine, 'text_recognizer'): + self.structure_engine.text_recognizer = None + if hasattr(self.structure_engine, 'layout_predictor'): + self.structure_engine.layout_predictor = None + + # Delete the engine + del self.structure_engine + self.structure_engine = None + + # Force garbage collection + gc.collect() + + # Clear GPU cache + if self.use_gpu: + paddle.device.cuda.empty_cache() + + logger.info("PP-StructureV3 engine unloaded successfully") + + except Exception as e: + logger.warning(f"Error unloading PP-StructureV3: {e}") + self.structure_engine = None + def clear_gpu_cache(self): """ Manually clear GPU memory cache. @@ -519,46 +613,160 @@ class OCRService: logger.warning(f"GPU memory cleanup failed (non-critical): {e}") # Don't fail the processing if cleanup fails - def check_gpu_memory(self, required_mb: int = 2000) -> bool: + def check_gpu_memory(self, required_mb: int = 2000, enable_fallback: bool = True) -> bool: """ - Check if sufficient GPU memory is available. + Check if sufficient GPU memory is available using MemoryGuard. + + This method now uses MemoryGuard for accurate memory queries across + multiple backends (pynvml, torch, paddle) instead of returning True + blindly for PaddlePaddle-only environments. 
Args: required_mb: Required memory in MB (default 2000MB for OCR models) + enable_fallback: If True and CPU fallback is enabled, switch to CPU mode + when memory is insufficient instead of returning False Returns: - True if sufficient memory is available or GPU is not used + True if sufficient memory is available, GPU is not used, or CPU fallback activated """ - try: - # Check GPU memory using torch if available, otherwise use PaddlePaddle - free_memory = None + # If not using GPU, always return True + if not self.use_gpu: + return True - if TORCH_AVAILABLE and torch.cuda.is_available(): - free_memory = torch.cuda.mem_get_info()[0] / 1024**2 - elif paddle.device.is_compiled_with_cuda(): - # PaddlePaddle doesn't have direct API to get free memory, - # so we rely on cleanup and continue - logger.debug("Using PaddlePaddle GPU, memory info not directly available") + try: + # Use MemoryGuard if available for accurate multi-backend memory queries + if self._memory_guard: + is_available, stats = self._memory_guard.check_memory( + required_mb=required_mb, + device_id=self.gpu_info.get('device_id', 0) + ) + + if not is_available: + logger.warning( + f"GPU memory check failed: {stats.gpu_free_mb:.0f}MB free, " + f"{required_mb}MB required ({stats.gpu_used_ratio*100:.1f}% used)" + ) + + # Try to free memory + logger.info("Attempting memory cleanup before retry...") + self._cleanup_unused_models() + self._memory_guard.clear_gpu_cache() + + # Check again + is_available, stats = self._memory_guard.check_memory(required_mb=required_mb) + + if not is_available: + # Memory still insufficient after cleanup + if enable_fallback and settings.enable_cpu_fallback: + logger.warning( + f"Insufficient GPU memory ({stats.gpu_free_mb:.0f}MB) after cleanup. " + f"Activating CPU fallback mode." 
+ ) + self._activate_cpu_fallback() + return True # Continue with CPU + else: + logger.error( + f"Insufficient GPU memory: {stats.gpu_free_mb:.0f}MB available, " + f"{required_mb}MB required" + ) + return False + + logger.debug( + f"GPU memory check passed: {stats.gpu_free_mb:.0f}MB free " + f"({stats.gpu_used_ratio*100:.1f}% used)" + ) return True - if free_memory is not None: - if free_memory < required_mb: - logger.warning(f"Low GPU memory: {free_memory:.0f}MB available, {required_mb}MB required") - # Try to free memory - self.cleanup_gpu_memory() - # Check again - if TORCH_AVAILABLE and torch.cuda.is_available(): - free_memory = torch.cuda.mem_get_info()[0] / 1024**2 - if free_memory < required_mb: - logger.error(f"Insufficient GPU memory after cleanup: {free_memory:.0f}MB") - return False - logger.debug(f"GPU memory check passed: {free_memory:.0f}MB available") + else: + # Fallback to original implementation + free_memory = None + + if TORCH_AVAILABLE and torch.cuda.is_available(): + free_memory = torch.cuda.mem_get_info()[0] / 1024**2 + elif paddle.device.is_compiled_with_cuda(): + # PaddlePaddle doesn't have direct API to get free memory, + # use allocated memory to estimate + device_id = self.gpu_info.get('device_id', 0) + allocated = paddle.device.cuda.memory_allocated(device_id) / (1024**2) + total = settings.gpu_memory_limit_mb + free_memory = max(0, total - allocated) + logger.debug(f"Estimated free GPU memory: {free_memory:.0f}MB (total: {total}MB, allocated: {allocated:.0f}MB)") + + if free_memory is not None: + if free_memory < required_mb: + logger.warning(f"Low GPU memory: {free_memory:.0f}MB available, {required_mb}MB required") + self.cleanup_gpu_memory() + + # Recheck + if TORCH_AVAILABLE and torch.cuda.is_available(): + free_memory = torch.cuda.mem_get_info()[0] / 1024**2 + else: + allocated = paddle.device.cuda.memory_allocated(device_id) / (1024**2) + free_memory = max(0, total - allocated) + + if free_memory < required_mb: + if enable_fallback and settings.enable_cpu_fallback: + logger.warning(f"Insufficient GPU memory after cleanup. Activating CPU fallback.") + self._activate_cpu_fallback() + return True + else: + logger.error(f"Insufficient GPU memory after cleanup: {free_memory:.0f}MB") + return False + + logger.debug(f"GPU memory check passed: {free_memory:.0f}MB available") + + return True - return True except Exception as e: logger.warning(f"GPU memory check failed: {e}") return True # Continue processing even if check fails + def _activate_cpu_fallback(self): + """ + Activate CPU fallback mode when GPU memory is insufficient. + This disables GPU usage for the current service instance. + """ + if self._cpu_fallback_active: + return # Already in CPU mode + + logger.warning("=== CPU FALLBACK MODE ACTIVATED ===") + logger.warning("GPU memory insufficient, switching to CPU processing") + logger.warning("Performance will be significantly reduced") + + self._cpu_fallback_active = True + self.use_gpu = False + + # Update GPU info to reflect fallback + self.gpu_info['cpu_fallback'] = True + self.gpu_info['fallback_reason'] = 'GPU memory insufficient' + + # Clear GPU cache to free memory + if self._memory_guard: + self._memory_guard.clear_gpu_cache() + + def _restore_gpu_mode(self): + """ + Attempt to restore GPU mode after CPU fallback. + Called when memory pressure has been relieved. 
+ """ + if not self._cpu_fallback_active: + return + + if not self.gpu_available: + return + + # Check if GPU memory is now available + if self._memory_guard: + is_available, stats = self._memory_guard.check_memory( + required_mb=settings.structure_model_memory_mb + ) + if is_available: + logger.info("GPU memory available, restoring GPU mode") + self._cpu_fallback_active = False + self.use_gpu = True + self.gpu_info.pop('cpu_fallback', None) + self.gpu_info.pop('fallback_reason', None) + def convert_pdf_to_images(self, pdf_path: Path, output_dir: Path) -> List[Path]: """ Convert PDF to images (one per page) @@ -626,6 +834,24 @@ class OCRService: threshold = confidence_threshold if confidence_threshold is not None else self.confidence_threshold try: + # Pre-operation memory check: Try to restore GPU if in fallback and memory available + if self._cpu_fallback_active: + self._restore_gpu_mode() + if not self._cpu_fallback_active: + logger.info("GPU mode restored for processing") + + # Initial memory check before starting any heavy processing + # Estimate memory requirement based on image type + estimated_memory_mb = 2500 # Conservative estimate for full OCR + layout + if detect_layout: + estimated_memory_mb += 500 # Additional for PP-StructureV3 + + if not self.check_gpu_memory(required_mb=estimated_memory_mb, enable_fallback=True): + logger.warning( + f"Pre-operation memory check failed ({estimated_memory_mb}MB required). " + f"Processing will attempt to proceed but may encounter issues." + ) + # Check if file is Office document if self.office_converter.is_office_document(image_path): logger.info(f"Detected Office document: {image_path.name}, converting to PDF") @@ -748,9 +974,12 @@ class OCRService: # Get OCR engine (for non-PDF images) ocr_engine = self.get_ocr_engine(lang) - # Check GPU memory before OCR processing - if not self.check_gpu_memory(required_mb=1500): - logger.warning("Insufficient GPU memory for OCR, attempting to proceed anyway") + # Secondary memory check before OCR processing + if not self.check_gpu_memory(required_mb=1500, enable_fallback=True): + logger.warning( + f"OCR memory check: insufficient GPU memory (1500MB required). " + f"Mode: {'CPU fallback' if self._cpu_fallback_active else 'GPU (low memory)'}" + ) # Get the actual image dimensions that OCR will use from PIL import Image @@ -950,6 +1179,18 @@ class OCRService: Tuple of (layout_data, images_metadata) """ try: + # Pre-operation memory check for layout analysis + if self._cpu_fallback_active: + self._restore_gpu_mode() + if not self._cpu_fallback_active: + logger.info("GPU mode restored for layout analysis") + + if not self.check_gpu_memory(required_mb=2000, enable_fallback=True): + logger.warning( + f"Layout analysis pre-check: insufficient GPU memory (2000MB required). " + f"Mode: {'CPU fallback' if self._cpu_fallback_active else 'GPU'}" + ) + structure_engine = self._ensure_structure_engine(pp_structure_params) # Try enhanced processing first @@ -998,11 +1239,21 @@ class OCRService: # Standard processing (original implementation) logger.info(f"Running standard layout analysis on {image_path.name}") - # Check GPU memory before processing - if not self.check_gpu_memory(required_mb=2000): - logger.warning("Insufficient GPU memory for PP-StructureV3, attempting to proceed anyway") + # Memory check before PP-StructureV3 processing + if not self.check_gpu_memory(required_mb=2000, enable_fallback=True): + logger.warning( + f"PP-StructureV3 memory check: insufficient GPU memory (2000MB required). 
" + f"Mode: {'CPU fallback' if self._cpu_fallback_active else 'GPU (low memory)'}" + ) - results = structure_engine.predict(str(image_path)) + # Use prediction semaphore to control concurrent predictions + # This prevents OOM errors from multiple simultaneous PP-StructureV3.predict() calls + with prediction_context(timeout=settings.service_acquire_timeout_seconds) as acquired: + if not acquired: + logger.error("Failed to acquire prediction slot (timeout), returning empty layout") + return None, [] + + results = structure_engine.predict(str(image_path)) layout_elements = [] images_metadata = [] @@ -1254,6 +1505,46 @@ class OCRService: if temp_pdf_path: unified_doc.metadata.original_filename = file_path.name + # HYBRID MODE: Check if Direct track missed images (e.g., inline image blocks) + # If so, use OCR to extract images and merge them into the Direct result + pages_with_missing_images = self.direct_extraction_engine.check_document_for_missing_images( + actual_file_path + ) + if pages_with_missing_images: + logger.info(f"Hybrid mode: Direct track missing images on pages {pages_with_missing_images}, using OCR to extract images") + try: + # Run OCR on the file to extract images + ocr_result = self.process_file_traditional( + actual_file_path, lang, detect_layout=True, + confidence_threshold=confidence_threshold, + output_dir=output_dir, pp_structure_params=pp_structure_params + ) + + # Convert OCR result to extract images + ocr_unified = self.ocr_to_unified_converter.convert( + ocr_result, actual_file_path, 0.0, lang + ) + + # Merge OCR-extracted images into Direct track result + images_added = self._merge_ocr_images_into_direct( + unified_doc, ocr_unified, pages_with_missing_images + ) + if images_added > 0: + logger.info(f"Hybrid mode: Added {images_added} images from OCR to Direct track result") + unified_doc.metadata.processing_track = ProcessingTrack.HYBRID + else: + # Fallback: OCR didn't find images either, render inline image blocks directly + logger.info("Hybrid mode: OCR didn't find images, falling back to inline image rendering") + images_added = self.direct_extraction_engine.render_inline_image_regions( + actual_file_path, unified_doc, pages_with_missing_images, output_dir + ) + if images_added > 0: + logger.info(f"Hybrid mode: Rendered {images_added} inline image regions") + unified_doc.metadata.processing_track = ProcessingTrack.HYBRID + except Exception as e: + logger.warning(f"Hybrid mode image extraction failed: {e}") + # Continue with Direct track result without images + # Use OCR track (either by recommendation or fallback) if recommendation.track == "ocr": # Use OCR for scanned documents, images, etc. 
@@ -1269,17 +1560,19 @@ class OCRService: ) unified_doc.document_id = document_id - # Update processing track metadata - unified_doc.metadata.processing_track = ( - ProcessingTrack.DIRECT if recommendation.track == "direct" - else ProcessingTrack.OCR - ) + # Update processing track metadata (only if not already set to HYBRID) + if unified_doc.metadata.processing_track != ProcessingTrack.HYBRID: + unified_doc.metadata.processing_track = ( + ProcessingTrack.DIRECT if recommendation.track == "direct" + else ProcessingTrack.OCR + ) # Calculate total processing time processing_time = (datetime.now() - start_time).total_seconds() unified_doc.metadata.processing_time = processing_time - logger.info(f"Document processing completed in {processing_time:.2f}s using {recommendation.track} track") + actual_track = unified_doc.metadata.processing_track.value + logger.info(f"Document processing completed in {processing_time:.2f}s using {actual_track} track") return unified_doc @@ -1290,6 +1583,75 @@ class OCRService: file_path, lang, detect_layout, confidence_threshold, output_dir, pp_structure_params ) + def _merge_ocr_images_into_direct( + self, + direct_doc: 'UnifiedDocument', + ocr_doc: 'UnifiedDocument', + pages_with_missing_images: List[int] + ) -> int: + """ + Merge OCR-extracted images into Direct track result. + + This is used in hybrid mode when Direct track couldn't extract certain + images (like logos composed of inline image blocks). + + Args: + direct_doc: UnifiedDocument from Direct track + ocr_doc: UnifiedDocument from OCR track + pages_with_missing_images: List of page numbers (1-indexed) that need images + + Returns: + Number of images added + """ + images_added = 0 + + try: + # Get image element types to look for + image_types = {ElementType.FIGURE, ElementType.IMAGE, ElementType.LOGO} + + for page_num in pages_with_missing_images: + # Find the target page in direct_doc + direct_page = None + for page in direct_doc.pages: + if page.page_number == page_num: + direct_page = page + break + + if not direct_page: + continue + + # Find the source page in ocr_doc + ocr_page = None + for page in ocr_doc.pages: + if page.page_number == page_num: + ocr_page = page + break + + if not ocr_page: + continue + + # Extract image elements from OCR page + for element in ocr_page.elements: + if element.type in image_types: + # Assign new element ID to avoid conflicts + new_element_id = f"hybrid_{element.element_id}" + element.element_id = new_element_id + + # Add to direct page + direct_page.elements.append(element) + images_added += 1 + logger.debug(f"Added image element {new_element_id} to page {page_num}") + + # Update image count in direct_doc metadata + if images_added > 0: + current_images = direct_doc.metadata.total_images or 0 + direct_doc.metadata.total_images = current_images + images_added + + except Exception as e: + logger.error(f"Error merging OCR images into Direct track: {e}") + + return images_added + def process_file_traditional( self, file_path: Path, @@ -1441,13 +1803,16 @@ class OCRService: UnifiedDocument if dual-track is enabled and use_dual_track=True, Dict with legacy format otherwise """ - if use_dual_track and self.dual_track_enabled: - # Use dual-track processing + # Use dual-track processing if: + # 1. use_dual_track is True (auto-detection), OR + # 2. 
force_track is specified (explicit track selection) + if (use_dual_track or force_track) and self.dual_track_enabled: + # Use dual-track processing (or forced track) return self.process_with_dual_track( file_path, lang, detect_layout, confidence_threshold, output_dir, force_track, pp_structure_params ) else: - # Use traditional OCR processing + # Use traditional OCR processing (no force_track support) return self.process_file_traditional( file_path, lang, detect_layout, confidence_threshold, output_dir, pp_structure_params ) diff --git a/backend/app/services/pdf_generator_service.py b/backend/app/services/pdf_generator_service.py index 4a4bf17..5e2e462 100644 --- a/backend/app/services/pdf_generator_service.py +++ b/backend/app/services/pdf_generator_service.py @@ -572,8 +572,10 @@ class PDFGeneratorService: processing_track = unified_doc.metadata.get('processing_track') # Route to track-specific rendering method - is_direct_track = (processing_track == 'direct' or - processing_track == ProcessingTrack.DIRECT) + # ProcessingTrack is (str, Enum), so comparing with enum value works for both string and enum + # HYBRID track uses Direct track rendering (Direct text/tables + OCR images) + is_direct_track = (processing_track == ProcessingTrack.DIRECT or + processing_track == ProcessingTrack.HYBRID) logger.info(f"Processing track: {processing_track}, using {'Direct' if is_direct_track else 'OCR'} track rendering") @@ -675,8 +677,11 @@ class PDFGeneratorService: logger.info("=== Direct Track PDF Generation ===") logger.info(f"Total pages: {len(unified_doc.pages)}") - # Set current track for helper methods - self.current_processing_track = 'direct' + # Set current track for helper methods (may be DIRECT or HYBRID) + if hasattr(unified_doc, 'metadata') and unified_doc.metadata: + self.current_processing_track = unified_doc.metadata.processing_track + else: + self.current_processing_track = ProcessingTrack.DIRECT # Get page dimensions from first page (for canvas initialization) if not unified_doc.pages: @@ -1074,11 +1079,16 @@ class PDFGeneratorService: # *** 優先級 1: 檢查 ocr_dimensions (UnifiedDocument 轉換來的) *** if 'ocr_dimensions' in ocr_data: dims = ocr_data['ocr_dimensions'] - w = float(dims.get('width', 0)) - h = float(dims.get('height', 0)) - if w > 0 and h > 0: - logger.info(f"使用 ocr_dimensions 欄位的頁面尺寸: {w:.1f} x {h:.1f}") - return (w, h) + # Handle both dict format {'width': w, 'height': h} and + # list format [{'page': 1, 'width': w, 'height': h}, ...] 
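+            # Illustrative shapes (values are examples only):
+            #   {'width': 595.0, 'height': 842.0}
+            #   [{'page': 1, 'width': 595.0, 'height': 842.0}, ...]  -> first entry is used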
+ if isinstance(dims, list) and len(dims) > 0: + dims = dims[0] # Use first page dimensions + if isinstance(dims, dict): + w = float(dims.get('width', 0)) + h = float(dims.get('height', 0)) + if w > 0 and h > 0: + logger.info(f"使用 ocr_dimensions 欄位的頁面尺寸: {w:.1f} x {h:.1f}") + return (w, h) # *** 優先級 2: 檢查原始 JSON 的 dimensions *** if 'dimensions' in ocr_data: @@ -1418,8 +1428,8 @@ class PDFGeneratorService: # Set font with track-specific styling # Note: OCR track has no StyleInfo (extracted from images), so no advanced formatting style_info = region.get('style') - is_direct_track = (self.current_processing_track == 'direct' or - self.current_processing_track == ProcessingTrack.DIRECT) + is_direct_track = (self.current_processing_track == ProcessingTrack.DIRECT or + self.current_processing_track == ProcessingTrack.HYBRID) if style_info and is_direct_track: # Direct track: Apply rich styling from StyleInfo @@ -1661,10 +1671,15 @@ class PDFGeneratorService: return # Construct full path to image + # saved_path is relative to result_dir (e.g., "imgs/element_id.png") image_path = result_dir / image_path_str + # Fallback for legacy data if not image_path.exists(): - logger.warning(f"Image not found: {image_path}") + image_path = result_dir / Path(image_path_str).name + + if not image_path.exists(): + logger.warning(f"Image not found: {image_path_str} (in {result_dir})") return # Get bbox for positioning @@ -2289,12 +2304,30 @@ class PDFGeneratorService: col_widths = element.metadata['column_widths'] logger.debug(f"Using extracted column widths: {col_widths}") - # Create table without rowHeights (will use canvas scaling instead) - t = Table(table_content, colWidths=col_widths) + # Use original row heights from extraction if available + # Row heights must match the number of data rows exactly + row_heights_list = None + if element.metadata and 'row_heights' in element.metadata: + extracted_row_heights = element.metadata['row_heights'] + num_data_rows = len(table_content) + num_height_rows = len(extracted_row_heights) + + if num_height_rows == num_data_rows: + row_heights_list = extracted_row_heights + logger.debug(f"Using extracted row heights ({num_height_rows} rows): {row_heights_list}") + else: + # Row counts don't match - this can happen with merged cells or empty rows + logger.warning(f"Row height mismatch: {num_height_rows} heights for {num_data_rows} data rows, falling back to auto-sizing") + + # Create table with both column widths and row heights for accurate sizing + t = Table(table_content, colWidths=col_widths, rowHeights=row_heights_list) # Apply style with minimal padding to reduce table extension + # Use Chinese font to support special characters (℃, μm, ≦, ×, Ω, etc.) 
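+        # Fall back to ReportLab's built-in Helvetica when no CJK-capable font was
+        # registered; Helvetica may not cover the special characters listed above.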
+ font_for_table = self.font_name if self.font_registered else 'Helvetica' style = TableStyle([ ('GRID', (0, 0), (-1, -1), 0.5, colors.grey), + ('FONTNAME', (0, 0), (-1, -1), font_for_table), ('FONTSIZE', (0, 0), (-1, -1), 8), ('ALIGN', (0, 0), (-1, -1), 'LEFT'), ('VALIGN', (0, 0), (-1, -1), 'TOP'), @@ -2307,8 +2340,8 @@ class PDFGeneratorService: ]) t.setStyle(style) - # CRITICAL: Use canvas scaling to fit table within bbox - # This is more reliable than rowHeights which doesn't always work + # Use canvas scaling as fallback to fit table within bbox + # With proper row heights, scaling should be minimal (close to 1.0) # Step 1: Wrap to get actual rendered size actual_width, actual_height = t.wrapOn(pdf_canvas, table_width * 10, table_height * 10) @@ -2358,11 +2391,16 @@ class PDFGeneratorService: logger.warning(f"No image path for element {element.element_id}") return - # Construct full path + # Construct full path to image + # saved_path is relative to result_dir (e.g., "document_id_p1_img0.png") image_path = result_dir / image_path_str + # Fallback for legacy data if not image_path.exists(): - logger.warning(f"Image not found: {image_path}") + image_path = result_dir / Path(image_path_str).name + + if not image_path.exists(): + logger.warning(f"Image not found: {image_path_str} (in {result_dir})") return # Get bbox @@ -2388,7 +2426,7 @@ class PDFGeneratorService: preserveAspectRatio=True ) - logger.debug(f"Drew image: {image_path_str}") + logger.debug(f"Drew image: {image_path} (from: {original_path_str})") except Exception as e: logger.error(f"Failed to draw image element {element.element_id}: {e}") diff --git a/backend/app/services/pp_structure_enhanced.py b/backend/app/services/pp_structure_enhanced.py index d1f00ea..4331c38 100644 --- a/backend/app/services/pp_structure_enhanced.py +++ b/backend/app/services/pp_structure_enhanced.py @@ -21,6 +21,8 @@ except ImportError: import paddle from paddleocr import PPStructureV3 from app.models.unified_document import ElementType +from app.core.config import settings +from app.services.memory_manager import prediction_context logger = logging.getLogger(__name__) @@ -96,8 +98,22 @@ class PPStructureEnhanced: try: logger.info(f"Enhanced PP-StructureV3 analysis on {image_path.name}") - # Perform structure analysis - results = self.structure_engine.predict(str(image_path)) + # Perform structure analysis with semaphore control + # This prevents OOM errors from multiple simultaneous predictions + with prediction_context(timeout=settings.service_acquire_timeout_seconds) as acquired: + if not acquired: + logger.error("Failed to acquire prediction slot (timeout), returning empty result") + return { + 'has_parsing_res_list': False, + 'elements': [], + 'total_elements': 0, + 'images': [], + 'tables': [], + 'element_types': {}, + 'error': 'Prediction slot timeout' + } + + results = self.structure_engine.predict(str(image_path)) all_elements = [] all_images = [] diff --git a/backend/app/services/service_pool.py b/backend/app/services/service_pool.py new file mode 100644 index 0000000..07db0b1 --- /dev/null +++ b/backend/app/services/service_pool.py @@ -0,0 +1,468 @@ +""" +Tool_OCR - OCR Service Pool +Manages a pool of OCRService instances to prevent duplicate model loading +and control concurrent GPU operations. 
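+
+The pool is a process-wide singleton (see get_service_pool()); callers obtain a
+service with acquire()/acquire_context() and hand it back with release().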
+""" + +import asyncio +import logging +import threading +import time +from contextlib import contextmanager +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Dict, List, Optional, TYPE_CHECKING + +from app.services.memory_manager import get_model_manager, MemoryConfig + +if TYPE_CHECKING: + from app.services.ocr_service import OCRService + +logger = logging.getLogger(__name__) + + +class ServiceState(Enum): + """State of a pooled service""" + AVAILABLE = "available" + IN_USE = "in_use" + UNHEALTHY = "unhealthy" + INITIALIZING = "initializing" + + +@dataclass +class PooledService: + """Wrapper for a pooled OCRService instance""" + service: Any # OCRService + device: str + state: ServiceState = ServiceState.AVAILABLE + created_at: float = field(default_factory=time.time) + last_used: float = field(default_factory=time.time) + use_count: int = 0 + error_count: int = 0 + current_task_id: Optional[str] = None + + +class PoolConfig: + """Configuration for the service pool""" + + def __init__( + self, + max_services_per_device: int = 1, + max_total_services: int = 2, + acquire_timeout_seconds: float = 300.0, + max_queue_size: int = 50, + health_check_interval_seconds: int = 60, + max_consecutive_errors: int = 3, + service_idle_timeout_seconds: int = 600, + enable_auto_scaling: bool = False, + ): + self.max_services_per_device = max_services_per_device + self.max_total_services = max_total_services + self.acquire_timeout_seconds = acquire_timeout_seconds + self.max_queue_size = max_queue_size + self.health_check_interval_seconds = health_check_interval_seconds + self.max_consecutive_errors = max_consecutive_errors + self.service_idle_timeout_seconds = service_idle_timeout_seconds + self.enable_auto_scaling = enable_auto_scaling + + +class OCRServicePool: + """ + Pool of OCRService instances with concurrency control. + + Features: + - Per-device instance management (one service per GPU) + - Queue-based task distribution + - Semaphore-based concurrency limits + - Health monitoring + - Automatic service recovery + """ + + _instance = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + """Singleton pattern""" + with cls._lock: + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._initialized = False + return cls._instance + + def __init__(self, config: Optional[PoolConfig] = None): + if self._initialized: + return + + self.config = config or PoolConfig() + self.services: Dict[str, List[PooledService]] = {} + self.semaphores: Dict[str, threading.Semaphore] = {} + self.queues: Dict[str, List] = {} + self._pool_lock = threading.RLock() + self._condition = threading.Condition(self._pool_lock) + + # Metrics + self._metrics = { + "total_acquisitions": 0, + "total_releases": 0, + "total_timeouts": 0, + "total_errors": 0, + "queue_waits": 0, + } + + # Initialize default device pool + self._initialize_device("GPU:0") + + self._initialized = True + logger.info("OCRServicePool initialized") + + def _initialize_device(self, device: str): + """Initialize pool resources for a device""" + with self._pool_lock: + if device not in self.services: + self.services[device] = [] + self.semaphores[device] = threading.Semaphore( + self.config.max_services_per_device + ) + self.queues[device] = [] + logger.info(f"Initialized pool for device {device}") + + def _create_service(self, device: str) -> PooledService: + """ + Create a new OCRService instance for the pool. 
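+        Callers invoke this while holding the pool lock, so a slow construction
+        blocks other waiters; creation time is logged for visibility.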
+ + Args: + device: Device identifier (e.g., "GPU:0", "CPU") + + Returns: + PooledService wrapper + """ + # Import here to avoid circular imports + from app.services.ocr_service import OCRService + + logger.info(f"Creating new OCRService for device {device}") + start_time = time.time() + + # Create service instance + service = OCRService() + + creation_time = time.time() - start_time + logger.info(f"OCRService created in {creation_time:.2f}s for device {device}") + + return PooledService( + service=service, + device=device, + state=ServiceState.AVAILABLE + ) + + def acquire( + self, + device: str = "GPU:0", + timeout: Optional[float] = None, + task_id: Optional[str] = None + ) -> Optional[PooledService]: + """ + Acquire an OCRService from the pool. + + Args: + device: Preferred device (e.g., "GPU:0") + timeout: Maximum time to wait for a service + task_id: Optional task ID for tracking + + Returns: + PooledService if available, None if timeout + """ + timeout = timeout or self.config.acquire_timeout_seconds + self._initialize_device(device) + + start_time = time.time() + deadline = start_time + timeout + + with self._condition: + while True: + # Try to get an available service + service = self._try_acquire_service(device, task_id) + if service: + self._metrics["total_acquisitions"] += 1 + return service + + # Check if we can create a new service + if self._can_create_service(device): + try: + pooled = self._create_service(device) + pooled.state = ServiceState.IN_USE + pooled.current_task_id = task_id + pooled.use_count += 1 + self.services[device].append(pooled) + self._metrics["total_acquisitions"] += 1 + logger.info(f"Created and acquired new service for {device}") + return pooled + except Exception as e: + logger.error(f"Failed to create service for {device}: {e}") + self._metrics["total_errors"] += 1 + + # Wait for a service to become available + remaining = deadline - time.time() + if remaining <= 0: + self._metrics["total_timeouts"] += 1 + logger.warning(f"Timeout waiting for service on {device}") + return None + + self._metrics["queue_waits"] += 1 + logger.debug(f"Waiting for service on {device} (timeout: {remaining:.1f}s)") + self._condition.wait(timeout=min(remaining, 1.0)) + + def _try_acquire_service(self, device: str, task_id: Optional[str]) -> Optional[PooledService]: + """Try to acquire an available service without waiting""" + for pooled in self.services.get(device, []): + if pooled.state == ServiceState.AVAILABLE: + pooled.state = ServiceState.IN_USE + pooled.last_used = time.time() + pooled.use_count += 1 + pooled.current_task_id = task_id + logger.debug(f"Acquired existing service for {device} (use #{pooled.use_count})") + return pooled + return None + + def _can_create_service(self, device: str) -> bool: + """Check if a new service can be created""" + device_count = len(self.services.get(device, [])) + total_count = sum(len(services) for services in self.services.values()) + + return ( + device_count < self.config.max_services_per_device and + total_count < self.config.max_total_services + ) + + def release(self, pooled: PooledService, error: Optional[Exception] = None): + """ + Release a service back to the pool. 
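+        A successful release resets error_count and marks the service AVAILABLE;
+        releasing with an error increments error_count and marks the service
+        UNHEALTHY once max_consecutive_errors is reached.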
+ + Args: + pooled: The pooled service to release + error: Optional error that occurred during use + """ + with self._condition: + if error: + pooled.error_count += 1 + self._metrics["total_errors"] += 1 + logger.warning(f"Service released with error: {error}") + + # Mark unhealthy if too many errors + if pooled.error_count >= self.config.max_consecutive_errors: + pooled.state = ServiceState.UNHEALTHY + logger.error(f"Service marked unhealthy after {pooled.error_count} errors") + else: + pooled.state = ServiceState.AVAILABLE + else: + pooled.error_count = 0 # Reset error count on success + pooled.state = ServiceState.AVAILABLE + + pooled.last_used = time.time() + pooled.current_task_id = None + self._metrics["total_releases"] += 1 + + # Clean up GPU memory after release + try: + model_manager = get_model_manager() + model_manager.memory_guard.clear_gpu_cache() + except Exception as e: + logger.debug(f"Cache clear after release failed: {e}") + + # Notify waiting threads + self._condition.notify_all() + + logger.debug(f"Service released for device {pooled.device}") + + @contextmanager + def acquire_context( + self, + device: str = "GPU:0", + timeout: Optional[float] = None, + task_id: Optional[str] = None + ): + """ + Context manager for acquiring and releasing a service. + + Usage: + with pool.acquire_context("GPU:0") as pooled: + result = pooled.service.process(...) + """ + pooled = None + error = None + try: + pooled = self.acquire(device, timeout, task_id) + if pooled is None: + raise TimeoutError(f"Failed to acquire service for {device}") + yield pooled + except Exception as e: + error = e + raise + finally: + if pooled: + self.release(pooled, error) + + def get_service(self, device: str = "GPU:0") -> Optional["OCRService"]: + """ + Get a service directly (for backward compatibility). + + This acquires a service and returns the underlying OCRService. + The caller is responsible for calling release_service() when done. + + Args: + device: Device identifier + + Returns: + OCRService instance or None + """ + pooled = self.acquire(device) + if pooled: + return pooled.service + return None + + def get_pool_stats(self) -> Dict: + """Get current pool statistics""" + with self._pool_lock: + stats = { + "devices": {}, + "metrics": self._metrics.copy(), + "total_services": 0, + "available_services": 0, + "in_use_services": 0, + } + + for device, services in self.services.items(): + available = sum(1 for s in services if s.state == ServiceState.AVAILABLE) + in_use = sum(1 for s in services if s.state == ServiceState.IN_USE) + unhealthy = sum(1 for s in services if s.state == ServiceState.UNHEALTHY) + + stats["devices"][device] = { + "total": len(services), + "available": available, + "in_use": in_use, + "unhealthy": unhealthy, + "max_allowed": self.config.max_services_per_device, + } + + stats["total_services"] += len(services) + stats["available_services"] += available + stats["in_use_services"] += in_use + + return stats + + def health_check(self) -> Dict: + """ + Perform health check on all pooled services. 
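+        A service counts as responsive when it exposes the expected process() and
+        get_gpu_status() attributes; any UNHEALTHY service makes the overall
+        result unhealthy.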
+ + Returns: + Health check results + """ + results = { + "healthy": True, + "services": [], + "timestamp": time.time() + } + + with self._pool_lock: + for device, services in self.services.items(): + for idx, pooled in enumerate(services): + service_health = { + "device": device, + "index": idx, + "state": pooled.state.value, + "error_count": pooled.error_count, + "use_count": pooled.use_count, + "idle_seconds": time.time() - pooled.last_used, + } + + # Check if service is responsive + if pooled.state == ServiceState.AVAILABLE: + try: + # Simple check - verify service has required attributes + has_process = hasattr(pooled.service, 'process') + has_gpu_status = hasattr(pooled.service, 'get_gpu_status') + service_health["responsive"] = has_process and has_gpu_status + except Exception as e: + service_health["responsive"] = False + service_health["error"] = str(e) + results["healthy"] = False + else: + service_health["responsive"] = pooled.state != ServiceState.UNHEALTHY + + if pooled.state == ServiceState.UNHEALTHY: + results["healthy"] = False + + results["services"].append(service_health) + + return results + + def recover_unhealthy(self): + """ + Attempt to recover unhealthy services. + """ + with self._pool_lock: + for device, services in self.services.items(): + for idx, pooled in enumerate(services): + if pooled.state == ServiceState.UNHEALTHY: + logger.info(f"Attempting to recover unhealthy service {device}:{idx}") + try: + # Remove old service + services.remove(pooled) + + # Create new service + new_pooled = self._create_service(device) + services.append(new_pooled) + logger.info(f"Successfully recovered service {device}:{idx}") + + except Exception as e: + logger.error(f"Failed to recover service {device}:{idx}: {e}") + + def shutdown(self): + """ + Shutdown the pool and cleanup all services. + """ + logger.info("OCRServicePool shutdown started") + + with self._pool_lock: + for device, services in self.services.items(): + for pooled in services: + try: + # Clean up service resources + if hasattr(pooled.service, 'cleanup_gpu_memory'): + pooled.service.cleanup_gpu_memory() + except Exception as e: + logger.warning(f"Error cleaning up service: {e}") + + # Clear all pools + self.services.clear() + self.semaphores.clear() + self.queues.clear() + + logger.info("OCRServicePool shutdown completed") + + +# Global singleton instance +_service_pool: Optional[OCRServicePool] = None + + +def get_service_pool(config: Optional[PoolConfig] = None) -> OCRServicePool: + """ + Get the global OCRServicePool instance. + + Args: + config: Optional configuration (only used on first call) + + Returns: + OCRServicePool singleton instance + """ + global _service_pool + if _service_pool is None: + _service_pool = OCRServicePool(config) + return _service_pool + + +def shutdown_service_pool(): + """Shutdown the global service pool""" + global _service_pool + if _service_pool is not None: + _service_pool.shutdown() + _service_pool = None diff --git a/backend/tests/services/test_memory_manager.py b/backend/tests/services/test_memory_manager.py new file mode 100644 index 0000000..8f02096 --- /dev/null +++ b/backend/tests/services/test_memory_manager.py @@ -0,0 +1,1986 @@ +""" +Tests for Memory Management Components + +Tests ModelManager, MemoryGuard, and related functionality. 
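+
+PaddlePaddle is mocked at import time below, so the suite runs without paddle
+installed and without a GPU.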
+""" + +import gc +import pytest +import threading +import time +from unittest.mock import Mock, patch, MagicMock +import sys + +# Mock paddle before importing memory_manager to avoid import errors +# when paddle is not installed in the test environment +paddle_mock = MagicMock() +paddle_mock.is_compiled_with_cuda.return_value = False +paddle_mock.device.cuda.device_count.return_value = 0 +paddle_mock.device.cuda.memory_allocated.return_value = 0 +paddle_mock.device.cuda.memory_reserved.return_value = 0 +paddle_mock.device.cuda.empty_cache = MagicMock() +sys.modules['paddle'] = paddle_mock + +from app.services.memory_manager import ( + ModelManager, + ModelEntry, + MemoryGuard, + MemoryConfig, + MemoryStats, + MemoryBackend, + get_model_manager, + shutdown_model_manager, +) + + +class TestMemoryConfig: + """Tests for MemoryConfig class""" + + def test_default_values(self): + """Test default configuration values""" + config = MemoryConfig() + assert config.warning_threshold == 0.80 + assert config.critical_threshold == 0.95 + assert config.emergency_threshold == 0.98 + assert config.model_idle_timeout_seconds == 300 + assert config.enable_auto_cleanup is True + assert config.max_concurrent_predictions == 2 + + def test_custom_values(self): + """Test custom configuration values""" + config = MemoryConfig( + warning_threshold=0.70, + critical_threshold=0.85, + model_idle_timeout_seconds=600, + ) + assert config.warning_threshold == 0.70 + assert config.critical_threshold == 0.85 + assert config.model_idle_timeout_seconds == 600 + + +class TestMemoryGuard: + """Tests for MemoryGuard class""" + + def setup_method(self): + """Setup for each test""" + self.config = MemoryConfig( + warning_threshold=0.80, + critical_threshold=0.95, + emergency_threshold=0.98, + ) + + def test_initialization(self): + """Test MemoryGuard initialization""" + guard = MemoryGuard(self.config) + assert guard.config == self.config + assert guard.backend is not None + guard.shutdown() + + def test_get_memory_stats(self): + """Test getting memory statistics""" + guard = MemoryGuard(self.config) + stats = guard.get_memory_stats() + assert isinstance(stats, MemoryStats) + assert stats.timestamp > 0 + guard.shutdown() + + def test_check_memory_below_warning(self): + """Test memory check when below warning threshold""" + guard = MemoryGuard(self.config) + + # Mock stats to be below warning + with patch.object(guard, 'get_memory_stats') as mock_stats: + mock_stats.return_value = MemoryStats( + gpu_used_ratio=0.50, + gpu_free_mb=4000, + gpu_total_mb=8000, + ) + is_available, stats = guard.check_memory(required_mb=1000) + assert is_available is True + + guard.shutdown() + + def test_check_memory_above_warning(self): + """Test memory check when above warning threshold""" + guard = MemoryGuard(self.config) + + # Mock stats to be above warning + with patch.object(guard, 'get_memory_stats') as mock_stats: + mock_stats.return_value = MemoryStats( + gpu_used_ratio=0.85, + gpu_free_mb=1200, + gpu_total_mb=8000, + ) + is_available, stats = guard.check_memory(required_mb=500) + # Should still return True (warning, not critical) + assert is_available is True + + guard.shutdown() + + def test_check_memory_above_critical(self): + """Test memory check when above critical threshold""" + guard = MemoryGuard(self.config) + + # Mock stats to be above critical + with patch.object(guard, 'get_memory_stats') as mock_stats: + mock_stats.return_value = MemoryStats( + gpu_used_ratio=0.96, + gpu_free_mb=320, + gpu_total_mb=8000, + ) + is_available, 
stats = guard.check_memory(required_mb=100) + # Should return False (critical) + assert is_available is False + + guard.shutdown() + + def test_check_memory_insufficient_free(self): + """Test memory check when insufficient free memory""" + guard = MemoryGuard(self.config) + + # Mock stats with insufficient free memory + with patch.object(guard, 'get_memory_stats') as mock_stats: + mock_stats.return_value = MemoryStats( + gpu_used_ratio=0.70, + gpu_free_mb=500, + gpu_total_mb=8000, + ) + is_available, stats = guard.check_memory(required_mb=1000) + # Should return False (not enough free) + assert is_available is False + + guard.shutdown() + + def test_alert_history(self): + """Test alert history functionality""" + guard = MemoryGuard(self.config) + + # Trigger some alerts + guard._add_alert("warning", "Test warning") + guard._add_alert("critical", "Test critical") + + alerts = guard.get_alerts() + assert len(alerts) == 2 + assert alerts[0]["level"] == "warning" + assert alerts[1]["level"] == "critical" + + guard.shutdown() + + def test_clear_gpu_cache(self): + """Test GPU cache clearing""" + guard = MemoryGuard(self.config) + # Should not raise even if no GPU + guard.clear_gpu_cache() + guard.shutdown() + + +class TestModelManager: + """Tests for ModelManager class""" + + def setup_method(self): + """Reset singleton before each test""" + shutdown_model_manager() + ModelManager._instance = None + ModelManager._lock = threading.Lock() + + def teardown_method(self): + """Cleanup after each test""" + shutdown_model_manager() + ModelManager._instance = None + + def test_singleton_pattern(self): + """Test that ModelManager is a singleton""" + config = MemoryConfig() + manager1 = ModelManager(config) + manager2 = ModelManager() + assert manager1 is manager2 + manager1.teardown() + + def test_get_or_load_model_new(self): + """Test loading a new model""" + config = MemoryConfig() + manager = ModelManager(config) + + mock_model = Mock() + loader_called = False + + def loader(): + nonlocal loader_called + loader_called = True + return mock_model + + model = manager.get_or_load_model( + model_id="test_model", + loader_func=loader, + estimated_memory_mb=100 + ) + + assert model is mock_model + assert loader_called is True + assert "test_model" in manager.models + assert manager.models["test_model"].ref_count == 1 + + manager.teardown() + + def test_get_or_load_model_cached(self): + """Test getting a cached model""" + config = MemoryConfig() + manager = ModelManager(config) + + mock_model = Mock() + load_count = 0 + + def loader(): + nonlocal load_count + load_count += 1 + return mock_model + + # First load + model1 = manager.get_or_load_model("test_model", loader) + # Second load (should return cached) + model2 = manager.get_or_load_model("test_model", loader) + + assert model1 is model2 + assert load_count == 1 # Loader should only be called once + assert manager.models["test_model"].ref_count == 2 + + manager.teardown() + + def test_release_model(self): + """Test releasing model references""" + config = MemoryConfig() + manager = ModelManager(config) + + model = manager.get_or_load_model("test_model", lambda: Mock()) + assert manager.models["test_model"].ref_count == 1 + + manager.release_model("test_model") + assert manager.models["test_model"].ref_count == 0 + + manager.teardown() + + def test_unload_model(self): + """Test unloading a model""" + config = MemoryConfig() + manager = ModelManager(config) + + manager.get_or_load_model("test_model", lambda: Mock()) + manager.release_model("test_model") + 
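+        # ref_count is back to 0, so the unload below should succeed without force=True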
+ success = manager.unload_model("test_model") + assert success is True + assert "test_model" not in manager.models + + manager.teardown() + + def test_unload_model_with_references(self): + """Test that model with active references is not unloaded""" + config = MemoryConfig() + manager = ModelManager(config) + + manager.get_or_load_model("test_model", lambda: Mock()) + # Don't release - ref_count is still 1 + + success = manager.unload_model("test_model", force=False) + assert success is False + assert "test_model" in manager.models + + manager.teardown() + + def test_unload_model_force(self): + """Test force unloading a model with references""" + config = MemoryConfig() + manager = ModelManager(config) + + manager.get_or_load_model("test_model", lambda: Mock()) + + success = manager.unload_model("test_model", force=True) + assert success is True + assert "test_model" not in manager.models + + manager.teardown() + + def test_cleanup_callback(self): + """Test cleanup callback is called on unload""" + config = MemoryConfig() + manager = ModelManager(config) + + cleanup_called = False + def cleanup(): + nonlocal cleanup_called + cleanup_called = True + + manager.get_or_load_model( + "test_model", + lambda: Mock(), + cleanup_callback=cleanup + ) + manager.release_model("test_model") + manager.unload_model("test_model") + + assert cleanup_called is True + manager.teardown() + + def test_get_model_stats(self): + """Test getting model statistics""" + config = MemoryConfig() + manager = ModelManager(config) + + manager.get_or_load_model("model1", lambda: Mock(), estimated_memory_mb=100) + manager.get_or_load_model("model2", lambda: Mock(), estimated_memory_mb=200) + + stats = manager.get_model_stats() + assert stats["total_models"] == 2 + assert "model1" in stats["models"] + assert "model2" in stats["models"] + assert stats["total_estimated_memory_mb"] == 300 + + manager.teardown() + + def test_idle_cleanup(self): + """Test idle model cleanup""" + config = MemoryConfig( + model_idle_timeout_seconds=1, # Short timeout for testing + memory_check_interval_seconds=60, # Don't auto-cleanup + ) + manager = ModelManager(config) + + manager.get_or_load_model("test_model", lambda: Mock()) + manager.release_model("test_model") + + # Manually set last_used to simulate idle + manager.models["test_model"].last_used = time.time() - 10 + + # Manually trigger cleanup + manager._cleanup_idle_models() + + assert "test_model" not in manager.models + manager.teardown() + + +class TestGetModelManager: + """Tests for get_model_manager helper function""" + + def setup_method(self): + """Reset singleton before each test""" + shutdown_model_manager() + ModelManager._instance = None + + def teardown_method(self): + """Cleanup after each test""" + shutdown_model_manager() + ModelManager._instance = None + + def test_get_model_manager_creates_singleton(self): + """Test that get_model_manager creates a singleton""" + manager1 = get_model_manager() + manager2 = get_model_manager() + assert manager1 is manager2 + shutdown_model_manager() + + def test_shutdown_model_manager(self): + """Test shutdown_model_manager cleans up""" + manager = get_model_manager() + manager.get_or_load_model("test", lambda: Mock()) + + shutdown_model_manager() + + # Should be able to create new manager + new_manager = get_model_manager() + assert "test" not in new_manager.models + shutdown_model_manager() + + +class TestConcurrency: + """Tests for concurrent access""" + + def setup_method(self): + shutdown_model_manager() + ModelManager._instance = 
None + + def teardown_method(self): + shutdown_model_manager() + ModelManager._instance = None + + def test_concurrent_model_access(self): + """Test concurrent model loading""" + config = MemoryConfig() + manager = ModelManager(config) + + load_count = 0 + lock = threading.Lock() + + def loader(): + nonlocal load_count + with lock: + load_count += 1 + time.sleep(0.1) # Simulate slow load + return Mock() + + results = [] + + def worker(): + model = manager.get_or_load_model("shared_model", loader) + results.append(model) + + threads = [threading.Thread(target=worker) for _ in range(5)] + for t in threads: + t.start() + for t in threads: + t.join() + + # All threads should get the same model + assert len(set(id(r) for r in results)) == 1 + # Loader should only be called once + assert load_count == 1 + # Ref count should match thread count + assert manager.models["shared_model"].ref_count == 5 + + manager.teardown() + + +class TestPredictionSemaphore: + """Tests for PredictionSemaphore class""" + + def setup_method(self): + """Reset singleton before each test""" + from app.services.memory_manager import shutdown_prediction_semaphore, PredictionSemaphore + shutdown_prediction_semaphore() + PredictionSemaphore._instance = None + PredictionSemaphore._lock = threading.Lock() + + def teardown_method(self): + """Cleanup after each test""" + from app.services.memory_manager import shutdown_prediction_semaphore, PredictionSemaphore + shutdown_prediction_semaphore() + PredictionSemaphore._instance = None + + def test_singleton_pattern(self): + """Test that PredictionSemaphore is a singleton""" + from app.services.memory_manager import PredictionSemaphore + sem1 = PredictionSemaphore(max_concurrent=2) + sem2 = PredictionSemaphore(max_concurrent=4) # Different config should be ignored + assert sem1 is sem2 + assert sem1._max_concurrent == 2 + + def test_acquire_release(self): + """Test basic acquire and release""" + from app.services.memory_manager import PredictionSemaphore + sem = PredictionSemaphore(max_concurrent=2) + + # Acquire first slot + assert sem.acquire(task_id="task1") is True + assert sem._active_predictions == 1 + + # Acquire second slot + assert sem.acquire(task_id="task2") is True + assert sem._active_predictions == 2 + + # Release one + sem.release(task_id="task1") + assert sem._active_predictions == 1 + + # Release another + sem.release(task_id="task2") + assert sem._active_predictions == 0 + + def test_acquire_blocks_when_full(self): + """Test that acquire blocks when all slots are taken""" + from app.services.memory_manager import PredictionSemaphore + sem = PredictionSemaphore(max_concurrent=1) + + # Acquire the only slot + assert sem.acquire(task_id="task1") is True + + # Try to acquire another with short timeout - should fail + result = sem.acquire(timeout=0.1, task_id="task2") + assert result is False + assert sem._total_timeouts == 1 + + # Release first slot + sem.release(task_id="task1") + + # Now should succeed + assert sem.acquire(task_id="task3") is True + + def test_get_stats(self): + """Test statistics tracking""" + from app.services.memory_manager import PredictionSemaphore + sem = PredictionSemaphore(max_concurrent=2) + + sem.acquire(task_id="task1") + sem.acquire(task_id="task2") + sem.release(task_id="task1") + + stats = sem.get_stats() + assert stats["max_concurrent"] == 2 + assert stats["active_predictions"] == 1 + assert stats["total_predictions"] == 2 + assert stats["total_timeouts"] == 0 + + def test_concurrent_acquire(self): + """Test concurrent access to 
semaphore""" + from app.services.memory_manager import PredictionSemaphore + sem = PredictionSemaphore(max_concurrent=2) + + results = [] + acquired_count = 0 + lock = threading.Lock() + + def worker(worker_id): + nonlocal acquired_count + if sem.acquire(timeout=1.0, task_id=f"task_{worker_id}"): + with lock: + acquired_count += 1 + time.sleep(0.1) # Simulate work + sem.release(task_id=f"task_{worker_id}") + results.append(worker_id) + + # Start 4 workers but only 2 slots + threads = [threading.Thread(target=worker, args=(i,)) for i in range(4)] + for t in threads: + t.start() + for t in threads: + t.join() + + # All should complete eventually + assert len(results) == 4 + + +class TestPredictionContext: + """Tests for prediction_context helper function""" + + def setup_method(self): + """Reset singleton before each test""" + from app.services.memory_manager import shutdown_prediction_semaphore, PredictionSemaphore + shutdown_prediction_semaphore() + PredictionSemaphore._instance = None + PredictionSemaphore._lock = threading.Lock() + + def teardown_method(self): + """Cleanup after each test""" + from app.services.memory_manager import shutdown_prediction_semaphore, PredictionSemaphore + shutdown_prediction_semaphore() + PredictionSemaphore._instance = None + + def test_context_manager_success(self): + """Test context manager for successful prediction""" + from app.services.memory_manager import prediction_context, get_prediction_semaphore + + # Initialize semaphore + sem = get_prediction_semaphore(max_concurrent=2) + + with prediction_context(timeout=5.0, task_id="test") as acquired: + assert acquired is True + assert sem._active_predictions == 1 + + # After exiting context, slot should be released + assert sem._active_predictions == 0 + + def test_context_manager_with_exception(self): + """Test context manager releases on exception""" + from app.services.memory_manager import prediction_context, get_prediction_semaphore + + sem = get_prediction_semaphore(max_concurrent=2) + + with pytest.raises(ValueError): + with prediction_context(task_id="test") as acquired: + assert acquired is True + raise ValueError("Test error") + + # Should still release the slot + assert sem._active_predictions == 0 + + def test_context_manager_timeout(self): + """Test context manager when timeout occurs""" + from app.services.memory_manager import prediction_context, get_prediction_semaphore + + sem = get_prediction_semaphore(max_concurrent=1) + + # Acquire the only slot + sem.acquire(task_id="blocker") + + # Context manager should timeout + with prediction_context(timeout=0.1, task_id="waiter") as acquired: + assert acquired is False + + # Release blocker + sem.release(task_id="blocker") + + +class TestBatchProcessor: + """Tests for BatchProcessor class""" + + def test_add_item(self): + """Test adding items to batch queue""" + from app.services.memory_manager import BatchProcessor, BatchItem, BatchPriority + + processor = BatchProcessor(max_batch_size=5) + item = BatchItem(item_id="test1", data="data1", priority=BatchPriority.NORMAL) + processor.add_item(item) + + assert processor.get_queue_size() == 1 + + def test_add_items_sorted_by_priority(self): + """Test that items are sorted by priority""" + from app.services.memory_manager import BatchProcessor, BatchItem, BatchPriority + + processor = BatchProcessor(max_batch_size=5) + + processor.add_item(BatchItem(item_id="low", data="low", priority=BatchPriority.LOW)) + processor.add_item(BatchItem(item_id="high", data="high", priority=BatchPriority.HIGH)) + 
processor.add_item(BatchItem(item_id="normal", data="normal", priority=BatchPriority.NORMAL)) + + # High priority should be first + assert processor._queue[0].item_id == "high" + assert processor._queue[1].item_id == "normal" + assert processor._queue[2].item_id == "low" + + def test_process_batch(self): + """Test processing a batch of items""" + from app.services.memory_manager import BatchProcessor, BatchItem, BatchPriority + + processor = BatchProcessor(max_batch_size=3, cleanup_between_batches=False) + + for i in range(5): + processor.add_item(BatchItem(item_id=f"item{i}", data=i)) + + results = processor.process_batch(lambda x: x * 2) + + # Should process max_batch_size items + assert len(results) == 3 + assert all(r.success for r in results) + assert processor.get_queue_size() == 2 # 2 remaining + + def test_process_all(self): + """Test processing all items""" + from app.services.memory_manager import BatchProcessor, BatchItem + + processor = BatchProcessor(max_batch_size=2, cleanup_between_batches=False) + + for i in range(5): + processor.add_item(BatchItem(item_id=f"item{i}", data=i)) + + results = processor.process_all(lambda x: x * 2) + + assert len(results) == 5 + assert processor.get_queue_size() == 0 + + def test_process_with_failure(self): + """Test handling of processing failures""" + from app.services.memory_manager import BatchProcessor, BatchItem + + processor = BatchProcessor(max_batch_size=3, cleanup_between_batches=False) + + processor.add_item(BatchItem(item_id="good", data=1)) + processor.add_item(BatchItem(item_id="bad", data="error")) + + def processor_func(data): + if data == "error": + raise ValueError("Test error") + return data * 2 + + results = processor.process_all(processor_func) + + assert len(results) == 2 + assert results[0].success is True + assert results[1].success is False + assert "Test error" in results[1].error + + def test_memory_constraint_batching(self): + """Test that batches respect memory constraints""" + from app.services.memory_manager import BatchProcessor, BatchItem + + processor = BatchProcessor( + max_batch_size=10, + max_memory_per_batch_mb=100.0, + cleanup_between_batches=False + ) + + # Add items that exceed memory limit + processor.add_item(BatchItem(item_id="item1", data=1, estimated_memory_mb=60.0)) + processor.add_item(BatchItem(item_id="item2", data=2, estimated_memory_mb=60.0)) + processor.add_item(BatchItem(item_id="item3", data=3, estimated_memory_mb=60.0)) + + results = processor.process_batch(lambda x: x) + + # Should only process items that fit in memory limit + # First item (60MB) fits, second (60MB) doesn't exceed 100MB together, third does + assert len(results) == 1 or len(results) == 2 + assert processor.get_queue_size() >= 1 + + def test_get_stats(self): + """Test statistics tracking""" + from app.services.memory_manager import BatchProcessor, BatchItem + + processor = BatchProcessor(max_batch_size=2, cleanup_between_batches=False) + processor.add_item(BatchItem(item_id="item1", data=1)) + processor.add_item(BatchItem(item_id="item2", data=2)) + + processor.process_all(lambda x: x) + + stats = processor.get_stats() + assert stats["total_processed"] == 2 + assert stats["total_batches"] == 1 + assert stats["total_failures"] == 0 + + +class TestProgressiveLoader: + """Tests for ProgressiveLoader class""" + + def test_initialize(self): + """Test initializing loader with page count""" + from app.services.memory_manager import ProgressiveLoader + + loader = ProgressiveLoader(lookahead_pages=2) + 
loader.initialize(total_pages=10) + + stats = loader.get_stats() + assert stats["total_pages"] == 10 + assert stats["loaded_pages_count"] == 0 + + def test_load_page(self): + """Test loading a single page""" + from app.services.memory_manager import ProgressiveLoader + + loader = ProgressiveLoader(lookahead_pages=2, cleanup_after_pages=10) + loader.initialize(total_pages=5) + + page_data = loader.load_page(0, lambda p: f"page_{p}_data") + + assert page_data == "page_0_data" + assert 0 in loader.get_loaded_pages() + + def test_load_page_caching(self): + """Test that loaded pages are cached""" + from app.services.memory_manager import ProgressiveLoader + + load_count = 0 + + def loader_func(page_num): + nonlocal load_count + load_count += 1 + return f"page_{page_num}" + + loader = ProgressiveLoader(lookahead_pages=2, cleanup_after_pages=10) + loader.initialize(total_pages=5) + + # Load same page twice + loader.load_page(0, loader_func) + loader.load_page(0, loader_func) + + assert load_count == 1 # Should only load once + + def test_unload_distant_pages(self): + """Test that distant pages are unloaded""" + from app.services.memory_manager import ProgressiveLoader + + loader = ProgressiveLoader(lookahead_pages=1, cleanup_after_pages=100) + loader.initialize(total_pages=10) + + # Load several pages + for i in range(5): + loader.load_page(i, lambda p: f"page_{p}") + + # After loading page 4, distant pages should be unloaded + loaded = loader.get_loaded_pages() + # Should keep only pages near current (4): pages 3, 4, and potentially 5 + assert 0 not in loaded # Page 0 should be unloaded + + def test_iterate_pages(self): + """Test iterating through all pages""" + from app.services.memory_manager import ProgressiveLoader + + loader = ProgressiveLoader(lookahead_pages=0, cleanup_after_pages=100) + loader.initialize(total_pages=5) + + results = loader.iterate_pages( + loader_func=lambda p: f"page_{p}", + processor_func=lambda p, data: f"processed_{data}" + ) + + assert len(results) == 5 + assert results[0] == "processed_page_0" + assert results[4] == "processed_page_4" + + def test_progress_callback(self): + """Test progress callback during iteration""" + from app.services.memory_manager import ProgressiveLoader + + loader = ProgressiveLoader(lookahead_pages=0, cleanup_after_pages=100) + loader.initialize(total_pages=3) + + progress_reports = [] + + def callback(current, total): + progress_reports.append((current, total)) + + loader.iterate_pages( + loader_func=lambda p: p, + processor_func=lambda p, d: d, + progress_callback=callback + ) + + assert progress_reports == [(1, 3), (2, 3), (3, 3)] + + def test_clear(self): + """Test clearing loaded pages""" + from app.services.memory_manager import ProgressiveLoader + + loader = ProgressiveLoader(cleanup_after_pages=100) + loader.initialize(total_pages=5) + + for i in range(3): + loader.load_page(i, lambda p: p) + + loader.clear() + + assert loader.get_loaded_pages() == [] + + +class TestPriorityOperationQueue: + """Tests for PriorityOperationQueue class""" + + def test_enqueue_dequeue(self): + """Test basic enqueue and dequeue""" + from app.services.memory_manager import PriorityOperationQueue, BatchPriority + + queue = PriorityOperationQueue(max_size=10) + + queue.enqueue("item1", "data1", BatchPriority.NORMAL) + result = queue.dequeue() + + assert result is not None + item_id, data, priority = result + assert item_id == "item1" + assert data == "data1" + assert priority == BatchPriority.NORMAL + + def test_priority_ordering(self): + """Test that 
higher priority items are dequeued first""" + from app.services.memory_manager import PriorityOperationQueue, BatchPriority + + queue = PriorityOperationQueue() + + queue.enqueue("low", "low_data", BatchPriority.LOW) + queue.enqueue("high", "high_data", BatchPriority.HIGH) + queue.enqueue("normal", "normal_data", BatchPriority.NORMAL) + queue.enqueue("critical", "critical_data", BatchPriority.CRITICAL) + + # Dequeue in priority order + item_id, _, _ = queue.dequeue() + assert item_id == "critical" + + item_id, _, _ = queue.dequeue() + assert item_id == "high" + + item_id, _, _ = queue.dequeue() + assert item_id == "normal" + + item_id, _, _ = queue.dequeue() + assert item_id == "low" + + def test_cancel(self): + """Test cancelling an operation""" + from app.services.memory_manager import PriorityOperationQueue, BatchPriority + + queue = PriorityOperationQueue() + + queue.enqueue("item1", "data1", BatchPriority.NORMAL) + queue.enqueue("item2", "data2", BatchPriority.NORMAL) + + # Cancel item1 + assert queue.cancel("item1") is True + + # Dequeue should skip cancelled item + result = queue.dequeue() + assert result[0] == "item2" + + def test_dequeue_empty_returns_none(self): + """Test that dequeue on empty queue returns None""" + from app.services.memory_manager import PriorityOperationQueue + + queue = PriorityOperationQueue() + result = queue.dequeue(timeout=0.1) + assert result is None + + def test_max_size_limit(self): + """Test that queue respects max size""" + from app.services.memory_manager import PriorityOperationQueue, BatchPriority + + queue = PriorityOperationQueue(max_size=2) + + assert queue.enqueue("item1", "data1") is True + assert queue.enqueue("item2", "data2") is True + # Third item should fail without timeout + assert queue.enqueue("item3", "data3", timeout=0.1) is False + + def test_get_stats(self): + """Test queue statistics""" + from app.services.memory_manager import PriorityOperationQueue, BatchPriority + + queue = PriorityOperationQueue() + + queue.enqueue("item1", "data1", BatchPriority.HIGH) + queue.enqueue("item2", "data2", BatchPriority.LOW) + queue.dequeue() + + stats = queue.get_stats() + assert stats["total_enqueued"] == 2 + assert stats["total_dequeued"] == 1 + assert stats["queue_size"] == 1 + + +class TestRecoveryManager: + """Tests for RecoveryManager class""" + + def test_can_attempt_recovery(self): + """Test recovery attempt checking""" + from app.services.memory_manager import RecoveryManager + + manager = RecoveryManager( + cooldown_seconds=1.0, + max_recovery_attempts=3 + ) + + can_recover, reason = manager.can_attempt_recovery() + assert can_recover is True + assert "allowed" in reason.lower() + + def test_cooldown_period(self): + """Test that cooldown period is enforced""" + from app.services.memory_manager import RecoveryManager + + manager = RecoveryManager( + cooldown_seconds=60.0, + max_recovery_attempts=10 + ) + + # First recovery + manager.attempt_recovery() + + # Should be in cooldown + assert manager.is_in_cooldown() is True + + can_recover, reason = manager.can_attempt_recovery() + assert can_recover is False + assert "cooldown" in reason.lower() + + def test_max_recovery_attempts(self): + """Test that max recovery attempts are enforced""" + from app.services.memory_manager import RecoveryManager + + manager = RecoveryManager( + cooldown_seconds=0.01, # Very short cooldown for testing + max_recovery_attempts=2, + recovery_window_seconds=60.0 + ) + + # Perform max attempts + for _ in range(2): + manager.attempt_recovery() + time.sleep(0.02) 
# Wait for cooldown + + # Next attempt should be blocked + can_recover, reason = manager.can_attempt_recovery() + assert can_recover is False + assert "max" in reason.lower() + + def test_recovery_callbacks(self): + """Test recovery event callbacks""" + from app.services.memory_manager import RecoveryManager + + manager = RecoveryManager(cooldown_seconds=0.01) + + start_called = False + complete_called = False + complete_success = None + + def on_start(): + nonlocal start_called + start_called = True + + def on_complete(success): + nonlocal complete_called, complete_success + complete_called = True + complete_success = success + + manager.register_callbacks(on_start=on_start, on_complete=on_complete) + manager.attempt_recovery() + + assert start_called is True + assert complete_called is True + assert complete_success is not None + + def test_get_state(self): + """Test getting recovery state""" + from app.services.memory_manager import RecoveryManager + + manager = RecoveryManager(cooldown_seconds=60.0) + manager.attempt_recovery(error="Test error") + + state = manager.get_state() + assert state["recovery_count"] == 1 + assert state["in_cooldown"] is True + assert state["last_error"] == "Test error" + assert state["cooldown_remaining_seconds"] > 0 + + def test_emergency_release(self): + """Test emergency memory release""" + from app.services.memory_manager import RecoveryManager + + manager = RecoveryManager() + + # Emergency release without model manager + result = manager.emergency_release(model_manager=None) + # Should complete without error (may or may not free memory in test env) + assert isinstance(result, bool) + + +# ============================================================================= +# Section 1.2: Test model reload after unload +# ============================================================================= + +class TestModelReloadAfterUnload: + """Tests for model reload after unload functionality""" + + def setup_method(self): + """Reset singleton before each test""" + shutdown_model_manager() + ModelManager._instance = None + ModelManager._lock = threading.Lock() + + def teardown_method(self): + """Cleanup after each test""" + shutdown_model_manager() + ModelManager._instance = None + + def test_reload_after_unload(self): + """Test that a model can be reloaded after being unloaded""" + config = MemoryConfig() + manager = ModelManager(config) + + load_count = 0 + + def loader(): + nonlocal load_count + load_count += 1 + return Mock(name=f"model_instance_{load_count}") + + # First load + model1 = manager.get_or_load_model("test_model", loader, estimated_memory_mb=100) + assert load_count == 1 + assert "test_model" in manager.models + + # Release and unload + manager.release_model("test_model") + success = manager.unload_model("test_model") + assert success is True + assert "test_model" not in manager.models + + # Reload + model2 = manager.get_or_load_model("test_model", loader, estimated_memory_mb=100) + assert load_count == 2 # Loader called again + assert "test_model" in manager.models + + # Models should be different instances + assert model1 is not model2 + + manager.teardown() + + def test_reload_preserves_config(self): + """Test that reloaded model uses the same configuration""" + config = MemoryConfig() + manager = ModelManager(config) + + cleanup_count = 0 + + def cleanup(): + nonlocal cleanup_count + cleanup_count += 1 + + # Load with cleanup callback + manager.get_or_load_model( + "test_model", + lambda: Mock(), + estimated_memory_mb=200, + 
cleanup_callback=cleanup + ) + + # Release, unload (cleanup should be called) + manager.release_model("test_model") + manager.unload_model("test_model") + assert cleanup_count == 1 + + # Reload with new cleanup callback + def new_cleanup(): + nonlocal cleanup_count + cleanup_count += 10 + + manager.get_or_load_model( + "test_model", + lambda: Mock(), + estimated_memory_mb=300, + cleanup_callback=new_cleanup + ) + + # Verify new estimated memory + assert manager.models["test_model"].estimated_memory_mb == 300 + + # Release and unload again + manager.release_model("test_model") + manager.unload_model("test_model") + assert cleanup_count == 11 # New cleanup was called + + manager.teardown() + + def test_concurrent_reload(self): + """Test concurrent reload operations""" + config = MemoryConfig() + manager = ModelManager(config) + + load_count = 0 + lock = threading.Lock() + + def loader(): + nonlocal load_count + with lock: + load_count += 1 + time.sleep(0.05) # Simulate slow load + return Mock() + + # Load, release, unload + manager.get_or_load_model("test_model", loader) + manager.release_model("test_model") + manager.unload_model("test_model") + + # Concurrent reload attempts + results = [] + + def worker(): + model = manager.get_or_load_model("test_model", loader) + results.append(model) + + threads = [threading.Thread(target=worker) for _ in range(5)] + for t in threads: + t.start() + for t in threads: + t.join() + + # All threads should get the same model instance + assert len(results) == 5 + assert len(set(id(r) for r in results)) == 1 + # Only one additional load should have occurred + assert load_count == 2 + + manager.teardown() + + +# ============================================================================= +# Section 4.2: Test memory savings with selective processing +# ============================================================================= + +class TestSelectiveProcessingMemorySavings: + """Tests for memory savings with selective processing""" + + def test_batch_processor_memory_constraints(self): + """Test that batch processor respects memory constraints""" + from app.services.memory_manager import BatchProcessor, BatchItem + + # Create processor with strict memory limit + processor = BatchProcessor( + max_batch_size=10, + max_memory_per_batch_mb=100.0, + cleanup_between_batches=False + ) + + # Add items with known memory estimates + processor.add_item(BatchItem(item_id="small1", data=1, estimated_memory_mb=20.0)) + processor.add_item(BatchItem(item_id="small2", data=2, estimated_memory_mb=20.0)) + processor.add_item(BatchItem(item_id="small3", data=3, estimated_memory_mb=20.0)) + processor.add_item(BatchItem(item_id="large1", data=4, estimated_memory_mb=80.0)) + + # First batch should include items that fit + results = processor.process_batch(lambda x: x) + + # Items should be processed respecting memory limit + total_memory_in_batch = sum( + item.estimated_memory_mb + for item in processor._queue + ) + sum(20.0 for _ in results) # Processed items had 20MB each + + # Remaining items should include the large one + assert processor.get_queue_size() >= 1 + + def test_progressive_loader_memory_efficiency(self): + """Test that progressive loader manages memory efficiently""" + from app.services.memory_manager import ProgressiveLoader + + loader = ProgressiveLoader(lookahead_pages=1, cleanup_after_pages=100) + loader.initialize(total_pages=10) + + pages_loaded = [] + + def loader_func(page_num): + pages_loaded.append(page_num) + return f"page_data_{page_num}" + + # Load 
pages sequentially + for i in range(10): + loader.load_page(i, loader_func) + + # Only recent pages should be kept in memory + loaded = loader.get_loaded_pages() + + # Should have unloaded distant pages + assert 0 not in loaded # First page should be unloaded + assert len(loaded) <= 3 # Current + lookahead + + loader.clear() + + def test_priority_queue_processing_order(self): + """Test that priority queue processes high priority items first""" + from app.services.memory_manager import PriorityOperationQueue, BatchPriority + + queue = PriorityOperationQueue() + + # Add items with different priorities + queue.enqueue("low1", "data", BatchPriority.LOW) + queue.enqueue("critical1", "data", BatchPriority.CRITICAL) + queue.enqueue("normal1", "data", BatchPriority.NORMAL) + queue.enqueue("high1", "data", BatchPriority.HIGH) + + # Process in priority order + order = [] + while True: + result = queue.dequeue(timeout=0.01) + if result is None: + break + order.append(result[0]) + + assert order[0] == "critical1" + assert order[1] == "high1" + assert order[2] == "normal1" + assert order[3] == "low1" + + +# ============================================================================= +# Section 5.2: Test recovery under various scenarios +# ============================================================================= + +class TestRecoveryScenarios: + """Tests for recovery under various scenarios""" + + def test_recovery_after_oom_simulation(self): + """Test recovery behavior after simulated OOM""" + from app.services.memory_manager import RecoveryManager + + manager = RecoveryManager( + cooldown_seconds=0.1, + max_recovery_attempts=5 + ) + + # Simulate OOM recovery + success = manager.attempt_recovery(error="CUDA out of memory") + assert success is not None # Recovery attempted + + # Check state + state = manager.get_state() + assert state["recovery_count"] == 1 + assert state["last_error"] == "CUDA out of memory" + + def test_recovery_cooldown_enforcement(self): + """Test that cooldown period is strictly enforced""" + from app.services.memory_manager import RecoveryManager + + manager = RecoveryManager( + cooldown_seconds=1.0, + max_recovery_attempts=10 + ) + + # First recovery + manager.attempt_recovery() + assert manager.is_in_cooldown() is True + + # Try immediate second recovery + can_recover, reason = manager.can_attempt_recovery() + assert can_recover is False + assert "cooldown" in reason.lower() + + # Wait for cooldown + time.sleep(1.1) + can_recover, reason = manager.can_attempt_recovery() + assert can_recover is True + + def test_recovery_max_attempts_window(self): + """Test that max recovery attempts are enforced within window""" + from app.services.memory_manager import RecoveryManager + + manager = RecoveryManager( + cooldown_seconds=0.01, + max_recovery_attempts=3, + recovery_window_seconds=60.0 + ) + + # Perform max attempts + for i in range(3): + manager.attempt_recovery(error=f"Error {i}") + time.sleep(0.02) # Wait for cooldown + + # Next attempt should be blocked + can_recover, reason = manager.can_attempt_recovery() + assert can_recover is False + assert "max" in reason.lower() + + def test_emergency_release_with_model_manager(self): + """Test emergency release unloads models""" + from app.services.memory_manager import RecoveryManager + + # Create a model manager with test models + shutdown_model_manager() + ModelManager._instance = None + config = MemoryConfig() + model_manager = ModelManager(config) + + # Load some test models + model_manager.get_or_load_model("model1", lambda: 
Mock(), estimated_memory_mb=100) + model_manager.get_or_load_model("model2", lambda: Mock(), estimated_memory_mb=200) + + # Release references + model_manager.release_model("model1") + model_manager.release_model("model2") + + assert len(model_manager.models) == 2 + + # Emergency release + recovery_manager = RecoveryManager() + result = recovery_manager.emergency_release(model_manager=model_manager) + + # Models should be unloaded + assert len(model_manager.models) == 0 + + model_manager.teardown() + shutdown_model_manager() + ModelManager._instance = None + + +# ============================================================================= +# Section 6.1: Test shutdown sequence +# ============================================================================= + +class TestShutdownSequence: + """Tests for shutdown sequence""" + + def test_model_manager_teardown(self): + """Test that teardown properly cleans up models""" + shutdown_model_manager() + ModelManager._instance = None + + config = MemoryConfig() + manager = ModelManager(config) + + # Load models + manager.get_or_load_model("model1", lambda: Mock()) + manager.get_or_load_model("model2", lambda: Mock()) + + assert len(manager.models) == 2 + + # Teardown + manager.teardown() + + assert len(manager.models) == 0 + assert manager._monitor_running is False + + shutdown_model_manager() + ModelManager._instance = None + + def test_cleanup_callbacks_called_on_teardown(self): + """Test that cleanup callbacks are called during teardown""" + shutdown_model_manager() + ModelManager._instance = None + + config = MemoryConfig() + manager = ModelManager(config) + + cleanup_calls = [] + + def cleanup1(): + cleanup_calls.append("model1") + + def cleanup2(): + cleanup_calls.append("model2") + + manager.get_or_load_model("model1", lambda: Mock(), cleanup_callback=cleanup1) + manager.get_or_load_model("model2", lambda: Mock(), cleanup_callback=cleanup2) + + # Teardown with force unload + manager.teardown() + + # Both callbacks should have been called + assert "model1" in cleanup_calls + assert "model2" in cleanup_calls + + shutdown_model_manager() + ModelManager._instance = None + + def test_prediction_semaphore_shutdown(self): + """Test prediction semaphore shutdown""" + from app.services.memory_manager import ( + get_prediction_semaphore, + shutdown_prediction_semaphore, + PredictionSemaphore + ) + + shutdown_prediction_semaphore() + PredictionSemaphore._instance = None + PredictionSemaphore._lock = threading.Lock() + + sem = get_prediction_semaphore(max_concurrent=2) + sem.acquire(task_id="test1") + + # Shutdown should reset the semaphore + shutdown_prediction_semaphore() + + # New instance should be fresh + new_sem = get_prediction_semaphore(max_concurrent=3) + assert new_sem._active_predictions == 0 + + shutdown_prediction_semaphore() + PredictionSemaphore._instance = None + + +# ============================================================================= +# Section 6.2: Test cleanup in error scenarios +# ============================================================================= + +class TestCleanupInErrorScenarios: + """Tests for cleanup in error scenarios""" + + def test_cleanup_after_loader_exception(self): + """Test cleanup when model loader raises exception""" + shutdown_model_manager() + ModelManager._instance = None + + config = MemoryConfig() + manager = ModelManager(config) + + def failing_loader(): + raise RuntimeError("Loader failed") + + with pytest.raises(RuntimeError): + manager.get_or_load_model("failing_model", failing_loader) + 
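+ # The RuntimeError from the loader should propagate to the caller unchanged; the assertion below confirms no half-loaded entry was registered. + # Illustrative caller-side pattern (hypothetical model name and loader, not part of this test): + # try: + #     model = manager.get_or_load_model("pp_structure", load_pp_structure) + # except RuntimeError: + #     model = None  # nothing stale to unload; a later retry can call get_or_load_model again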
+ # Model should not be in the manager + assert "failing_model" not in manager.models + + manager.teardown() + shutdown_model_manager() + ModelManager._instance = None + + def test_cleanup_after_processing_error(self): + """Test cleanup after processing error in batch processor""" + from app.services.memory_manager import BatchProcessor, BatchItem + + processor = BatchProcessor(max_batch_size=3, cleanup_between_batches=False) + + processor.add_item(BatchItem(item_id="good1", data=1)) + processor.add_item(BatchItem(item_id="bad", data="error")) + processor.add_item(BatchItem(item_id="good2", data=2)) + + def processor_func(data): + if data == "error": + raise ValueError("Processing error") + return data * 2 + + results = processor.process_all(processor_func) + + # Good items should succeed, bad item should fail + assert len(results) == 3 + assert results[0].success is True + assert results[1].success is False + assert results[2].success is True + + # Stats should reflect failure + stats = processor.get_stats() + assert stats["total_failures"] == 1 + + def test_pool_release_with_error(self): + """Test that pool properly handles release with error""" + from app.services.service_pool import ( + OCRServicePool, + PoolConfig, + PooledService, + ServiceState, + shutdown_service_pool + ) + + shutdown_service_pool() + OCRServicePool._instance = None + OCRServicePool._lock = threading.Lock() + + config = PoolConfig(max_consecutive_errors=2) + pool = OCRServicePool(config) + + # Pre-populate with a mock service + mock_service = Mock() + pooled_service = PooledService(service=mock_service, device="GPU:0") + pool.services["GPU:0"].append(pooled_service) + + # Acquire and release with errors + pooled = pool.acquire(device="GPU:0") + pool.release(pooled, error=Exception("Error 1")) + assert pooled.error_count == 1 + assert pooled.state == ServiceState.AVAILABLE + + pooled = pool.acquire(device="GPU:0") + pool.release(pooled, error=Exception("Error 2")) + assert pooled.error_count == 2 + assert pooled.state == ServiceState.UNHEALTHY + + pool.shutdown() + shutdown_service_pool() + OCRServicePool._instance = None + + +# ============================================================================= +# Section 8.1: Memory leak detection tests +# ============================================================================= + +class TestMemoryLeakDetection: + """Tests for memory leak detection""" + + def test_no_leak_on_model_cycle(self): + """Test that loading and unloading models doesn't leak""" + shutdown_model_manager() + ModelManager._instance = None + + config = MemoryConfig() + manager = ModelManager(config) + + initial_model_count = len(manager.models) + + # Perform multiple load/unload cycles + for i in range(10): + manager.get_or_load_model(f"temp_model_{i}", lambda: Mock()) + manager.release_model(f"temp_model_{i}") + manager.unload_model(f"temp_model_{i}") + + # Should be back to initial state + assert len(manager.models) == initial_model_count + + # Force gc and check + gc.collect() + + manager.teardown() + shutdown_model_manager() + ModelManager._instance = None + + def test_no_leak_on_semaphore_cycle(self): + """Test that semaphore acquire/release doesn't leak""" + from app.services.memory_manager import ( + PredictionSemaphore, + shutdown_prediction_semaphore + ) + + shutdown_prediction_semaphore() + PredictionSemaphore._instance = None + PredictionSemaphore._lock = threading.Lock() + + sem = PredictionSemaphore(max_concurrent=2) + + # Perform many acquire/release cycles + for i in range(100): + 
sem.acquire(task_id=f"task_{i}") + sem.release(task_id=f"task_{i}") + + # Active predictions should be 0 + assert sem._active_predictions == 0 + assert sem._queue_depth == 0 + + shutdown_prediction_semaphore() + PredictionSemaphore._instance = None + + def test_no_leak_in_batch_processor(self): + """Test that batch processor doesn't leak items""" + from app.services.memory_manager import BatchProcessor, BatchItem + + processor = BatchProcessor(max_batch_size=5, cleanup_between_batches=False) + + # Add and process many items + for i in range(50): + processor.add_item(BatchItem(item_id=f"item_{i}", data=i)) + + processor.process_all(lambda x: x * 2) + + # Queue should be empty + assert processor.get_queue_size() == 0 + + # Stats should be accurate + stats = processor.get_stats() + assert stats["total_processed"] == 50 + + +# ============================================================================= +# Section 8.1: Stress tests with concurrent requests +# ============================================================================= + +class TestStressConcurrentRequests: + """Stress tests with concurrent requests""" + + def test_concurrent_model_access_stress(self): + """Stress test concurrent model access""" + shutdown_model_manager() + ModelManager._instance = None + + config = MemoryConfig() + manager = ModelManager(config) + + load_count = 0 + lock = threading.Lock() + + def loader(): + nonlocal load_count + with lock: + load_count += 1 + return Mock() + + results = [] + errors = [] + + def worker(worker_id): + try: + model = manager.get_or_load_model("shared_model", loader) + time.sleep(0.01) # Simulate work + manager.release_model("shared_model") + results.append(worker_id) + except Exception as e: + errors.append(str(e)) + + # Launch many concurrent workers + threads = [threading.Thread(target=worker, args=(i,)) for i in range(20)] + for t in threads: + t.start() + for t in threads: + t.join() + + # All workers should complete without errors + assert len(errors) == 0 + assert len(results) == 20 + + # Loader should only be called once + assert load_count == 1 + + manager.teardown() + shutdown_model_manager() + ModelManager._instance = None + + def test_concurrent_semaphore_stress(self): + """Stress test concurrent semaphore operations""" + from app.services.memory_manager import ( + PredictionSemaphore, + shutdown_prediction_semaphore + ) + + shutdown_prediction_semaphore() + PredictionSemaphore._instance = None + PredictionSemaphore._lock = threading.Lock() + + sem = PredictionSemaphore(max_concurrent=3) + + results = [] + max_concurrent_observed = 0 + current_count = 0 + lock = threading.Lock() + + def worker(worker_id): + nonlocal max_concurrent_observed, current_count + if sem.acquire(timeout=10.0, task_id=f"task_{worker_id}"): + with lock: + current_count += 1 + max_concurrent_observed = max(max_concurrent_observed, current_count) + + time.sleep(0.02) # Simulate work + + with lock: + current_count -= 1 + + sem.release(task_id=f"task_{worker_id}") + results.append(worker_id) + + # Launch many concurrent workers + threads = [threading.Thread(target=worker, args=(i,)) for i in range(15)] + for t in threads: + t.start() + for t in threads: + t.join() + + # All should complete + assert len(results) == 15 + + # Max concurrent should not exceed limit + assert max_concurrent_observed <= 3 + + shutdown_prediction_semaphore() + PredictionSemaphore._instance = None + + def test_concurrent_batch_processing(self): + """Stress test concurrent batch processing""" + from 
app.services.memory_manager import BatchProcessor, BatchItem + + processor = BatchProcessor(max_batch_size=3, cleanup_between_batches=False) + + # Add items from multiple threads + def add_items(start_id): + for i in range(10): + processor.add_item(BatchItem(item_id=f"item_{start_id}_{i}", data=i)) + + threads = [threading.Thread(target=add_items, args=(i,)) for i in range(5)] + for t in threads: + t.start() + for t in threads: + t.join() + + # Should have 50 items + assert processor.get_queue_size() == 50 + + # Process all + results = processor.process_all(lambda x: x) + assert len(results) == 50 + + + # ============================================================================= + # Section 8.1: Performance benchmarks + # ============================================================================= + + class TestPerformanceBenchmarks: + """Performance benchmark tests""" + + def test_model_load_performance(self): + """Benchmark model loading performance""" + shutdown_model_manager() + ModelManager._instance = None + + config = MemoryConfig() + manager = ModelManager(config) + + load_times = [] + + for i in range(5): + start = time.time() + manager.get_or_load_model(f"model_{i}", lambda: Mock()) + load_times.append(time.time() - start) + + # Average load time should be reasonable (< 100ms for mock) + avg_load_time = sum(load_times) / len(load_times) + assert avg_load_time < 0.1 # 100ms + + manager.teardown() + shutdown_model_manager() + ModelManager._instance = None + + def test_semaphore_throughput(self): + """Benchmark semaphore throughput""" + from app.services.memory_manager import ( + PredictionSemaphore, + shutdown_prediction_semaphore + ) + + shutdown_prediction_semaphore() + PredictionSemaphore._instance = None + PredictionSemaphore._lock = threading.Lock() + + sem = PredictionSemaphore(max_concurrent=10) + + start = time.time() + iterations = 1000 + + for i in range(iterations): + sem.acquire(timeout=1.0) + sem.release() + + elapsed = time.time() - start + + # Should handle at least 1000 ops/sec + ops_per_sec = iterations / elapsed + assert ops_per_sec > 1000 + + shutdown_prediction_semaphore() + PredictionSemaphore._instance = None + + def test_batch_processor_throughput(self): + """Benchmark batch processor throughput""" + from app.services.memory_manager import BatchProcessor, BatchItem + + processor = BatchProcessor(max_batch_size=100, cleanup_between_batches=False) + + # Add many items + for i in range(1000): + processor.add_item(BatchItem(item_id=f"item_{i}", data=i)) + + start = time.time() + results = processor.process_all(lambda x: x * 2) + elapsed = time.time() - start + + # Should process at least 1000 items/sec + items_per_sec = len(results) / elapsed + assert items_per_sec > 1000 + + stats = processor.get_stats() + assert stats["total_processed"] == 1000 + + + # ============================================================================= + # Tests for Memory Dump and Prometheus Metrics (Section 5.2 & 7.2) + # ============================================================================= + + class TestMemoryDumper: + """Tests for MemoryDumper class""" + + def setup_method(self): + """Reset singletons before each test""" + from app.services.memory_manager import shutdown_memory_dumper + shutdown_memory_dumper() + shutdown_model_manager() + ModelManager._instance = None + + def teardown_method(self): + """Cleanup after each test""" + from app.services.memory_manager import shutdown_memory_dumper + shutdown_memory_dumper() + shutdown_model_manager() + ModelManager._instance = 
None + + def test_create_dump(self): + """Test creating a memory dump""" + from app.services.memory_manager import MemoryDumper, MemoryDump + + dumper = MemoryDumper() + dump = dumper.create_dump() + + assert isinstance(dump, MemoryDump) + assert dump.timestamp > 0 + assert isinstance(dump.loaded_models, list) + assert isinstance(dump.gc_stats, dict) + + def test_dump_history(self): + """Test dump history tracking""" + from app.services.memory_manager import MemoryDumper + + dumper = MemoryDumper() + + # Create multiple dumps + for _ in range(5): + dumper.create_dump() + + history = dumper.get_dump_history() + assert len(history) == 5 + + latest = dumper.get_latest_dump() + assert latest is history[-1] + + def test_dump_comparison(self): + """Test comparing two dumps""" + from app.services.memory_manager import MemoryDumper + + dumper = MemoryDumper() + + dump1 = dumper.create_dump() + time.sleep(0.1) + dump2 = dumper.create_dump() + + comparison = dumper.compare_dumps(dump1, dump2) + + assert "time_delta_seconds" in comparison + assert comparison["time_delta_seconds"] > 0 + assert "gpu_memory_change_mb" in comparison + assert "cpu_memory_change_mb" in comparison + + def test_dump_to_dict(self): + """Test converting dump to dictionary""" + from app.services.memory_manager import MemoryDumper + + dumper = MemoryDumper() + dump = dumper.create_dump() + dump_dict = dumper.to_dict(dump) + + assert "timestamp" in dump_dict + assert "gpu" in dump_dict + assert "cpu" in dump_dict + assert "models" in dump_dict + assert "predictions" in dump_dict + + +class TestPrometheusMetrics: + """Tests for PrometheusMetrics class""" + + def setup_method(self): + """Reset singletons before each test""" + from app.services.memory_manager import shutdown_prometheus_metrics + shutdown_prometheus_metrics() + shutdown_model_manager() + ModelManager._instance = None + + def teardown_method(self): + """Cleanup after each test""" + from app.services.memory_manager import shutdown_prometheus_metrics + shutdown_prometheus_metrics() + shutdown_model_manager() + ModelManager._instance = None + + def test_export_metrics(self): + """Test exporting metrics in Prometheus format""" + from app.services.memory_manager import PrometheusMetrics + + prometheus = PrometheusMetrics() + metrics = prometheus.export_metrics() + + # Should be a non-empty string + assert isinstance(metrics, str) + assert len(metrics) > 0 + + # Should contain expected metric prefixes + assert "tool_ocr_memory_" in metrics + + def test_metric_format(self): + """Test that metrics follow Prometheus format""" + from app.services.memory_manager import PrometheusMetrics + + prometheus = PrometheusMetrics() + metrics = prometheus.export_metrics() + + lines = metrics.split("\n") + + # Check for HELP and TYPE comments + help_lines = [l for l in lines if l.startswith("# HELP")] + type_lines = [l for l in lines if l.startswith("# TYPE")] + + assert len(help_lines) > 0 + assert len(type_lines) > 0 + + def test_custom_metrics(self): + """Test setting custom metrics""" + from app.services.memory_manager import PrometheusMetrics + + prometheus = PrometheusMetrics() + + prometheus.set_custom_metric("custom_value", 42.0) + prometheus.set_custom_metric("labeled_value", 100.0, {"env": "test"}) + + metrics = prometheus.export_metrics() + + assert "custom_value" in metrics or "42" in metrics + + def test_get_prometheus_metrics_singleton(self): + """Test prometheus metrics singleton""" + from app.services.memory_manager import get_prometheus_metrics, shutdown_prometheus_metrics 
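+ # get_prometheus_metrics() is expected to hand back one shared exporter until shutdown_prometheus_metrics() resets it, matching the other module-level singleton helpers exercised in these tests. + # Sketch of how a scrape endpoint could expose it (assumes a FastAPI app object and route name; only export_metrics() is taken from the tested API): + # from fastapi.responses import PlainTextResponse + # @app.get("/metrics", response_class=PlainTextResponse) + # def metrics_endpoint() -> str: + #     return get_prometheus_metrics().export_metrics()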
+ + metrics1 = get_prometheus_metrics() + metrics2 = get_prometheus_metrics() + + assert metrics1 is metrics2 + + shutdown_prometheus_metrics() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/backend/tests/services/test_ocr_memory_integration.py b/backend/tests/services/test_ocr_memory_integration.py new file mode 100644 index 0000000..8de172a --- /dev/null +++ b/backend/tests/services/test_ocr_memory_integration.py @@ -0,0 +1,380 @@ +""" +Tests for OCR Service Memory Integration + +Tests the integration of MemoryGuard with OCRService patterns, +including pre-operation memory checks and CPU fallback logic. +""" + +import pytest +from unittest.mock import Mock, patch, MagicMock +import sys + +# Mock paddle before importing memory_manager +paddle_mock = MagicMock() +paddle_mock.is_compiled_with_cuda.return_value = False +paddle_mock.device.cuda.device_count.return_value = 0 +paddle_mock.device.cuda.memory_allocated.return_value = 0 +paddle_mock.device.cuda.memory_reserved.return_value = 0 +paddle_mock.device.cuda.empty_cache = MagicMock() +sys.modules['paddle'] = paddle_mock + +from app.services.memory_manager import ( + MemoryGuard, + MemoryConfig, + MemoryStats, +) + + +class TestMemoryGuardIntegration: + """Tests for MemoryGuard integration patterns used in OCRService""" + + def setup_method(self): + """Setup for each test""" + self.config = MemoryConfig( + warning_threshold=0.80, + critical_threshold=0.95, + emergency_threshold=0.98, + enable_cpu_fallback=True, + ) + + def teardown_method(self): + """Cleanup after each test""" + pass + + def test_memory_check_below_threshold_allows_processing(self): + """Test that memory check returns True when below thresholds""" + guard = MemoryGuard(self.config) + + # Mock stats below warning threshold + with patch.object(guard, 'get_memory_stats') as mock_stats: + mock_stats.return_value = MemoryStats( + gpu_used_ratio=0.50, + gpu_free_mb=4000, + gpu_total_mb=8000, + ) + + is_available, stats = guard.check_memory(required_mb=2000) + + assert is_available is True + assert stats.gpu_free_mb >= 2000 + + guard.shutdown() + + def test_memory_check_above_critical_blocks_processing(self): + """Test that memory check returns False when above critical threshold""" + guard = MemoryGuard(self.config) + + # Mock stats above critical threshold + with patch.object(guard, 'get_memory_stats') as mock_stats: + mock_stats.return_value = MemoryStats( + gpu_used_ratio=0.96, + gpu_free_mb=320, + gpu_total_mb=8000, + ) + + is_available, stats = guard.check_memory(required_mb=1000) + + assert is_available is False + + guard.shutdown() + + def test_memory_check_insufficient_free_memory(self): + """Test that memory check returns False when free memory < required""" + guard = MemoryGuard(self.config) + + # Mock stats with insufficient free memory but below critical ratio + with patch.object(guard, 'get_memory_stats') as mock_stats: + mock_stats.return_value = MemoryStats( + gpu_used_ratio=0.70, + gpu_free_mb=500, + gpu_total_mb=8000, + ) + + is_available, stats = guard.check_memory(required_mb=1000) + + # Should return False (not enough free memory) + assert is_available is False + + guard.shutdown() + + +class TestCPUFallbackPattern: + """Tests for CPU fallback pattern as used in OCRService""" + + def test_cpu_fallback_activation_pattern(self): + """Test the CPU fallback activation pattern""" + # Simulate the pattern used in OCRService._activate_cpu_fallback + + class MockOCRService: + def __init__(self): + self._cpu_fallback_active = False + 
self.use_gpu = True + self.gpu_available = True + self.gpu_info = {'device_id': 0} + self._memory_guard = Mock() + + def _activate_cpu_fallback(self): + if self._cpu_fallback_active: + return + + self._cpu_fallback_active = True + self.use_gpu = False + self.gpu_info['cpu_fallback'] = True + self.gpu_info['fallback_reason'] = 'GPU memory insufficient' + + if self._memory_guard: + self._memory_guard.clear_gpu_cache() + + service = MockOCRService() + + # Verify initial state + assert service._cpu_fallback_active is False + assert service.use_gpu is True + + # Activate fallback + service._activate_cpu_fallback() + + # Verify fallback state + assert service._cpu_fallback_active is True + assert service.use_gpu is False + assert service.gpu_info.get('cpu_fallback') is True + service._memory_guard.clear_gpu_cache.assert_called_once() + + def test_cpu_fallback_idempotent(self): + """Test that CPU fallback activation is idempotent""" + class MockOCRService: + def __init__(self): + self._cpu_fallback_active = False + self.use_gpu = True + self._memory_guard = Mock() + self.gpu_info = {} + + def _activate_cpu_fallback(self): + if self._cpu_fallback_active: + return + self._cpu_fallback_active = True + self.use_gpu = False + if self._memory_guard: + self._memory_guard.clear_gpu_cache() + + service = MockOCRService() + + # Activate twice + service._activate_cpu_fallback() + service._activate_cpu_fallback() + + # clear_gpu_cache should only be called once + assert service._memory_guard.clear_gpu_cache.call_count == 1 + + def test_gpu_mode_restoration_pattern(self): + """Test the GPU mode restoration pattern""" + # Simulate the pattern used in OCRService._restore_gpu_mode + + class MockOCRService: + def __init__(self): + self._cpu_fallback_active = True + self.use_gpu = False + self.gpu_available = True + self.gpu_info = { + 'device_id': 0, + 'cpu_fallback': True, + 'fallback_reason': 'test' + } + self._memory_guard = Mock() + + def _restore_gpu_mode(self): + if not self._cpu_fallback_active: + return + + if not self.gpu_available: + return + + # Check if GPU memory is now available + if self._memory_guard: + is_available, stats = self._memory_guard.check_memory(required_mb=2000) + if is_available: + self._cpu_fallback_active = False + self.use_gpu = True + self.gpu_info.pop('cpu_fallback', None) + self.gpu_info.pop('fallback_reason', None) + + service = MockOCRService() + + # Mock memory guard to indicate sufficient memory + mock_stats = Mock() + mock_stats.gpu_free_mb = 5000 + service._memory_guard.check_memory.return_value = (True, mock_stats) + + # Restore GPU mode + service._restore_gpu_mode() + + # Verify GPU mode restored + assert service._cpu_fallback_active is False + assert service.use_gpu is True + assert 'cpu_fallback' not in service.gpu_info + + def test_gpu_mode_not_restored_when_memory_still_low(self): + """Test that GPU mode is not restored when memory is still low""" + class MockOCRService: + def __init__(self): + self._cpu_fallback_active = True + self.use_gpu = False + self.gpu_available = True + self.gpu_info = {'cpu_fallback': True} + self._memory_guard = Mock() + + def _restore_gpu_mode(self): + if not self._cpu_fallback_active: + return + if not self.gpu_available: + return + if self._memory_guard: + is_available, stats = self._memory_guard.check_memory(required_mb=2000) + if is_available: + self._cpu_fallback_active = False + self.use_gpu = True + + service = MockOCRService() + + # Mock memory guard to indicate insufficient memory + mock_stats = Mock() + mock_stats.gpu_free_mb = 
500 + service._memory_guard.check_memory.return_value = (False, mock_stats) + + # Try to restore GPU mode + service._restore_gpu_mode() + + # Verify still in fallback mode + assert service._cpu_fallback_active is True + assert service.use_gpu is False + + +class TestPreOperationMemoryCheckPattern: + """Tests for pre-operation memory check pattern as used in OCRService""" + + def test_pre_operation_check_with_fallback(self): + """Test the pre-operation memory check pattern with fallback""" + guard = MemoryGuard(MemoryConfig( + warning_threshold=0.80, + critical_threshold=0.95, + enable_cpu_fallback=True, + )) + + # Simulate the pattern: + # 1. Check if in CPU fallback mode + # 2. Try to restore GPU mode if memory available + # 3. Perform memory check for operation + + class MockService: + def __init__(self): + self._cpu_fallback_active = False + self.use_gpu = True + self.gpu_available = True + self._memory_guard = guard + + def _restore_gpu_mode(self): + pass # Simplified + + def pre_operation_check(self, required_mb: int) -> bool: + # Try restore first + if self._cpu_fallback_active: + self._restore_gpu_mode() + + # Perform memory check + if not self.use_gpu: + return True # CPU mode, no GPU check needed + + is_available, stats = self._memory_guard.check_memory(required_mb=required_mb) + return is_available + + service = MockService() + + # Mock sufficient memory + with patch.object(guard, 'get_memory_stats') as mock_stats: + mock_stats.return_value = MemoryStats( + gpu_used_ratio=0.50, + gpu_free_mb=4000, + gpu_total_mb=8000, + ) + + result = service.pre_operation_check(required_mb=2000) + assert result is True + + guard.shutdown() + + def test_pre_operation_check_returns_true_in_cpu_mode(self): + """Test that pre-operation check returns True when in CPU mode""" + class MockService: + def __init__(self): + self._cpu_fallback_active = True + self.use_gpu = False + self._memory_guard = Mock() + + def pre_operation_check(self, required_mb: int) -> bool: + if not self.use_gpu: + return True # CPU mode, no GPU check needed + return False + + service = MockService() + result = service.pre_operation_check(required_mb=5000) + + # Should return True because we're in CPU mode + assert result is True + # Memory guard should not be called + service._memory_guard.check_memory.assert_not_called() + + +class TestMemoryCheckWithCleanup: + """Tests for memory check with cleanup pattern""" + + def test_memory_check_triggers_cleanup_on_failure(self): + """Test that memory check triggers cleanup when insufficient""" + guard = MemoryGuard(MemoryConfig( + warning_threshold=0.80, + critical_threshold=0.95, + )) + + # Track cleanup calls + cleanup_called = False + + def mock_cleanup(): + nonlocal cleanup_called + cleanup_called = True + + class MockService: + def __init__(self): + self._memory_guard = guard + self.cleanup_func = mock_cleanup + + def check_gpu_memory(self, required_mb: int) -> bool: + # First check + with patch.object(self._memory_guard, 'get_memory_stats') as mock_stats: + # First call - low memory + mock_stats.return_value = MemoryStats( + gpu_used_ratio=0.96, + gpu_free_mb=300, + gpu_total_mb=8000, + ) + + is_available, stats = self._memory_guard.check_memory(required_mb=required_mb) + + if not is_available: + # Trigger cleanup + self.cleanup_func() + self._memory_guard.clear_gpu_cache() + return False + + return True + + service = MockService() + result = service.check_gpu_memory(required_mb=1000) + + # Cleanup should have been triggered + assert cleanup_called is True + assert result is 
False + + guard.shutdown() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/backend/tests/services/test_service_pool.py b/backend/tests/services/test_service_pool.py new file mode 100644 index 0000000..1f09e6d --- /dev/null +++ b/backend/tests/services/test_service_pool.py @@ -0,0 +1,387 @@ +""" +Tests for OCR Service Pool + +Tests OCRServicePool functionality including acquire, release, and concurrency. +""" + +import pytest +import threading +import time +from unittest.mock import Mock, patch, MagicMock +import sys + +# Mock paddle before importing service_pool to avoid import errors +# when paddle is not installed in the test environment +paddle_mock = MagicMock() +paddle_mock.is_compiled_with_cuda.return_value = False +paddle_mock.device.cuda.device_count.return_value = 0 +paddle_mock.device.cuda.memory_allocated.return_value = 0 +paddle_mock.device.cuda.memory_reserved.return_value = 0 +paddle_mock.device.cuda.empty_cache = MagicMock() +sys.modules['paddle'] = paddle_mock + +from app.services.service_pool import ( + OCRServicePool, + PooledService, + PoolConfig, + ServiceState, + get_service_pool, + shutdown_service_pool, +) + + +class TestPoolConfig: + """Tests for PoolConfig class""" + + def test_default_values(self): + """Test default configuration values""" + config = PoolConfig() + assert config.max_services_per_device == 1 + assert config.max_total_services == 2 + assert config.acquire_timeout_seconds == 300.0 + assert config.max_queue_size == 50 + assert config.max_consecutive_errors == 3 + + def test_custom_values(self): + """Test custom configuration values""" + config = PoolConfig( + max_services_per_device=2, + max_total_services=4, + acquire_timeout_seconds=60.0, + ) + assert config.max_services_per_device == 2 + assert config.max_total_services == 4 + assert config.acquire_timeout_seconds == 60.0 + + +class TestPooledService: + """Tests for PooledService class""" + + def test_creation(self): + """Test PooledService creation""" + mock_service = Mock() + pooled = PooledService( + service=mock_service, + device="GPU:0", + ) + assert pooled.service is mock_service + assert pooled.device == "GPU:0" + assert pooled.state == ServiceState.AVAILABLE + assert pooled.use_count == 0 + assert pooled.error_count == 0 + + +class TestOCRServicePool: + """Tests for OCRServicePool class""" + + def setup_method(self): + """Reset singleton before each test""" + shutdown_service_pool() + OCRServicePool._instance = None + OCRServicePool._lock = threading.Lock() + + def teardown_method(self): + """Cleanup after each test""" + shutdown_service_pool() + OCRServicePool._instance = None + + def test_singleton_pattern(self): + """Test that OCRServicePool is a singleton""" + pool1 = OCRServicePool() + pool2 = OCRServicePool() + assert pool1 is pool2 + pool1.shutdown() + + def test_initialize_device(self): + """Test device initialization""" + config = PoolConfig() + pool = OCRServicePool(config) + + # Default device should be initialized + assert "GPU:0" in pool.services + assert "GPU:0" in pool.semaphores + + # Test adding new device + pool._initialize_device("GPU:1") + assert "GPU:1" in pool.services + assert "GPU:1" in pool.semaphores + + pool.shutdown() + + def test_acquire_creates_service(self): + """Test that acquire creates a new service if none available""" + config = PoolConfig(max_services_per_device=1) + pool = OCRServicePool(config) + + # Pre-populate with a mock service + mock_service = Mock() + mock_service.process = Mock() + mock_service.get_gpu_status = 
Mock() + pooled_service = PooledService(service=mock_service, device="GPU:0") + pool.services["GPU:0"].append(pooled_service) + + pooled = pool.acquire(device="GPU:0", timeout=5.0) + assert pooled is not None + assert pooled.state == ServiceState.IN_USE + assert pooled.use_count == 1 + + pool.shutdown() + + def test_acquire_reuses_available_service(self): + """Test that acquire reuses available services""" + config = PoolConfig(max_services_per_device=1) + pool = OCRServicePool(config) + + # Pre-populate with a mock service + mock_service = Mock() + pooled_service = PooledService(service=mock_service, device="GPU:0") + pool.services["GPU:0"].append(pooled_service) + + # First acquire + pooled1 = pool.acquire(device="GPU:0") + service_id = id(pooled1.service) + pool.release(pooled1) + + # Second acquire should get the same service + pooled2 = pool.acquire(device="GPU:0") + assert id(pooled2.service) == service_id + assert pooled2.use_count == 2 + + pool.shutdown() + + def test_release_makes_service_available(self): + """Test that release makes service available again""" + config = PoolConfig() + pool = OCRServicePool(config) + + # Pre-populate with a mock service + mock_service = Mock() + pooled_service = PooledService(service=mock_service, device="GPU:0") + pool.services["GPU:0"].append(pooled_service) + + pooled = pool.acquire(device="GPU:0") + assert pooled.state == ServiceState.IN_USE + + pool.release(pooled) + assert pooled.state == ServiceState.AVAILABLE + + pool.shutdown() + + def test_release_with_error(self): + """Test that release with error increments error count""" + config = PoolConfig(max_consecutive_errors=3) + pool = OCRServicePool(config) + + # Pre-populate with a mock service + mock_service = Mock() + pooled_service = PooledService(service=mock_service, device="GPU:0") + pool.services["GPU:0"].append(pooled_service) + + pooled = pool.acquire(device="GPU:0") + pool.release(pooled, error=Exception("Test error")) + + assert pooled.error_count == 1 + assert pooled.state == ServiceState.AVAILABLE + + pool.shutdown() + + def test_release_marks_unhealthy_after_errors(self): + """Test that service is marked unhealthy after too many errors""" + config = PoolConfig(max_consecutive_errors=2) + pool = OCRServicePool(config) + + # Pre-populate with a mock service + mock_service = Mock() + pooled_service = PooledService(service=mock_service, device="GPU:0") + pool.services["GPU:0"].append(pooled_service) + + pooled = pool.acquire(device="GPU:0") + pool.release(pooled, error=Exception("Error 1")) + + pooled = pool.acquire(device="GPU:0") + pool.release(pooled, error=Exception("Error 2")) + + assert pooled.state == ServiceState.UNHEALTHY + assert pooled.error_count == 2 + + pool.shutdown() + + def test_acquire_context_manager(self): + """Test context manager for acquire/release""" + config = PoolConfig() + pool = OCRServicePool(config) + + # Pre-populate with a mock service + mock_service = Mock() + pooled_service = PooledService(service=mock_service, device="GPU:0") + pool.services["GPU:0"].append(pooled_service) + + with pool.acquire_context(device="GPU:0") as pooled: + assert pooled is not None + assert pooled.state == ServiceState.IN_USE + + # After context, service should be available + assert pooled.state == ServiceState.AVAILABLE + + pool.shutdown() + + def test_acquire_context_manager_with_error(self): + """Test context manager releases on error""" + config = PoolConfig() + pool = OCRServicePool(config) + + # Pre-populate with a mock service + mock_service = Mock() + 
pooled_service = PooledService(service=mock_service, device="GPU:0") + pool.services["GPU:0"].append(pooled_service) + + with pytest.raises(ValueError): + with pool.acquire_context(device="GPU:0") as pooled: + raise ValueError("Test error") + + # Service should still be available after error + assert pooled.error_count == 1 + + pool.shutdown() + + def test_acquire_timeout(self): + """Test that acquire times out when no service available""" + config = PoolConfig( + max_services_per_device=1, + max_total_services=1, + ) + pool = OCRServicePool(config) + + # Pre-populate with a mock service + mock_service = Mock() + pooled_service = PooledService(service=mock_service, device="GPU:0") + pool.services["GPU:0"].append(pooled_service) + + # Acquire the only service + pooled1 = pool.acquire(device="GPU:0") + assert pooled1 is not None + + # Try to acquire another - should timeout + pooled2 = pool.acquire(device="GPU:0", timeout=0.5) + assert pooled2 is None + + pool.shutdown() + + def test_get_pool_stats(self): + """Test pool statistics""" + config = PoolConfig() + pool = OCRServicePool(config) + + # Pre-populate with a mock service + mock_service = Mock() + pooled_service = PooledService(service=mock_service, device="GPU:0") + pool.services["GPU:0"].append(pooled_service) + + # Acquire a service + pooled = pool.acquire(device="GPU:0") + + stats = pool.get_pool_stats() + assert stats["total_services"] == 1 + assert stats["in_use_services"] == 1 + assert stats["available_services"] == 0 + assert stats["metrics"]["total_acquisitions"] == 1 + + pool.release(pooled) + + stats = pool.get_pool_stats() + assert stats["available_services"] == 1 + assert stats["metrics"]["total_releases"] == 1 + + pool.shutdown() + + def test_health_check(self): + """Test health check functionality""" + config = PoolConfig() + pool = OCRServicePool(config) + + # Pre-populate with a mock service + mock_service = Mock() + mock_service.process = Mock() + mock_service.get_gpu_status = Mock() + pooled_service = PooledService(service=mock_service, device="GPU:0") + pool.services["GPU:0"].append(pooled_service) + + # Acquire and release to update use_count + pooled = pool.acquire(device="GPU:0") + pool.release(pooled) + + health = pool.health_check() + assert health["healthy"] is True + assert len(health["services"]) == 1 + assert health["services"][0]["responsive"] is True + + pool.shutdown() + + def test_concurrent_acquire(self): + """Test concurrent service acquisition""" + config = PoolConfig( + max_services_per_device=2, + max_total_services=2, + ) + pool = OCRServicePool(config) + + # Pre-populate with 2 mock services + for i in range(2): + mock_service = Mock() + pooled_service = PooledService(service=mock_service, device="GPU:0") + pool.services["GPU:0"].append(pooled_service) + + results = [] + + def worker(worker_id): + pooled = pool.acquire(device="GPU:0", timeout=5.0, task_id=f"task_{worker_id}") + if pooled: + results.append((worker_id, pooled)) + time.sleep(0.1) # Simulate work + pool.release(pooled) + + threads = [threading.Thread(target=worker, args=(i,)) for i in range(4)] + for t in threads: + t.start() + for t in threads: + t.join() + + # All workers should have acquired a service + assert len(results) == 4 + + pool.shutdown() + + +class TestGetServicePool: + """Tests for get_service_pool helper function""" + + def setup_method(self): + """Reset singleton before each test""" + shutdown_service_pool() + OCRServicePool._instance = None + + def teardown_method(self): + """Cleanup after each test""" + 
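# Tearing down the module-level singleton here keeps later tests from inheriting a pool this test already shut down; setup_method applies the same reset before each test. +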
shutdown_service_pool() + OCRServicePool._instance = None + + def test_get_service_pool_creates_singleton(self): + """Test that get_service_pool creates a singleton""" + pool1 = get_service_pool() + pool2 = get_service_pool() + assert pool1 is pool2 + shutdown_service_pool() + + def test_shutdown_service_pool(self): + """Test shutdown_service_pool cleans up""" + pool = get_service_pool() + shutdown_service_pool() + + # Should be able to create new pool + new_pool = get_service_pool() + assert new_pool._initialized is True + shutdown_service_pool() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/frontend/package-lock.json b/frontend/package-lock.json index 89113c9..ffec023 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -8,6 +8,7 @@ "name": "frontend", "version": "0.0.0", "dependencies": { + "@radix-ui/react-select": "^2.2.6", "@tanstack/react-query": "^5.90.7", "axios": "^1.13.2", "class-variance-authority": "^0.7.0", @@ -87,7 +88,6 @@ "integrity": "sha512-e7jT4DxYvIDLk1ZHmU/m/mB19rex9sv0c2ftBtjSBv+kVM/902eh0fINUzD7UwLLNR+jU585GxUJ8/EBfAM5fw==", "dev": true, "license": "MIT", - "peer": true, "dependencies": { "@babel/code-frame": "^7.27.1", "@babel/generator": "^7.28.5", @@ -947,6 +947,44 @@ "node": "^18.18.0 || ^20.9.0 || >=21.1.0" } }, + "node_modules/@floating-ui/core": { + "version": "1.7.3", + "resolved": "https://registry.npmjs.org/@floating-ui/core/-/core-1.7.3.tgz", + "integrity": "sha512-sGnvb5dmrJaKEZ+LDIpguvdX3bDlEllmv4/ClQ9awcmCZrlx5jQyyMWFM5kBI+EyNOCDDiKk8il0zeuX3Zlg/w==", + "license": "MIT", + "dependencies": { + "@floating-ui/utils": "^0.2.10" + } + }, + "node_modules/@floating-ui/dom": { + "version": "1.7.4", + "resolved": "https://registry.npmjs.org/@floating-ui/dom/-/dom-1.7.4.tgz", + "integrity": "sha512-OOchDgh4F2CchOX94cRVqhvy7b3AFb+/rQXyswmzmGakRfkMgoWVjfnLWkRirfLEfuD4ysVW16eXzwt3jHIzKA==", + "license": "MIT", + "dependencies": { + "@floating-ui/core": "^1.7.3", + "@floating-ui/utils": "^0.2.10" + } + }, + "node_modules/@floating-ui/react-dom": { + "version": "2.1.6", + "resolved": "https://registry.npmjs.org/@floating-ui/react-dom/-/react-dom-2.1.6.tgz", + "integrity": "sha512-4JX6rEatQEvlmgU80wZyq9RT96HZJa88q8hp0pBd+LrczeDI4o6uA2M+uvxngVHo4Ihr8uibXxH6+70zhAFrVw==", + "license": "MIT", + "dependencies": { + "@floating-ui/dom": "^1.7.4" + }, + "peerDependencies": { + "react": ">=16.8.0", + "react-dom": ">=16.8.0" + } + }, + "node_modules/@floating-ui/utils": { + "version": "0.2.10", + "resolved": "https://registry.npmjs.org/@floating-ui/utils/-/utils-0.2.10.tgz", + "integrity": "sha512-aGTxbpbg8/b5JfU1HXSrbH3wXZuLPJcNEcZQFMxLs3oSzgtVu6nFPkbbGGUvBcUjKV2YyB9Wxxabo+HEH9tcRQ==", + "license": "MIT" + }, "node_modules/@humanfs/core": { "version": "0.19.1", "resolved": "https://registry.npmjs.org/@humanfs/core/-/core-0.19.1.tgz", @@ -1272,6 +1310,502 @@ "node": ">= 8" } }, + "node_modules/@radix-ui/number": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/@radix-ui/number/-/number-1.1.1.tgz", + "integrity": "sha512-MkKCwxlXTgz6CFoJx3pCwn07GKp36+aZyu/u2Ln2VrA5DcdyCZkASEDBTd8x5whTQQL5CiYf4prXKLcgQdv29g==", + "license": "MIT" + }, + "node_modules/@radix-ui/primitive": { + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/@radix-ui/primitive/-/primitive-1.1.3.tgz", + "integrity": "sha512-JTF99U/6XIjCBo0wqkU5sK10glYe27MRRsfwoiq5zzOEZLHU3A3KCMa5X/azekYRCJ0HlwI0crAXS/5dEHTzDg==", + "license": "MIT" + }, + "node_modules/@radix-ui/react-arrow": { + "version": "1.1.7", + "resolved": 
"https://registry.npmjs.org/@radix-ui/react-arrow/-/react-arrow-1.1.7.tgz", + "integrity": "sha512-F+M1tLhO+mlQaOWspE8Wstg+z6PwxwRd8oQ8IXceWz92kfAmalTRf0EjrouQeo7QssEPfCn05B4Ihs1K9WQ/7w==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-primitive": "2.1.3" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-collection": { + "version": "1.1.7", + "resolved": "https://registry.npmjs.org/@radix-ui/react-collection/-/react-collection-1.1.7.tgz", + "integrity": "sha512-Fh9rGN0MoI4ZFUNyfFVNU4y9LUz93u9/0K+yLgA2bwRojxM8JU1DyvvMBabnZPBgMWREAJvU2jjVzq+LrFUglw==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-compose-refs": "1.1.2", + "@radix-ui/react-context": "1.1.2", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-slot": "1.2.3" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-compose-refs": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/@radix-ui/react-compose-refs/-/react-compose-refs-1.1.2.tgz", + "integrity": "sha512-z4eqJvfiNnFMHIIvXP3CY57y2WJs5g2v3X0zm9mEJkrkNv4rDxu+sg9Jh8EkXyeqBkB7SOcboo9dMVqhyrACIg==", + "license": "MIT", + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-context": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/@radix-ui/react-context/-/react-context-1.1.2.tgz", + "integrity": "sha512-jCi/QKUM2r1Ju5a3J64TH2A5SpKAgh0LpknyqdQ4m6DCV0xJ2HG1xARRwNGPQfi1SLdLWZ1OJz6F4OMBBNiGJA==", + "license": "MIT", + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-direction": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/@radix-ui/react-direction/-/react-direction-1.1.1.tgz", + "integrity": "sha512-1UEWRX6jnOA2y4H5WczZ44gOOjTEmlqv1uNW4GAJEO5+bauCBhv8snY65Iw5/VOS/ghKN9gr2KjnLKxrsvoMVw==", + "license": "MIT", + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-dismissable-layer": { + "version": "1.1.11", + "resolved": "https://registry.npmjs.org/@radix-ui/react-dismissable-layer/-/react-dismissable-layer-1.1.11.tgz", + "integrity": "sha512-Nqcp+t5cTB8BinFkZgXiMJniQH0PsUt2k51FUhbdfeKvc4ACcG2uQniY/8+h1Yv6Kza4Q7lD7PQV0z0oicE0Mg==", + "license": "MIT", + "dependencies": { + "@radix-ui/primitive": "1.1.3", + "@radix-ui/react-compose-refs": "1.1.2", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-use-callback-ref": "1.1.1", + "@radix-ui/react-use-escape-keydown": "1.1.1" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 
|| ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-focus-guards": { + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/@radix-ui/react-focus-guards/-/react-focus-guards-1.1.3.tgz", + "integrity": "sha512-0rFg/Rj2Q62NCm62jZw0QX7a3sz6QCQU0LpZdNrJX8byRGaGVTqbrW9jAoIAHyMQqsNpeZ81YgSizOt5WXq0Pw==", + "license": "MIT", + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-focus-scope": { + "version": "1.1.7", + "resolved": "https://registry.npmjs.org/@radix-ui/react-focus-scope/-/react-focus-scope-1.1.7.tgz", + "integrity": "sha512-t2ODlkXBQyn7jkl6TNaw/MtVEVvIGelJDCG41Okq/KwUsJBwQ4XVZsHAVUkK4mBv3ewiAS3PGuUWuY2BoK4ZUw==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-compose-refs": "1.1.2", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-use-callback-ref": "1.1.1" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-id": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/@radix-ui/react-id/-/react-id-1.1.1.tgz", + "integrity": "sha512-kGkGegYIdQsOb4XjsfM97rXsiHaBwco+hFI66oO4s9LU+PLAC5oJ7khdOVFxkhsmlbpUqDAvXw11CluXP+jkHg==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-use-layout-effect": "1.1.1" + }, + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-popper": { + "version": "1.2.8", + "resolved": "https://registry.npmjs.org/@radix-ui/react-popper/-/react-popper-1.2.8.tgz", + "integrity": "sha512-0NJQ4LFFUuWkE7Oxf0htBKS6zLkkjBH+hM1uk7Ng705ReR8m/uelduy1DBo0PyBXPKVnBA6YBlU94MBGXrSBCw==", + "license": "MIT", + "dependencies": { + "@floating-ui/react-dom": "^2.0.0", + "@radix-ui/react-arrow": "1.1.7", + "@radix-ui/react-compose-refs": "1.1.2", + "@radix-ui/react-context": "1.1.2", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-use-callback-ref": "1.1.1", + "@radix-ui/react-use-layout-effect": "1.1.1", + "@radix-ui/react-use-rect": "1.1.1", + "@radix-ui/react-use-size": "1.1.1", + "@radix-ui/rect": "1.1.1" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-portal": { + "version": "1.1.9", + "resolved": "https://registry.npmjs.org/@radix-ui/react-portal/-/react-portal-1.1.9.tgz", + "integrity": "sha512-bpIxvq03if6UNwXZ+HTK71JLh4APvnXntDc6XOX8UVq4XQOVl7lwok0AvIl+b8zgCw3fSaVTZMpAPPagXbKmHQ==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-use-layout-effect": "1.1.1" + }, + "peerDependencies": { + 
"@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-primitive": { + "version": "2.1.3", + "resolved": "https://registry.npmjs.org/@radix-ui/react-primitive/-/react-primitive-2.1.3.tgz", + "integrity": "sha512-m9gTwRkhy2lvCPe6QJp4d3G1TYEUHn/FzJUtq9MjH46an1wJU+GdoGC5VLof8RX8Ft/DlpshApkhswDLZzHIcQ==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-slot": "1.2.3" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-select": { + "version": "2.2.6", + "resolved": "https://registry.npmjs.org/@radix-ui/react-select/-/react-select-2.2.6.tgz", + "integrity": "sha512-I30RydO+bnn2PQztvo25tswPH+wFBjehVGtmagkU78yMdwTwVf12wnAOF+AeP8S2N8xD+5UPbGhkUfPyvT+mwQ==", + "license": "MIT", + "dependencies": { + "@radix-ui/number": "1.1.1", + "@radix-ui/primitive": "1.1.3", + "@radix-ui/react-collection": "1.1.7", + "@radix-ui/react-compose-refs": "1.1.2", + "@radix-ui/react-context": "1.1.2", + "@radix-ui/react-direction": "1.1.1", + "@radix-ui/react-dismissable-layer": "1.1.11", + "@radix-ui/react-focus-guards": "1.1.3", + "@radix-ui/react-focus-scope": "1.1.7", + "@radix-ui/react-id": "1.1.1", + "@radix-ui/react-popper": "1.2.8", + "@radix-ui/react-portal": "1.1.9", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-slot": "1.2.3", + "@radix-ui/react-use-callback-ref": "1.1.1", + "@radix-ui/react-use-controllable-state": "1.2.2", + "@radix-ui/react-use-layout-effect": "1.1.1", + "@radix-ui/react-use-previous": "1.1.1", + "@radix-ui/react-visually-hidden": "1.2.3", + "aria-hidden": "^1.2.4", + "react-remove-scroll": "^2.6.3" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-slot": { + "version": "1.2.3", + "resolved": "https://registry.npmjs.org/@radix-ui/react-slot/-/react-slot-1.2.3.tgz", + "integrity": "sha512-aeNmHnBxbi2St0au6VBVC7JXFlhLlOnvIIlePNniyUNAClzmtAUEY8/pBiK3iHjufOlwA+c20/8jngo7xcrg8A==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-compose-refs": "1.1.2" + }, + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-use-callback-ref": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/@radix-ui/react-use-callback-ref/-/react-use-callback-ref-1.1.1.tgz", + "integrity": "sha512-FkBMwD+qbGQeMu1cOHnuGB6x4yzPjho8ap5WtbEJ26umhgqVXbhekKUQO+hZEL1vU92a3wHwdp0HAcqAUF5iDg==", + "license": "MIT", + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + 
"node_modules/@radix-ui/react-use-controllable-state": { + "version": "1.2.2", + "resolved": "https://registry.npmjs.org/@radix-ui/react-use-controllable-state/-/react-use-controllable-state-1.2.2.tgz", + "integrity": "sha512-BjasUjixPFdS+NKkypcyyN5Pmg83Olst0+c6vGov0diwTEo6mgdqVR6hxcEgFuh4QrAs7Rc+9KuGJ9TVCj0Zzg==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-use-effect-event": "0.0.2", + "@radix-ui/react-use-layout-effect": "1.1.1" + }, + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-use-effect-event": { + "version": "0.0.2", + "resolved": "https://registry.npmjs.org/@radix-ui/react-use-effect-event/-/react-use-effect-event-0.0.2.tgz", + "integrity": "sha512-Qp8WbZOBe+blgpuUT+lw2xheLP8q0oatc9UpmiemEICxGvFLYmHm9QowVZGHtJlGbS6A6yJ3iViad/2cVjnOiA==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-use-layout-effect": "1.1.1" + }, + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-use-escape-keydown": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/@radix-ui/react-use-escape-keydown/-/react-use-escape-keydown-1.1.1.tgz", + "integrity": "sha512-Il0+boE7w/XebUHyBjroE+DbByORGR9KKmITzbR7MyQ4akpORYP/ZmbhAr0DG7RmmBqoOnZdy2QlvajJ2QA59g==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-use-callback-ref": "1.1.1" + }, + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-use-layout-effect": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/@radix-ui/react-use-layout-effect/-/react-use-layout-effect-1.1.1.tgz", + "integrity": "sha512-RbJRS4UWQFkzHTTwVymMTUv8EqYhOp8dOOviLj2ugtTiXRaRQS7GLGxZTLL1jWhMeoSCf5zmcZkqTl9IiYfXcQ==", + "license": "MIT", + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-use-previous": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/@radix-ui/react-use-previous/-/react-use-previous-1.1.1.tgz", + "integrity": "sha512-2dHfToCj/pzca2Ck724OZ5L0EVrr3eHRNsG/b3xQJLA2hZpVCS99bLAX+hm1IHXDEnzU6by5z/5MIY794/a8NQ==", + "license": "MIT", + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-use-rect": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/@radix-ui/react-use-rect/-/react-use-rect-1.1.1.tgz", + "integrity": "sha512-QTYuDesS0VtuHNNvMh+CjlKJ4LJickCMUAqjlE3+j8w+RlRpwyX3apEQKGFzbZGdo7XNG1tXa+bQqIE7HIXT2w==", + "license": "MIT", + "dependencies": { + "@radix-ui/rect": "1.1.1" + }, + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-use-size": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/@radix-ui/react-use-size/-/react-use-size-1.1.1.tgz", + 
"integrity": "sha512-ewrXRDTAqAXlkl6t/fkXWNAhFX9I+CkKlw6zjEwk86RSPKwZr3xpBRso655aqYafwtnbpHLj6toFzmd6xdVptQ==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-use-layout-effect": "1.1.1" + }, + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-visually-hidden": { + "version": "1.2.3", + "resolved": "https://registry.npmjs.org/@radix-ui/react-visually-hidden/-/react-visually-hidden-1.2.3.tgz", + "integrity": "sha512-pzJq12tEaaIhqjbzpCuv/OypJY/BPavOofm+dbab+MHLajy277+1lLm6JFcGgF5eskJ6mquGirhXY2GD/8u8Ug==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-primitive": "2.1.3" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/rect": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/@radix-ui/rect/-/rect-1.1.1.tgz", + "integrity": "sha512-HPwpGIzkl28mWyZqG52jiqDJ12waP11Pa1lGoiyUkIEuMLBP0oeK/C89esbXrxsky5we7dfd8U58nm0SgAWpVw==", + "license": "MIT" + }, "node_modules/@rolldown/pluginutils": { "version": "1.0.0-beta.47", "resolved": "https://registry.npmjs.org/@rolldown/pluginutils/-/pluginutils-1.0.0-beta.47.tgz", @@ -1990,7 +2524,6 @@ "integrity": "sha512-GNWcUTRBgIRJD5zj+Tq0fKOJ5XZajIiBroOF0yvj2bSU1WvNdYS/dn9UxwsujGW4JX06dnHyjV2y9rRaybH0iQ==", "dev": true, "license": "MIT", - "peer": true, "dependencies": { "undici-types": "~7.16.0" } @@ -2000,7 +2533,6 @@ "resolved": "https://registry.npmjs.org/@types/react/-/react-19.2.4.tgz", "integrity": "sha512-tBFxBp9Nfyy5rsmefN+WXc1JeW/j2BpBHFdLZbEVfs9wn3E3NRFxwV0pJg8M1qQAexFpvz73hJXFofV0ZAu92A==", "license": "MIT", - "peer": true, "dependencies": { "csstype": "^3.0.2" } @@ -2009,7 +2541,7 @@ "version": "19.2.3", "resolved": "https://registry.npmjs.org/@types/react-dom/-/react-dom-19.2.3.tgz", "integrity": "sha512-jp2L/eY6fn+KgVVQAOqYItbF0VY/YApe5Mz2F0aykSO8gx31bYCZyvSeYxCHKvzHG5eZjc+zyaS5BrBWya2+kQ==", - "dev": true, + "devOptional": true, "license": "MIT", "peerDependencies": { "@types/react": "^19.2.0" @@ -2067,7 +2599,6 @@ "integrity": "sha512-tK3GPFWbirvNgsNKto+UmB/cRtn6TZfyw0D6IKrW55n6Vbs7KJoZtI//kpTKzE/DUmmnAFD8/Ca46s7Obs92/w==", "dev": true, "license": "MIT", - "peer": true, "dependencies": { "@typescript-eslint/scope-manager": "8.46.4", "@typescript-eslint/types": "8.46.4", @@ -2326,7 +2857,6 @@ "integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==", "dev": true, "license": "MIT", - "peer": true, "bin": { "acorn": "bin/acorn" }, @@ -2384,6 +2914,18 @@ "dev": true, "license": "Python-2.0" }, + "node_modules/aria-hidden": { + "version": "1.2.6", + "resolved": "https://registry.npmjs.org/aria-hidden/-/aria-hidden-1.2.6.tgz", + "integrity": "sha512-ik3ZgC9dY/lYVVM++OISsaYDeg1tb0VtP5uL3ouh1koGOaUMDPpbFIei4JkFimWUFPn90sbMNMXQAIVOlnYKJA==", + "license": "MIT", + "dependencies": { + "tslib": "^2.0.0" + }, + "engines": { + "node": ">=10" + } + }, "node_modules/asynckit": { "version": "0.4.0", "resolved": "https://registry.npmjs.org/asynckit/-/asynckit-0.4.0.tgz", @@ -2519,7 +3061,6 @@ } ], "license": "MIT", - "peer": true, "dependencies": { "baseline-browser-mapping": "^2.8.25", 
"caniuse-lite": "^1.0.30001754", @@ -2817,6 +3358,12 @@ "node": ">=8" } }, + "node_modules/detect-node-es": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/detect-node-es/-/detect-node-es-1.1.0.tgz", + "integrity": "sha512-ypdmJU/TbBby2Dxibuv7ZLW3Bs1QEmM7nHjEANfohJLvE0XVujisn1qPJcZxg+qDucsr+bP6fLD1rPS3AhJ7EQ==", + "license": "MIT" + }, "node_modules/devlop": { "version": "1.1.0", "resolved": "https://registry.npmjs.org/devlop/-/devlop-1.1.0.tgz", @@ -2981,7 +3528,6 @@ "integrity": "sha512-BhHmn2yNOFA9H9JmmIVKJmd288g9hrVRDkdoIgRCRuSySRUHH7r/DI6aAXW9T1WwUuY3DFgrcaqB+deURBLR5g==", "dev": true, "license": "MIT", - "peer": true, "dependencies": { "@eslint-community/eslint-utils": "^4.8.0", "@eslint-community/regexpp": "^4.12.1", @@ -3414,6 +3960,15 @@ "url": "https://github.com/sponsors/ljharb" } }, + "node_modules/get-nonce": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/get-nonce/-/get-nonce-1.0.1.tgz", + "integrity": "sha512-FJhYRoDaiatfEkUK8HKlicmu/3SGFD51q3itKDGoSTysQJBnfOcxU5GxnhE1E6soB76MbT0MBtnKJuXyAx+96Q==", + "license": "MIT", + "engines": { + "node": ">=6" + } + }, "node_modules/get-proto": { "version": "1.0.1", "resolved": "https://registry.npmjs.org/get-proto/-/get-proto-1.0.1.tgz", @@ -3606,7 +4161,6 @@ } ], "license": "MIT", - "peer": true, "dependencies": { "@babel/runtime": "^7.27.6" }, @@ -5096,7 +5650,6 @@ } ], "license": "MIT", - "peer": true, "dependencies": { "nanoid": "^3.3.11", "picocolors": "^1.1.1", @@ -5186,7 +5739,6 @@ "resolved": "https://registry.npmjs.org/react/-/react-19.2.0.tgz", "integrity": "sha512-tmbWg6W31tQLeB5cdIBOicJDJRR2KzXsV7uSK9iNfLWQ5bIZfxuPEHp7M8wiHyHnn0DD1i7w3Zmin0FtkrwoCQ==", "license": "MIT", - "peer": true, "engines": { "node": ">=0.10.0" } @@ -5196,7 +5748,6 @@ "resolved": "https://registry.npmjs.org/react-dom/-/react-dom-19.2.0.tgz", "integrity": "sha512-UlbRu4cAiGaIewkPyiRGJk0imDN2T3JjieT6spoL2UeSf5od4n5LB/mQ4ejmxhCFT1tYe8IvaFulzynWovsEFQ==", "license": "MIT", - "peer": true, "dependencies": { "scheduler": "^0.27.0" }, @@ -5332,6 +5883,53 @@ "node": ">=0.10.0" } }, + "node_modules/react-remove-scroll": { + "version": "2.7.1", + "resolved": "https://registry.npmjs.org/react-remove-scroll/-/react-remove-scroll-2.7.1.tgz", + "integrity": "sha512-HpMh8+oahmIdOuS5aFKKY6Pyog+FNaZV/XyJOq7b4YFwsFHe5yYfdbIalI4k3vU2nSDql7YskmUseHsRrJqIPA==", + "license": "MIT", + "dependencies": { + "react-remove-scroll-bar": "^2.3.7", + "react-style-singleton": "^2.2.3", + "tslib": "^2.1.0", + "use-callback-ref": "^1.3.3", + "use-sidecar": "^1.1.3" + }, + "engines": { + "node": ">=10" + }, + "peerDependencies": { + "@types/react": "*", + "react": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/react-remove-scroll-bar": { + "version": "2.3.8", + "resolved": "https://registry.npmjs.org/react-remove-scroll-bar/-/react-remove-scroll-bar-2.3.8.tgz", + "integrity": "sha512-9r+yi9+mgU33AKcj6IbT9oRCO78WriSj6t/cF8DWBZJ9aOGPOTEDvdUDz1FwKim7QXWwmHqtdHnRJfhAxEG46Q==", + "license": "MIT", + "dependencies": { + "react-style-singleton": "^2.2.2", + "tslib": "^2.0.0" + }, + "engines": { + "node": ">=10" + }, + "peerDependencies": { + "@types/react": "*", + "react": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, "node_modules/react-router": { "version": "7.9.6", "resolved": "https://registry.npmjs.org/react-router/-/react-router-7.9.6.tgz", @@ -5370,6 +5968,28 
@@ "react-dom": ">=18" } }, + "node_modules/react-style-singleton": { + "version": "2.2.3", + "resolved": "https://registry.npmjs.org/react-style-singleton/-/react-style-singleton-2.2.3.tgz", + "integrity": "sha512-b6jSvxvVnyptAiLjbkWLE/lOnR4lfTtDAl+eUC7RZy+QQWc6wRzIV2CE6xBuMmDxc2qIihtDCZD5NPOFl7fRBQ==", + "license": "MIT", + "dependencies": { + "get-nonce": "^1.0.0", + "tslib": "^2.0.0" + }, + "engines": { + "node": ">=10" + }, + "peerDependencies": { + "@types/react": "*", + "react": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, "node_modules/remark-parse": { "version": "11.0.0", "resolved": "https://registry.npmjs.org/remark-parse/-/remark-parse-11.0.0.tgz", @@ -5691,7 +6311,6 @@ "integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==", "dev": true, "license": "MIT", - "peer": true, "engines": { "node": ">=12" }, @@ -5770,7 +6389,6 @@ "integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==", "devOptional": true, "license": "Apache-2.0", - "peer": true, "bin": { "tsc": "bin/tsc", "tsserver": "bin/tsserver" @@ -5938,6 +6556,49 @@ "punycode": "^2.1.0" } }, + "node_modules/use-callback-ref": { + "version": "1.3.3", + "resolved": "https://registry.npmjs.org/use-callback-ref/-/use-callback-ref-1.3.3.tgz", + "integrity": "sha512-jQL3lRnocaFtu3V00JToYz/4QkNWswxijDaCVNZRiRTO3HQDLsdu1ZtmIUvV4yPp+rvWm5j0y0TG/S61cuijTg==", + "license": "MIT", + "dependencies": { + "tslib": "^2.0.0" + }, + "engines": { + "node": ">=10" + }, + "peerDependencies": { + "@types/react": "*", + "react": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/use-sidecar": { + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/use-sidecar/-/use-sidecar-1.1.3.tgz", + "integrity": "sha512-Fedw0aZvkhynoPYlA5WXrMCAMm+nSWdZt6lzJQ7Ok8S6Q+VsHmHpRWndVRJ8Be0ZbkfPc5LRYH+5XrzXcEeLRQ==", + "license": "MIT", + "dependencies": { + "detect-node-es": "^1.1.0", + "tslib": "^2.0.0" + }, + "engines": { + "node": ">=10" + }, + "peerDependencies": { + "@types/react": "*", + "react": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, "node_modules/use-sync-external-store": { "version": "1.6.0", "resolved": "https://registry.npmjs.org/use-sync-external-store/-/use-sync-external-store-1.6.0.tgz", @@ -5981,7 +6642,6 @@ "integrity": "sha512-BxAKBWmIbrDgrokdGZH1IgkIk/5mMHDreLDmCJ0qpyJaAteP8NvMhkwr/ZCQNqNH97bw/dANTE9PDzqwJghfMQ==", "dev": true, "license": "MIT", - "peer": true, "dependencies": { "esbuild": "^0.25.0", "fdir": "^6.5.0", @@ -6075,7 +6735,6 @@ "integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==", "dev": true, "license": "MIT", - "peer": true, "engines": { "node": ">=12" }, diff --git a/frontend/package.json b/frontend/package.json index 1f85631..5528d8b 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -10,6 +10,7 @@ "preview": "vite preview" }, "dependencies": { + "@radix-ui/react-select": "^2.2.6", "@tanstack/react-query": "^5.90.7", "axios": "^1.13.2", "class-variance-authority": "^0.7.0", diff --git a/frontend/src/components/PDFViewer.tsx b/frontend/src/components/PDFViewer.tsx index db4c8c6..cc65274 100644 --- a/frontend/src/components/PDFViewer.tsx +++ 
b/frontend/src/components/PDFViewer.tsx @@ -1,11 +1,17 @@ -import { useState, useMemo } from 'react' -import { Document, Page } from 'react-pdf' +import { useState, useCallback, useMemo, useRef, useEffect } from 'react' +import { Document, Page, pdfjs } from 'react-pdf' +import type { PDFDocumentProxy } from 'pdfjs-dist' import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card' import { Button } from '@/components/ui/button' -import { ChevronLeft, ChevronRight, ZoomIn, ZoomOut } from 'lucide-react' +import { ChevronLeft, ChevronRight, ZoomIn, ZoomOut, Loader2 } from 'lucide-react' import 'react-pdf/dist/Page/AnnotationLayer.css' import 'react-pdf/dist/Page/TextLayer.css' +// Configure standard font data URL for proper font rendering +const pdfOptions = { + standardFontDataUrl: `https://unpkg.com/pdfjs-dist@${pdfjs.version}/standard_fonts/`, +} + interface PDFViewerProps { title?: string pdfUrl: string @@ -17,41 +23,56 @@ export default function PDFViewer({ title, pdfUrl, className, httpHeaders }: PDF const [numPages, setNumPages] = useState(0) const [pageNumber, setPageNumber] = useState(1) const [scale, setScale] = useState(1.0) - const [loading, setLoading] = useState(true) + const [documentLoaded, setDocumentLoaded] = useState(false) const [error, setError] = useState(null) - // Memoize the file prop to prevent unnecessary reloads + // Store PDF document reference + const pdfDocRef = useRef(null) + + // Memoize file config to prevent unnecessary reloads const fileConfig = useMemo(() => { return httpHeaders ? { url: pdfUrl, httpHeaders } : pdfUrl }, [pdfUrl, httpHeaders]) - const onDocumentLoadSuccess = ({ numPages }: { numPages: number }) => { - setNumPages(numPages) - setLoading(false) + // Reset state when URL changes + useEffect(() => { + setDocumentLoaded(false) setError(null) - } + setNumPages(0) + setPageNumber(1) + pdfDocRef.current = null + }, [pdfUrl]) - const onDocumentLoadError = (error: Error) => { - console.error('Error loading PDF:', error) - setError('Failed to load PDF. 
Please try again later.') - setLoading(false) - } + const onDocumentLoadSuccess = useCallback((pdf: { numPages: number }) => { + pdfDocRef.current = pdf as unknown as PDFDocumentProxy + setNumPages(pdf.numPages) + setPageNumber(1) + setDocumentLoaded(true) + setError(null) + }, []) - const goToPreviousPage = () => { + const onDocumentLoadError = useCallback((err: Error) => { + console.error('Error loading PDF:', err) + setError('無法載入 PDF 檔案。請稍後再試。') + setDocumentLoaded(false) + pdfDocRef.current = null + }, []) + + const goToPreviousPage = useCallback(() => { setPageNumber((prev) => Math.max(prev - 1, 1)) - } + }, []) - const goToNextPage = () => { + const goToNextPage = useCallback(() => { setPageNumber((prev) => Math.min(prev + 1, numPages)) - } + }, [numPages]) - const zoomIn = () => { + const zoomIn = useCallback(() => { setScale((prev) => Math.min(prev + 0.2, 3.0)) - } + }, []) - const zoomOut = () => { + const zoomOut = useCallback(() => { setScale((prev) => Math.max(prev - 0.2, 0.5)) - } + }, []) return ( @@ -69,18 +90,18 @@ export default function PDFViewer({ title, pdfUrl, className, httpHeaders }: PDF variant="outline" size="sm" onClick={goToPreviousPage} - disabled={pageNumber <= 1 || loading} + disabled={pageNumber <= 1 || !documentLoaded} > - Page {pageNumber} of {numPages || '...'} + 第 {pageNumber} 頁 / 共 {numPages || '...'} 頁 @@ -92,7 +113,7 @@ export default function PDFViewer({ title, pdfUrl, className, httpHeaders }: PDF variant="outline" size="sm" onClick={zoomOut} - disabled={scale <= 0.5 || loading} + disabled={scale <= 0.5 || !documentLoaded} > @@ -103,7 +124,7 @@ export default function PDFViewer({ title, pdfUrl, className, httpHeaders }: PDF variant="outline" size="sm" onClick={zoomIn} - disabled={scale >= 3.0 || loading} + disabled={scale >= 3.0 || !documentLoaded} > @@ -113,39 +134,48 @@ export default function PDFViewer({ title, pdfUrl, className, httpHeaders }: PDF {/* PDF Document */}
- {loading && ( -
-
-
- )} - - {error && ( + {error ? (
-

Error

+

錯誤

{error}

- )} - - {!error && ( + ) : ( -
+
+ +

載入 PDF 中...

+
} > - + {documentLoaded && ( + + +
+ } + error={ +
+ 無法載入第 {pageNumber} 頁 +
+ } + /> + )} )} diff --git a/frontend/src/components/TaskNotFound.tsx b/frontend/src/components/TaskNotFound.tsx new file mode 100644 index 0000000..d791b1d --- /dev/null +++ b/frontend/src/components/TaskNotFound.tsx @@ -0,0 +1,46 @@ +import { useNavigate } from 'react-router-dom' +import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card' +import { Button } from '@/components/ui/button' +import { Trash2 } from 'lucide-react' + +interface TaskNotFoundProps { + taskId: string | null + onClearAndUpload: () => void +} + +export default function TaskNotFound({ taskId, onClearAndUpload }: TaskNotFoundProps) { + const navigate = useNavigate() + + const handleClick = () => { + onClearAndUpload() + navigate('/upload') + } + + return ( +
+ + +
+
+ +
+
+ 任務已刪除 +
+ +

+ 此任務已被刪除或不存在。請上傳新檔案以建立新任務。 +

+ {taskId && ( +

+ 任務 ID: {taskId} +

+ )} + +
+
+
+ ) +} diff --git a/frontend/src/components/ui/select.tsx b/frontend/src/components/ui/select.tsx index aa15d77..ea38493 100644 --- a/frontend/src/components/ui/select.tsx +++ b/frontend/src/components/ui/select.tsx @@ -1,12 +1,14 @@ import * as React from 'react' +import * as SelectPrimitive from '@radix-ui/react-select' import { cn } from '@/lib/utils' -import { ChevronDown } from 'lucide-react' +import { Check, ChevronDown, ChevronUp } from 'lucide-react' -export interface SelectProps extends React.SelectHTMLAttributes { +// Simple native select for backwards compatibility +export interface NativeSelectProps extends React.SelectHTMLAttributes { options: Array<{ value: string; label: string }> } -const Select = React.forwardRef( +const NativeSelect = React.forwardRef( ({ className, options, ...props }, ref) => { return (
@@ -33,6 +35,168 @@ const Select = React.forwardRef( ) } ) -Select.displayName = 'Select' +NativeSelect.displayName = 'NativeSelect' -export { Select } +const Select = SelectPrimitive.Root + +const SelectGroup = SelectPrimitive.Group + +const SelectValue = SelectPrimitive.Value + +const SelectTrigger = React.forwardRef< + React.ComponentRef, + React.ComponentPropsWithoutRef +>(({ className, children, ...props }, ref) => ( + span]:line-clamp-1', + className + )} + {...props} + > + {children} + + + + +)) +SelectTrigger.displayName = SelectPrimitive.Trigger.displayName + +const SelectScrollUpButton = React.forwardRef< + React.ComponentRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + + + +)) +SelectScrollUpButton.displayName = SelectPrimitive.ScrollUpButton.displayName + +const SelectScrollDownButton = React.forwardRef< + React.ComponentRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + + + +)) +SelectScrollDownButton.displayName = SelectPrimitive.ScrollDownButton.displayName + +const SelectContent = React.forwardRef< + React.ComponentRef, + React.ComponentPropsWithoutRef +>(({ className, children, position = 'popper', ...props }, ref) => ( + + + + + {children} + + + + +)) +SelectContent.displayName = SelectPrimitive.Content.displayName + +const SelectLabel = React.forwardRef< + React.ComponentRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)) +SelectLabel.displayName = SelectPrimitive.Label.displayName + +const SelectItem = React.forwardRef< + React.ComponentRef, + React.ComponentPropsWithoutRef +>(({ className, children, ...props }, ref) => ( + + + + + + + + {children} + +)) +SelectItem.displayName = SelectPrimitive.Item.displayName + +const SelectSeparator = React.forwardRef< + React.ComponentRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)) +SelectSeparator.displayName = SelectPrimitive.Separator.displayName + +export { + Select, + SelectGroup, + SelectValue, + SelectTrigger, + SelectContent, + SelectLabel, + SelectItem, + SelectSeparator, + SelectScrollUpButton, + SelectScrollDownButton, + NativeSelect, +} diff --git a/frontend/src/hooks/useTaskValidation.ts b/frontend/src/hooks/useTaskValidation.ts new file mode 100644 index 0000000..bd3f415 --- /dev/null +++ b/frontend/src/hooks/useTaskValidation.ts @@ -0,0 +1,64 @@ +import { useEffect, useState } from 'react' +import { useQuery } from '@tanstack/react-query' +import { useUploadStore } from '@/store/uploadStore' +import { apiClientV2 } from '@/services/apiV2' +import type { TaskDetail } from '@/types/apiV2' + +interface UseTaskValidationResult { + taskId: string | null + taskDetail: TaskDetail | undefined + isLoading: boolean + isNotFound: boolean + clearAndReset: () => void +} + +/** + * Hook for validating task existence and handling deleted tasks gracefully. + * Shows loading state first, then either returns task data or marks as not found. + */ +export function useTaskValidation(options?: { + refetchInterval?: number | false | ((query: any) => number | false) +}): UseTaskValidationResult { + const { batchId, clearUpload } = useUploadStore() + const taskId = batchId ? 
String(batchId) : null + + const [isNotFound, setIsNotFound] = useState(false) + + const { data: taskDetail, isLoading, error, isFetching } = useQuery({ + queryKey: ['taskDetail', taskId], + queryFn: () => apiClientV2.getTask(taskId!), + enabled: !!taskId && !isNotFound, + retry: (failureCount, error: any) => { + // Don't retry on 404 + if (error?.response?.status === 404) { + return false + } + return failureCount < 2 + }, + refetchInterval: options?.refetchInterval ?? false, + // Disable stale time to ensure we check fresh data + staleTime: 0, + }) + + // Handle 404 error - mark as not found immediately + useEffect(() => { + if (error && (error as any)?.response?.status === 404) { + setIsNotFound(true) + } + }, [error]) + + // Clear state and store + const clearAndReset = () => { + clearUpload() + setIsNotFound(false) + } + + return { + taskId, + taskDetail, + // Show loading if we have a taskId and are still fetching (but not if already marked as not found) + isLoading: !!taskId && !isNotFound && (isLoading || isFetching) && !taskDetail, + isNotFound, + clearAndReset, + } +} diff --git a/frontend/src/main.tsx b/frontend/src/main.tsx index 0c5d6d4..7def4e4 100644 --- a/frontend/src/main.tsx +++ b/frontend/src/main.tsx @@ -1,4 +1,3 @@ -import { StrictMode } from 'react' import { createRoot } from 'react-dom/client' import { BrowserRouter } from 'react-router-dom' import { QueryClient, QueryClientProvider } from '@tanstack/react-query' @@ -10,8 +9,8 @@ import App from './App.tsx' // Configure PDF.js worker for react-pdf import { pdfjs } from 'react-pdf' -// Use the worker from react-pdf's bundled pdfjs-dist -pdfjs.GlobalWorkerOptions.workerSrc = `//unpkg.com/pdfjs-dist@${pdfjs.version}/build/pdf.worker.min.mjs` +// Use CDN for the worker (most reliable for Vite) +pdfjs.GlobalWorkerOptions.workerSrc = `https://unpkg.com/pdfjs-dist@${pdfjs.version}/build/pdf.worker.min.mjs` // Create React Query client const queryClient = new QueryClient({ @@ -24,16 +23,16 @@ const queryClient = new QueryClient({ }, }) +// Note: StrictMode disabled due to react-pdf incompatibility +// StrictMode's double-invocation causes PDF worker race conditions createRoot(document.getElementById('root')!).render( - - - - - - - - - - - , + + + + + + + + + , ) diff --git a/frontend/src/pages/ProcessingPage.tsx b/frontend/src/pages/ProcessingPage.tsx index 392ff4b..265ee9e 100644 --- a/frontend/src/pages/ProcessingPage.tsx +++ b/frontend/src/pages/ProcessingPage.tsx @@ -1,26 +1,35 @@ -import { useEffect, useState } from 'react' +import { useState, useEffect } from 'react' import { useNavigate } from 'react-router-dom' import { useTranslation } from 'react-i18next' -import { useQuery, useMutation } from '@tanstack/react-query' +import { useMutation } from '@tanstack/react-query' import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card' import { Progress } from '@/components/ui/progress' import { Button } from '@/components/ui/button' import { Badge } from '@/components/ui/badge' import { useToast } from '@/components/ui/toast' -import { useUploadStore } from '@/store/uploadStore' import { apiClientV2 } from '@/services/apiV2' import { Play, CheckCircle, FileText, AlertCircle, Clock, Activity, Loader2 } from 'lucide-react' import PPStructureParams from '@/components/PPStructureParams' +import TaskNotFound from '@/components/TaskNotFound' +import { useTaskValidation } from '@/hooks/useTaskValidation' import type { PPStructureV3Params, ProcessingOptions } from '@/types/apiV2' export default function 
ProcessingPage() { const { t } = useTranslation() const navigate = useNavigate() const { toast } = useToast() - const { batchId } = useUploadStore() - // In V2, batchId is actually a task_id (string) - const taskId = batchId ? String(batchId) : null + // Use shared hook for task validation + const { taskId, taskDetail, isLoading: isValidating, isNotFound, clearAndReset } = useTaskValidation({ + refetchInterval: (query) => { + const data = query.state.data + if (!data) return 2000 + if (data.status === 'completed' || data.status === 'failed') { + return false + } + return 2000 + }, + }) // PP-StructureV3 parameters state const [ppStructureParams, setPpStructureParams] = useState({}) @@ -56,22 +65,6 @@ export default function ProcessingPage() { }, }) - // Poll task status - const { data: taskDetail } = useQuery({ - queryKey: ['taskDetail', taskId], - queryFn: () => apiClientV2.getTask(taskId!), - enabled: !!taskId, - refetchInterval: (query) => { - const data = query.state.data - if (!data) return 2000 - // Stop polling if completed or failed - if (data.status === 'completed' || data.status === 'failed') { - return false - } - return 2000 // Poll every 2 seconds - }, - }) - // Auto-redirect when completed useEffect(() => { if (taskDetail?.status === 'completed') { @@ -115,6 +108,23 @@ export default function ProcessingPage() { } } + // Show loading while validating task + if (isValidating) { + return ( +
+
+ +

載入任務資訊...

+
+
+ ) + } + + // Show message when task was deleted + if (isNotFound) { + return + } + // Show helpful message when no task is selected if (!taskId) { return ( diff --git a/frontend/src/pages/ResultsPage.tsx b/frontend/src/pages/ResultsPage.tsx index 33cd9b9..bdc6a0a 100644 --- a/frontend/src/pages/ResultsPage.tsx +++ b/frontend/src/pages/ResultsPage.tsx @@ -1,29 +1,23 @@ +import { useMemo } from 'react' import { useNavigate } from 'react-router-dom' import { useTranslation } from 'react-i18next' -import { useQuery } from '@tanstack/react-query' import { Button } from '@/components/ui/button' import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card' import PDFViewer from '@/components/PDFViewer' import { useToast } from '@/components/ui/toast' -import { useUploadStore } from '@/store/uploadStore' import { apiClientV2 } from '@/services/apiV2' import { FileText, Download, AlertCircle, TrendingUp, Clock, Layers, FileJson, Loader2 } from 'lucide-react' import { Badge } from '@/components/ui/badge' +import TaskNotFound from '@/components/TaskNotFound' +import { useTaskValidation } from '@/hooks/useTaskValidation' export default function ResultsPage() { const { t } = useTranslation() const navigate = useNavigate() const { toast } = useToast() - const { batchId } = useUploadStore() - // In V2, batchId is actually a task_id (string) - const taskId = batchId ? String(batchId) : null - - // Get task details - const { data: taskDetail, isLoading } = useQuery({ - queryKey: ['taskDetail', taskId], - queryFn: () => apiClientV2.getTask(taskId!), - enabled: !!taskId, + // Use shared hook for task validation + const { taskId, taskDetail, isLoading, isNotFound, clearAndReset } = useTaskValidation({ refetchInterval: (query) => { const data = query.state.data if (!data) return 2000 @@ -34,6 +28,19 @@ export default function ResultsPage() { }, }) + // Construct PDF URL for preview - memoize to prevent unnecessary reloads + // Must be called unconditionally before any early returns (React hooks rule) + const API_BASE_URL = import.meta.env.VITE_API_BASE_URL || 'http://localhost:8000' + const pdfUrl = useMemo(() => { + return taskId ? `${API_BASE_URL}/api/v2/tasks/${taskId}/download/pdf` : '' + }, [taskId, API_BASE_URL]) + + // Get auth token for PDF preview - memoize to prevent new object reference each render + const pdfHttpHeaders = useMemo(() => { + const authToken = localStorage.getItem('auth_token_v2') + return authToken ? { Authorization: `Bearer ${authToken}` } : undefined + }, []) + const handleDownloadPDF = async () => { if (!taskId) return try { @@ -101,6 +108,23 @@ export default function ResultsPage() { } } + // Show loading while validating task + if (isLoading) { + return ( +
+
+ +

載入任務結果...

+
+
+ ) + } + + // Show message when task was deleted + if (isNotFound) { + return + } + // Show helpful message when no task is selected if (!taskId) { return ( @@ -127,17 +151,7 @@ export default function ResultsPage() { ) } - if (isLoading) { - return ( -
-
- -

載入任務結果...

-
-
- ) - } - + // Fallback for no task detail (shouldn't happen with proper validation) if (!taskDetail) { return (
@@ -157,14 +171,6 @@ export default function ResultsPage() { const isCompleted = taskDetail.status === 'completed' - // Construct PDF URL for preview - const API_BASE_URL = import.meta.env.VITE_API_BASE_URL || 'http://localhost:8000' - const pdfUrl = taskId ? `${API_BASE_URL}/api/v2/tasks/${taskId}/download/pdf` : '' - - // Get auth token for PDF preview - const authToken = localStorage.getItem('auth_token_v2') - const pdfHttpHeaders = authToken ? { Authorization: `Bearer ${authToken}` } : undefined - return (
{/* Page Header */} diff --git a/frontend/src/pages/TaskDetailPage.tsx b/frontend/src/pages/TaskDetailPage.tsx index ae8b523..2c5b6ff 100644 --- a/frontend/src/pages/TaskDetailPage.tsx +++ b/frontend/src/pages/TaskDetailPage.tsx @@ -1,3 +1,4 @@ +import { useMemo } from 'react' import { useParams, useNavigate } from 'react-router-dom' import { useTranslation } from 'react-i18next' import { useQuery } from '@tanstack/react-query' @@ -65,6 +66,19 @@ export default function TaskDetailPage() { retry: false, }) + // Construct PDF URL for preview - memoize to prevent unnecessary reloads + // Must be called unconditionally before any early returns (React hooks rule) + const API_BASE_URL = import.meta.env.VITE_API_BASE_URL || 'http://localhost:8000' + const pdfUrl = useMemo(() => { + return taskId ? `${API_BASE_URL}/api/v2/tasks/${taskId}/download/pdf` : '' + }, [taskId, API_BASE_URL]) + + // Get auth token for PDF preview - memoize to prevent new object reference each render + const pdfHttpHeaders = useMemo(() => { + const authToken = localStorage.getItem('auth_token_v2') + return authToken ? { Authorization: `Bearer ${authToken}` } : undefined + }, []) + const getTrackBadge = (track?: ProcessingTrack) => { if (!track) return null switch (track) { @@ -218,14 +232,6 @@ export default function TaskDetailPage() { const isProcessing = taskDetail.status === 'processing' const isFailed = taskDetail.status === 'failed' - // Construct PDF URL for preview - const API_BASE_URL = import.meta.env.VITE_API_BASE_URL || 'http://localhost:8000' - const pdfUrl = taskId ? `${API_BASE_URL}/api/v2/tasks/${taskId}/download/pdf` : '' - - // Get auth token for PDF preview - const authToken = localStorage.getItem('auth_token_v2') - const pdfHttpHeaders = authToken ? { Authorization: `Bearer ${authToken}` } : undefined - return (
{/* Page Header */} diff --git a/frontend/src/pages/TaskHistoryPage.tsx b/frontend/src/pages/TaskHistoryPage.tsx index e68f4da..deffc66 100644 --- a/frontend/src/pages/TaskHistoryPage.tsx +++ b/frontend/src/pages/TaskHistoryPage.tsx @@ -28,7 +28,7 @@ import { TableHeader, TableRow, } from '@/components/ui/table' -import { Select } from '@/components/ui/select' +import { NativeSelect } from '@/components/ui/select' import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card' export default function TaskHistoryPage() { @@ -112,6 +112,43 @@ export default function TaskHistoryPage() { } } + // Delete all tasks + const handleDeleteAll = async () => { + if (tasks.length === 0) { + alert('沒有可刪除的任務') + return + } + + if (!confirm(`確定要刪除所有 ${total} 個任務嗎?此操作無法復原!`)) return + + try { + setLoading(true) + // Delete tasks one by one + for (const task of tasks) { + await apiClientV2.deleteTask(task.task_id) + } + // If there are more pages, keep fetching and deleting + let hasMoreTasks = hasMore + while (hasMoreTasks) { + const response = await apiClientV2.listTasks({ page: 1, page_size: 100 }) + if (response.tasks.length === 0) break + for (const task of response.tasks) { + await apiClientV2.deleteTask(task.task_id) + } + hasMoreTasks = response.has_more + } + fetchTasks() + fetchStats() + alert('所有任務已刪除') + } catch (err: any) { + alert(err.response?.data?.detail || '刪除任務失敗') + fetchTasks() + fetchStats() + } finally { + setLoading(false) + } + } + // View task details const handleViewDetails = (taskId: string) => { navigate(`/tasks/${taskId}`) @@ -220,10 +257,16 @@ export default function TaskHistoryPage() {

任務歷史

查看和管理您的 OCR 任務

- +
+ + +
{/* Statistics */} @@ -288,7 +331,7 @@ export default function TaskHistoryPage() {
-