diff --git a/backend/app/core/config.py b/backend/app/core/config.py
index 20c6722..9ffc019 100644
--- a/backend/app/core/config.py
+++ b/backend/app/core/config.py
@@ -104,6 +104,37 @@ class Settings(BaseSettings):
enable_cudnn_benchmark: bool = Field(default=True) # Optimize convolution algorithms
num_threads: int = Field(default=4) # CPU threads for preprocessing
+ # ===== Enhanced Memory Management Configuration =====
+ # Memory thresholds (as ratio of total GPU memory)
+ memory_warning_threshold: float = Field(default=0.80) # 80% - start warning
+ memory_critical_threshold: float = Field(default=0.95) # 95% - throttle operations
+ memory_emergency_threshold: float = Field(default=0.98) # 98% - emergency cleanup
+
+ # Memory monitoring
+ memory_check_interval_seconds: int = Field(default=30) # Background check interval
+ enable_memory_alerts: bool = Field(default=True) # Enable memory alerts
+
+ # Model lifecycle management
+ enable_model_lifecycle_management: bool = Field(default=True) # Use ModelManager
+ pp_structure_idle_timeout_seconds: int = Field(default=300) # Unload PP-Structure after idle
+ structure_model_memory_mb: int = Field(default=2000) # Estimated memory for PP-StructureV3
+ ocr_model_memory_mb: int = Field(default=500) # Estimated memory per OCR language model
+
+ # Service pool configuration
+ enable_service_pool: bool = Field(default=True) # Use OCRServicePool
+ max_services_per_device: int = Field(default=1) # Max OCRService per GPU
+ max_total_services: int = Field(default=2) # Max total OCRService instances
+ service_acquire_timeout_seconds: float = Field(default=300.0) # Timeout for acquiring service
+ max_queue_size: int = Field(default=50) # Max pending tasks per device
+
+ # Concurrency control
+ max_concurrent_predictions: int = Field(default=2) # Max concurrent PP-StructureV3 predictions
+ enable_cpu_fallback: bool = Field(default=True) # Fall back to CPU when GPU memory low
+
+ # Emergency recovery
+ enable_emergency_cleanup: bool = Field(default=True) # Auto-cleanup on memory pressure
+ enable_worker_restart: bool = Field(default=False) # Restart workers on OOM (requires supervisor)
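+
+ # Example environment overrides (a sketch assuming pydantic BaseSettings
+ # env loading with no env_prefix; adjust names if a prefix is configured):
+ # MEMORY_WARNING_THRESHOLD=0.85
+ # MAX_CONCURRENT_PREDICTIONS=1
+ # ENABLE_SERVICE_POOL=false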
+
# ===== File Upload Configuration =====
max_upload_size: int = Field(default=52428800) # 50MB
allowed_extensions: str = Field(default="png,jpg,jpeg,pdf,bmp,tiff,doc,docx,ppt,pptx")
diff --git a/backend/app/main.py b/backend/app/main.py
index e888224..51e8eef 100644
--- a/backend/app/main.py
+++ b/backend/app/main.py
@@ -7,10 +7,103 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
import logging
+import signal
+import sys
+import asyncio
from pathlib import Path
+from typing import Optional
from app.core.config import settings
+
+# =============================================================================
+# Section 6.1: Signal Handlers
+# =============================================================================
+
+# Flag to indicate graceful shutdown is in progress
+_shutdown_requested = False
+_shutdown_complete = asyncio.Event()
+
+# Track active connections for draining
+_active_connections = 0
+_connection_lock = asyncio.Lock()
+
+
+async def increment_connections():
+ """Track active connection count"""
+ global _active_connections
+ async with _connection_lock:
+ _active_connections += 1
+
+
+async def decrement_connections():
+ """Track active connection count"""
+ global _active_connections
+ async with _connection_lock:
+ _active_connections -= 1
+
+
+def get_active_connections() -> int:
+ """Get current active connection count"""
+ return _active_connections
+
+
+def is_shutdown_requested() -> bool:
+ """Check if graceful shutdown has been requested"""
+ return _shutdown_requested
+
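+# Minimal sketch (an assumption, not part of this change) of how the counters
+# above would typically be driven, so drain_connections() sees real in-flight
+# requests. The middleware name is illustrative:
+#
+#   @app.middleware("http")
+#   async def track_connections(request, call_next):
+#       await increment_connections()
+#       try:
+#           return await call_next(request)
+#       finally:
+#           await decrement_connections()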
+
+def _signal_handler(signum: int, frame) -> None:
+ """
+ Signal handler for SIGTERM and SIGINT.
+
+ Initiates graceful shutdown by setting the shutdown flag.
+ The actual cleanup is handled by the lifespan context manager.
+ """
+ global _shutdown_requested
+ signal_name = signal.Signals(signum).name
+ logger = logging.getLogger(__name__)
+
+ if _shutdown_requested:
+ logger.warning(f"Received {signal_name} again, forcing immediate exit...")
+ sys.exit(1)
+
+ logger.info(f"Received {signal_name}, initiating graceful shutdown...")
+ _shutdown_requested = True
+
+ # Try to stop the event loop gracefully
+ try:
+ loop = asyncio.get_event_loop()
+ if loop.is_running():
+ # Schedule shutdown event
+ loop.call_soon_threadsafe(_shutdown_complete.set)
+ except RuntimeError:
+ pass # No event loop running
+
+
+def setup_signal_handlers():
+ """
+ Set up signal handlers for graceful shutdown.
+
+ Handles:
+ - SIGTERM: Standard termination signal (from systemd, docker, etc.)
+ - SIGINT: Keyboard interrupt (Ctrl+C)
+ """
+ logger = logging.getLogger(__name__)
+
+ try:
+ # SIGTERM - Standard termination signal
+ signal.signal(signal.SIGTERM, _signal_handler)
+ logger.info("SIGTERM handler installed")
+
+ # SIGINT - Keyboard interrupt
+ signal.signal(signal.SIGINT, _signal_handler)
+ logger.info("SIGINT handler installed")
+
+ except (ValueError, OSError) as e:
+ # Signal handling may not be available in all contexts
+ logger.warning(f"Could not install signal handlers: {e}")
+
# Ensure log directory exists before configuring logging
Path(settings.log_file).parent.mkdir(parents=True, exist_ok=True)
@@ -38,16 +131,91 @@ for logger_name in ['app', 'app.services', 'app.services.pdf_generator_service',
logger = logging.getLogger(__name__)
+async def drain_connections(timeout: float = 30.0):
+ """
+ Wait for active connections to complete (connection draining).
+
+ Args:
+ timeout: Maximum time to wait for connections to drain
+ """
+ logger.info(f"Draining connections (timeout={timeout}s)...")
+ start_time = asyncio.get_event_loop().time()
+
+ while get_active_connections() > 0:
+ elapsed = asyncio.get_event_loop().time() - start_time
+ if elapsed >= timeout:
+ logger.warning(
+ f"Connection drain timeout after {timeout}s. "
+ f"{get_active_connections()} connections still active."
+ )
+ break
+
+ logger.info(f"Waiting for {get_active_connections()} active connections...")
+ await asyncio.sleep(1.0)
+
+ if get_active_connections() == 0:
+ logger.info("All connections drained successfully")
+
+
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Application lifespan events"""
# Startup
logger.info("Starting Tool_OCR V2 application...")
+ # Set up signal handlers for graceful shutdown
+ setup_signal_handlers()
+
# Ensure all directories exist
settings.ensure_directories()
logger.info("All directories created/verified")
+ # Initialize memory management if enabled
+ if settings.enable_model_lifecycle_management:
+ try:
+ from app.services.memory_manager import get_model_manager, MemoryConfig
+
+ memory_config = MemoryConfig(
+ warning_threshold=settings.memory_warning_threshold,
+ critical_threshold=settings.memory_critical_threshold,
+ emergency_threshold=settings.memory_emergency_threshold,
+ model_idle_timeout_seconds=settings.pp_structure_idle_timeout_seconds,
+ memory_check_interval_seconds=settings.memory_check_interval_seconds,
+ enable_auto_cleanup=settings.enable_memory_optimization,
+ enable_emergency_cleanup=settings.enable_emergency_cleanup,
+ max_concurrent_predictions=settings.max_concurrent_predictions,
+ enable_cpu_fallback=settings.enable_cpu_fallback,
+ gpu_memory_limit_mb=settings.gpu_memory_limit_mb,
+ )
+ get_model_manager(memory_config)
+ logger.info("Memory management initialized")
+ except Exception as e:
+ logger.warning(f"Failed to initialize memory management: {e}")
+
+ # Initialize service pool if enabled
+ if settings.enable_service_pool:
+ try:
+ from app.services.service_pool import get_service_pool, PoolConfig
+
+ pool_config = PoolConfig(
+ max_services_per_device=settings.max_services_per_device,
+ max_total_services=settings.max_total_services,
+ acquire_timeout_seconds=settings.service_acquire_timeout_seconds,
+ max_queue_size=settings.max_queue_size,
+ )
+ get_service_pool(pool_config)
+ logger.info("OCR service pool initialized")
+ except Exception as e:
+ logger.warning(f"Failed to initialize service pool: {e}")
+
+ # Initialize prediction semaphore for controlling concurrent PP-StructureV3 predictions
+ try:
+ from app.services.memory_manager import get_prediction_semaphore
+ get_prediction_semaphore(max_concurrent=settings.max_concurrent_predictions)
+ logger.info(f"Prediction semaphore initialized (max_concurrent={settings.max_concurrent_predictions})")
+ except Exception as e:
+ logger.warning(f"Failed to initialize prediction semaphore: {e}")
+
logger.info("Application startup complete")
yield
@@ -55,6 +223,45 @@ async def lifespan(app: FastAPI):
# Shutdown
logger.info("Shutting down Tool_OCR application...")
+ # Connection draining - wait for active requests to complete
+ await drain_connections(timeout=30.0)
+
+ # Shutdown recovery manager if initialized
+ try:
+ from app.services.memory_manager import shutdown_recovery_manager
+ shutdown_recovery_manager()
+ logger.info("Recovery manager shutdown complete")
+ except Exception as e:
+ logger.debug(f"Recovery manager shutdown skipped: {e}")
+
+ # Shutdown service pool
+ if settings.enable_service_pool:
+ try:
+ from app.services.service_pool import shutdown_service_pool
+ shutdown_service_pool()
+ logger.info("Service pool shutdown complete")
+ except Exception as e:
+ logger.warning(f"Error shutting down service pool: {e}")
+
+ # Shutdown prediction semaphore
+ try:
+ from app.services.memory_manager import shutdown_prediction_semaphore
+ shutdown_prediction_semaphore()
+ logger.info("Prediction semaphore shutdown complete")
+ except Exception as e:
+ logger.warning(f"Error shutting down prediction semaphore: {e}")
+
+ # Shutdown memory manager
+ if settings.enable_model_lifecycle_management:
+ try:
+ from app.services.memory_manager import shutdown_model_manager
+ shutdown_model_manager()
+ logger.info("Memory manager shutdown complete")
+ except Exception as e:
+ logger.warning(f"Error shutting down memory manager: {e}")
+
+ logger.info("Tool_OCR shutdown complete")
+
# Create FastAPI application
app = FastAPI(
@@ -77,9 +284,7 @@ app.add_middleware(
# Health check endpoint
@app.get("/health")
async def health_check():
- """Health check endpoint with GPU status"""
- from app.services.ocr_service import OCRService
-
+ """Health check endpoint with GPU status and memory management info"""
response = {
"status": "healthy",
"service": "Tool_OCR V2",
@@ -88,10 +293,31 @@ async def health_check():
# Add GPU status information
try:
- # Create temporary OCRService instance to get GPU status
- # In production, this should be a singleton service
- ocr_service = OCRService()
- gpu_status = ocr_service.get_gpu_status()
+ # Use service pool if available to avoid creating new instances
+ gpu_status = None
+ if settings.enable_service_pool:
+ try:
+ from app.services.service_pool import get_service_pool
+ pool = get_service_pool()
+ pool_stats = pool.get_pool_stats()
+ response["service_pool"] = pool_stats
+
+ # Get GPU status from first available service
+ for device, services in pool.services.items():
+ for pooled in services:
+ if hasattr(pooled.service, 'get_gpu_status'):
+ gpu_status = pooled.service.get_gpu_status()
+ break
+ if gpu_status:
+ break
+ except Exception as e:
+ logger.debug(f"Could not get service pool stats: {e}")
+
+ # Fallback: create temporary instance if no pool or no service available
+ if gpu_status is None:
+ from app.services.ocr_service import OCRService
+ ocr_service = OCRService()
+ gpu_status = ocr_service.get_gpu_status()
response["gpu"] = {
"available": gpu_status.get("gpu_available", False),
@@ -120,6 +346,15 @@ async def health_check():
"error": str(e),
}
+ # Add memory management status
+ if settings.enable_model_lifecycle_management:
+ try:
+ from app.services.memory_manager import get_model_manager
+ model_manager = get_model_manager()
+ response["memory_management"] = model_manager.get_model_stats()
+ except Exception as e:
+ logger.debug(f"Could not get memory management stats: {e}")
+
return response
diff --git a/backend/app/models/unified_document.py b/backend/app/models/unified_document.py
index ce236ea..bd7cd72 100644
--- a/backend/app/models/unified_document.py
+++ b/backend/app/models/unified_document.py
@@ -212,26 +212,44 @@ class TableData:
if self.caption:
html.append(f"
{self.caption}
")
- # Group cells by row
- rows_data = {}
+ # Group cells by row and column for quick lookup
+ cell_map = {}
for cell in self.cells:
- if cell.row not in rows_data:
- rows_data[cell.row] = []
- rows_data[cell.row].append(cell)
+ cell_map[(cell.row, cell.col)] = cell
- # Generate HTML
+ # Track which cells are covered by row/col spans
+ covered = set()
+ for cell in self.cells:
+ if cell.row_span > 1 or cell.col_span > 1:
+ for r in range(cell.row, cell.row + cell.row_span):
+ for c in range(cell.col, cell.col + cell.col_span):
+ if (r, c) != (cell.row, cell.col):
+ covered.add((r, c))
+
+ # Generate HTML with proper column filling
for row_idx in range(self.rows):
html.append("
")
- if row_idx in rows_data:
- for cell in sorted(rows_data[row_idx], key=lambda c: c.col):
+ for col_idx in range(self.cols):
+ # Skip cells covered by row/col spans
+ if (row_idx, col_idx) in covered:
+ continue
+
+ cell = cell_map.get((row_idx, col_idx))
+ tag = "th" if row_idx == 0 and self.headers else "td"
+
+ if cell:
span_attrs = []
if cell.row_span > 1:
span_attrs.append(f'rowspan="{cell.row_span}"')
if cell.col_span > 1:
span_attrs.append(f'colspan="{cell.col_span}"')
span_str = " ".join(span_attrs)
- tag = "th" if row_idx == 0 and self.headers else "td"
- html.append(f'<{tag} {span_str}>{cell.content}</{tag}>')
+ content = cell.content if cell.content else ""
+ html.append(f'<{tag} {span_str}>{content}</{tag}>')
+ else:
+ # Fill in empty cell for missing positions
+ html.append(f'<{tag}></{tag}>')
+
html.append("
")
html.append("")
diff --git a/backend/app/routers/tasks.py b/backend/app/routers/tasks.py
index 8e0505c..38a170b 100644
--- a/backend/app/routers/tasks.py
+++ b/backend/app/routers/tasks.py
@@ -39,6 +39,7 @@ from app.schemas.task import (
from app.services.task_service import task_service
from app.services.file_access_service import file_access_service
from app.services.ocr_service import OCRService
+from app.services.service_pool import get_service_pool, PoolConfig
# Import dual-track components
try:
@@ -47,6 +48,13 @@ try:
except ImportError:
DUAL_TRACK_AVAILABLE = False
+# Service pool availability (depends on the memory_manager module being importable)
+SERVICE_POOL_AVAILABLE = True
+try:
+ from app.services.memory_manager import get_model_manager
+except ImportError:
+ SERVICE_POOL_AVAILABLE = False
+
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v2/tasks", tags=["Tasks"])
@@ -63,7 +71,10 @@ def process_task_ocr(
pp_structure_params: Optional[dict] = None
):
"""
- Background task to process OCR for a task with dual-track support
+ Background task to process OCR for a task with dual-track support.
+
+ Uses OCRServicePool to acquire a shared service instance instead of
+ creating a new one, preventing GPU memory proliferation.
Args:
task_id: Task UUID string
@@ -80,6 +91,7 @@ def process_task_ocr(
db = SessionLocal()
start_time = datetime.now()
+ pooled_service = None
try:
logger.info(f"Starting OCR processing for task {task_id}, file: {filename}")
@@ -91,16 +103,39 @@ def process_task_ocr(
logger.error(f"Task {task_id} not found in database")
return
- # Initialize OCR service
- ocr_service = OCRService()
+ # Acquire OCR service from pool (or create new if pool disabled)
+ ocr_service = None
+ if settings.enable_service_pool and SERVICE_POOL_AVAILABLE:
+ try:
+ service_pool = get_service_pool()
+ pooled_service = service_pool.acquire(
+ device="GPU:0",
+ timeout=settings.service_acquire_timeout_seconds,
+ task_id=task_id
+ )
+ if pooled_service:
+ ocr_service = pooled_service.service
+ logger.info(f"Acquired OCR service from pool for task {task_id}")
+ else:
+ logger.warning(f"Timeout acquiring service from pool, creating new instance")
+ except Exception as e:
+ logger.warning(f"Failed to acquire service from pool: {e}, creating new instance")
+
+ # Fallback: create new instance if pool acquisition failed
+ if ocr_service is None:
+ logger.info("Creating new OCRService instance (pool disabled or unavailable)")
+ ocr_service = OCRService()
# Create result directory before OCR processing (needed for saving extracted images)
result_dir = Path(settings.result_dir) / task_id
result_dir.mkdir(parents=True, exist_ok=True)
- # Process the file with OCR (use dual-track if available)
- if use_dual_track and hasattr(ocr_service, 'process'):
- # Use new dual-track processing
+ # Process the file with OCR
+ # Use dual-track processing if:
+ # 1. use_dual_track is True (auto-detection)
+ # 2. OR force_track is specified (explicit track selection)
+ if (use_dual_track or force_track) and hasattr(ocr_service, 'process'):
+ # Use new dual-track processing (or forced track)
ocr_result = ocr_service.process(
file_path=Path(file_path),
lang=language,
@@ -111,7 +146,7 @@ def process_task_ocr(
pp_structure_params=pp_structure_params
)
else:
- # Fall back to traditional processing
+ # Fall back to traditional processing (no force_track support)
ocr_result = ocr_service.process_image(
image_path=Path(file_path),
lang=language,
@@ -131,6 +166,16 @@ def process_task_ocr(
source_file_path=Path(file_path)
)
+ # Release service back to pool (success case)
+ if pooled_service:
+ try:
+ service_pool = get_service_pool()
+ service_pool.release(pooled_service, error=None)
+ logger.info(f"Released OCR service back to pool for task {task_id}")
+ pooled_service = None # Prevent double release in finally
+ except Exception as e:
+ logger.warning(f"Failed to release service to pool: {e}")
+
# Close old session and create fresh one to avoid MySQL timeout
# (long OCR processing may cause connection to become stale)
db.close()
@@ -158,6 +203,15 @@ def process_task_ocr(
except Exception as e:
logger.exception(f"OCR processing failed for task {task_id}")
+ # Release service back to pool with error
+ if pooled_service:
+ try:
+ service_pool = get_service_pool()
+ service_pool.release(pooled_service, error=e)
+ pooled_service = None
+ except Exception as release_error:
+ logger.warning(f"Failed to release service to pool: {release_error}")
+
# Update task status to failed (direct database update)
try:
task = db.query(Task).filter(Task.id == task_db_id).first()
@@ -170,6 +224,13 @@ def process_task_ocr(
logger.error(f"Failed to update task status: {update_error}")
finally:
+ # Ensure service is released in case of any missed release
+ if pooled_service:
+ try:
+ service_pool = get_service_pool()
+ service_pool.release(pooled_service, error=None)
+ except Exception:
+ pass
db.close()
@@ -330,7 +391,13 @@ async def get_task(
with open(result_path) as f:
result_data = json.load(f)
metadata = result_data.get("metadata", {})
- processing_track = metadata.get("processing_track")
+ track_str = metadata.get("processing_track")
+ # Convert string to enum to avoid Pydantic serialization warning
+ if track_str:
+ try:
+ processing_track = ProcessingTrackEnum(track_str)
+ except ValueError:
+ processing_track = None
except Exception:
pass # Silently ignore errors reading the result file
diff --git a/backend/app/services/direct_extraction_engine.py b/backend/app/services/direct_extraction_engine.py
index 0ed9405..d0b23b9 100644
--- a/backend/app/services/direct_extraction_engine.py
+++ b/backend/app/services/direct_extraction_engine.py
@@ -247,9 +247,11 @@ class DirectExtractionEngine:
element_counter += len(image_elements)
# Extract vector graphics (charts, diagrams) from drawing commands
+ # Pass table_bboxes to filter out table border drawings before clustering
if self.enable_image_extraction:
vector_elements = self._extract_vector_graphics(
- page, page_num, document_id, element_counter, output_dir
+ page, page_num, document_id, element_counter, output_dir,
+ table_bboxes=table_bboxes
)
elements.extend(vector_elements)
element_counter += len(vector_elements)
@@ -705,40 +707,52 @@ class DirectExtractionEngine:
y1=bbox_data[3]
)
- # Extract column widths from table cells
+ # Extract column widths from table cells by analyzing X boundaries
column_widths = []
if hasattr(table, 'cells') and table.cells:
- # Group cells by column
- cols_x = {}
+ # Collect all unique X boundaries (both left and right edges)
+ x_boundaries = set()
for cell in table.cells:
- col_idx = None
- # Determine column index by x0 position
- for idx, x0 in enumerate(sorted(set(c[0] for c in table.cells))):
- if abs(cell[0] - x0) < 1.0: # Within 1pt tolerance
- col_idx = idx
- break
+ x_boundaries.add(round(cell[0], 1)) # x0 (left edge)
+ x_boundaries.add(round(cell[2], 1)) # x1 (right edge)
- if col_idx is not None:
- if col_idx not in cols_x:
- cols_x[col_idx] = {'x0': cell[0], 'x1': cell[2]}
- else:
- cols_x[col_idx]['x1'] = max(cols_x[col_idx]['x1'], cell[2])
+ # Sort boundaries to get column edges
+ sorted_x = sorted(x_boundaries)
- # Calculate width for each column
- for col_idx in sorted(cols_x.keys()):
- width = cols_x[col_idx]['x1'] - cols_x[col_idx]['x0']
- column_widths.append(width)
+ # Calculate column widths from adjacent boundaries
+ if len(sorted_x) >= 2:
+ column_widths = [sorted_x[i+1] - sorted_x[i] for i in range(len(sorted_x)-1)]
+ logger.debug(f"Calculated column widths from {len(sorted_x)} boundaries: {column_widths}")
+
+ # Extract row heights from table cells by analyzing Y boundaries
+ row_heights = []
+ if hasattr(table, 'cells') and table.cells:
+ # Collect all unique Y boundaries (both top and bottom edges)
+ y_boundaries = set()
+ for cell in table.cells:
+ y_boundaries.add(round(cell[1], 1)) # y0 (top edge)
+ y_boundaries.add(round(cell[3], 1)) # y1 (bottom edge)
+
+ # Sort boundaries to get row edges
+ sorted_y = sorted(y_boundaries)
+
+ # Calculate row heights from adjacent boundaries
+ if len(sorted_y) >= 2:
+ row_heights = [sorted_y[i+1] - sorted_y[i] for i in range(len(sorted_y)-1)]
+ logger.debug(f"Calculated row heights from {len(sorted_y)} boundaries: {row_heights}")
# Create table cells
+ # Note: Include ALL cells (even empty ones) to preserve table structure
+ # This is critical for correct HTML generation and PDF rendering
cells = []
for row_idx, row in enumerate(data):
for col_idx, cell_text in enumerate(row):
- if cell_text:
- cells.append(TableCell(
- row=row_idx,
- col=col_idx,
- content=str(cell_text) if cell_text else ""
- ))
+ # Always add cell, even if empty, to maintain table structure
+ cells.append(TableCell(
+ row=row_idx,
+ col=col_idx,
+ content=str(cell_text) if cell_text else ""
+ ))
# Create table data
table_data = TableData(
@@ -748,8 +762,13 @@ class DirectExtractionEngine:
headers=data[0] if data else None # Assume first row is header
)
- # Store column widths in metadata
- metadata = {"column_widths": column_widths} if column_widths else None
+ # Store column widths and row heights in metadata
+ metadata = {}
+ if column_widths:
+ metadata["column_widths"] = column_widths
+ if row_heights:
+ metadata["row_heights"] = row_heights
+ metadata = metadata if metadata else None
return DocumentElement(
element_id=f"table_{page_num}_{counter}",
@@ -978,7 +997,9 @@ class DirectExtractionEngine:
image_filename = f"{document_id}_p{page_num}_img{img_idx}.png"
image_path = output_dir / image_filename
pix.save(str(image_path))
- image_data["saved_path"] = str(image_path)
+ # Store relative filename only (consistent with OCR track)
+ # PDF generator will join with result_dir to get full path
+ image_data["saved_path"] = image_filename
logger.debug(f"Saved image to {image_path}")
element = DocumentElement(
@@ -1001,12 +1022,272 @@ class DirectExtractionEngine:
return elements
+ def has_missing_images(self, page: fitz.Page) -> bool:
+ """
+ Detect if a page likely has images that weren't extracted.
+
+ This checks for inline image blocks (type=1 in text dict) which indicate
+ graphics composed of many small image blocks (like logos) that
+ page.get_images() cannot detect.
+
+ Args:
+ page: PyMuPDF page object
+
+ Returns:
+ True if there are likely missing images that need OCR extraction
+ """
+ try:
+ # Check if get_images found anything
+ standard_images = page.get_images()
+ if standard_images:
+ return False # Standard images were found, no need for fallback
+
+ # Check for inline image blocks (type=1)
+ text_dict = page.get_text("dict", sort=True)
+ blocks = text_dict.get("blocks", [])
+
+ image_block_count = sum(1 for b in blocks if b.get("type") == 1)
+
+ # If there are many inline image blocks, likely there's a logo or graphic
+ if image_block_count >= 10:
+ logger.info(f"Detected {image_block_count} inline image blocks - may need OCR for image extraction")
+ return True
+
+ return False
+
+ except Exception as e:
+ logger.warning(f"Error checking for missing images: {e}")
+ return False
+
+ def check_document_for_missing_images(self, pdf_path: Path) -> List[int]:
+ """
+ Check a PDF document for pages that likely have missing images.
+
+ This opens the PDF and checks each page for inline image blocks
+ that weren't extracted by get_images().
+
+ Args:
+ pdf_path: Path to the PDF file
+
+ Returns:
+ List of page numbers (1-indexed) that have missing images
+ """
+ pages_with_missing_images = []
+
+ try:
+ doc = fitz.open(str(pdf_path))
+ for page_num in range(len(doc)):
+ page = doc[page_num]
+ if self.has_missing_images(page):
+ pages_with_missing_images.append(page_num + 1) # 1-indexed
+ doc.close()
+
+ if pages_with_missing_images:
+ logger.info(f"Document has missing images on pages: {pages_with_missing_images}")
+
+ except Exception as e:
+ logger.error(f"Error checking document for missing images: {e}")
+
+ return pages_with_missing_images
+
+ def render_inline_image_regions(
+ self,
+ pdf_path: Path,
+ unified_doc: 'UnifiedDocument',
+ pages: List[int],
+ output_dir: Optional[Path] = None
+ ) -> int:
+ """
+ Render inline image regions and add them to the unified document.
+
+ This is a fallback when OCR doesn't detect images. It clusters inline
+ image blocks (type=1) and renders them as images.
+
+ Args:
+ pdf_path: Path to the PDF file
+ unified_doc: UnifiedDocument to add images to
+ pages: List of page numbers (1-indexed) to process
+ output_dir: Directory to save rendered images
+
+ Returns:
+ Number of images added
+ """
+ images_added = 0
+
+ try:
+ doc = fitz.open(str(pdf_path))
+
+ for page_num in pages:
+ if page_num < 1 or page_num > len(doc):
+ continue
+
+ page = doc[page_num - 1] # 0-indexed
+ page_rect = page.rect
+
+ # Get inline image blocks
+ text_dict = page.get_text("dict", sort=True)
+ blocks = text_dict.get("blocks", [])
+
+ image_blocks = []
+ for block in blocks:
+ if block.get("type") == 1: # Image block
+ bbox = block.get("bbox")
+ if bbox:
+ image_blocks.append(fitz.Rect(bbox))
+
+ if len(image_blocks) < 5: # Reduced from 10
+ logger.debug(f"Page {page_num}: Only {len(image_blocks)} inline image blocks, skipping")
+ continue
+
+ logger.info(f"Page {page_num}: Found {len(image_blocks)} inline image blocks")
+
+ # Cluster nearby image blocks
+ regions = self._cluster_nearby_rects(image_blocks, tolerance=5.0)
+ logger.info(f"Page {page_num}: Clustered into {len(regions)} regions")
+
+ # Find the corresponding page in unified_doc
+ target_page = None
+ for p in unified_doc.pages:
+ if p.page_number == page_num:
+ target_page = p
+ break
+
+ if not target_page:
+ continue
+
+ for region_idx, region_rect in enumerate(regions):
+ logger.info(f"Page {page_num} region {region_idx}: {region_rect} (w={region_rect.width:.1f}, h={region_rect.height:.1f})")
+
+ # Skip very small regions
+ if region_rect.width < 30 or region_rect.height < 30:
+ logger.info(f" -> Skipped: too small (min 30x30)")
+ continue
+
+ # Skip regions that are primarily in the table area (below top 40%)
+ # But allow regions that START in the top portion
+ page_30_pct = page_rect.height * 0.3
+ page_40_pct = page_rect.height * 0.4
+ if region_rect.y0 > page_40_pct:
+ logger.info(f" -> Skipped: y0={region_rect.y0:.1f} > 40% of page ({page_40_pct:.1f})")
+ continue
+
+ logger.info(f"Rendering inline image region {region_idx} on page {page_num}: {region_rect}")
+
+ try:
+ # Add small padding
+ clip_rect = region_rect + (-2, -2, 2, 2)
+ clip_rect.intersect(page_rect)
+
+ # Render at 2x resolution
+ mat = fitz.Matrix(2, 2)
+ pix = page.get_pixmap(clip=clip_rect, matrix=mat, alpha=False)
+
+ # Create bounding box
+ bbox = BoundingBox(
+ x0=clip_rect.x0,
+ y0=clip_rect.y0,
+ x1=clip_rect.x1,
+ y1=clip_rect.y1
+ )
+
+ image_data = {
+ "width": pix.width,
+ "height": pix.height,
+ "colorspace": "rgb",
+ "type": "inline_region"
+ }
+
+ # Save image if output directory provided
+ if output_dir:
+ output_dir.mkdir(parents=True, exist_ok=True)
+ doc_id = unified_doc.document_id or "unknown"
+ image_filename = f"{doc_id}_p{page_num}_logo{region_idx}.png"
+ image_path = output_dir / image_filename
+ pix.save(str(image_path))
+ image_data["saved_path"] = image_filename
+ logger.info(f"Saved inline image region to {image_path}")
+
+ element = DocumentElement(
+ element_id=f"logo_{page_num}_{region_idx}",
+ type=ElementType.LOGO,
+ content=image_data,
+ bbox=bbox,
+ confidence=0.9,
+ metadata={
+ "region_type": "inline_image_blocks",
+ "block_count": len(image_blocks)
+ }
+ )
+ target_page.elements.append(element)
+ images_added += 1
+
+ pix = None # Free memory
+
+ except Exception as e:
+ logger.error(f"Error rendering inline image region {region_idx}: {e}")
+
+ doc.close()
+
+ if images_added > 0:
+ current_images = unified_doc.metadata.total_images or 0
+ unified_doc.metadata.total_images = current_images + images_added
+ logger.info(f"Added {images_added} inline image regions to document")
+
+ except Exception as e:
+ logger.error(f"Error rendering inline image regions: {e}")
+
+ return images_added
+
+ def _cluster_nearby_rects(self, rects: List[fitz.Rect], tolerance: float = 5.0) -> List[fitz.Rect]:
+ """Cluster nearby rectangles into regions."""
+ if not rects:
+ return []
+
+ sorted_rects = sorted(rects, key=lambda r: (r.y0, r.x0))
+
+ merged = []
+ for rect in sorted_rects:
+ merged_with_existing = False
+ for i, region in enumerate(merged):
+ expanded = region + (-tolerance, -tolerance, tolerance, tolerance)
+ if expanded.intersects(rect):
+ merged[i] = region | rect
+ merged_with_existing = True
+ break
+ if not merged_with_existing:
+ merged.append(rect)
+
+ # Second pass: merge any regions that now overlap
+ changed = True
+ while changed:
+ changed = False
+ new_merged = []
+ skip = set()
+
+ for i, r1 in enumerate(merged):
+ if i in skip:
+ continue
+ current = r1
+ for j, r2 in enumerate(merged[i+1:], start=i+1):
+ if j in skip:
+ continue
+ expanded = current + (-tolerance, -tolerance, tolerance, tolerance)
+ if expanded.intersects(r2):
+ current = current | r2
+ skip.add(j)
+ changed = True
+ new_merged.append(current)
+ merged = new_merged
+
+ return merged
+
def _extract_vector_graphics(self,
page: fitz.Page,
page_num: int,
document_id: str,
counter: int,
- output_dir: Optional[Path]) -> List[DocumentElement]:
+ output_dir: Optional[Path],
+ table_bboxes: Optional[List[BoundingBox]] = None) -> List[DocumentElement]:
"""
Extract vector graphics (charts, diagrams) from page.
@@ -1020,6 +1301,7 @@ class DirectExtractionEngine:
document_id: Unique document identifier
counter: Starting counter for element IDs
output_dir: Directory to save rendered graphics
+ table_bboxes: List of table bounding boxes to exclude table border drawings
Returns:
List of DocumentElement objects representing vector graphics
@@ -1034,16 +1316,25 @@ class DirectExtractionEngine:
logger.debug(f"Page {page_num} contains {len(drawings)} vector drawing commands")
+ # Filter out drawings that are likely table borders
+ # Table borders are typically thin rectangular lines within table regions
+ non_table_drawings = self._filter_table_border_drawings(drawings, table_bboxes)
+ logger.debug(f"After filtering table borders: {len(non_table_drawings)} drawings remain")
+
+ if not non_table_drawings:
+ logger.debug("All drawings appear to be table borders, no vector graphics to extract")
+ return elements
+
# Cluster drawings into groups (charts, diagrams, etc.)
try:
- # PyMuPDF's cluster_drawings() groups nearby drawings automatically
- drawing_clusters = page.cluster_drawings()
+ # Use custom clustering that only considers non-table drawings
+ drawing_clusters = self._cluster_non_table_drawings(page, non_table_drawings)
logger.debug(f"Clustered into {len(drawing_clusters)} groups")
except (AttributeError, TypeError) as e:
# cluster_drawings not available or has different signature
# Fallback: try to identify charts by analyzing drawing density
- logger.warning(f"cluster_drawings() failed ({e}), using fallback method")
- drawing_clusters = self._cluster_drawings_fallback(page, drawings)
+ logger.warning(f"Custom clustering failed ({e}), using fallback method")
+ drawing_clusters = self._cluster_drawings_fallback(page, non_table_drawings)
for cluster_idx, bbox in enumerate(drawing_clusters):
# Ignore small regions (likely noise or separator lines)
@@ -1148,6 +1439,124 @@ class DirectExtractionEngine:
return filtered_clusters
+ def _filter_table_border_drawings(self, drawings: list, table_bboxes: Optional[List[BoundingBox]]) -> list:
+ """
+ Filter out drawings that are likely table borders.
+
+ Table borders are typically:
+ - Thin rectangular lines (height or width < 5pt)
+ - Located within or on the edge of table bounding boxes
+
+ Args:
+ drawings: List of PyMuPDF drawing objects
+ table_bboxes: List of table bounding boxes
+
+ Returns:
+ List of drawings that are NOT table borders (likely logos, charts, etc.)
+ """
+ if not table_bboxes:
+ return drawings
+
+ non_table_drawings = []
+ table_border_count = 0
+
+ for drawing in drawings:
+ rect = drawing.get('rect')
+ if not rect:
+ continue
+
+ draw_rect = fitz.Rect(rect)
+
+ # Check if this drawing is a thin line (potential table border)
+ is_thin_line = draw_rect.width < 5 or draw_rect.height < 5
+
+ # Check if drawing overlaps significantly with any table
+ overlaps_table = False
+ for table_bbox in table_bboxes:
+ table_rect = fitz.Rect(table_bbox.x0, table_bbox.y0, table_bbox.x1, table_bbox.y1)
+
+ # Expand table rect slightly to include border lines on edges
+ expanded_table = table_rect + (-5, -5, 5, 5)
+
+ if expanded_table.contains(draw_rect) or expanded_table.intersects(draw_rect):
+ # Calculate overlap ratio
+ intersection = draw_rect & expanded_table
+ if not intersection.is_empty:
+ overlap_ratio = intersection.get_area() / draw_rect.get_area() if draw_rect.get_area() > 0 else 0
+
+ # If drawing is mostly inside table region, it's likely a border
+ if overlap_ratio > 0.8:
+ overlaps_table = True
+ break
+
+ # Keep drawing if it's NOT (thin line AND overlapping table)
+ # This keeps: logos (complex shapes), charts outside tables, etc.
+ if is_thin_line and overlaps_table:
+ table_border_count += 1
+ else:
+ non_table_drawings.append(drawing)
+
+ if table_border_count > 0:
+ logger.debug(f"Filtered out {table_border_count} table border drawings")
+
+ return non_table_drawings
+
+ def _cluster_non_table_drawings(self, page: fitz.Page, drawings: list) -> list:
+ """
+ Cluster non-table drawings into groups.
+
+ This method clusters drawings that have been pre-filtered to exclude table borders.
+ It uses a more conservative clustering approach suitable for logos and charts.
+
+ Args:
+ page: PyMuPDF page object
+ drawings: Pre-filtered list of drawings (excluding table borders)
+
+ Returns:
+ List of fitz.Rect representing clustered drawing regions
+ """
+ if not drawings:
+ return []
+
+ # Collect all drawing bounding boxes
+ bboxes = []
+ for drawing in drawings:
+ rect = drawing.get('rect')
+ if rect:
+ bboxes.append(fitz.Rect(rect))
+
+ if not bboxes:
+ return []
+
+ # More conservative clustering with smaller tolerance
+ # This prevents grouping distant graphics together
+ clusters = []
+ tolerance = 10 # Smaller tolerance than fallback (was 20)
+
+ for bbox in bboxes:
+ # Try to merge with existing cluster
+ merged = False
+ for i, cluster in enumerate(clusters):
+ # Check if bbox is close to this cluster
+ expanded_cluster = cluster + (-tolerance, -tolerance, tolerance, tolerance)
+ if expanded_cluster.intersects(bbox):
+ # Merge bbox into cluster
+ clusters[i] = cluster | bbox # Union of rectangles
+ merged = True
+ break
+
+ if not merged:
+ # Create new cluster
+ clusters.append(bbox)
+
+ # Filter out very small clusters (noise)
+ # Keep minimum 30x30 for logos (smaller than default 50x50)
+ filtered_clusters = [c for c in clusters if c.width >= 30 and c.height >= 30]
+
+ logger.debug(f"Non-table clustering: {len(bboxes)} drawings -> {len(clusters)} clusters -> {len(filtered_clusters)} filtered")
+
+ return filtered_clusters
+
def _deduplicate_table_chart_overlap(self, elements: List[DocumentElement]) -> List[DocumentElement]:
"""
Intelligently resolve TABLE-CHART overlaps based on table structure completeness.
diff --git a/backend/app/services/memory_manager.py b/backend/app/services/memory_manager.py
new file mode 100644
index 0000000..57ce976
--- /dev/null
+++ b/backend/app/services/memory_manager.py
@@ -0,0 +1,2269 @@
+"""
+Tool_OCR - Memory Management System
+Provides centralized model lifecycle management with reference counting,
+idle timeout, and GPU memory monitoring.
+"""
+
+import asyncio
+import gc
+import logging
+import threading
+import time
+from dataclasses import dataclass, field
+from datetime import datetime
+from enum import Enum
+from typing import Any, Callable, Dict, List, Optional, Tuple
+from weakref import WeakValueDictionary
+
+import paddle
+
+# Optional torch import for additional GPU memory management
+try:
+ import torch
+ TORCH_AVAILABLE = True
+except ImportError:
+ TORCH_AVAILABLE = False
+
+# Optional pynvml import for NVIDIA GPU monitoring
+try:
+ import pynvml
+ PYNVML_AVAILABLE = True
+except ImportError:
+ PYNVML_AVAILABLE = False
+
+logger = logging.getLogger(__name__)
+
+
+class MemoryBackend(Enum):
+ """Available memory query backends"""
+ PADDLE = "paddle"
+ TORCH = "torch"
+ PYNVML = "pynvml"
+ NONE = "none"
+
+
+@dataclass
+class MemoryStats:
+ """GPU/CPU memory statistics"""
+ gpu_used_mb: float = 0.0
+ gpu_free_mb: float = 0.0
+ gpu_total_mb: float = 0.0
+ gpu_used_ratio: float = 0.0
+ cpu_used_mb: float = 0.0
+ cpu_available_mb: float = 0.0
+ timestamp: float = field(default_factory=time.time)
+ backend: MemoryBackend = MemoryBackend.NONE
+
+
+@dataclass
+class ModelEntry:
+ """Entry for a managed model"""
+ model: Any
+ model_id: str
+ ref_count: int = 0
+ last_used: float = field(default_factory=time.time)
+ created_at: float = field(default_factory=time.time)
+ estimated_memory_mb: float = 0.0
+ is_loading: bool = False
+ cleanup_callback: Optional[Callable] = None
+
+
+class MemoryConfig:
+ """Configuration for memory management"""
+
+ def __init__(
+ self,
+ warning_threshold: float = 0.80,
+ critical_threshold: float = 0.95,
+ emergency_threshold: float = 0.98,
+ model_idle_timeout_seconds: int = 300,
+ memory_check_interval_seconds: int = 30,
+ enable_auto_cleanup: bool = True,
+ enable_emergency_cleanup: bool = True,
+ max_concurrent_predictions: int = 2,
+ enable_cpu_fallback: bool = True,
+ gpu_memory_limit_mb: int = 6144,
+ ):
+ self.warning_threshold = warning_threshold
+ self.critical_threshold = critical_threshold
+ self.emergency_threshold = emergency_threshold
+ self.model_idle_timeout_seconds = model_idle_timeout_seconds
+ self.memory_check_interval_seconds = memory_check_interval_seconds
+ self.enable_auto_cleanup = enable_auto_cleanup
+ self.enable_emergency_cleanup = enable_emergency_cleanup
+ self.max_concurrent_predictions = max_concurrent_predictions
+ self.enable_cpu_fallback = enable_cpu_fallback
+ self.gpu_memory_limit_mb = gpu_memory_limit_mb
+
+
+class MemoryGuard:
+ """
+ Monitor GPU/CPU memory usage and trigger preventive actions.
+
+ Supports multiple backends: paddle.device.cuda, pynvml, torch
+ """
+
+ def __init__(self, config: Optional[MemoryConfig] = None):
+ self.config = config or MemoryConfig()
+ self.backend = self._detect_backend()
+ self._history: List[MemoryStats] = []
+ self._max_history = 100
+ self._alerts: List[Dict] = []
+ self._lock = threading.Lock()
+
+ # Initialize pynvml if available
+ self._nvml_handle = None
+ if self.backend == MemoryBackend.PYNVML:
+ self._init_pynvml()
+
+ logger.info(f"MemoryGuard initialized with backend: {self.backend.value}")
+
+ def _detect_backend(self) -> MemoryBackend:
+ """Detect the best available memory query backend"""
+ # Prefer pynvml for accurate GPU memory info
+ if PYNVML_AVAILABLE:
+ try:
+ pynvml.nvmlInit()
+ pynvml.nvmlShutdown()
+ return MemoryBackend.PYNVML
+ except Exception:
+ pass
+
+ # Fall back to torch if available
+ if TORCH_AVAILABLE and torch.cuda.is_available():
+ return MemoryBackend.TORCH
+
+ # Fall back to paddle
+ if paddle.is_compiled_with_cuda():
+ try:
+ if paddle.device.cuda.device_count() > 0:
+ return MemoryBackend.PADDLE
+ except Exception:
+ pass
+
+ return MemoryBackend.NONE
+
+ def _init_pynvml(self):
+ """Initialize pynvml for GPU monitoring"""
+ try:
+ pynvml.nvmlInit()
+ self._nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(0)
+ logger.info("pynvml initialized for GPU monitoring")
+ except Exception as e:
+ logger.warning(f"Failed to initialize pynvml: {e}")
+ self.backend = MemoryBackend.PADDLE if paddle.is_compiled_with_cuda() else MemoryBackend.NONE
+
+ def get_memory_stats(self, device_id: int = 0) -> MemoryStats:
+ """
+ Get current memory statistics.
+
+ Args:
+ device_id: GPU device ID (default 0)
+
+ Returns:
+ MemoryStats with current memory usage
+ """
+ stats = MemoryStats(backend=self.backend)
+
+ try:
+ if self.backend == MemoryBackend.PYNVML and self._nvml_handle:
+ mem_info = pynvml.nvmlDeviceGetMemoryInfo(self._nvml_handle)
+ stats.gpu_total_mb = mem_info.total / (1024**2)
+ stats.gpu_used_mb = mem_info.used / (1024**2)
+ stats.gpu_free_mb = mem_info.free / (1024**2)
+ stats.gpu_used_ratio = mem_info.used / mem_info.total if mem_info.total > 0 else 0
+
+ elif self.backend == MemoryBackend.TORCH:
+ stats.gpu_total_mb = torch.cuda.get_device_properties(device_id).total_memory / (1024**2)
+ stats.gpu_used_mb = torch.cuda.memory_allocated(device_id) / (1024**2)
+ stats.gpu_free_mb = (torch.cuda.get_device_properties(device_id).total_memory -
+ torch.cuda.memory_allocated(device_id)) / (1024**2)
+ stats.gpu_used_ratio = stats.gpu_used_mb / stats.gpu_total_mb if stats.gpu_total_mb > 0 else 0
+
+ elif self.backend == MemoryBackend.PADDLE:
+ # Paddle doesn't provide total memory directly, use allocated/reserved
+ stats.gpu_used_mb = paddle.device.cuda.memory_allocated(device_id) / (1024**2)
+ stats.gpu_free_mb = 0 # Not directly available
+ stats.gpu_total_mb = self.config.gpu_memory_limit_mb
+ stats.gpu_used_ratio = stats.gpu_used_mb / stats.gpu_total_mb if stats.gpu_total_mb > 0 else 0
+
+ # Get CPU memory info
+ try:
+ import psutil
+ mem = psutil.virtual_memory()
+ stats.cpu_used_mb = mem.used / (1024**2)
+ stats.cpu_available_mb = mem.available / (1024**2)
+ except ImportError:
+ pass
+
+ except Exception as e:
+ logger.warning(f"Failed to get memory stats: {e}")
+
+ # Store in history
+ with self._lock:
+ self._history.append(stats)
+ if len(self._history) > self._max_history:
+ self._history.pop(0)
+
+ return stats
+
+ def check_memory(self, required_mb: int = 0, device_id: int = 0) -> Tuple[bool, MemoryStats]:
+ """
+ Check if sufficient GPU memory is available.
+
+ Args:
+ required_mb: Required memory in MB (0 for just checking thresholds)
+ device_id: GPU device ID
+
+ Returns:
+ Tuple of (is_available, current_stats)
+ """
+ stats = self.get_memory_stats(device_id)
+
+ # Check if we have enough free memory
+ if required_mb > 0 and stats.gpu_free_mb > 0:
+ if stats.gpu_free_mb < required_mb:
+ self._add_alert("insufficient_memory",
+ f"Required {required_mb}MB but only {stats.gpu_free_mb:.0f}MB available")
+ return False, stats
+
+ # Check threshold levels
+ if stats.gpu_used_ratio > self.config.emergency_threshold:
+ self._add_alert("emergency",
+ f"GPU memory at {stats.gpu_used_ratio*100:.1f}% (emergency threshold)")
+ return False, stats
+
+ if stats.gpu_used_ratio > self.config.critical_threshold:
+ self._add_alert("critical",
+ f"GPU memory at {stats.gpu_used_ratio*100:.1f}% (critical threshold)")
+ return False, stats
+
+ if stats.gpu_used_ratio > self.config.warning_threshold:
+ self._add_alert("warning",
+ f"GPU memory at {stats.gpu_used_ratio*100:.1f}% (warning threshold)")
+
+ return True, stats
+
+ def _add_alert(self, level: str, message: str):
+ """Add an alert to the alert history"""
+ alert = {
+ "level": level,
+ "message": message,
+ "timestamp": time.time()
+ }
+ with self._lock:
+ self._alerts.append(alert)
+ # Keep last 50 alerts
+ if len(self._alerts) > 50:
+ self._alerts.pop(0)
+
+ if level == "emergency":
+ logger.error(f"MEMORY ALERT [{level}]: {message}")
+ elif level == "critical":
+ logger.warning(f"MEMORY ALERT [{level}]: {message}")
+ else:
+ logger.info(f"MEMORY ALERT [{level}]: {message}")
+
+ def get_alerts(self, since_timestamp: float = 0) -> List[Dict]:
+ """Get alerts since a given timestamp"""
+ with self._lock:
+ return [a for a in self._alerts if a["timestamp"] > since_timestamp]
+
+ def clear_gpu_cache(self):
+ """Clear GPU memory cache"""
+ try:
+ if TORCH_AVAILABLE and torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ torch.cuda.synchronize()
+ logger.debug("Cleared PyTorch GPU cache")
+
+ if paddle.is_compiled_with_cuda():
+ paddle.device.cuda.empty_cache()
+ logger.debug("Cleared PaddlePaddle GPU cache")
+
+ gc.collect()
+
+ except Exception as e:
+ logger.warning(f"GPU cache clear failed: {e}")
+
+ def shutdown(self):
+ """Clean up resources"""
+ if PYNVML_AVAILABLE and self._nvml_handle:
+ try:
+ pynvml.nvmlShutdown()
+ except Exception:
+ pass
+
+
+class PredictionSemaphore:
+ """
+ Semaphore for controlling concurrent PP-StructureV3 predictions.
+
+ PP-StructureV3.predict() is memory-intensive. Running multiple predictions
+ simultaneously can cause OOM errors. This class limits concurrent predictions
+ and provides timeout handling.
+ """
+
+ _instance = None
+ _lock = threading.Lock()
+
+ def __new__(cls, *args, **kwargs):
+ """Singleton pattern - ensure only one PredictionSemaphore exists"""
+ with cls._lock:
+ if cls._instance is None:
+ cls._instance = super().__new__(cls)
+ cls._instance._initialized = False
+ return cls._instance
+
+ def __init__(self, max_concurrent: int = 2, default_timeout: float = 300.0):
+ if self._initialized:
+ return
+
+ self._max_concurrent = max_concurrent
+ self._default_timeout = default_timeout
+ self._semaphore = threading.Semaphore(max_concurrent)
+ self._condition = threading.Condition()
+ self._queue_depth = 0
+ self._active_predictions = 0
+
+ # Metrics
+ self._total_predictions = 0
+ self._total_timeouts = 0
+ self._total_wait_time = 0.0
+ self._metrics_lock = threading.Lock()
+
+ self._initialized = True
+ logger.info(f"PredictionSemaphore initialized (max_concurrent={max_concurrent})")
+
+ def acquire(self, timeout: Optional[float] = None, task_id: Optional[str] = None) -> bool:
+ """
+ Acquire a prediction slot.
+
+ Args:
+ timeout: Timeout in seconds (None for default, 0 for non-blocking)
+ task_id: Optional task identifier for logging
+
+ Returns:
+ True if acquired, False if timed out
+ """
+ timeout = timeout if timeout is not None else self._default_timeout
+ start_time = time.time()
+
+ with self._condition:
+ self._queue_depth += 1
+
+ task_str = f" for task {task_id}" if task_id else ""
+ logger.debug(f"Waiting for prediction slot{task_str} (queue_depth={self._queue_depth})")
+
+ try:
+ # timeout == 0 means a non-blocking attempt ("0 for non-blocking" per the
+ # docstring); otherwise wait up to `timeout` seconds for a slot
+ if timeout == 0:
+ acquired = self._semaphore.acquire(blocking=False)
+ else:
+ acquired = self._semaphore.acquire(timeout=timeout)
+
+ wait_time = time.time() - start_time
+ with self._metrics_lock:
+ self._total_wait_time += wait_time
+ if acquired:
+ self._total_predictions += 1
+ self._active_predictions += 1
+ else:
+ self._total_timeouts += 1
+
+ with self._condition:
+ self._queue_depth -= 1
+
+ if acquired:
+ logger.debug(f"Prediction slot acquired{task_str} (waited {wait_time:.2f}s, active={self._active_predictions})")
+ else:
+ logger.warning(f"Prediction slot timeout{task_str} after {timeout}s")
+
+ return acquired
+
+ except Exception as e:
+ with self._condition:
+ self._queue_depth -= 1
+ logger.error(f"Error acquiring prediction slot: {e}")
+ return False
+
+ def release(self, task_id: Optional[str] = None):
+ """
+ Release a prediction slot.
+
+ Args:
+ task_id: Optional task identifier for logging
+ """
+ self._semaphore.release()
+
+ with self._metrics_lock:
+ self._active_predictions = max(0, self._active_predictions - 1)
+
+ task_str = f" for task {task_id}" if task_id else ""
+ logger.debug(f"Prediction slot released{task_str} (active={self._active_predictions})")
+
+ def get_stats(self) -> Dict:
+ """Get prediction semaphore statistics"""
+ with self._metrics_lock:
+ avg_wait = self._total_wait_time / max(1, self._total_predictions)
+ return {
+ "max_concurrent": self._max_concurrent,
+ "active_predictions": self._active_predictions,
+ "queue_depth": self._queue_depth,
+ "total_predictions": self._total_predictions,
+ "total_timeouts": self._total_timeouts,
+ "average_wait_seconds": round(avg_wait, 3),
+ }
+
+ def reset_metrics(self):
+ """Reset metrics counters"""
+ with self._metrics_lock:
+ self._total_predictions = 0
+ self._total_timeouts = 0
+ self._total_wait_time = 0.0
+
+
+class PredictionContext:
+ """
+ Context manager for PP-StructureV3 predictions with semaphore control.
+
+ Usage:
+ with prediction_context(task_id="task_123") as acquired:
+ if acquired:
+ result = structure_engine.predict(image_path)
+ else:
+ # Handle timeout
+ """
+
+ def __init__(
+ self,
+ semaphore: PredictionSemaphore,
+ timeout: Optional[float] = None,
+ task_id: Optional[str] = None
+ ):
+ self._semaphore = semaphore
+ self._timeout = timeout
+ self._task_id = task_id
+ self._acquired = False
+
+ def __enter__(self) -> bool:
+ self._acquired = self._semaphore.acquire(
+ timeout=self._timeout,
+ task_id=self._task_id
+ )
+ return self._acquired
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ if self._acquired:
+ self._semaphore.release(task_id=self._task_id)
+ return False # Don't suppress exceptions
+
+
+# Global prediction semaphore instance
+_prediction_semaphore: Optional[PredictionSemaphore] = None
+
+
+def get_prediction_semaphore(max_concurrent: Optional[int] = None) -> PredictionSemaphore:
+ """
+ Get the global PredictionSemaphore instance.
+
+ Args:
+ max_concurrent: Max concurrent predictions (only used on first call)
+
+ Returns:
+ PredictionSemaphore singleton instance
+ """
+ global _prediction_semaphore
+ if _prediction_semaphore is None:
+ from app.core.config import settings
+ max_conc = max_concurrent or settings.max_concurrent_predictions
+ _prediction_semaphore = PredictionSemaphore(max_concurrent=max_conc)
+ return _prediction_semaphore
+
+
+def shutdown_prediction_semaphore():
+ """Reset the global PredictionSemaphore instance"""
+ global _prediction_semaphore
+ if _prediction_semaphore is not None:
+ # Reset singleton for clean state
+ PredictionSemaphore._instance = None
+ PredictionSemaphore._lock = threading.Lock()
+ _prediction_semaphore = None
+
+
+def prediction_context(
+ timeout: Optional[float] = None,
+ task_id: Optional[str] = None
+) -> PredictionContext:
+ """
+ Create a prediction context manager.
+
+ Args:
+ timeout: Timeout in seconds for acquiring slot
+ task_id: Optional task identifier for logging
+
+ Returns:
+ PredictionContext context manager
+ """
+ semaphore = get_prediction_semaphore()
+ return PredictionContext(semaphore, timeout, task_id)
+
+
+class ModelManager:
+ """
+ Centralized model lifecycle management with reference counting and idle timeout.
+
+ Features:
+ - Reference counting for shared model instances
+ - Idle timeout for automatic unloading
+ - LRU eviction when memory pressure
+ - Thread-safe operations
+ """
+
+ _instance = None
+ _lock = threading.Lock()
+
+ def __new__(cls, *args, **kwargs):
+ """Singleton pattern - ensure only one ModelManager exists"""
+ with cls._lock:
+ if cls._instance is None:
+ cls._instance = super().__new__(cls)
+ cls._instance._initialized = False
+ return cls._instance
+
+ def __init__(self, config: Optional[MemoryConfig] = None):
+ if self._initialized:
+ return
+
+ self.config = config or MemoryConfig()
+ self.models: Dict[str, ModelEntry] = {}
+ self.memory_guard = MemoryGuard(self.config)
+ self._model_lock = threading.RLock()
+ self._loading_locks: Dict[str, threading.Lock] = {}
+
+ # Start background timeout monitor
+ self._monitor_running = True
+ self._monitor_thread = threading.Thread(
+ target=self._timeout_monitor_loop,
+ daemon=True,
+ name="ModelManager-TimeoutMonitor"
+ )
+ self._monitor_thread.start()
+
+ self._initialized = True
+ logger.info("ModelManager initialized")
+
+ def _timeout_monitor_loop(self):
+ """Background thread to monitor and unload idle models"""
+ while self._monitor_running:
+ try:
+ time.sleep(self.config.memory_check_interval_seconds)
+ if self.config.enable_auto_cleanup:
+ self._cleanup_idle_models()
+ except Exception as e:
+ logger.error(f"Error in timeout monitor: {e}")
+
+ def _cleanup_idle_models(self):
+ """Unload models that have been idle longer than the timeout"""
+ current_time = time.time()
+ models_to_unload = []
+
+ with self._model_lock:
+ for model_id, entry in self.models.items():
+ # Only unload if no active references and idle timeout exceeded
+ if entry.ref_count <= 0:
+ idle_time = current_time - entry.last_used
+ if idle_time > self.config.model_idle_timeout_seconds:
+ models_to_unload.append(model_id)
+
+ for model_id in models_to_unload:
+ self.unload_model(model_id, force=False)
+
+ def get_or_load_model(
+ self,
+ model_id: str,
+ loader_func: Callable[[], Any],
+ estimated_memory_mb: float = 0,
+ cleanup_callback: Optional[Callable] = None
+ ) -> Any:
+ """
+ Get a model by ID, loading it if not already loaded.
+
+ Args:
+ model_id: Unique identifier for the model
+ loader_func: Function to call to load the model if not cached
+ estimated_memory_mb: Estimated memory usage for this model
+ cleanup_callback: Optional callback to run before unloading
+
+ Returns:
+ The model instance
+ """
+ with self._model_lock:
+ # Check if model is already loaded
+ if model_id in self.models:
+ entry = self.models[model_id]
+ if not entry.is_loading:
+ entry.ref_count += 1
+ entry.last_used = time.time()
+ logger.debug(f"Model {model_id} acquired (ref_count={entry.ref_count})")
+ return entry.model
+
+ # Create loading lock for this model if not exists
+ if model_id not in self._loading_locks:
+ self._loading_locks[model_id] = threading.Lock()
+
+ # Load model outside the main lock to allow concurrent operations
+ loading_lock = self._loading_locks[model_id]
+
+ with loading_lock:
+ # Double-check after acquiring loading lock
+ with self._model_lock:
+ if model_id in self.models and not self.models[model_id].is_loading:
+ entry = self.models[model_id]
+ entry.ref_count += 1
+ entry.last_used = time.time()
+ return entry.model
+
+ # Mark as loading
+ self.models[model_id] = ModelEntry(
+ model=None,
+ model_id=model_id,
+ is_loading=True,
+ estimated_memory_mb=estimated_memory_mb,
+ cleanup_callback=cleanup_callback
+ )
+
+ try:
+ # Check memory before loading
+ if estimated_memory_mb > 0:
+ is_available, stats = self.memory_guard.check_memory(int(estimated_memory_mb))
+ if not is_available and self.config.enable_emergency_cleanup:
+ logger.warning(f"Memory low, attempting cleanup before loading {model_id}")
+ self._evict_lru_models(required_mb=estimated_memory_mb)
+
+ # Load the model
+ logger.info(f"Loading model {model_id} (estimated {estimated_memory_mb}MB)")
+ start_time = time.time()
+ model = loader_func()
+ load_time = time.time() - start_time
+ logger.info(f"Model {model_id} loaded in {load_time:.2f}s")
+
+ # Update entry
+ with self._model_lock:
+ self.models[model_id] = ModelEntry(
+ model=model,
+ model_id=model_id,
+ ref_count=1,
+ estimated_memory_mb=estimated_memory_mb,
+ cleanup_callback=cleanup_callback,
+ is_loading=False
+ )
+
+ return model
+
+ except Exception as e:
+ # Clean up failed entry
+ with self._model_lock:
+ if model_id in self.models:
+ del self.models[model_id]
+ logger.error(f"Failed to load model {model_id}: {e}")
+ raise
+
+ def release_model(self, model_id: str):
+ """
+ Release a reference to a model.
+
+ Args:
+ model_id: Model identifier
+ """
+ with self._model_lock:
+ if model_id in self.models:
+ entry = self.models[model_id]
+ entry.ref_count = max(0, entry.ref_count - 1)
+ entry.last_used = time.time()
+ logger.debug(f"Model {model_id} released (ref_count={entry.ref_count})")
+
+ def unload_model(self, model_id: str, force: bool = False) -> bool:
+ """
+ Unload a model from memory.
+
+ Args:
+ model_id: Model identifier
+ force: Force unload even if references exist
+
+ Returns:
+ True if model was unloaded
+ """
+ with self._model_lock:
+ if model_id not in self.models:
+ return False
+
+ entry = self.models[model_id]
+
+ # Don't unload if there are active references (unless forced)
+ if entry.ref_count > 0 and not force:
+ logger.warning(f"Cannot unload {model_id}: {entry.ref_count} active references")
+ return False
+
+ # Run cleanup callback if provided
+ if entry.cleanup_callback:
+ try:
+ entry.cleanup_callback()
+ except Exception as e:
+ logger.warning(f"Cleanup callback failed for {model_id}: {e}")
+
+ # Delete model
+ del self.models[model_id]
+ logger.info(f"Model {model_id} unloaded")
+
+ # Clear GPU cache after unloading
+ self.memory_guard.clear_gpu_cache()
+ return True
+
+ def _evict_lru_models(self, required_mb: float = 0):
+ """
+ Evict least recently used models to free memory.
+
+ Args:
+ required_mb: Target amount of memory to free
+ """
+        # Select candidates under the lock, then unload outside it so that
+        # unload_model() can take the lock itself without risk of deadlock
+        with self._model_lock:
+            # Sort models by last_used (oldest first), excluding those with references
+            eviction_candidates = [
+                (model_id, entry)
+                for model_id, entry in self.models.items()
+                if entry.ref_count <= 0 and not entry.is_loading
+            ]
+            eviction_candidates.sort(key=lambda x: x[1].last_used)
+
+        freed_mb = 0.0
+        for model_id, entry in eviction_candidates:
+            if self.unload_model(model_id, force=False):
+                freed_mb += entry.estimated_memory_mb
+                logger.info(f"Evicted LRU model {model_id}, freed ~{entry.estimated_memory_mb}MB")
+
+                if required_mb > 0 and freed_mb >= required_mb:
+                    break
+
+ def get_model_stats(self) -> Dict:
+ """Get statistics about loaded models"""
+ with self._model_lock:
+ return {
+ "total_models": len(self.models),
+ "models": {
+ model_id: {
+ "ref_count": entry.ref_count,
+ "last_used": entry.last_used,
+ "estimated_memory_mb": entry.estimated_memory_mb,
+ "is_loading": entry.is_loading,
+ "idle_seconds": time.time() - entry.last_used
+ }
+ for model_id, entry in self.models.items()
+ },
+ "total_estimated_memory_mb": sum(
+ e.estimated_memory_mb for e in self.models.values()
+ ),
+ "memory_stats": self.memory_guard.get_memory_stats().__dict__
+ }
+
+ def teardown(self):
+ """
+ Clean up all models and resources.
+ Called during application shutdown.
+ """
+ logger.info("ModelManager teardown started")
+
+ # Stop monitor thread
+ self._monitor_running = False
+
+ # Unload all models
+ with self._model_lock:
+ model_ids = list(self.models.keys())
+
+ for model_id in model_ids:
+ self.unload_model(model_id, force=True)
+
+ # Clean up memory guard
+ self.memory_guard.shutdown()
+
+ logger.info("ModelManager teardown completed")
+
+
+# Global singleton instance
+_model_manager: Optional[ModelManager] = None
+
+
+def get_model_manager(config: Optional[MemoryConfig] = None) -> ModelManager:
+ """
+ Get the global ModelManager instance.
+
+ Args:
+ config: Optional configuration (only used on first call)
+
+ Returns:
+ ModelManager singleton instance
+ """
+ global _model_manager
+ if _model_manager is None:
+ _model_manager = ModelManager(config)
+ return _model_manager
+
+
+def shutdown_model_manager():
+ """Shutdown the global ModelManager instance"""
+ global _model_manager
+ if _model_manager is not None:
+ _model_manager.teardown()
+ _model_manager = None
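+
+
+# Hedged usage sketch (illustrative only, not called by the application):
+# the intended acquire/release pattern around get_or_load_model().  The
+# model_id and the build_engine loader below are placeholders, not real APIs.
+def _example_model_checkout(build_engine: Callable[[], Any]) -> None:
+    manager = get_model_manager()
+    model = manager.get_or_load_model(
+        model_id="pp_structure_v3",
+        loader_func=build_engine,
+        estimated_memory_mb=2000,
+    )
+    try:
+        _ = model  # run inference with the checked-out model here
+    finally:
+        # Always release so idle cleanup and LRU eviction can reclaim the model
+        manager.release_model("pp_structure_v3")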
+
+
+# =============================================================================
+# Section 4.2: Batch Processing and Progressive Loading
+# =============================================================================
+
+class BatchPriority(Enum):
+ """Priority levels for batch operations"""
+ LOW = 0
+ NORMAL = 1
+ HIGH = 2
+ CRITICAL = 3
+
+
+@dataclass
+class BatchItem:
+ """Item in a processing batch"""
+ item_id: str
+ data: Any
+ priority: BatchPriority = BatchPriority.NORMAL
+ created_at: float = field(default_factory=time.time)
+ estimated_memory_mb: float = 0.0
+ metadata: Dict = field(default_factory=dict)
+
+
+@dataclass
+class BatchResult:
+ """Result of batch processing"""
+ item_id: str
+ success: bool
+ result: Any = None
+ error: Optional[str] = None
+ processing_time_ms: float = 0.0
+
+
+class BatchProcessor:
+ """
+ Process items in batches to optimize memory usage for large documents.
+
+ Features:
+ - Memory-aware batch sizing
+ - Priority-based processing
+ - Progress tracking
+ - Automatic memory cleanup between batches
+ """
+
+ def __init__(
+ self,
+ max_batch_size: int = 5,
+ max_memory_per_batch_mb: float = 2000.0,
+ memory_guard: Optional[MemoryGuard] = None,
+ cleanup_between_batches: bool = True
+ ):
+ """
+ Initialize BatchProcessor.
+
+ Args:
+ max_batch_size: Maximum items per batch
+ max_memory_per_batch_mb: Maximum memory allowed per batch
+ memory_guard: MemoryGuard instance for memory monitoring
+ cleanup_between_batches: Whether to clear GPU cache between batches
+ """
+ self.max_batch_size = max_batch_size
+ self.max_memory_per_batch_mb = max_memory_per_batch_mb
+ self.memory_guard = memory_guard or MemoryGuard()
+ self.cleanup_between_batches = cleanup_between_batches
+
+ self._queue: List[BatchItem] = []
+ self._lock = threading.Lock()
+ self._processing = False
+
+ # Statistics
+ self._total_processed = 0
+ self._total_batches = 0
+ self._total_failures = 0
+
+ logger.info(f"BatchProcessor initialized (max_batch_size={max_batch_size}, max_memory={max_memory_per_batch_mb}MB)")
+
+ def add_item(self, item: BatchItem):
+ """Add an item to the processing queue"""
+ with self._lock:
+ self._queue.append(item)
+ # Sort by priority (highest first), then by creation time (oldest first)
+ self._queue.sort(key=lambda x: (-x.priority.value, x.created_at))
+ logger.debug(f"Added item {item.item_id} to batch queue (queue_size={len(self._queue)})")
+
+ def add_items(self, items: List[BatchItem]):
+ """Add multiple items to the processing queue"""
+ with self._lock:
+ self._queue.extend(items)
+ self._queue.sort(key=lambda x: (-x.priority.value, x.created_at))
+ logger.debug(f"Added {len(items)} items to batch queue (queue_size={len(self._queue)})")
+
+ def _create_batch(self) -> List[BatchItem]:
+ """Create a batch from the queue based on size and memory constraints"""
+ batch = []
+ batch_memory = 0.0
+
+ with self._lock:
+ remaining = []
+ for item in self._queue:
+ # Check if adding this item would exceed limits
+ if len(batch) >= self.max_batch_size:
+ remaining.append(item)
+ continue
+
+ if batch_memory + item.estimated_memory_mb > self.max_memory_per_batch_mb and batch:
+ remaining.append(item)
+ continue
+
+ batch.append(item)
+ batch_memory += item.estimated_memory_mb
+
+ self._queue = remaining
+
+ return batch
+
+ def process_batch(
+ self,
+ processor_func: Callable[[Any], Any],
+ progress_callback: Optional[Callable[[int, int, BatchResult], None]] = None
+ ) -> List[BatchResult]:
+ """
+ Process a single batch of items.
+
+ Args:
+ processor_func: Function to process each item (receives item.data)
+ progress_callback: Optional callback(current, total, result)
+
+ Returns:
+ List of BatchResult for each item in the batch
+ """
+ batch = self._create_batch()
+ if not batch:
+ return []
+
+ self._processing = True
+ results = []
+ total = len(batch)
+
+ try:
+ for i, item in enumerate(batch):
+ start_time = time.time()
+ result = BatchResult(item_id=item.item_id, success=False)
+
+ try:
+ # Check memory before processing
+ is_available, stats = self.memory_guard.check_memory(
+ int(item.estimated_memory_mb)
+ )
+ if not is_available:
+ logger.warning(f"Insufficient memory for item {item.item_id}, cleaning up...")
+ self.memory_guard.clear_gpu_cache()
+ gc.collect()
+
+ # Process item
+ result.result = processor_func(item.data)
+ result.success = True
+
+ except Exception as e:
+ result.error = str(e)
+ self._total_failures += 1
+ logger.error(f"Failed to process item {item.item_id}: {e}")
+
+ result.processing_time_ms = (time.time() - start_time) * 1000
+ results.append(result)
+ self._total_processed += 1
+
+ # Call progress callback
+ if progress_callback:
+ progress_callback(i + 1, total, result)
+
+ self._total_batches += 1
+
+ finally:
+ self._processing = False
+
+ # Clean up after batch
+ if self.cleanup_between_batches:
+ self.memory_guard.clear_gpu_cache()
+ gc.collect()
+
+ return results
+
+ def process_all(
+ self,
+ processor_func: Callable[[Any], Any],
+ progress_callback: Optional[Callable[[int, int, BatchResult], None]] = None
+ ) -> List[BatchResult]:
+ """
+ Process all items in the queue.
+
+ Args:
+ processor_func: Function to process each item
+ progress_callback: Optional progress callback
+
+ Returns:
+ List of all BatchResults
+ """
+ all_results = []
+
+ while True:
+ with self._lock:
+ if not self._queue:
+ break
+
+ batch_results = self.process_batch(processor_func, progress_callback)
+ all_results.extend(batch_results)
+
+ return all_results
+
+ def get_queue_size(self) -> int:
+ """Get current queue size"""
+ with self._lock:
+ return len(self._queue)
+
+ def get_stats(self) -> Dict:
+ """Get processing statistics"""
+ with self._lock:
+ return {
+ "queue_size": len(self._queue),
+ "total_processed": self._total_processed,
+ "total_batches": self._total_batches,
+ "total_failures": self._total_failures,
+ "is_processing": self._processing,
+ "max_batch_size": self.max_batch_size,
+ "max_memory_per_batch_mb": self.max_memory_per_batch_mb,
+ }
+
+ def clear_queue(self):
+ """Clear the processing queue"""
+ with self._lock:
+ self._queue.clear()
+ logger.info("Batch queue cleared")
+
+
+class ProgressiveLoader:
+ """
+ Progressive page loader for multi-page documents.
+
+ Loads and processes pages incrementally to minimize memory usage.
+ Supports lookahead loading for better performance.
+ """
+
+ def __init__(
+ self,
+ lookahead_pages: int = 2,
+ memory_guard: Optional[MemoryGuard] = None,
+ cleanup_after_pages: int = 5
+ ):
+ """
+ Initialize ProgressiveLoader.
+
+ Args:
+ lookahead_pages: Number of pages to load ahead
+ memory_guard: MemoryGuard instance
+ cleanup_after_pages: Trigger cleanup after this many pages
+ """
+ self.lookahead_pages = lookahead_pages
+ self.memory_guard = memory_guard or MemoryGuard()
+ self.cleanup_after_pages = cleanup_after_pages
+
+ self._loaded_pages: Dict[int, Any] = {}
+ self._lock = threading.Lock()
+ self._current_page = 0
+ self._total_pages = 0
+ self._pages_since_cleanup = 0
+
+ logger.info(f"ProgressiveLoader initialized (lookahead={lookahead_pages})")
+
+ def initialize(self, total_pages: int):
+ """Initialize loader with total page count"""
+ with self._lock:
+ self._total_pages = total_pages
+ self._current_page = 0
+ self._loaded_pages.clear()
+ self._pages_since_cleanup = 0
+ logger.info(f"ProgressiveLoader initialized for {total_pages} pages")
+
+ def load_page(
+ self,
+ page_num: int,
+ loader_func: Callable[[int], Any],
+ unload_distant: bool = True
+ ) -> Any:
+ """
+ Load a specific page.
+
+ Args:
+ page_num: Page number to load (0-indexed)
+ loader_func: Function to load page (receives page_num)
+ unload_distant: Unload pages far from current position
+
+ Returns:
+ Loaded page data
+ """
+ with self._lock:
+ # Check if already loaded
+ if page_num in self._loaded_pages:
+ self._current_page = page_num
+ return self._loaded_pages[page_num]
+
+ # Load the page
+ logger.debug(f"Loading page {page_num}")
+ page_data = loader_func(page_num)
+
+ with self._lock:
+ self._loaded_pages[page_num] = page_data
+ self._current_page = page_num
+ self._pages_since_cleanup += 1
+
+ # Unload distant pages to save memory
+ if unload_distant:
+ self._unload_distant_pages()
+
+ # Trigger cleanup if needed
+ if self._pages_since_cleanup >= self.cleanup_after_pages:
+ self.memory_guard.clear_gpu_cache()
+ gc.collect()
+ self._pages_since_cleanup = 0
+
+ return page_data
+
+ def _unload_distant_pages(self):
+ """Unload pages far from current position"""
+ keep_range = range(
+ max(0, self._current_page - 1),
+ min(self._total_pages, self._current_page + self.lookahead_pages + 1)
+ )
+
+ pages_to_unload = [
+ p for p in self._loaded_pages.keys()
+ if p not in keep_range
+ ]
+
+ for page_num in pages_to_unload:
+ del self._loaded_pages[page_num]
+ logger.debug(f"Unloaded distant page {page_num}")
+
+ def prefetch_pages(
+ self,
+ start_page: int,
+ loader_func: Callable[[int], Any]
+ ):
+ """
+ Prefetch upcoming pages in background.
+
+ Args:
+ start_page: Starting page number
+ loader_func: Function to load page
+ """
+ for i in range(self.lookahead_pages):
+ page_num = start_page + i + 1
+ if page_num >= self._total_pages:
+ break
+
+ with self._lock:
+ if page_num in self._loaded_pages:
+ continue
+
+ try:
+ self.load_page(page_num, loader_func, unload_distant=False)
+ except Exception as e:
+ logger.warning(f"Prefetch failed for page {page_num}: {e}")
+
+ def iterate_pages(
+ self,
+ loader_func: Callable[[int], Any],
+ processor_func: Callable[[int, Any], Any],
+ progress_callback: Optional[Callable[[int, int], None]] = None
+ ) -> List[Any]:
+ """
+ Iterate through all pages with progressive loading.
+
+ Args:
+ loader_func: Function to load a page
+ processor_func: Function to process page (receives page_num, data)
+ progress_callback: Optional callback(current_page, total_pages)
+
+ Returns:
+ List of results from processor_func
+ """
+ results = []
+
+ for page_num in range(self._total_pages):
+ # Load page
+ page_data = self.load_page(page_num, loader_func)
+
+ # Process page
+ result = processor_func(page_num, page_data)
+ results.append(result)
+
+ # Report progress
+ if progress_callback:
+ progress_callback(page_num + 1, self._total_pages)
+
+            # Prefetch the next pages in a daemon thread so the main
+            # processing loop is not blocked
+            if self.lookahead_pages > 0:
+ prefetch_thread = threading.Thread(
+ target=self.prefetch_pages,
+ args=(page_num, loader_func),
+ daemon=True
+ )
+ prefetch_thread.start()
+
+ return results
+
+ def get_loaded_pages(self) -> List[int]:
+ """Get list of currently loaded page numbers"""
+ with self._lock:
+ return list(self._loaded_pages.keys())
+
+ def get_stats(self) -> Dict:
+ """Get loader statistics"""
+ with self._lock:
+ return {
+ "total_pages": self._total_pages,
+ "current_page": self._current_page,
+ "loaded_pages_count": len(self._loaded_pages),
+ "loaded_pages": list(self._loaded_pages.keys()),
+ "lookahead_pages": self.lookahead_pages,
+ "pages_since_cleanup": self._pages_since_cleanup,
+ }
+
+ def clear(self):
+ """Clear all loaded pages"""
+ with self._lock:
+ self._loaded_pages.clear()
+ self._current_page = 0
+ self._pages_since_cleanup = 0
+ self.memory_guard.clear_gpu_cache()
+ gc.collect()
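+
+
+# Hedged usage sketch (illustrative only): iterating a multi-page document
+# with ProgressiveLoader.  load_page_image and ocr_page are placeholder
+# callables supplied by the caller.
+def _example_progressive_iteration(
+    total_pages: int,
+    load_page_image: Callable[[int], Any],
+    ocr_page: Callable[[int, Any], Any],
+) -> List[Any]:
+    loader = ProgressiveLoader(lookahead_pages=2, cleanup_after_pages=5)
+    loader.initialize(total_pages)
+    results = loader.iterate_pages(
+        loader_func=load_page_image,
+        processor_func=ocr_page,
+        progress_callback=lambda current, total: logger.info(f"Page {current}/{total}"),
+    )
+    loader.clear()  # drop cached pages and clear the GPU cache
+    return results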
+
+
+class PriorityOperationQueue:
+ """
+ Priority queue for OCR operations.
+
+ Higher priority operations are processed first.
+ Supports timeout and cancellation.
+ """
+
+ def __init__(self, max_size: int = 100):
+ """
+ Initialize priority queue.
+
+ Args:
+ max_size: Maximum queue size (0 for unlimited)
+ """
+ self.max_size = max_size
+ self._queue: List[Tuple[BatchPriority, float, str, Any]] = [] # (priority, timestamp, id, data)
+ self._lock = threading.Lock()
+ self._condition = threading.Condition(self._lock)
+ self._cancelled: set = set()
+
+ # Statistics
+ self._total_enqueued = 0
+ self._total_dequeued = 0
+ self._total_cancelled = 0
+
+ logger.info(f"PriorityOperationQueue initialized (max_size={max_size})")
+
+ def enqueue(
+ self,
+ item_id: str,
+ data: Any,
+ priority: BatchPriority = BatchPriority.NORMAL,
+ timeout: Optional[float] = None
+ ) -> bool:
+ """
+ Add an operation to the queue.
+
+ Args:
+ item_id: Unique identifier for the operation
+ data: Operation data
+ priority: Operation priority
+ timeout: Optional timeout to wait for space in queue
+
+ Returns:
+ True if enqueued successfully
+ """
+ with self._condition:
+ # Wait for space if queue is full
+ if self.max_size > 0 and len(self._queue) >= self.max_size:
+ if timeout is not None:
+ result = self._condition.wait_for(
+ lambda: len(self._queue) < self.max_size,
+ timeout=timeout
+ )
+ if not result:
+ logger.warning(f"Queue full, timeout waiting to enqueue {item_id}")
+ return False
+ else:
+ logger.warning(f"Queue full, cannot enqueue {item_id}")
+ return False
+
+ # Add to queue (negative priority for max-heap behavior)
+ import heapq
+ heapq.heappush(
+ self._queue,
+ (-priority.value, time.time(), item_id, data)
+ )
+ self._total_enqueued += 1
+ self._condition.notify()
+
+ logger.debug(f"Enqueued operation {item_id} with priority {priority.name}")
+ return True
+
+ def dequeue(self, timeout: Optional[float] = None) -> Optional[Tuple[str, Any, BatchPriority]]:
+ """
+ Get the highest priority operation from the queue.
+
+ Args:
+ timeout: Optional timeout to wait for an item
+
+ Returns:
+ Tuple of (item_id, data, priority) or None if timeout
+ """
+ import heapq
+
+ with self._condition:
+ # Wait for an item
+ if not self._queue:
+ if timeout is not None:
+ result = self._condition.wait_for(
+ lambda: len(self._queue) > 0,
+ timeout=timeout
+ )
+ if not result:
+ return None
+ else:
+ return None
+
+            # Pop the highest-priority item, skipping cancelled ones in place
+            # (recursing into dequeue() here would try to re-acquire the
+            # non-reentrant lock and deadlock)
+            while True:
+                neg_priority, _, item_id, data = heapq.heappop(self._queue)
+                priority = BatchPriority(-neg_priority)
+
+                if item_id not in self._cancelled:
+                    break
+
+                self._cancelled.discard(item_id)
+                self._total_cancelled += 1
+                if not self._queue:
+                    self._condition.notify()
+                    return None
+
+ self._total_dequeued += 1
+ self._condition.notify()
+
+ logger.debug(f"Dequeued operation {item_id} with priority {priority.name}")
+ return item_id, data, priority
+
+ def cancel(self, item_id: str) -> bool:
+ """
+ Cancel a pending operation.
+
+ Args:
+ item_id: Operation identifier to cancel
+
+ Returns:
+ True if the operation was found and marked for cancellation
+ """
+ with self._lock:
+ # Check if item is in queue
+ for _, _, qid, _ in self._queue:
+ if qid == item_id:
+ self._cancelled.add(item_id)
+ logger.info(f"Operation {item_id} marked for cancellation")
+ return True
+ return False
+
+ def get_size(self) -> int:
+ """Get current queue size"""
+ with self._lock:
+ return len(self._queue)
+
+ def get_stats(self) -> Dict:
+ """Get queue statistics"""
+ with self._lock:
+ # Count by priority
+ priority_counts = {p.name: 0 for p in BatchPriority}
+ for neg_priority, _, _, _ in self._queue:
+ priority = BatchPriority(-neg_priority)
+ priority_counts[priority.name] += 1
+
+ return {
+ "queue_size": len(self._queue),
+ "max_size": self.max_size,
+ "total_enqueued": self._total_enqueued,
+ "total_dequeued": self._total_dequeued,
+ "total_cancelled": self._total_cancelled,
+ "pending_cancellations": len(self._cancelled),
+ "by_priority": priority_counts,
+ }
+
+ def clear(self):
+ """Clear the queue"""
+ with self._lock:
+ self._queue.clear()
+ self._cancelled.clear()
+ logger.info("Priority queue cleared")
+
+
+# =============================================================================
+# Section 5.2: Recovery Mechanisms
+# =============================================================================
+
+@dataclass
+class RecoveryState:
+ """State of recovery mechanism"""
+ last_recovery_time: float = 0.0
+ recovery_count: int = 0
+ in_cooldown: bool = False
+ cooldown_until: float = 0.0
+ last_error: Optional[str] = None
+
+
+class RecoveryManager:
+ """
+ Manages recovery mechanisms for memory issues and failures.
+
+ Features:
+ - Emergency memory release
+ - Cooldown period after recovery
+ - Recovery attempt limits
+ """
+
+ def __init__(
+ self,
+ cooldown_seconds: float = 30.0,
+ max_recovery_attempts: int = 3,
+ recovery_window_seconds: float = 300.0,
+ memory_guard: Optional[MemoryGuard] = None
+ ):
+ """
+ Initialize RecoveryManager.
+
+ Args:
+ cooldown_seconds: Cooldown period after recovery
+ max_recovery_attempts: Max recovery attempts within window
+ recovery_window_seconds: Window for counting recovery attempts
+ memory_guard: MemoryGuard instance
+ """
+ self.cooldown_seconds = cooldown_seconds
+ self.max_recovery_attempts = max_recovery_attempts
+ self.recovery_window_seconds = recovery_window_seconds
+ self.memory_guard = memory_guard or MemoryGuard()
+
+ self._state = RecoveryState()
+ self._lock = threading.Lock()
+ self._recovery_times: List[float] = []
+
+ # Callbacks
+ self._on_recovery_start: List[Callable] = []
+ self._on_recovery_complete: List[Callable[[bool], None]] = []
+
+ logger.info(f"RecoveryManager initialized (cooldown={cooldown_seconds}s)")
+
+ def register_callbacks(
+ self,
+ on_start: Optional[Callable] = None,
+ on_complete: Optional[Callable[[bool], None]] = None
+ ):
+ """Register recovery event callbacks"""
+ if on_start:
+ self._on_recovery_start.append(on_start)
+ if on_complete:
+ self._on_recovery_complete.append(on_complete)
+
+ def is_in_cooldown(self) -> bool:
+ """Check if currently in cooldown period"""
+ with self._lock:
+ if not self._state.in_cooldown:
+ return False
+
+ if time.time() >= self._state.cooldown_until:
+ self._state.in_cooldown = False
+ return False
+
+ return True
+
+ def get_cooldown_remaining(self) -> float:
+ """Get remaining cooldown time in seconds"""
+ with self._lock:
+ if not self._state.in_cooldown:
+ return 0.0
+ return max(0, self._state.cooldown_until - time.time())
+
+ def _count_recent_recoveries(self) -> int:
+ """Count recovery attempts within the window"""
+ cutoff = time.time() - self.recovery_window_seconds
+ with self._lock:
+ # Clean old entries
+ self._recovery_times = [t for t in self._recovery_times if t > cutoff]
+ return len(self._recovery_times)
+
+ def can_attempt_recovery(self) -> Tuple[bool, str]:
+ """
+ Check if recovery can be attempted.
+
+ Returns:
+ Tuple of (can_recover, reason)
+ """
+ if self.is_in_cooldown():
+ remaining = self.get_cooldown_remaining()
+ return False, f"In cooldown period ({remaining:.1f}s remaining)"
+
+ recent_count = self._count_recent_recoveries()
+ if recent_count >= self.max_recovery_attempts:
+ return False, f"Max recovery attempts ({self.max_recovery_attempts}) reached"
+
+ return True, "Recovery allowed"
+
+ def attempt_recovery(self, error: Optional[str] = None) -> bool:
+ """
+ Attempt memory recovery.
+
+ Args:
+ error: Optional error message that triggered recovery
+
+ Returns:
+ True if recovery was successful
+ """
+ can_recover, reason = self.can_attempt_recovery()
+ if not can_recover:
+ logger.warning(f"Cannot attempt recovery: {reason}")
+ return False
+
+ logger.info("Starting memory recovery...")
+
+ # Notify callbacks
+ for callback in self._on_recovery_start:
+ try:
+ callback()
+ except Exception as e:
+ logger.warning(f"Recovery start callback failed: {e}")
+
+ success = False
+ try:
+ # Step 1: Clear GPU cache
+ self.memory_guard.clear_gpu_cache()
+
+ # Step 2: Force garbage collection
+ gc.collect()
+
+ # Step 3: Check memory status
+ is_available, stats = self.memory_guard.check_memory()
+ success = is_available or stats.gpu_used_ratio < 0.9
+
+ if success:
+ logger.info(
+ f"Memory recovery successful. GPU: {stats.gpu_used_ratio*100:.1f}% used"
+ )
+ else:
+ logger.warning(
+ f"Memory recovery incomplete. GPU still at {stats.gpu_used_ratio*100:.1f}%"
+ )
+
+ except Exception as e:
+ logger.error(f"Recovery failed with error: {e}")
+ success = False
+
+ # Update state
+ with self._lock:
+ self._state.last_recovery_time = time.time()
+ self._state.recovery_count += 1
+ self._state.last_error = error
+ self._recovery_times.append(time.time())
+
+ # Enter cooldown
+ self._state.in_cooldown = True
+ self._state.cooldown_until = time.time() + self.cooldown_seconds
+
+ logger.info(f"Entering cooldown period ({self.cooldown_seconds}s)")
+
+ # Notify callbacks
+ for callback in self._on_recovery_complete:
+ try:
+ callback(success)
+ except Exception as e:
+ logger.warning(f"Recovery complete callback failed: {e}")
+
+ return success
+
+ def emergency_release(self, model_manager: Optional['ModelManager'] = None) -> bool:
+ """
+ Emergency memory release - more aggressive than normal recovery.
+
+ Args:
+ model_manager: Optional ModelManager to unload models from
+
+ Returns:
+ True if significant memory was freed
+ """
+ logger.warning("Initiating EMERGENCY memory release")
+
+ initial_stats = self.memory_guard.get_memory_stats()
+
+ # Step 1: Unload all models if model_manager provided
+ if model_manager:
+ logger.info("Unloading all models...")
+ try:
+ model_ids = list(model_manager.models.keys())
+ for model_id in model_ids:
+ model_manager.unload_model(model_id, force=True)
+ except Exception as e:
+ logger.error(f"Failed to unload models: {e}")
+
+ # Step 2: Clear all caches
+ self.memory_guard.clear_gpu_cache()
+
+ # Step 3: Multiple rounds of garbage collection
+ for i in range(3):
+ gc.collect()
+ time.sleep(0.1)
+
+ # Step 4: Check improvement
+ final_stats = self.memory_guard.get_memory_stats()
+ freed_mb = initial_stats.gpu_used_mb - final_stats.gpu_used_mb
+
+ logger.info(
+ f"Emergency release complete. Freed ~{freed_mb:.0f}MB. "
+ f"GPU: {final_stats.gpu_used_mb:.0f}MB / {final_stats.gpu_total_mb:.0f}MB "
+ f"({final_stats.gpu_used_ratio*100:.1f}%)"
+ )
+
+ return freed_mb > 100 # Consider success if freed >100MB
+
+    def get_state(self) -> Dict:
+        """Get current recovery state"""
+        # Compute derived values inline: calling get_cooldown_remaining() or
+        # _count_recent_recoveries() here would re-acquire the non-reentrant
+        # lock this method already holds
+        with self._lock:
+            now = time.time()
+            cooldown_remaining = (
+                max(0.0, self._state.cooldown_until - now)
+                if self._state.in_cooldown else 0.0
+            )
+            cutoff = now - self.recovery_window_seconds
+            self._recovery_times = [t for t in self._recovery_times if t > cutoff]
+            return {
+                "last_recovery_time": self._state.last_recovery_time,
+                "recovery_count": self._state.recovery_count,
+                "in_cooldown": self._state.in_cooldown,
+                "cooldown_remaining_seconds": cooldown_remaining,
+                "recent_recoveries": len(self._recovery_times),
+                "max_recovery_attempts": self.max_recovery_attempts,
+                "last_error": self._state.last_error,
+            }
+
+
+# Global recovery manager instance
+_recovery_manager: Optional[RecoveryManager] = None
+
+
+def get_recovery_manager(
+ cooldown_seconds: float = 30.0,
+ max_recovery_attempts: int = 3
+) -> RecoveryManager:
+ """
+ Get the global RecoveryManager instance.
+
+ Args:
+ cooldown_seconds: Cooldown period after recovery
+ max_recovery_attempts: Max recovery attempts within window
+
+ Returns:
+ RecoveryManager singleton instance
+ """
+ global _recovery_manager
+ if _recovery_manager is None:
+ _recovery_manager = RecoveryManager(
+ cooldown_seconds=cooldown_seconds,
+ max_recovery_attempts=max_recovery_attempts
+ )
+ return _recovery_manager
+
+
+def shutdown_recovery_manager():
+ """Shutdown the global RecoveryManager instance"""
+ global _recovery_manager
+ _recovery_manager = None
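+
+
+# Hedged usage sketch (illustrative only): wrapping an OOM-prone call with the
+# recovery manager.  run_prediction is a placeholder, and MemoryError stands in
+# for whatever out-of-memory exception the inference backend actually raises.
+def _example_recover_and_retry(run_prediction: Callable[[], Any]) -> Any:
+    recovery = get_recovery_manager()
+    try:
+        return run_prediction()
+    except MemoryError as exc:
+        # attempt_recovery() clears caches, records the attempt, and enters a
+        # cooldown; it refuses to run again until the cooldown has elapsed
+        if recovery.attempt_recovery(error=str(exc)):
+            return run_prediction()
+        raise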
+
+
+# =============================================================================
+# Section 5.2: Memory Dump for Debugging
+# =============================================================================
+
+@dataclass
+class MemoryDumpEntry:
+ """Entry in a memory dump"""
+ object_type: str
+ object_id: str
+ size_bytes: int
+ ref_count: int
+ details: Dict = field(default_factory=dict)
+
+
+@dataclass
+class MemoryDump:
+ """Complete memory dump for debugging"""
+ timestamp: float
+ total_gpu_memory_mb: float
+ used_gpu_memory_mb: float
+ free_gpu_memory_mb: float
+ total_cpu_memory_mb: float
+ used_cpu_memory_mb: float
+ loaded_models: List[Dict]
+ active_predictions: int
+ queue_depth: int
+ service_pool_stats: Dict
+ recovery_state: Dict
+ python_objects: List[MemoryDumpEntry] = field(default_factory=list)
+ gc_stats: Dict = field(default_factory=dict)
+
+
+class MemoryDumper:
+ """
+ Creates memory dumps for debugging memory issues.
+
+ Captures comprehensive memory state including:
+ - GPU/CPU memory usage
+ - Loaded models and their references
+ - Active predictions and queue state
+ - Python garbage collector statistics
+ - Large object tracking
+ """
+
+ def __init__(self, memory_guard: Optional[MemoryGuard] = None):
+ """
+ Initialize MemoryDumper.
+
+ Args:
+ memory_guard: MemoryGuard instance for memory queries
+ """
+ self.memory_guard = memory_guard or MemoryGuard()
+ self._dump_history: List[MemoryDump] = []
+ self._max_history = 10
+ self._lock = threading.Lock()
+
+ logger.info("MemoryDumper initialized")
+
+ def create_dump(
+ self,
+ include_python_objects: bool = False,
+ min_object_size: int = 1048576 # 1MB
+ ) -> MemoryDump:
+ """
+ Create a memory dump capturing current state.
+
+ Args:
+ include_python_objects: Include large Python objects in dump
+ min_object_size: Minimum object size to include (bytes)
+
+ Returns:
+ MemoryDump with current memory state
+ """
+ logger.info("Creating memory dump...")
+
+ # Get memory stats
+ stats = self.memory_guard.get_memory_stats()
+
+ # Get model manager stats
+ loaded_models = []
+ try:
+ model_manager = get_model_manager()
+ model_stats = model_manager.get_model_stats()
+ loaded_models = [
+ {
+ "model_id": model_id,
+ "ref_count": info["ref_count"],
+ "estimated_memory_mb": info["estimated_memory_mb"],
+ "idle_seconds": info["idle_seconds"],
+ "is_loading": info["is_loading"],
+ }
+ for model_id, info in model_stats.get("models", {}).items()
+ ]
+ except Exception as e:
+ logger.debug(f"Could not get model stats: {e}")
+
+ # Get prediction semaphore stats
+ active_predictions = 0
+ queue_depth = 0
+ try:
+ semaphore = get_prediction_semaphore()
+ sem_stats = semaphore.get_stats()
+ active_predictions = sem_stats.get("active_predictions", 0)
+ queue_depth = sem_stats.get("queue_depth", 0)
+ except Exception as e:
+ logger.debug(f"Could not get semaphore stats: {e}")
+
+ # Get service pool stats
+ service_pool_stats = {}
+ try:
+ from app.services.service_pool import get_service_pool
+ pool = get_service_pool()
+ service_pool_stats = pool.get_pool_stats()
+ except Exception as e:
+ logger.debug(f"Could not get service pool stats: {e}")
+
+ # Get recovery state
+ recovery_state = {}
+ try:
+ recovery_manager = get_recovery_manager()
+ recovery_state = recovery_manager.get_state()
+ except Exception as e:
+ logger.debug(f"Could not get recovery state: {e}")
+
+ # Get GC stats
+ gc_stats = {
+ "counts": gc.get_count(),
+ "threshold": gc.get_threshold(),
+ "is_tracking": gc.isenabled(),
+ }
+
+ # Create dump
+ dump = MemoryDump(
+ timestamp=time.time(),
+ total_gpu_memory_mb=stats.gpu_total_mb,
+ used_gpu_memory_mb=stats.gpu_used_mb,
+ free_gpu_memory_mb=stats.gpu_free_mb,
+ total_cpu_memory_mb=stats.cpu_used_mb + stats.cpu_available_mb,
+ used_cpu_memory_mb=stats.cpu_used_mb,
+ loaded_models=loaded_models,
+ active_predictions=active_predictions,
+ queue_depth=queue_depth,
+ service_pool_stats=service_pool_stats,
+ recovery_state=recovery_state,
+ gc_stats=gc_stats,
+ )
+
+ # Optionally include large Python objects
+ if include_python_objects:
+ dump.python_objects = self._get_large_objects(min_object_size)
+
+ # Store in history
+ with self._lock:
+ self._dump_history.append(dump)
+ if len(self._dump_history) > self._max_history:
+ self._dump_history.pop(0)
+
+ logger.info(
+ f"Memory dump created: GPU {stats.gpu_used_mb:.0f}/{stats.gpu_total_mb:.0f}MB, "
+ f"{len(loaded_models)} models, {active_predictions} active predictions"
+ )
+
+ return dump
+
+ def _get_large_objects(self, min_size: int) -> List[MemoryDumpEntry]:
+ """Get list of large Python objects for debugging"""
+ large_objects = []
+
+ try:
+ import sys
+
+ # Get all objects tracked by GC
+ for obj in gc.get_objects():
+ try:
+ size = sys.getsizeof(obj)
+ if size >= min_size:
+ entry = MemoryDumpEntry(
+ object_type=type(obj).__name__,
+ object_id=str(id(obj)),
+ size_bytes=size,
+ ref_count=sys.getrefcount(obj),
+ details={
+ "module": getattr(type(obj), "__module__", "unknown"),
+ }
+ )
+ large_objects.append(entry)
+ except Exception:
+ pass # Skip objects that can't be measured
+
+ # Sort by size descending
+ large_objects.sort(key=lambda x: x.size_bytes, reverse=True)
+
+ # Limit to top 100
+ return large_objects[:100]
+
+ except Exception as e:
+ logger.warning(f"Failed to get large objects: {e}")
+ return []
+
+ def get_dump_history(self) -> List[MemoryDump]:
+ """Get recent dump history"""
+ with self._lock:
+ return list(self._dump_history)
+
+ def get_latest_dump(self) -> Optional[MemoryDump]:
+ """Get the most recent dump"""
+ with self._lock:
+ return self._dump_history[-1] if self._dump_history else None
+
+ def compare_dumps(
+ self,
+ dump1: MemoryDump,
+ dump2: MemoryDump
+ ) -> Dict:
+ """
+ Compare two memory dumps to identify changes.
+
+ Args:
+ dump1: First (earlier) dump
+ dump2: Second (later) dump
+
+ Returns:
+ Dictionary with comparison results
+ """
+ return {
+ "time_delta_seconds": dump2.timestamp - dump1.timestamp,
+ "gpu_memory_change_mb": dump2.used_gpu_memory_mb - dump1.used_gpu_memory_mb,
+ "cpu_memory_change_mb": dump2.used_cpu_memory_mb - dump1.used_cpu_memory_mb,
+ "model_count_change": len(dump2.loaded_models) - len(dump1.loaded_models),
+ "prediction_count_change": dump2.active_predictions - dump1.active_predictions,
+ "dump1_timestamp": dump1.timestamp,
+ "dump2_timestamp": dump2.timestamp,
+ }
+
+ def to_dict(self, dump: MemoryDump) -> Dict:
+ """Convert a MemoryDump to a dictionary for JSON serialization"""
+ return {
+ "timestamp": dump.timestamp,
+ "gpu": {
+ "total_mb": dump.total_gpu_memory_mb,
+ "used_mb": dump.used_gpu_memory_mb,
+ "free_mb": dump.free_gpu_memory_mb,
+ "utilization_percent": (
+ dump.used_gpu_memory_mb / dump.total_gpu_memory_mb * 100
+ if dump.total_gpu_memory_mb > 0 else 0
+ ),
+ },
+ "cpu": {
+ "total_mb": dump.total_cpu_memory_mb,
+ "used_mb": dump.used_cpu_memory_mb,
+ },
+ "models": dump.loaded_models,
+ "predictions": {
+ "active": dump.active_predictions,
+ "queue_depth": dump.queue_depth,
+ },
+ "service_pool": dump.service_pool_stats,
+ "recovery": dump.recovery_state,
+ "gc": dump.gc_stats,
+ "large_objects_count": len(dump.python_objects),
+ }
+
+
+# Global memory dumper instance
+_memory_dumper: Optional[MemoryDumper] = None
+
+
+def get_memory_dumper() -> MemoryDumper:
+ """Get the global MemoryDumper instance"""
+ global _memory_dumper
+ if _memory_dumper is None:
+ _memory_dumper = MemoryDumper()
+ return _memory_dumper
+
+
+def shutdown_memory_dumper():
+ """Shutdown the global MemoryDumper instance"""
+ global _memory_dumper
+ _memory_dumper = None
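+
+
+# Hedged usage sketch (illustrative only): capturing dumps before and after a
+# suspect operation and logging the delta.  suspect_operation is a placeholder.
+def _example_dump_comparison(suspect_operation: Callable[[], Any]) -> Dict:
+    dumper = get_memory_dumper()
+    before = dumper.create_dump()
+    suspect_operation()
+    after = dumper.create_dump()
+    delta = dumper.compare_dumps(before, after)
+    logger.info(f"GPU memory change: {delta['gpu_memory_change_mb']:.0f}MB")
+    return delta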
+
+
+# =============================================================================
+# Section 7.2: Prometheus Metrics Export
+# =============================================================================
+
+class PrometheusMetrics:
+ """
+ Prometheus metrics exporter for memory management.
+
+ Exposes metrics in Prometheus text format for monitoring:
+ - GPU/CPU memory usage
+ - Model lifecycle metrics
+ - Prediction semaphore metrics
+ - Service pool metrics
+ - Recovery metrics
+ """
+
+ # Metric names
+ METRIC_PREFIX = "tool_ocr_memory_"
+
+ def __init__(self):
+ """Initialize PrometheusMetrics"""
+ self._custom_metrics: Dict[str, float] = {}
+ self._lock = threading.Lock()
+ logger.info("PrometheusMetrics initialized")
+
+ def set_custom_metric(self, name: str, value: float, labels: Optional[Dict[str, str]] = None):
+ """
+ Set a custom metric value.
+
+ Args:
+ name: Metric name
+ value: Metric value
+ labels: Optional labels for the metric
+ """
+ with self._lock:
+ key = name if not labels else f"{name}{{{self._format_labels(labels)}}}"
+ self._custom_metrics[key] = value
+
+ def _format_labels(self, labels: Dict[str, str]) -> str:
+ """Format labels for Prometheus"""
+ return ",".join(f'{k}="{v}"' for k, v in sorted(labels.items()))
+
+ def _format_metric(self, name: str, value: float, help_text: str, metric_type: str = "gauge") -> str:
+ """Format a single metric in Prometheus format"""
+ lines = [
+ f"# HELP {self.METRIC_PREFIX}{name} {help_text}",
+ f"# TYPE {self.METRIC_PREFIX}{name} {metric_type}",
+ f"{self.METRIC_PREFIX}{name} {value}",
+ ]
+ return "\n".join(lines)
+
+ def _format_metric_with_labels(
+ self,
+ name: str,
+ values: List[Tuple[Dict[str, str], float]],
+ help_text: str,
+ metric_type: str = "gauge"
+ ) -> str:
+ """Format a metric with labels in Prometheus format"""
+ lines = [
+ f"# HELP {self.METRIC_PREFIX}{name} {help_text}",
+ f"# TYPE {self.METRIC_PREFIX}{name} {metric_type}",
+ ]
+ for labels, value in values:
+ label_str = self._format_labels(labels)
+ lines.append(f"{self.METRIC_PREFIX}{name}{{{label_str}}} {value}")
+ return "\n".join(lines)
+
+ def export_metrics(self) -> str:
+ """
+ Export all metrics in Prometheus text format.
+
+ Returns:
+ String containing metrics in Prometheus exposition format
+ """
+ metrics = []
+
+ # GPU Memory metrics
+ try:
+ guard = MemoryGuard()
+ stats = guard.get_memory_stats()
+
+ metrics.append(self._format_metric(
+ "gpu_total_bytes",
+ stats.gpu_total_mb * 1024 * 1024,
+ "Total GPU memory in bytes"
+ ))
+ metrics.append(self._format_metric(
+ "gpu_used_bytes",
+ stats.gpu_used_mb * 1024 * 1024,
+ "Used GPU memory in bytes"
+ ))
+ metrics.append(self._format_metric(
+ "gpu_free_bytes",
+ stats.gpu_free_mb * 1024 * 1024,
+ "Free GPU memory in bytes"
+ ))
+ metrics.append(self._format_metric(
+ "gpu_utilization_ratio",
+ stats.gpu_used_ratio,
+ "GPU memory utilization ratio (0-1)"
+ ))
+ metrics.append(self._format_metric(
+ "cpu_used_bytes",
+ stats.cpu_used_mb * 1024 * 1024,
+ "Used CPU memory in bytes"
+ ))
+ metrics.append(self._format_metric(
+ "cpu_available_bytes",
+ stats.cpu_available_mb * 1024 * 1024,
+ "Available CPU memory in bytes"
+ ))
+
+ guard.shutdown()
+ except Exception as e:
+ logger.debug(f"Could not get memory stats for metrics: {e}")
+
+ # Model Manager metrics
+ try:
+ model_manager = get_model_manager()
+ model_stats = model_manager.get_model_stats()
+
+ metrics.append(self._format_metric(
+ "models_loaded_total",
+ model_stats.get("total_models", 0),
+ "Total number of loaded models"
+ ))
+ metrics.append(self._format_metric(
+ "models_memory_bytes",
+ model_stats.get("total_estimated_memory_mb", 0) * 1024 * 1024,
+ "Estimated total memory used by loaded models in bytes"
+ ))
+
+ # Per-model metrics
+ model_values = []
+ for model_id, info in model_stats.get("models", {}).items():
+ model_values.append((
+ {"model_id": model_id},
+ info.get("ref_count", 0)
+ ))
+ if model_values:
+ metrics.append(self._format_metric_with_labels(
+ "model_ref_count",
+ model_values,
+ "Reference count per model"
+ ))
+
+ except Exception as e:
+ logger.debug(f"Could not get model stats for metrics: {e}")
+
+ # Prediction Semaphore metrics
+ try:
+ semaphore = get_prediction_semaphore()
+ sem_stats = semaphore.get_stats()
+
+ metrics.append(self._format_metric(
+ "predictions_active",
+ sem_stats.get("active_predictions", 0),
+ "Number of currently active predictions"
+ ))
+ metrics.append(self._format_metric(
+ "predictions_queue_depth",
+ sem_stats.get("queue_depth", 0),
+ "Number of predictions waiting in queue"
+ ))
+ metrics.append(self._format_metric(
+ "predictions_total",
+ sem_stats.get("total_predictions", 0),
+ "Total number of predictions processed",
+ metric_type="counter"
+ ))
+ metrics.append(self._format_metric(
+ "predictions_timeouts_total",
+ sem_stats.get("total_timeouts", 0),
+ "Total number of prediction timeouts",
+ metric_type="counter"
+ ))
+ metrics.append(self._format_metric(
+ "predictions_avg_wait_seconds",
+ sem_stats.get("average_wait_seconds", 0),
+ "Average wait time for predictions in seconds"
+ ))
+ metrics.append(self._format_metric(
+ "predictions_max_concurrent",
+ sem_stats.get("max_concurrent", 2),
+ "Maximum concurrent predictions allowed"
+ ))
+
+ except Exception as e:
+ logger.debug(f"Could not get semaphore stats for metrics: {e}")
+
+ # Service Pool metrics
+ try:
+ from app.services.service_pool import get_service_pool
+ pool = get_service_pool()
+ pool_stats = pool.get_pool_stats()
+
+ metrics.append(self._format_metric(
+ "pool_services_total",
+ pool_stats.get("total_services", 0),
+ "Total number of services in pool"
+ ))
+ metrics.append(self._format_metric(
+ "pool_services_available",
+ pool_stats.get("available_services", 0),
+ "Number of available services in pool"
+ ))
+ metrics.append(self._format_metric(
+ "pool_services_in_use",
+ pool_stats.get("in_use_services", 0),
+ "Number of services currently in use"
+ ))
+
+ pool_metrics = pool_stats.get("metrics", {})
+ metrics.append(self._format_metric(
+ "pool_acquisitions_total",
+ pool_metrics.get("total_acquisitions", 0),
+ "Total number of service acquisitions",
+ metric_type="counter"
+ ))
+ metrics.append(self._format_metric(
+ "pool_releases_total",
+ pool_metrics.get("total_releases", 0),
+ "Total number of service releases",
+ metric_type="counter"
+ ))
+ metrics.append(self._format_metric(
+ "pool_timeouts_total",
+ pool_metrics.get("total_timeouts", 0),
+ "Total number of acquisition timeouts",
+ metric_type="counter"
+ ))
+ metrics.append(self._format_metric(
+ "pool_errors_total",
+ pool_metrics.get("total_errors", 0),
+ "Total number of pool errors",
+ metric_type="counter"
+ ))
+
+ except Exception as e:
+ logger.debug(f"Could not get pool stats for metrics: {e}")
+
+ # Recovery Manager metrics
+ try:
+ recovery_manager = get_recovery_manager()
+ recovery_state = recovery_manager.get_state()
+
+ metrics.append(self._format_metric(
+ "recovery_count_total",
+ recovery_state.get("recovery_count", 0),
+ "Total number of recovery attempts",
+ metric_type="counter"
+ ))
+ metrics.append(self._format_metric(
+ "recovery_in_cooldown",
+ 1 if recovery_state.get("in_cooldown", False) else 0,
+ "Whether recovery is currently in cooldown (1=yes, 0=no)"
+ ))
+ metrics.append(self._format_metric(
+ "recovery_cooldown_remaining_seconds",
+ recovery_state.get("cooldown_remaining_seconds", 0),
+ "Remaining cooldown time in seconds"
+ ))
+ metrics.append(self._format_metric(
+ "recovery_recent_count",
+ recovery_state.get("recent_recoveries", 0),
+ "Number of recent recovery attempts within window"
+ ))
+
+ except Exception as e:
+ logger.debug(f"Could not get recovery stats for metrics: {e}")
+
+        # Custom metrics (any labels are already embedded in the stored key)
+        with self._lock:
+            for name, value in self._custom_metrics.items():
+                metrics.append(f"{self.METRIC_PREFIX}{name} {value}")
+
+ return "\n\n".join(metrics) + "\n"
+
+
+# Global Prometheus metrics instance
+_prometheus_metrics: Optional[PrometheusMetrics] = None
+
+
+def get_prometheus_metrics() -> PrometheusMetrics:
+ """Get the global PrometheusMetrics instance"""
+ global _prometheus_metrics
+ if _prometheus_metrics is None:
+ _prometheus_metrics = PrometheusMetrics()
+ return _prometheus_metrics
+
+
+def shutdown_prometheus_metrics():
+ """Shutdown the global PrometheusMetrics instance"""
+ global _prometheus_metrics
+ _prometheus_metrics = None
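+
+
+# Hedged usage sketch (illustrative only): a /metrics handler can return this
+# exposition text with content type text/plain; the HTTP wiring is assumed to
+# live elsewhere and is not part of this module.
+def _example_metrics_scrape() -> str:
+    exporter = get_prometheus_metrics()
+    exporter.set_custom_metric("scrape_demo", 1.0, labels={"source": "example"})
+    return exporter.export_metrics()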
diff --git a/backend/app/services/ocr_service.py b/backend/app/services/ocr_service.py
index 1b57dd7..a419a6b 100644
--- a/backend/app/services/ocr_service.py
+++ b/backend/app/services/ocr_service.py
@@ -25,6 +25,7 @@ except ImportError:
from app.core.config import settings
from app.services.office_converter import OfficeConverter, OfficeConverterError
+from app.services.memory_manager import get_model_manager, MemoryConfig, MemoryGuard, prediction_context
# Import dual-track components
try:
@@ -96,6 +97,26 @@ class OCRService:
self._model_last_used = {} # Track last usage time for each model
self._memory_warning_logged = False
+ # Initialize MemoryGuard for enhanced memory monitoring
+ self._memory_guard = None
+ if settings.enable_model_lifecycle_management:
+ try:
+ memory_config = MemoryConfig(
+ warning_threshold=settings.memory_warning_threshold,
+ critical_threshold=settings.memory_critical_threshold,
+ emergency_threshold=settings.memory_emergency_threshold,
+ model_idle_timeout_seconds=settings.pp_structure_idle_timeout_seconds,
+ gpu_memory_limit_mb=settings.gpu_memory_limit_mb,
+ enable_cpu_fallback=settings.enable_cpu_fallback,
+ )
+ self._memory_guard = MemoryGuard(memory_config)
+ logger.debug("MemoryGuard initialized for OCRService")
+ except Exception as e:
+ logger.warning(f"Failed to initialize MemoryGuard: {e}")
+
+ # Track if CPU fallback was activated
+ self._cpu_fallback_active = False
+
self._detect_and_configure_gpu()
# Log GPU optimization settings
@@ -217,53 +238,91 @@ class OCRService:
def _check_gpu_memory_usage(self):
"""
Check GPU memory usage and log warnings if approaching limits.
- Implements memory optimization for RTX 4060 8GB.
+ Uses MemoryGuard for enhanced monitoring with multiple backends.
"""
if not self.use_gpu or not settings.enable_memory_optimization:
return
try:
- device_id = self.gpu_info.get('device_id', 0)
- memory_allocated = paddle.device.cuda.memory_allocated(device_id)
- memory_allocated_mb = memory_allocated / (1024**2)
- memory_limit_mb = settings.gpu_memory_limit_mb
+ # Use MemoryGuard if available for better monitoring
+ if self._memory_guard:
+ stats = self._memory_guard.get_memory_stats()
- utilization = (memory_allocated_mb / memory_limit_mb * 100) if memory_limit_mb > 0 else 0
+ # Log based on usage ratio
+ if stats.gpu_used_ratio > 0.90 and not self._memory_warning_logged:
+ logger.warning(
+ f"GPU memory usage critical: {stats.gpu_used_mb:.0f}MB / {stats.gpu_total_mb:.0f}MB "
+ f"({stats.gpu_used_ratio*100:.1f}%)"
+ )
+ logger.warning("Consider enabling auto_unload_unused_models or reducing batch size")
+ self._memory_warning_logged = True
- if utilization > 90 and not self._memory_warning_logged:
- logger.warning(f"GPU memory usage high: {memory_allocated_mb:.0f}MB / {memory_limit_mb}MB ({utilization:.1f}%)")
- logger.warning("Consider enabling auto_unload_unused_models or reducing batch size")
- self._memory_warning_logged = True
- elif utilization > 75:
- logger.info(f"GPU memory: {memory_allocated_mb:.0f}MB / {memory_limit_mb}MB ({utilization:.1f}%)")
+ # Trigger emergency cleanup if enabled
+ if settings.enable_emergency_cleanup:
+ self._cleanup_unused_models()
+ self._memory_guard.clear_gpu_cache()
+
+ elif stats.gpu_used_ratio > 0.75:
+ logger.info(
+ f"GPU memory: {stats.gpu_used_mb:.0f}MB / {stats.gpu_total_mb:.0f}MB "
+ f"({stats.gpu_used_ratio*100:.1f}%)"
+ )
+ else:
+ # Fallback to original implementation
+ device_id = self.gpu_info.get('device_id', 0)
+ memory_allocated = paddle.device.cuda.memory_allocated(device_id)
+ memory_allocated_mb = memory_allocated / (1024**2)
+ memory_limit_mb = settings.gpu_memory_limit_mb
+
+ utilization = (memory_allocated_mb / memory_limit_mb * 100) if memory_limit_mb > 0 else 0
+
+ if utilization > 90 and not self._memory_warning_logged:
+ logger.warning(f"GPU memory usage high: {memory_allocated_mb:.0f}MB / {memory_limit_mb}MB ({utilization:.1f}%)")
+ logger.warning("Consider enabling auto_unload_unused_models or reducing batch size")
+ self._memory_warning_logged = True
+ elif utilization > 75:
+ logger.info(f"GPU memory: {memory_allocated_mb:.0f}MB / {memory_limit_mb}MB ({utilization:.1f}%)")
except Exception as e:
logger.debug(f"Memory check failed: {e}")
def _cleanup_unused_models(self):
"""
- Clean up unused language models to free GPU memory.
+ Clean up unused models (including PP-StructureV3) to free GPU memory.
Models idle longer than model_idle_timeout_seconds will be unloaded.
+
+ Note: PP-StructureV3 is NO LONGER exempted from cleanup - it will be
+ unloaded based on pp_structure_idle_timeout_seconds configuration.
"""
if not settings.auto_unload_unused_models:
return
current_time = datetime.now()
- timeout = settings.model_idle_timeout_seconds
models_to_remove = []
for lang, last_used in self._model_last_used.items():
- if lang == 'structure': # Don't unload structure engine
- continue
+ # Use different timeout for structure engine vs language models
+ if lang == 'structure':
+ timeout = settings.pp_structure_idle_timeout_seconds
+ else:
+ timeout = settings.model_idle_timeout_seconds
+
idle_seconds = (current_time - last_used).total_seconds()
if idle_seconds > timeout:
models_to_remove.append(lang)
- for lang in models_to_remove:
- if lang in self.ocr_engines:
- logger.info(f"Unloading idle OCR engine for {lang} (idle {timeout}s)")
- del self.ocr_engines[lang]
- del self._model_last_used[lang]
+ for model_key in models_to_remove:
+ if model_key == 'structure':
+ if self.structure_engine is not None:
+ logger.info(f"Unloading idle PP-StructureV3 engine (idle {settings.pp_structure_idle_timeout_seconds}s)")
+ self._unload_structure_engine()
+ if model_key in self._model_last_used:
+ del self._model_last_used[model_key]
+ elif model_key in self.ocr_engines:
+ logger.info(f"Unloading idle OCR engine for {model_key} (idle {settings.model_idle_timeout_seconds}s)")
+ del self.ocr_engines[model_key]
+ if model_key in self._model_last_used:
+ del self._model_last_used[model_key]
if models_to_remove and self.use_gpu:
# Clear CUDA cache
@@ -273,6 +332,41 @@ class OCRService:
except Exception as e:
logger.debug(f"Cache clear failed: {e}")
+ def _unload_structure_engine(self):
+ """
+ Properly unload PP-StructureV3 engine and free GPU memory.
+ """
+ if self.structure_engine is None:
+ return
+
+ try:
+ # Clear internal engine components
+ if hasattr(self.structure_engine, 'table_engine'):
+ self.structure_engine.table_engine = None
+ if hasattr(self.structure_engine, 'text_detector'):
+ self.structure_engine.text_detector = None
+ if hasattr(self.structure_engine, 'text_recognizer'):
+ self.structure_engine.text_recognizer = None
+ if hasattr(self.structure_engine, 'layout_predictor'):
+ self.structure_engine.layout_predictor = None
+
+ # Delete the engine
+ del self.structure_engine
+ self.structure_engine = None
+
+ # Force garbage collection
+ gc.collect()
+
+ # Clear GPU cache
+ if self.use_gpu:
+ paddle.device.cuda.empty_cache()
+
+ logger.info("PP-StructureV3 engine unloaded successfully")
+
+ except Exception as e:
+ logger.warning(f"Error unloading PP-StructureV3: {e}")
+ self.structure_engine = None
+
def clear_gpu_cache(self):
"""
Manually clear GPU memory cache.
@@ -519,46 +613,160 @@ class OCRService:
logger.warning(f"GPU memory cleanup failed (non-critical): {e}")
# Don't fail the processing if cleanup fails
- def check_gpu_memory(self, required_mb: int = 2000) -> bool:
+ def check_gpu_memory(self, required_mb: int = 2000, enable_fallback: bool = True) -> bool:
"""
- Check if sufficient GPU memory is available.
+ Check if sufficient GPU memory is available using MemoryGuard.
+
+ This method now uses MemoryGuard for accurate memory queries across
+ multiple backends (pynvml, torch, paddle) instead of returning True
+ blindly for PaddlePaddle-only environments.
Args:
required_mb: Required memory in MB (default 2000MB for OCR models)
+ enable_fallback: If True and CPU fallback is enabled, switch to CPU mode
+ when memory is insufficient instead of returning False
Returns:
- True if sufficient memory is available or GPU is not used
+ True if sufficient memory is available, GPU is not used, or CPU fallback activated
"""
- try:
- # Check GPU memory using torch if available, otherwise use PaddlePaddle
- free_memory = None
+ # If not using GPU, always return True
+ if not self.use_gpu:
+ return True
- if TORCH_AVAILABLE and torch.cuda.is_available():
- free_memory = torch.cuda.mem_get_info()[0] / 1024**2
- elif paddle.device.is_compiled_with_cuda():
- # PaddlePaddle doesn't have direct API to get free memory,
- # so we rely on cleanup and continue
- logger.debug("Using PaddlePaddle GPU, memory info not directly available")
+ try:
+ # Use MemoryGuard if available for accurate multi-backend memory queries
+ if self._memory_guard:
+ is_available, stats = self._memory_guard.check_memory(
+ required_mb=required_mb,
+ device_id=self.gpu_info.get('device_id', 0)
+ )
+
+ if not is_available:
+ logger.warning(
+ f"GPU memory check failed: {stats.gpu_free_mb:.0f}MB free, "
+ f"{required_mb}MB required ({stats.gpu_used_ratio*100:.1f}% used)"
+ )
+
+ # Try to free memory
+ logger.info("Attempting memory cleanup before retry...")
+ self._cleanup_unused_models()
+ self._memory_guard.clear_gpu_cache()
+
+ # Check again
+ is_available, stats = self._memory_guard.check_memory(required_mb=required_mb)
+
+ if not is_available:
+ # Memory still insufficient after cleanup
+ if enable_fallback and settings.enable_cpu_fallback:
+ logger.warning(
+ f"Insufficient GPU memory ({stats.gpu_free_mb:.0f}MB) after cleanup. "
+ f"Activating CPU fallback mode."
+ )
+ self._activate_cpu_fallback()
+ return True # Continue with CPU
+ else:
+ logger.error(
+ f"Insufficient GPU memory: {stats.gpu_free_mb:.0f}MB available, "
+ f"{required_mb}MB required"
+ )
+ return False
+
+ logger.debug(
+ f"GPU memory check passed: {stats.gpu_free_mb:.0f}MB free "
+ f"({stats.gpu_used_ratio*100:.1f}% used)"
+ )
return True
- if free_memory is not None:
- if free_memory < required_mb:
- logger.warning(f"Low GPU memory: {free_memory:.0f}MB available, {required_mb}MB required")
- # Try to free memory
- self.cleanup_gpu_memory()
- # Check again
- if TORCH_AVAILABLE and torch.cuda.is_available():
- free_memory = torch.cuda.mem_get_info()[0] / 1024**2
- if free_memory < required_mb:
- logger.error(f"Insufficient GPU memory after cleanup: {free_memory:.0f}MB")
- return False
- logger.debug(f"GPU memory check passed: {free_memory:.0f}MB available")
+ else:
+ # Fallback to original implementation
+ free_memory = None
+
+ if TORCH_AVAILABLE and torch.cuda.is_available():
+ free_memory = torch.cuda.mem_get_info()[0] / 1024**2
+ elif paddle.device.is_compiled_with_cuda():
+ # PaddlePaddle doesn't have direct API to get free memory,
+ # use allocated memory to estimate
+ device_id = self.gpu_info.get('device_id', 0)
+ allocated = paddle.device.cuda.memory_allocated(device_id) / (1024**2)
+ total = settings.gpu_memory_limit_mb
+ free_memory = max(0, total - allocated)
+ logger.debug(f"Estimated free GPU memory: {free_memory:.0f}MB (total: {total}MB, allocated: {allocated:.0f}MB)")
+
+ if free_memory is not None:
+ if free_memory < required_mb:
+ logger.warning(f"Low GPU memory: {free_memory:.0f}MB available, {required_mb}MB required")
+ self.cleanup_gpu_memory()
+
+ # Recheck
+ if TORCH_AVAILABLE and torch.cuda.is_available():
+ free_memory = torch.cuda.mem_get_info()[0] / 1024**2
+                        else:
+                            # Re-derive so this branch doesn't depend on the elif above
+                            device_id = self.gpu_info.get('device_id', 0)
+                            total = settings.gpu_memory_limit_mb
+                            allocated = paddle.device.cuda.memory_allocated(device_id) / (1024**2)
+                            free_memory = max(0, total - allocated)
+
+ if free_memory < required_mb:
+ if enable_fallback and settings.enable_cpu_fallback:
+ logger.warning(f"Insufficient GPU memory after cleanup. Activating CPU fallback.")
+ self._activate_cpu_fallback()
+ return True
+ else:
+ logger.error(f"Insufficient GPU memory after cleanup: {free_memory:.0f}MB")
+ return False
+
+ logger.debug(f"GPU memory check passed: {free_memory:.0f}MB available")
+
+ return True
- return True
except Exception as e:
logger.warning(f"GPU memory check failed: {e}")
return True # Continue processing even if check fails
+ def _activate_cpu_fallback(self):
+ """
+ Activate CPU fallback mode when GPU memory is insufficient.
+ This disables GPU usage for the current service instance.
+ """
+ if self._cpu_fallback_active:
+ return # Already in CPU mode
+
+ logger.warning("=== CPU FALLBACK MODE ACTIVATED ===")
+ logger.warning("GPU memory insufficient, switching to CPU processing")
+ logger.warning("Performance will be significantly reduced")
+
+ self._cpu_fallback_active = True
+ self.use_gpu = False
+
+ # Update GPU info to reflect fallback
+ self.gpu_info['cpu_fallback'] = True
+ self.gpu_info['fallback_reason'] = 'GPU memory insufficient'
+
+ # Clear GPU cache to free memory
+ if self._memory_guard:
+ self._memory_guard.clear_gpu_cache()
+
+ def _restore_gpu_mode(self):
+ """
+ Attempt to restore GPU mode after CPU fallback.
+ Called when memory pressure has been relieved.
+ """
+ if not self._cpu_fallback_active:
+ return
+
+ if not self.gpu_available:
+ return
+
+ # Check if GPU memory is now available
+ if self._memory_guard:
+ is_available, stats = self._memory_guard.check_memory(
+ required_mb=settings.structure_model_memory_mb
+ )
+ if is_available:
+ logger.info("GPU memory available, restoring GPU mode")
+ self._cpu_fallback_active = False
+ self.use_gpu = True
+ self.gpu_info.pop('cpu_fallback', None)
+ self.gpu_info.pop('fallback_reason', None)
+
def convert_pdf_to_images(self, pdf_path: Path, output_dir: Path) -> List[Path]:
"""
Convert PDF to images (one per page)
@@ -626,6 +834,24 @@ class OCRService:
threshold = confidence_threshold if confidence_threshold is not None else self.confidence_threshold
try:
+ # Pre-operation memory check: Try to restore GPU if in fallback and memory available
+ if self._cpu_fallback_active:
+ self._restore_gpu_mode()
+ if not self._cpu_fallback_active:
+ logger.info("GPU mode restored for processing")
+
+ # Initial memory check before starting any heavy processing
+ # Estimate memory requirement based on image type
+ estimated_memory_mb = 2500 # Conservative estimate for full OCR + layout
+ if detect_layout:
+ estimated_memory_mb += 500 # Additional for PP-StructureV3
+
+ if not self.check_gpu_memory(required_mb=estimated_memory_mb, enable_fallback=True):
+ logger.warning(
+ f"Pre-operation memory check failed ({estimated_memory_mb}MB required). "
+ f"Processing will attempt to proceed but may encounter issues."
+ )
+
# Check if file is Office document
if self.office_converter.is_office_document(image_path):
logger.info(f"Detected Office document: {image_path.name}, converting to PDF")
@@ -748,9 +974,12 @@ class OCRService:
# Get OCR engine (for non-PDF images)
ocr_engine = self.get_ocr_engine(lang)
- # Check GPU memory before OCR processing
- if not self.check_gpu_memory(required_mb=1500):
- logger.warning("Insufficient GPU memory for OCR, attempting to proceed anyway")
+ # Secondary memory check before OCR processing
+ if not self.check_gpu_memory(required_mb=1500, enable_fallback=True):
+ logger.warning(
+ f"OCR memory check: insufficient GPU memory (1500MB required). "
+ f"Mode: {'CPU fallback' if self._cpu_fallback_active else 'GPU (low memory)'}"
+ )
# Get the actual image dimensions that OCR will use
from PIL import Image
@@ -950,6 +1179,18 @@ class OCRService:
Tuple of (layout_data, images_metadata)
"""
try:
+ # Pre-operation memory check for layout analysis
+ if self._cpu_fallback_active:
+ self._restore_gpu_mode()
+ if not self._cpu_fallback_active:
+ logger.info("GPU mode restored for layout analysis")
+
+ if not self.check_gpu_memory(required_mb=2000, enable_fallback=True):
+ logger.warning(
+ f"Layout analysis pre-check: insufficient GPU memory (2000MB required). "
+ f"Mode: {'CPU fallback' if self._cpu_fallback_active else 'GPU'}"
+ )
+
structure_engine = self._ensure_structure_engine(pp_structure_params)
# Try enhanced processing first
@@ -998,11 +1239,21 @@ class OCRService:
# Standard processing (original implementation)
logger.info(f"Running standard layout analysis on {image_path.name}")
- # Check GPU memory before processing
- if not self.check_gpu_memory(required_mb=2000):
- logger.warning("Insufficient GPU memory for PP-StructureV3, attempting to proceed anyway")
+ # Memory check before PP-StructureV3 processing
+ if not self.check_gpu_memory(required_mb=2000, enable_fallback=True):
+ logger.warning(
+ f"PP-StructureV3 memory check: insufficient GPU memory (2000MB required). "
+ f"Mode: {'CPU fallback' if self._cpu_fallback_active else 'GPU (low memory)'}"
+ )
- results = structure_engine.predict(str(image_path))
+ # Use prediction semaphore to control concurrent predictions
+ # This prevents OOM errors from multiple simultaneous PP-StructureV3.predict() calls
+ with prediction_context(timeout=settings.service_acquire_timeout_seconds) as acquired:
+ if not acquired:
+ logger.error("Failed to acquire prediction slot (timeout), returning empty layout")
+ return None, []
+
+ results = structure_engine.predict(str(image_path))
layout_elements = []
images_metadata = []
@@ -1254,6 +1505,46 @@ class OCRService:
if temp_pdf_path:
unified_doc.metadata.original_filename = file_path.name
+ # HYBRID MODE: Check if Direct track missed images (e.g., inline image blocks)
+ # If so, use OCR to extract images and merge them into the Direct result
+ pages_with_missing_images = self.direct_extraction_engine.check_document_for_missing_images(
+ actual_file_path
+ )
+ if pages_with_missing_images:
+ logger.info(f"Hybrid mode: Direct track missing images on pages {pages_with_missing_images}, using OCR to extract images")
+ try:
+ # Run OCR on the file to extract images
+ ocr_result = self.process_file_traditional(
+ actual_file_path, lang, detect_layout=True,
+ confidence_threshold=confidence_threshold,
+ output_dir=output_dir, pp_structure_params=pp_structure_params
+ )
+
+ # Convert OCR result to extract images
+ ocr_unified = self.ocr_to_unified_converter.convert(
+ ocr_result, actual_file_path, 0.0, lang
+ )
+
+ # Merge OCR-extracted images into Direct track result
+ images_added = self._merge_ocr_images_into_direct(
+ unified_doc, ocr_unified, pages_with_missing_images
+ )
+ if images_added > 0:
+ logger.info(f"Hybrid mode: Added {images_added} images from OCR to Direct track result")
+ unified_doc.metadata.processing_track = ProcessingTrack.HYBRID
+ else:
+ # Fallback: OCR didn't find images either, render inline image blocks directly
+ logger.info("Hybrid mode: OCR didn't find images, falling back to inline image rendering")
+ images_added = self.direct_extraction_engine.render_inline_image_regions(
+ actual_file_path, unified_doc, pages_with_missing_images, output_dir
+ )
+ if images_added > 0:
+ logger.info(f"Hybrid mode: Rendered {images_added} inline image regions")
+ unified_doc.metadata.processing_track = ProcessingTrack.HYBRID
+ except Exception as e:
+ logger.warning(f"Hybrid mode image extraction failed: {e}")
+ # Continue with Direct track result without images
+
# Use OCR track (either by recommendation or fallback)
if recommendation.track == "ocr":
# Use OCR for scanned documents, images, etc.
@@ -1269,17 +1560,19 @@ class OCRService:
)
unified_doc.document_id = document_id
- # Update processing track metadata
- unified_doc.metadata.processing_track = (
- ProcessingTrack.DIRECT if recommendation.track == "direct"
- else ProcessingTrack.OCR
- )
+ # Update processing track metadata (only if not already set to HYBRID)
+ if unified_doc.metadata.processing_track != ProcessingTrack.HYBRID:
+ unified_doc.metadata.processing_track = (
+ ProcessingTrack.DIRECT if recommendation.track == "direct"
+ else ProcessingTrack.OCR
+ )
# Calculate total processing time
processing_time = (datetime.now() - start_time).total_seconds()
unified_doc.metadata.processing_time = processing_time
- logger.info(f"Document processing completed in {processing_time:.2f}s using {recommendation.track} track")
+ actual_track = unified_doc.metadata.processing_track.value
+ logger.info(f"Document processing completed in {processing_time:.2f}s using {actual_track} track")
return unified_doc
@@ -1290,6 +1583,75 @@ class OCRService:
file_path, lang, detect_layout, confidence_threshold, output_dir, pp_structure_params
)
+ def _merge_ocr_images_into_direct(
+ self,
+ direct_doc: 'UnifiedDocument',
+ ocr_doc: 'UnifiedDocument',
+ pages_with_missing_images: List[int]
+ ) -> int:
+ """
+ Merge OCR-extracted images into Direct track result.
+
+ This is used in hybrid mode when Direct track couldn't extract certain
+ images (like logos composed of inline image blocks).
+
+ Args:
+ direct_doc: UnifiedDocument from Direct track
+ ocr_doc: UnifiedDocument from OCR track
+ pages_with_missing_images: List of page numbers (1-indexed) that need images
+
+ Returns:
+ Number of images added
+ """
+ images_added = 0
+
+ try:
+ # Get image element types to look for
+ image_types = {ElementType.FIGURE, ElementType.IMAGE, ElementType.LOGO}
+
+ for page_num in pages_with_missing_images:
+ # Find the target page in direct_doc
+ direct_page = None
+ for page in direct_doc.pages:
+ if page.page_number == page_num:
+ direct_page = page
+ break
+
+ if not direct_page:
+ continue
+
+ # Find the source page in ocr_doc
+ ocr_page = None
+ for page in ocr_doc.pages:
+ if page.page_number == page_num:
+ ocr_page = page
+ break
+
+ if not ocr_page:
+ continue
+
+ # Extract image elements from OCR page
+ for element in ocr_page.elements:
+ if element.type in image_types:
+ # Assign new element ID to avoid conflicts
+ new_element_id = f"hybrid_{element.element_id}"
+ element.element_id = new_element_id
+
+ # Add to direct page
+ direct_page.elements.append(element)
+ images_added += 1
+ logger.debug(f"Added image element {new_element_id} to page {page_num}")
+
+ # Update image count in direct_doc metadata
+ if images_added > 0:
+ current_images = direct_doc.metadata.total_images or 0
+ direct_doc.metadata.total_images = current_images + images_added
+
+ except Exception as e:
+ logger.error(f"Error merging OCR images into Direct track: {e}")
+
+ return images_added
+
def process_file_traditional(
self,
file_path: Path,
@@ -1441,13 +1803,16 @@ class OCRService:
UnifiedDocument if dual-track is enabled and use_dual_track=True,
Dict with legacy format otherwise
"""
- if use_dual_track and self.dual_track_enabled:
- # Use dual-track processing
+ # Use dual-track processing if:
+ # 1. use_dual_track is True (auto-detection), OR
+ # 2. force_track is specified (explicit track selection)
+ if (use_dual_track or force_track) and self.dual_track_enabled:
+ # Use dual-track processing (or forced track)
return self.process_with_dual_track(
file_path, lang, detect_layout, confidence_threshold, output_dir, force_track, pp_structure_params
)
else:
- # Use traditional OCR processing
+ # Use traditional OCR processing (no force_track support)
return self.process_file_traditional(
file_path, lang, detect_layout, confidence_threshold, output_dir, pp_structure_params
)
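
The ocr_service.py changes above combine three mechanisms: a pre-operation GPU memory check, an automatic CPU fallback with later GPU restoration, and the widened dual-track dispatch that now honors force_track. The standalone sketch below (not part of the patch) illustrates how the memory-check and fallback pieces are meant to interact before any heavy processing; the method names mirror the diff, while prepare_for_heavy_operation and its service argument are hypothetical and exist only for illustration.

def prepare_for_heavy_operation(service, detect_layout: bool) -> bool:
    """Mirror the pre-check sequence the patched OCRService runs before heavy work."""
    # 1. If an earlier operation forced CPU fallback, try to return to GPU first
    if service._cpu_fallback_active:
        service._restore_gpu_mode()

    # 2. Estimate the requirement and pre-check memory; enable_fallback=True lets
    #    check_gpu_memory() switch the service to CPU instead of hard-failing
    estimated_memory_mb = 2500
    if detect_layout:
        estimated_memory_mb += 500  # extra headroom for PP-StructureV3
    return service.check_gpu_memory(required_mb=estimated_memory_mb, enable_fallback=True)
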
diff --git a/backend/app/services/pdf_generator_service.py b/backend/app/services/pdf_generator_service.py
index 4a4bf17..5e2e462 100644
--- a/backend/app/services/pdf_generator_service.py
+++ b/backend/app/services/pdf_generator_service.py
@@ -572,8 +572,10 @@ class PDFGeneratorService:
processing_track = unified_doc.metadata.get('processing_track')
# Route to track-specific rendering method
- is_direct_track = (processing_track == 'direct' or
- processing_track == ProcessingTrack.DIRECT)
+        # ProcessingTrack subclasses (str, Enum), so comparing against an enum member works whether processing_track is a str or an enum
+ # HYBRID track uses Direct track rendering (Direct text/tables + OCR images)
+ is_direct_track = (processing_track == ProcessingTrack.DIRECT or
+ processing_track == ProcessingTrack.HYBRID)
logger.info(f"Processing track: {processing_track}, using {'Direct' if is_direct_track else 'OCR'} track rendering")
@@ -675,8 +677,11 @@ class PDFGeneratorService:
logger.info("=== Direct Track PDF Generation ===")
logger.info(f"Total pages: {len(unified_doc.pages)}")
- # Set current track for helper methods
- self.current_processing_track = 'direct'
+ # Set current track for helper methods (may be DIRECT or HYBRID)
+ if hasattr(unified_doc, 'metadata') and unified_doc.metadata:
+ self.current_processing_track = unified_doc.metadata.processing_track
+ else:
+ self.current_processing_track = ProcessingTrack.DIRECT
# Get page dimensions from first page (for canvas initialization)
if not unified_doc.pages:
@@ -1074,11 +1079,16 @@ class PDFGeneratorService:
# *** Priority 1: check ocr_dimensions (converted from UnifiedDocument) ***
if 'ocr_dimensions' in ocr_data:
dims = ocr_data['ocr_dimensions']
- w = float(dims.get('width', 0))
- h = float(dims.get('height', 0))
- if w > 0 and h > 0:
-                logger.info(f"Using page dimensions from the ocr_dimensions field: {w:.1f} x {h:.1f}")
- return (w, h)
+ # Handle both dict format {'width': w, 'height': h} and
+ # list format [{'page': 1, 'width': w, 'height': h}, ...]
+ if isinstance(dims, list) and len(dims) > 0:
+ dims = dims[0] # Use first page dimensions
+ if isinstance(dims, dict):
+ w = float(dims.get('width', 0))
+ h = float(dims.get('height', 0))
+ if w > 0 and h > 0:
+                    logger.info(f"Using page dimensions from the ocr_dimensions field: {w:.1f} x {h:.1f}")
+ return (w, h)
# *** Priority 2: check dimensions in the original JSON ***
if 'dimensions' in ocr_data:
@@ -1418,8 +1428,8 @@ class PDFGeneratorService:
# Set font with track-specific styling
# Note: OCR track has no StyleInfo (extracted from images), so no advanced formatting
style_info = region.get('style')
- is_direct_track = (self.current_processing_track == 'direct' or
- self.current_processing_track == ProcessingTrack.DIRECT)
+ is_direct_track = (self.current_processing_track == ProcessingTrack.DIRECT or
+ self.current_processing_track == ProcessingTrack.HYBRID)
if style_info and is_direct_track:
# Direct track: Apply rich styling from StyleInfo
@@ -1661,10 +1671,15 @@ class PDFGeneratorService:
return
# Construct full path to image
+ # saved_path is relative to result_dir (e.g., "imgs/element_id.png")
image_path = result_dir / image_path_str
+ # Fallback for legacy data
if not image_path.exists():
- logger.warning(f"Image not found: {image_path}")
+ image_path = result_dir / Path(image_path_str).name
+
+ if not image_path.exists():
+ logger.warning(f"Image not found: {image_path_str} (in {result_dir})")
return
# Get bbox for positioning
@@ -2289,12 +2304,30 @@ class PDFGeneratorService:
col_widths = element.metadata['column_widths']
logger.debug(f"Using extracted column widths: {col_widths}")
- # Create table without rowHeights (will use canvas scaling instead)
- t = Table(table_content, colWidths=col_widths)
+ # Use original row heights from extraction if available
+ # Row heights must match the number of data rows exactly
+ row_heights_list = None
+ if element.metadata and 'row_heights' in element.metadata:
+ extracted_row_heights = element.metadata['row_heights']
+ num_data_rows = len(table_content)
+ num_height_rows = len(extracted_row_heights)
+
+ if num_height_rows == num_data_rows:
+ row_heights_list = extracted_row_heights
+ logger.debug(f"Using extracted row heights ({num_height_rows} rows): {row_heights_list}")
+ else:
+ # Row counts don't match - this can happen with merged cells or empty rows
+ logger.warning(f"Row height mismatch: {num_height_rows} heights for {num_data_rows} data rows, falling back to auto-sizing")
+
+ # Create table with both column widths and row heights for accurate sizing
+ t = Table(table_content, colWidths=col_widths, rowHeights=row_heights_list)
# Apply style with minimal padding to reduce table extension
+ # Use Chinese font to support special characters (℃, μm, ≦, ×, Ω, etc.)
+ font_for_table = self.font_name if self.font_registered else 'Helvetica'
style = TableStyle([
('GRID', (0, 0), (-1, -1), 0.5, colors.grey),
+ ('FONTNAME', (0, 0), (-1, -1), font_for_table),
('FONTSIZE', (0, 0), (-1, -1), 8),
('ALIGN', (0, 0), (-1, -1), 'LEFT'),
('VALIGN', (0, 0), (-1, -1), 'TOP'),
@@ -2307,8 +2340,8 @@ class PDFGeneratorService:
])
t.setStyle(style)
- # CRITICAL: Use canvas scaling to fit table within bbox
- # This is more reliable than rowHeights which doesn't always work
+ # Use canvas scaling as fallback to fit table within bbox
+ # With proper row heights, scaling should be minimal (close to 1.0)
# Step 1: Wrap to get actual rendered size
actual_width, actual_height = t.wrapOn(pdf_canvas, table_width * 10, table_height * 10)
@@ -2358,11 +2391,16 @@ class PDFGeneratorService:
logger.warning(f"No image path for element {element.element_id}")
return
- # Construct full path
+ # Construct full path to image
+ # saved_path is relative to result_dir (e.g., "document_id_p1_img0.png")
image_path = result_dir / image_path_str
+ # Fallback for legacy data
if not image_path.exists():
- logger.warning(f"Image not found: {image_path}")
+ image_path = result_dir / Path(image_path_str).name
+
+ if not image_path.exists():
+ logger.warning(f"Image not found: {image_path_str} (in {result_dir})")
return
# Get bbox
@@ -2388,7 +2426,7 @@ class PDFGeneratorService:
preserveAspectRatio=True
)
- logger.debug(f"Drew image: {image_path_str}")
+        logger.debug(f"Drew image: {image_path} (from: {image_path_str})")
except Exception as e:
logger.error(f"Failed to draw image element {element.element_id}: {e}")
diff --git a/backend/app/services/pp_structure_enhanced.py b/backend/app/services/pp_structure_enhanced.py
index d1f00ea..4331c38 100644
--- a/backend/app/services/pp_structure_enhanced.py
+++ b/backend/app/services/pp_structure_enhanced.py
@@ -21,6 +21,8 @@ except ImportError:
import paddle
from paddleocr import PPStructureV3
from app.models.unified_document import ElementType
+from app.core.config import settings
+from app.services.memory_manager import prediction_context
logger = logging.getLogger(__name__)
@@ -96,8 +98,22 @@ class PPStructureEnhanced:
try:
logger.info(f"Enhanced PP-StructureV3 analysis on {image_path.name}")
- # Perform structure analysis
- results = self.structure_engine.predict(str(image_path))
+ # Perform structure analysis with semaphore control
+ # This prevents OOM errors from multiple simultaneous predictions
+ with prediction_context(timeout=settings.service_acquire_timeout_seconds) as acquired:
+ if not acquired:
+ logger.error("Failed to acquire prediction slot (timeout), returning empty result")
+ return {
+ 'has_parsing_res_list': False,
+ 'elements': [],
+ 'total_elements': 0,
+ 'images': [],
+ 'tables': [],
+ 'element_types': {},
+ 'error': 'Prediction slot timeout'
+ }
+
+ results = self.structure_engine.predict(str(image_path))
all_elements = []
all_images = []
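
Both predict() call sites are now wrapped in prediction_context(), which yields a boolean indicating whether a concurrency slot was obtained before the timeout. A hedged sketch of the same pattern as a reusable wrapper (guarded_predict is hypothetical; prediction_context, settings, and the task_id keyword mirror the code and tests in this patch):

from app.core.config import settings
from app.services.memory_manager import prediction_context

def guarded_predict(structure_engine, image_path):
    """Run predict() only while holding a prediction slot; return None on timeout."""
    with prediction_context(timeout=settings.service_acquire_timeout_seconds,
                            task_id=str(image_path)) as acquired:
        if not acquired:
            # No slot within the timeout - the caller decides how to degrade
            return None
        return structure_engine.predict(str(image_path))
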
diff --git a/backend/app/services/service_pool.py b/backend/app/services/service_pool.py
new file mode 100644
index 0000000..07db0b1
--- /dev/null
+++ b/backend/app/services/service_pool.py
@@ -0,0 +1,468 @@
+"""
+Tool_OCR - OCR Service Pool
+Manages a pool of OCRService instances to prevent duplicate model loading
+and control concurrent GPU operations.
+"""
+
+import asyncio
+import logging
+import threading
+import time
+from contextlib import contextmanager
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import Any, Dict, List, Optional, TYPE_CHECKING
+
+from app.services.memory_manager import get_model_manager, MemoryConfig
+
+if TYPE_CHECKING:
+ from app.services.ocr_service import OCRService
+
+logger = logging.getLogger(__name__)
+
+
+class ServiceState(Enum):
+ """State of a pooled service"""
+ AVAILABLE = "available"
+ IN_USE = "in_use"
+ UNHEALTHY = "unhealthy"
+ INITIALIZING = "initializing"
+
+
+@dataclass
+class PooledService:
+ """Wrapper for a pooled OCRService instance"""
+ service: Any # OCRService
+ device: str
+ state: ServiceState = ServiceState.AVAILABLE
+ created_at: float = field(default_factory=time.time)
+ last_used: float = field(default_factory=time.time)
+ use_count: int = 0
+ error_count: int = 0
+ current_task_id: Optional[str] = None
+
+
+class PoolConfig:
+ """Configuration for the service pool"""
+
+ def __init__(
+ self,
+ max_services_per_device: int = 1,
+ max_total_services: int = 2,
+ acquire_timeout_seconds: float = 300.0,
+ max_queue_size: int = 50,
+ health_check_interval_seconds: int = 60,
+ max_consecutive_errors: int = 3,
+ service_idle_timeout_seconds: int = 600,
+ enable_auto_scaling: bool = False,
+ ):
+ self.max_services_per_device = max_services_per_device
+ self.max_total_services = max_total_services
+ self.acquire_timeout_seconds = acquire_timeout_seconds
+ self.max_queue_size = max_queue_size
+ self.health_check_interval_seconds = health_check_interval_seconds
+ self.max_consecutive_errors = max_consecutive_errors
+ self.service_idle_timeout_seconds = service_idle_timeout_seconds
+ self.enable_auto_scaling = enable_auto_scaling
+
+
+class OCRServicePool:
+ """
+ Pool of OCRService instances with concurrency control.
+
+ Features:
+ - Per-device instance management (one service per GPU)
+ - Queue-based task distribution
+ - Semaphore-based concurrency limits
+ - Health monitoring
+ - Automatic service recovery
+ """
+
+ _instance = None
+ _lock = threading.Lock()
+
+ def __new__(cls, *args, **kwargs):
+ """Singleton pattern"""
+ with cls._lock:
+ if cls._instance is None:
+ cls._instance = super().__new__(cls)
+ cls._instance._initialized = False
+ return cls._instance
+
+ def __init__(self, config: Optional[PoolConfig] = None):
+ if self._initialized:
+ return
+
+ self.config = config or PoolConfig()
+ self.services: Dict[str, List[PooledService]] = {}
+ self.semaphores: Dict[str, threading.Semaphore] = {}
+ self.queues: Dict[str, List] = {}
+ self._pool_lock = threading.RLock()
+ self._condition = threading.Condition(self._pool_lock)
+
+ # Metrics
+ self._metrics = {
+ "total_acquisitions": 0,
+ "total_releases": 0,
+ "total_timeouts": 0,
+ "total_errors": 0,
+ "queue_waits": 0,
+ }
+
+ # Initialize default device pool
+ self._initialize_device("GPU:0")
+
+ self._initialized = True
+ logger.info("OCRServicePool initialized")
+
+ def _initialize_device(self, device: str):
+ """Initialize pool resources for a device"""
+ with self._pool_lock:
+ if device not in self.services:
+ self.services[device] = []
+ self.semaphores[device] = threading.Semaphore(
+ self.config.max_services_per_device
+ )
+ self.queues[device] = []
+ logger.info(f"Initialized pool for device {device}")
+
+ def _create_service(self, device: str) -> PooledService:
+ """
+ Create a new OCRService instance for the pool.
+
+ Args:
+ device: Device identifier (e.g., "GPU:0", "CPU")
+
+ Returns:
+ PooledService wrapper
+ """
+ # Import here to avoid circular imports
+ from app.services.ocr_service import OCRService
+
+ logger.info(f"Creating new OCRService for device {device}")
+ start_time = time.time()
+
+ # Create service instance
+ service = OCRService()
+
+ creation_time = time.time() - start_time
+ logger.info(f"OCRService created in {creation_time:.2f}s for device {device}")
+
+ return PooledService(
+ service=service,
+ device=device,
+ state=ServiceState.AVAILABLE
+ )
+
+ def acquire(
+ self,
+ device: str = "GPU:0",
+ timeout: Optional[float] = None,
+ task_id: Optional[str] = None
+ ) -> Optional[PooledService]:
+ """
+ Acquire an OCRService from the pool.
+
+ Args:
+ device: Preferred device (e.g., "GPU:0")
+ timeout: Maximum time to wait for a service
+ task_id: Optional task ID for tracking
+
+ Returns:
+ PooledService if available, None if timeout
+ """
+ timeout = timeout or self.config.acquire_timeout_seconds
+ self._initialize_device(device)
+
+ start_time = time.time()
+ deadline = start_time + timeout
+
+ with self._condition:
+ while True:
+ # Try to get an available service
+ service = self._try_acquire_service(device, task_id)
+ if service:
+ self._metrics["total_acquisitions"] += 1
+ return service
+
+ # Check if we can create a new service
+ if self._can_create_service(device):
+ try:
+ pooled = self._create_service(device)
+ pooled.state = ServiceState.IN_USE
+ pooled.current_task_id = task_id
+ pooled.use_count += 1
+ self.services[device].append(pooled)
+ self._metrics["total_acquisitions"] += 1
+ logger.info(f"Created and acquired new service for {device}")
+ return pooled
+ except Exception as e:
+ logger.error(f"Failed to create service for {device}: {e}")
+ self._metrics["total_errors"] += 1
+
+ # Wait for a service to become available
+ remaining = deadline - time.time()
+ if remaining <= 0:
+ self._metrics["total_timeouts"] += 1
+ logger.warning(f"Timeout waiting for service on {device}")
+ return None
+
+ self._metrics["queue_waits"] += 1
+ logger.debug(f"Waiting for service on {device} (timeout: {remaining:.1f}s)")
+ self._condition.wait(timeout=min(remaining, 1.0))
+
+ def _try_acquire_service(self, device: str, task_id: Optional[str]) -> Optional[PooledService]:
+ """Try to acquire an available service without waiting"""
+ for pooled in self.services.get(device, []):
+ if pooled.state == ServiceState.AVAILABLE:
+ pooled.state = ServiceState.IN_USE
+ pooled.last_used = time.time()
+ pooled.use_count += 1
+ pooled.current_task_id = task_id
+ logger.debug(f"Acquired existing service for {device} (use #{pooled.use_count})")
+ return pooled
+ return None
+
+ def _can_create_service(self, device: str) -> bool:
+ """Check if a new service can be created"""
+ device_count = len(self.services.get(device, []))
+ total_count = sum(len(services) for services in self.services.values())
+
+ return (
+ device_count < self.config.max_services_per_device and
+ total_count < self.config.max_total_services
+ )
+
+ def release(self, pooled: PooledService, error: Optional[Exception] = None):
+ """
+ Release a service back to the pool.
+
+ Args:
+ pooled: The pooled service to release
+ error: Optional error that occurred during use
+ """
+ with self._condition:
+ if error:
+ pooled.error_count += 1
+ self._metrics["total_errors"] += 1
+ logger.warning(f"Service released with error: {error}")
+
+ # Mark unhealthy if too many errors
+ if pooled.error_count >= self.config.max_consecutive_errors:
+ pooled.state = ServiceState.UNHEALTHY
+ logger.error(f"Service marked unhealthy after {pooled.error_count} errors")
+ else:
+ pooled.state = ServiceState.AVAILABLE
+ else:
+ pooled.error_count = 0 # Reset error count on success
+ pooled.state = ServiceState.AVAILABLE
+
+ pooled.last_used = time.time()
+ pooled.current_task_id = None
+ self._metrics["total_releases"] += 1
+
+ # Clean up GPU memory after release
+ try:
+ model_manager = get_model_manager()
+ model_manager.memory_guard.clear_gpu_cache()
+ except Exception as e:
+ logger.debug(f"Cache clear after release failed: {e}")
+
+ # Notify waiting threads
+ self._condition.notify_all()
+
+ logger.debug(f"Service released for device {pooled.device}")
+
+ @contextmanager
+ def acquire_context(
+ self,
+ device: str = "GPU:0",
+ timeout: Optional[float] = None,
+ task_id: Optional[str] = None
+ ):
+ """
+ Context manager for acquiring and releasing a service.
+
+ Usage:
+ with pool.acquire_context("GPU:0") as pooled:
+ result = pooled.service.process(...)
+ """
+ pooled = None
+ error = None
+ try:
+ pooled = self.acquire(device, timeout, task_id)
+ if pooled is None:
+ raise TimeoutError(f"Failed to acquire service for {device}")
+ yield pooled
+ except Exception as e:
+ error = e
+ raise
+ finally:
+ if pooled:
+ self.release(pooled, error)
+
+ def get_service(self, device: str = "GPU:0") -> Optional["OCRService"]:
+ """
+ Get a service directly (for backward compatibility).
+
+ This acquires a service and returns the underlying OCRService.
+        The PooledService wrapper needed by release() is discarded here, so prefer
+        acquire_context() when the service must be returned to the pool.
+
+ Args:
+ device: Device identifier
+
+ Returns:
+ OCRService instance or None
+ """
+ pooled = self.acquire(device)
+ if pooled:
+ return pooled.service
+ return None
+
+ def get_pool_stats(self) -> Dict:
+ """Get current pool statistics"""
+ with self._pool_lock:
+ stats = {
+ "devices": {},
+ "metrics": self._metrics.copy(),
+ "total_services": 0,
+ "available_services": 0,
+ "in_use_services": 0,
+ }
+
+ for device, services in self.services.items():
+ available = sum(1 for s in services if s.state == ServiceState.AVAILABLE)
+ in_use = sum(1 for s in services if s.state == ServiceState.IN_USE)
+ unhealthy = sum(1 for s in services if s.state == ServiceState.UNHEALTHY)
+
+ stats["devices"][device] = {
+ "total": len(services),
+ "available": available,
+ "in_use": in_use,
+ "unhealthy": unhealthy,
+ "max_allowed": self.config.max_services_per_device,
+ }
+
+ stats["total_services"] += len(services)
+ stats["available_services"] += available
+ stats["in_use_services"] += in_use
+
+ return stats
+
+ def health_check(self) -> Dict:
+ """
+ Perform health check on all pooled services.
+
+ Returns:
+ Health check results
+ """
+ results = {
+ "healthy": True,
+ "services": [],
+ "timestamp": time.time()
+ }
+
+ with self._pool_lock:
+ for device, services in self.services.items():
+ for idx, pooled in enumerate(services):
+ service_health = {
+ "device": device,
+ "index": idx,
+ "state": pooled.state.value,
+ "error_count": pooled.error_count,
+ "use_count": pooled.use_count,
+ "idle_seconds": time.time() - pooled.last_used,
+ }
+
+ # Check if service is responsive
+ if pooled.state == ServiceState.AVAILABLE:
+ try:
+ # Simple check - verify service has required attributes
+ has_process = hasattr(pooled.service, 'process')
+ has_gpu_status = hasattr(pooled.service, 'get_gpu_status')
+ service_health["responsive"] = has_process and has_gpu_status
+ except Exception as e:
+ service_health["responsive"] = False
+ service_health["error"] = str(e)
+ results["healthy"] = False
+ else:
+ service_health["responsive"] = pooled.state != ServiceState.UNHEALTHY
+
+ if pooled.state == ServiceState.UNHEALTHY:
+ results["healthy"] = False
+
+ results["services"].append(service_health)
+
+ return results
+
+ def recover_unhealthy(self):
+ """
+ Attempt to recover unhealthy services.
+ """
+ with self._pool_lock:
+ for device, services in self.services.items():
+                for idx, pooled in enumerate(list(services)):  # iterate over a snapshot; services is mutated below
+ if pooled.state == ServiceState.UNHEALTHY:
+ logger.info(f"Attempting to recover unhealthy service {device}:{idx}")
+ try:
+ # Remove old service
+ services.remove(pooled)
+
+ # Create new service
+ new_pooled = self._create_service(device)
+ services.append(new_pooled)
+ logger.info(f"Successfully recovered service {device}:{idx}")
+
+ except Exception as e:
+ logger.error(f"Failed to recover service {device}:{idx}: {e}")
+
+ def shutdown(self):
+ """
+ Shutdown the pool and cleanup all services.
+ """
+ logger.info("OCRServicePool shutdown started")
+
+ with self._pool_lock:
+ for device, services in self.services.items():
+ for pooled in services:
+ try:
+ # Clean up service resources
+ if hasattr(pooled.service, 'cleanup_gpu_memory'):
+ pooled.service.cleanup_gpu_memory()
+ except Exception as e:
+ logger.warning(f"Error cleaning up service: {e}")
+
+ # Clear all pools
+ self.services.clear()
+ self.semaphores.clear()
+ self.queues.clear()
+
+ logger.info("OCRServicePool shutdown completed")
+
+
+# Global singleton instance
+_service_pool: Optional[OCRServicePool] = None
+
+
+def get_service_pool(config: Optional[PoolConfig] = None) -> OCRServicePool:
+ """
+ Get the global OCRServicePool instance.
+
+ Args:
+ config: Optional configuration (only used on first call)
+
+ Returns:
+ OCRServicePool singleton instance
+ """
+ global _service_pool
+ if _service_pool is None:
+ _service_pool = OCRServicePool(config)
+ return _service_pool
+
+
+def shutdown_service_pool():
+ """Shutdown the global service pool"""
+ global _service_pool
+ if _service_pool is not None:
+ _service_pool.shutdown()
+ _service_pool = None
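
Typical consumption of the new pool might look like the sketch below. get_service_pool, PoolConfig, and acquire_context come from the module added above; with_pooled_service, report_pool_health, and the process_fn callback are hypothetical wrappers for illustration. Because the pool is a singleton, the PoolConfig only takes effect on the first call.

from typing import Any, Callable
from app.services.service_pool import PoolConfig, get_service_pool

def with_pooled_service(process_fn: Callable[[Any], Any], task_id: str):
    """Run process_fn(service) while holding a pooled OCRService."""
    pool = get_service_pool(PoolConfig(max_services_per_device=1, max_total_services=2))
    # acquire_context() releases the service (and records any error) even if process_fn raises
    with pool.acquire_context(device="GPU:0", timeout=300.0, task_id=task_id) as pooled:
        return process_fn(pooled.service)

def report_pool_health() -> dict:
    """Summarize pool state, e.g. for a health endpoint."""
    pool = get_service_pool()
    stats = pool.get_pool_stats()
    health = pool.health_check()
    return {
        "healthy": health["healthy"],
        "in_use": stats["in_use_services"],
        "available": stats["available_services"],
    }
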
diff --git a/backend/tests/services/test_memory_manager.py b/backend/tests/services/test_memory_manager.py
new file mode 100644
index 0000000..8f02096
--- /dev/null
+++ b/backend/tests/services/test_memory_manager.py
@@ -0,0 +1,1986 @@
+"""
+Tests for Memory Management Components
+
+Tests ModelManager, MemoryGuard, and related functionality.
+"""
+
+import gc
+import pytest
+import threading
+import time
+from unittest.mock import Mock, patch, MagicMock
+import sys
+
+# Mock paddle before importing memory_manager to avoid import errors
+# when paddle is not installed in the test environment
+paddle_mock = MagicMock()
+paddle_mock.is_compiled_with_cuda.return_value = False
+paddle_mock.device.cuda.device_count.return_value = 0
+paddle_mock.device.cuda.memory_allocated.return_value = 0
+paddle_mock.device.cuda.memory_reserved.return_value = 0
+paddle_mock.device.cuda.empty_cache = MagicMock()
+sys.modules['paddle'] = paddle_mock
+
+from app.services.memory_manager import (
+ ModelManager,
+ ModelEntry,
+ MemoryGuard,
+ MemoryConfig,
+ MemoryStats,
+ MemoryBackend,
+ get_model_manager,
+ shutdown_model_manager,
+)
+
+
+class TestMemoryConfig:
+ """Tests for MemoryConfig class"""
+
+ def test_default_values(self):
+ """Test default configuration values"""
+ config = MemoryConfig()
+ assert config.warning_threshold == 0.80
+ assert config.critical_threshold == 0.95
+ assert config.emergency_threshold == 0.98
+ assert config.model_idle_timeout_seconds == 300
+ assert config.enable_auto_cleanup is True
+ assert config.max_concurrent_predictions == 2
+
+ def test_custom_values(self):
+ """Test custom configuration values"""
+ config = MemoryConfig(
+ warning_threshold=0.70,
+ critical_threshold=0.85,
+ model_idle_timeout_seconds=600,
+ )
+ assert config.warning_threshold == 0.70
+ assert config.critical_threshold == 0.85
+ assert config.model_idle_timeout_seconds == 600
+
+
+class TestMemoryGuard:
+ """Tests for MemoryGuard class"""
+
+ def setup_method(self):
+ """Setup for each test"""
+ self.config = MemoryConfig(
+ warning_threshold=0.80,
+ critical_threshold=0.95,
+ emergency_threshold=0.98,
+ )
+
+ def test_initialization(self):
+ """Test MemoryGuard initialization"""
+ guard = MemoryGuard(self.config)
+ assert guard.config == self.config
+ assert guard.backend is not None
+ guard.shutdown()
+
+ def test_get_memory_stats(self):
+ """Test getting memory statistics"""
+ guard = MemoryGuard(self.config)
+ stats = guard.get_memory_stats()
+ assert isinstance(stats, MemoryStats)
+ assert stats.timestamp > 0
+ guard.shutdown()
+
+ def test_check_memory_below_warning(self):
+ """Test memory check when below warning threshold"""
+ guard = MemoryGuard(self.config)
+
+ # Mock stats to be below warning
+ with patch.object(guard, 'get_memory_stats') as mock_stats:
+ mock_stats.return_value = MemoryStats(
+ gpu_used_ratio=0.50,
+ gpu_free_mb=4000,
+ gpu_total_mb=8000,
+ )
+ is_available, stats = guard.check_memory(required_mb=1000)
+ assert is_available is True
+
+ guard.shutdown()
+
+ def test_check_memory_above_warning(self):
+ """Test memory check when above warning threshold"""
+ guard = MemoryGuard(self.config)
+
+ # Mock stats to be above warning
+ with patch.object(guard, 'get_memory_stats') as mock_stats:
+ mock_stats.return_value = MemoryStats(
+ gpu_used_ratio=0.85,
+ gpu_free_mb=1200,
+ gpu_total_mb=8000,
+ )
+ is_available, stats = guard.check_memory(required_mb=500)
+ # Should still return True (warning, not critical)
+ assert is_available is True
+
+ guard.shutdown()
+
+ def test_check_memory_above_critical(self):
+ """Test memory check when above critical threshold"""
+ guard = MemoryGuard(self.config)
+
+ # Mock stats to be above critical
+ with patch.object(guard, 'get_memory_stats') as mock_stats:
+ mock_stats.return_value = MemoryStats(
+ gpu_used_ratio=0.96,
+ gpu_free_mb=320,
+ gpu_total_mb=8000,
+ )
+ is_available, stats = guard.check_memory(required_mb=100)
+ # Should return False (critical)
+ assert is_available is False
+
+ guard.shutdown()
+
+ def test_check_memory_insufficient_free(self):
+ """Test memory check when insufficient free memory"""
+ guard = MemoryGuard(self.config)
+
+ # Mock stats with insufficient free memory
+ with patch.object(guard, 'get_memory_stats') as mock_stats:
+ mock_stats.return_value = MemoryStats(
+ gpu_used_ratio=0.70,
+ gpu_free_mb=500,
+ gpu_total_mb=8000,
+ )
+ is_available, stats = guard.check_memory(required_mb=1000)
+ # Should return False (not enough free)
+ assert is_available is False
+
+ guard.shutdown()
+
+ def test_alert_history(self):
+ """Test alert history functionality"""
+ guard = MemoryGuard(self.config)
+
+ # Trigger some alerts
+ guard._add_alert("warning", "Test warning")
+ guard._add_alert("critical", "Test critical")
+
+ alerts = guard.get_alerts()
+ assert len(alerts) == 2
+ assert alerts[0]["level"] == "warning"
+ assert alerts[1]["level"] == "critical"
+
+ guard.shutdown()
+
+ def test_clear_gpu_cache(self):
+ """Test GPU cache clearing"""
+ guard = MemoryGuard(self.config)
+ # Should not raise even if no GPU
+ guard.clear_gpu_cache()
+ guard.shutdown()
+
+
+class TestModelManager:
+ """Tests for ModelManager class"""
+
+ def setup_method(self):
+ """Reset singleton before each test"""
+ shutdown_model_manager()
+ ModelManager._instance = None
+ ModelManager._lock = threading.Lock()
+
+ def teardown_method(self):
+ """Cleanup after each test"""
+ shutdown_model_manager()
+ ModelManager._instance = None
+
+ def test_singleton_pattern(self):
+ """Test that ModelManager is a singleton"""
+ config = MemoryConfig()
+ manager1 = ModelManager(config)
+ manager2 = ModelManager()
+ assert manager1 is manager2
+ manager1.teardown()
+
+ def test_get_or_load_model_new(self):
+ """Test loading a new model"""
+ config = MemoryConfig()
+ manager = ModelManager(config)
+
+ mock_model = Mock()
+ loader_called = False
+
+ def loader():
+ nonlocal loader_called
+ loader_called = True
+ return mock_model
+
+ model = manager.get_or_load_model(
+ model_id="test_model",
+ loader_func=loader,
+ estimated_memory_mb=100
+ )
+
+ assert model is mock_model
+ assert loader_called is True
+ assert "test_model" in manager.models
+ assert manager.models["test_model"].ref_count == 1
+
+ manager.teardown()
+
+ def test_get_or_load_model_cached(self):
+ """Test getting a cached model"""
+ config = MemoryConfig()
+ manager = ModelManager(config)
+
+ mock_model = Mock()
+ load_count = 0
+
+ def loader():
+ nonlocal load_count
+ load_count += 1
+ return mock_model
+
+ # First load
+ model1 = manager.get_or_load_model("test_model", loader)
+ # Second load (should return cached)
+ model2 = manager.get_or_load_model("test_model", loader)
+
+ assert model1 is model2
+ assert load_count == 1 # Loader should only be called once
+ assert manager.models["test_model"].ref_count == 2
+
+ manager.teardown()
+
+ def test_release_model(self):
+ """Test releasing model references"""
+ config = MemoryConfig()
+ manager = ModelManager(config)
+
+ model = manager.get_or_load_model("test_model", lambda: Mock())
+ assert manager.models["test_model"].ref_count == 1
+
+ manager.release_model("test_model")
+ assert manager.models["test_model"].ref_count == 0
+
+ manager.teardown()
+
+ def test_unload_model(self):
+ """Test unloading a model"""
+ config = MemoryConfig()
+ manager = ModelManager(config)
+
+ manager.get_or_load_model("test_model", lambda: Mock())
+ manager.release_model("test_model")
+
+ success = manager.unload_model("test_model")
+ assert success is True
+ assert "test_model" not in manager.models
+
+ manager.teardown()
+
+ def test_unload_model_with_references(self):
+ """Test that model with active references is not unloaded"""
+ config = MemoryConfig()
+ manager = ModelManager(config)
+
+ manager.get_or_load_model("test_model", lambda: Mock())
+ # Don't release - ref_count is still 1
+
+ success = manager.unload_model("test_model", force=False)
+ assert success is False
+ assert "test_model" in manager.models
+
+ manager.teardown()
+
+ def test_unload_model_force(self):
+ """Test force unloading a model with references"""
+ config = MemoryConfig()
+ manager = ModelManager(config)
+
+ manager.get_or_load_model("test_model", lambda: Mock())
+
+ success = manager.unload_model("test_model", force=True)
+ assert success is True
+ assert "test_model" not in manager.models
+
+ manager.teardown()
+
+ def test_cleanup_callback(self):
+ """Test cleanup callback is called on unload"""
+ config = MemoryConfig()
+ manager = ModelManager(config)
+
+ cleanup_called = False
+ def cleanup():
+ nonlocal cleanup_called
+ cleanup_called = True
+
+ manager.get_or_load_model(
+ "test_model",
+ lambda: Mock(),
+ cleanup_callback=cleanup
+ )
+ manager.release_model("test_model")
+ manager.unload_model("test_model")
+
+ assert cleanup_called is True
+ manager.teardown()
+
+ def test_get_model_stats(self):
+ """Test getting model statistics"""
+ config = MemoryConfig()
+ manager = ModelManager(config)
+
+ manager.get_or_load_model("model1", lambda: Mock(), estimated_memory_mb=100)
+ manager.get_or_load_model("model2", lambda: Mock(), estimated_memory_mb=200)
+
+ stats = manager.get_model_stats()
+ assert stats["total_models"] == 2
+ assert "model1" in stats["models"]
+ assert "model2" in stats["models"]
+ assert stats["total_estimated_memory_mb"] == 300
+
+ manager.teardown()
+
+ def test_idle_cleanup(self):
+ """Test idle model cleanup"""
+ config = MemoryConfig(
+ model_idle_timeout_seconds=1, # Short timeout for testing
+ memory_check_interval_seconds=60, # Don't auto-cleanup
+ )
+ manager = ModelManager(config)
+
+ manager.get_or_load_model("test_model", lambda: Mock())
+ manager.release_model("test_model")
+
+ # Manually set last_used to simulate idle
+ manager.models["test_model"].last_used = time.time() - 10
+
+ # Manually trigger cleanup
+ manager._cleanup_idle_models()
+
+ assert "test_model" not in manager.models
+ manager.teardown()
+
+
+class TestGetModelManager:
+ """Tests for get_model_manager helper function"""
+
+ def setup_method(self):
+ """Reset singleton before each test"""
+ shutdown_model_manager()
+ ModelManager._instance = None
+
+ def teardown_method(self):
+ """Cleanup after each test"""
+ shutdown_model_manager()
+ ModelManager._instance = None
+
+ def test_get_model_manager_creates_singleton(self):
+ """Test that get_model_manager creates a singleton"""
+ manager1 = get_model_manager()
+ manager2 = get_model_manager()
+ assert manager1 is manager2
+ shutdown_model_manager()
+
+ def test_shutdown_model_manager(self):
+ """Test shutdown_model_manager cleans up"""
+ manager = get_model_manager()
+ manager.get_or_load_model("test", lambda: Mock())
+
+ shutdown_model_manager()
+
+ # Should be able to create new manager
+ new_manager = get_model_manager()
+ assert "test" not in new_manager.models
+ shutdown_model_manager()
+
+
+class TestConcurrency:
+ """Tests for concurrent access"""
+
+ def setup_method(self):
+ shutdown_model_manager()
+ ModelManager._instance = None
+
+ def teardown_method(self):
+ shutdown_model_manager()
+ ModelManager._instance = None
+
+ def test_concurrent_model_access(self):
+ """Test concurrent model loading"""
+ config = MemoryConfig()
+ manager = ModelManager(config)
+
+ load_count = 0
+ lock = threading.Lock()
+
+ def loader():
+ nonlocal load_count
+ with lock:
+ load_count += 1
+ time.sleep(0.1) # Simulate slow load
+ return Mock()
+
+ results = []
+
+ def worker():
+ model = manager.get_or_load_model("shared_model", loader)
+ results.append(model)
+
+ threads = [threading.Thread(target=worker) for _ in range(5)]
+ for t in threads:
+ t.start()
+ for t in threads:
+ t.join()
+
+ # All threads should get the same model
+ assert len(set(id(r) for r in results)) == 1
+ # Loader should only be called once
+ assert load_count == 1
+ # Ref count should match thread count
+ assert manager.models["shared_model"].ref_count == 5
+
+ manager.teardown()
+
+
+class TestPredictionSemaphore:
+ """Tests for PredictionSemaphore class"""
+
+ def setup_method(self):
+ """Reset singleton before each test"""
+ from app.services.memory_manager import shutdown_prediction_semaphore, PredictionSemaphore
+ shutdown_prediction_semaphore()
+ PredictionSemaphore._instance = None
+ PredictionSemaphore._lock = threading.Lock()
+
+ def teardown_method(self):
+ """Cleanup after each test"""
+ from app.services.memory_manager import shutdown_prediction_semaphore, PredictionSemaphore
+ shutdown_prediction_semaphore()
+ PredictionSemaphore._instance = None
+
+ def test_singleton_pattern(self):
+ """Test that PredictionSemaphore is a singleton"""
+ from app.services.memory_manager import PredictionSemaphore
+ sem1 = PredictionSemaphore(max_concurrent=2)
+ sem2 = PredictionSemaphore(max_concurrent=4) # Different config should be ignored
+ assert sem1 is sem2
+ assert sem1._max_concurrent == 2
+
+ def test_acquire_release(self):
+ """Test basic acquire and release"""
+ from app.services.memory_manager import PredictionSemaphore
+ sem = PredictionSemaphore(max_concurrent=2)
+
+ # Acquire first slot
+ assert sem.acquire(task_id="task1") is True
+ assert sem._active_predictions == 1
+
+ # Acquire second slot
+ assert sem.acquire(task_id="task2") is True
+ assert sem._active_predictions == 2
+
+ # Release one
+ sem.release(task_id="task1")
+ assert sem._active_predictions == 1
+
+ # Release another
+ sem.release(task_id="task2")
+ assert sem._active_predictions == 0
+
+ def test_acquire_blocks_when_full(self):
+ """Test that acquire blocks when all slots are taken"""
+ from app.services.memory_manager import PredictionSemaphore
+ sem = PredictionSemaphore(max_concurrent=1)
+
+ # Acquire the only slot
+ assert sem.acquire(task_id="task1") is True
+
+ # Try to acquire another with short timeout - should fail
+ result = sem.acquire(timeout=0.1, task_id="task2")
+ assert result is False
+ assert sem._total_timeouts == 1
+
+ # Release first slot
+ sem.release(task_id="task1")
+
+ # Now should succeed
+ assert sem.acquire(task_id="task3") is True
+
+ def test_get_stats(self):
+ """Test statistics tracking"""
+ from app.services.memory_manager import PredictionSemaphore
+ sem = PredictionSemaphore(max_concurrent=2)
+
+ sem.acquire(task_id="task1")
+ sem.acquire(task_id="task2")
+ sem.release(task_id="task1")
+
+ stats = sem.get_stats()
+ assert stats["max_concurrent"] == 2
+ assert stats["active_predictions"] == 1
+ assert stats["total_predictions"] == 2
+ assert stats["total_timeouts"] == 0
+
+ def test_concurrent_acquire(self):
+ """Test concurrent access to semaphore"""
+ from app.services.memory_manager import PredictionSemaphore
+ sem = PredictionSemaphore(max_concurrent=2)
+
+ results = []
+ acquired_count = 0
+ lock = threading.Lock()
+
+ def worker(worker_id):
+ nonlocal acquired_count
+ if sem.acquire(timeout=1.0, task_id=f"task_{worker_id}"):
+ with lock:
+ acquired_count += 1
+ time.sleep(0.1) # Simulate work
+ sem.release(task_id=f"task_{worker_id}")
+ results.append(worker_id)
+
+ # Start 4 workers but only 2 slots
+ threads = [threading.Thread(target=worker, args=(i,)) for i in range(4)]
+ for t in threads:
+ t.start()
+ for t in threads:
+ t.join()
+
+ # All should complete eventually
+ assert len(results) == 4
+
+
+class TestPredictionContext:
+ """Tests for prediction_context helper function"""
+
+ def setup_method(self):
+ """Reset singleton before each test"""
+ from app.services.memory_manager import shutdown_prediction_semaphore, PredictionSemaphore
+ shutdown_prediction_semaphore()
+ PredictionSemaphore._instance = None
+ PredictionSemaphore._lock = threading.Lock()
+
+ def teardown_method(self):
+ """Cleanup after each test"""
+ from app.services.memory_manager import shutdown_prediction_semaphore, PredictionSemaphore
+ shutdown_prediction_semaphore()
+ PredictionSemaphore._instance = None
+
+ def test_context_manager_success(self):
+ """Test context manager for successful prediction"""
+ from app.services.memory_manager import prediction_context, get_prediction_semaphore
+
+ # Initialize semaphore
+ sem = get_prediction_semaphore(max_concurrent=2)
+
+ with prediction_context(timeout=5.0, task_id="test") as acquired:
+ assert acquired is True
+ assert sem._active_predictions == 1
+
+ # After exiting context, slot should be released
+ assert sem._active_predictions == 0
+
+ def test_context_manager_with_exception(self):
+ """Test context manager releases on exception"""
+ from app.services.memory_manager import prediction_context, get_prediction_semaphore
+
+ sem = get_prediction_semaphore(max_concurrent=2)
+
+ with pytest.raises(ValueError):
+ with prediction_context(task_id="test") as acquired:
+ assert acquired is True
+ raise ValueError("Test error")
+
+ # Should still release the slot
+ assert sem._active_predictions == 0
+
+ def test_context_manager_timeout(self):
+ """Test context manager when timeout occurs"""
+ from app.services.memory_manager import prediction_context, get_prediction_semaphore
+
+ sem = get_prediction_semaphore(max_concurrent=1)
+
+ # Acquire the only slot
+ sem.acquire(task_id="blocker")
+
+ # Context manager should timeout
+ with prediction_context(timeout=0.1, task_id="waiter") as acquired:
+ assert acquired is False
+
+ # Release blocker
+ sem.release(task_id="blocker")
+
+
+class TestBatchProcessor:
+ """Tests for BatchProcessor class"""
+
+ def test_add_item(self):
+ """Test adding items to batch queue"""
+ from app.services.memory_manager import BatchProcessor, BatchItem, BatchPriority
+
+ processor = BatchProcessor(max_batch_size=5)
+ item = BatchItem(item_id="test1", data="data1", priority=BatchPriority.NORMAL)
+ processor.add_item(item)
+
+ assert processor.get_queue_size() == 1
+
+ def test_add_items_sorted_by_priority(self):
+ """Test that items are sorted by priority"""
+ from app.services.memory_manager import BatchProcessor, BatchItem, BatchPriority
+
+ processor = BatchProcessor(max_batch_size=5)
+
+ processor.add_item(BatchItem(item_id="low", data="low", priority=BatchPriority.LOW))
+ processor.add_item(BatchItem(item_id="high", data="high", priority=BatchPriority.HIGH))
+ processor.add_item(BatchItem(item_id="normal", data="normal", priority=BatchPriority.NORMAL))
+
+ # High priority should be first
+ assert processor._queue[0].item_id == "high"
+ assert processor._queue[1].item_id == "normal"
+ assert processor._queue[2].item_id == "low"
+
+ def test_process_batch(self):
+ """Test processing a batch of items"""
+ from app.services.memory_manager import BatchProcessor, BatchItem, BatchPriority
+
+ processor = BatchProcessor(max_batch_size=3, cleanup_between_batches=False)
+
+ for i in range(5):
+ processor.add_item(BatchItem(item_id=f"item{i}", data=i))
+
+ results = processor.process_batch(lambda x: x * 2)
+
+ # Should process max_batch_size items
+ assert len(results) == 3
+ assert all(r.success for r in results)
+ assert processor.get_queue_size() == 2 # 2 remaining
+
+ def test_process_all(self):
+ """Test processing all items"""
+ from app.services.memory_manager import BatchProcessor, BatchItem
+
+ processor = BatchProcessor(max_batch_size=2, cleanup_between_batches=False)
+
+ for i in range(5):
+ processor.add_item(BatchItem(item_id=f"item{i}", data=i))
+
+ results = processor.process_all(lambda x: x * 2)
+
+ assert len(results) == 5
+ assert processor.get_queue_size() == 0
+
+ def test_process_with_failure(self):
+ """Test handling of processing failures"""
+ from app.services.memory_manager import BatchProcessor, BatchItem
+
+ processor = BatchProcessor(max_batch_size=3, cleanup_between_batches=False)
+
+ processor.add_item(BatchItem(item_id="good", data=1))
+ processor.add_item(BatchItem(item_id="bad", data="error"))
+
+ def processor_func(data):
+ if data == "error":
+ raise ValueError("Test error")
+ return data * 2
+
+ results = processor.process_all(processor_func)
+
+ assert len(results) == 2
+ assert results[0].success is True
+ assert results[1].success is False
+ assert "Test error" in results[1].error
+
+ def test_memory_constraint_batching(self):
+ """Test that batches respect memory constraints"""
+ from app.services.memory_manager import BatchProcessor, BatchItem
+
+ processor = BatchProcessor(
+ max_batch_size=10,
+ max_memory_per_batch_mb=100.0,
+ cleanup_between_batches=False
+ )
+
+ # Add items that exceed memory limit
+ processor.add_item(BatchItem(item_id="item1", data=1, estimated_memory_mb=60.0))
+ processor.add_item(BatchItem(item_id="item2", data=2, estimated_memory_mb=60.0))
+ processor.add_item(BatchItem(item_id="item3", data=3, estimated_memory_mb=60.0))
+
+ results = processor.process_batch(lambda x: x)
+
+        # Should only process items that fit within the memory limit.
+        # The first item (60MB) fits; adding the second reaches 120MB, which exceeds the
+        # 100MB cap, so 1 or 2 items are processed depending on when the limit is enforced.
+ assert len(results) == 1 or len(results) == 2
+ assert processor.get_queue_size() >= 1
+
+ def test_get_stats(self):
+ """Test statistics tracking"""
+ from app.services.memory_manager import BatchProcessor, BatchItem
+
+ processor = BatchProcessor(max_batch_size=2, cleanup_between_batches=False)
+ processor.add_item(BatchItem(item_id="item1", data=1))
+ processor.add_item(BatchItem(item_id="item2", data=2))
+
+ processor.process_all(lambda x: x)
+
+ stats = processor.get_stats()
+ assert stats["total_processed"] == 2
+ assert stats["total_batches"] == 1
+ assert stats["total_failures"] == 0
+
+
+class TestProgressiveLoader:
+ """Tests for ProgressiveLoader class"""
+
+ def test_initialize(self):
+ """Test initializing loader with page count"""
+ from app.services.memory_manager import ProgressiveLoader
+
+ loader = ProgressiveLoader(lookahead_pages=2)
+ loader.initialize(total_pages=10)
+
+ stats = loader.get_stats()
+ assert stats["total_pages"] == 10
+ assert stats["loaded_pages_count"] == 0
+
+ def test_load_page(self):
+ """Test loading a single page"""
+ from app.services.memory_manager import ProgressiveLoader
+
+ loader = ProgressiveLoader(lookahead_pages=2, cleanup_after_pages=10)
+ loader.initialize(total_pages=5)
+
+ page_data = loader.load_page(0, lambda p: f"page_{p}_data")
+
+ assert page_data == "page_0_data"
+ assert 0 in loader.get_loaded_pages()
+
+ def test_load_page_caching(self):
+ """Test that loaded pages are cached"""
+ from app.services.memory_manager import ProgressiveLoader
+
+ load_count = 0
+
+ def loader_func(page_num):
+ nonlocal load_count
+ load_count += 1
+ return f"page_{page_num}"
+
+ loader = ProgressiveLoader(lookahead_pages=2, cleanup_after_pages=10)
+ loader.initialize(total_pages=5)
+
+ # Load same page twice
+ loader.load_page(0, loader_func)
+ loader.load_page(0, loader_func)
+
+ assert load_count == 1 # Should only load once
+
+ def test_unload_distant_pages(self):
+ """Test that distant pages are unloaded"""
+ from app.services.memory_manager import ProgressiveLoader
+
+ loader = ProgressiveLoader(lookahead_pages=1, cleanup_after_pages=100)
+ loader.initialize(total_pages=10)
+
+ # Load several pages
+ for i in range(5):
+ loader.load_page(i, lambda p: f"page_{p}")
+
+ # After loading page 4, distant pages should be unloaded
+ loaded = loader.get_loaded_pages()
+ # Should keep only pages near current (4): pages 3, 4, and potentially 5
+ assert 0 not in loaded # Page 0 should be unloaded
+
+ def test_iterate_pages(self):
+ """Test iterating through all pages"""
+ from app.services.memory_manager import ProgressiveLoader
+
+ loader = ProgressiveLoader(lookahead_pages=0, cleanup_after_pages=100)
+ loader.initialize(total_pages=5)
+
+ results = loader.iterate_pages(
+ loader_func=lambda p: f"page_{p}",
+ processor_func=lambda p, data: f"processed_{data}"
+ )
+
+ assert len(results) == 5
+ assert results[0] == "processed_page_0"
+ assert results[4] == "processed_page_4"
+
+ def test_progress_callback(self):
+ """Test progress callback during iteration"""
+ from app.services.memory_manager import ProgressiveLoader
+
+ loader = ProgressiveLoader(lookahead_pages=0, cleanup_after_pages=100)
+ loader.initialize(total_pages=3)
+
+ progress_reports = []
+
+ def callback(current, total):
+ progress_reports.append((current, total))
+
+ loader.iterate_pages(
+ loader_func=lambda p: p,
+ processor_func=lambda p, d: d,
+ progress_callback=callback
+ )
+
+ assert progress_reports == [(1, 3), (2, 3), (3, 3)]
+
+ def test_clear(self):
+ """Test clearing loaded pages"""
+ from app.services.memory_manager import ProgressiveLoader
+
+ loader = ProgressiveLoader(cleanup_after_pages=100)
+ loader.initialize(total_pages=5)
+
+ for i in range(3):
+ loader.load_page(i, lambda p: p)
+
+ loader.clear()
+
+ assert loader.get_loaded_pages() == []
+
+
+class TestPriorityOperationQueue:
+ """Tests for PriorityOperationQueue class"""
+
+ def test_enqueue_dequeue(self):
+ """Test basic enqueue and dequeue"""
+ from app.services.memory_manager import PriorityOperationQueue, BatchPriority
+
+ queue = PriorityOperationQueue(max_size=10)
+
+ queue.enqueue("item1", "data1", BatchPriority.NORMAL)
+ result = queue.dequeue()
+
+ assert result is not None
+ item_id, data, priority = result
+ assert item_id == "item1"
+ assert data == "data1"
+ assert priority == BatchPriority.NORMAL
+
+ def test_priority_ordering(self):
+ """Test that higher priority items are dequeued first"""
+ from app.services.memory_manager import PriorityOperationQueue, BatchPriority
+
+ queue = PriorityOperationQueue()
+
+ queue.enqueue("low", "low_data", BatchPriority.LOW)
+ queue.enqueue("high", "high_data", BatchPriority.HIGH)
+ queue.enqueue("normal", "normal_data", BatchPriority.NORMAL)
+ queue.enqueue("critical", "critical_data", BatchPriority.CRITICAL)
+
+ # Dequeue in priority order
+ item_id, _, _ = queue.dequeue()
+ assert item_id == "critical"
+
+ item_id, _, _ = queue.dequeue()
+ assert item_id == "high"
+
+ item_id, _, _ = queue.dequeue()
+ assert item_id == "normal"
+
+ item_id, _, _ = queue.dequeue()
+ assert item_id == "low"
+
+ def test_cancel(self):
+ """Test cancelling an operation"""
+ from app.services.memory_manager import PriorityOperationQueue, BatchPriority
+
+ queue = PriorityOperationQueue()
+
+ queue.enqueue("item1", "data1", BatchPriority.NORMAL)
+ queue.enqueue("item2", "data2", BatchPriority.NORMAL)
+
+ # Cancel item1
+ assert queue.cancel("item1") is True
+
+ # Dequeue should skip cancelled item
+ result = queue.dequeue()
+ assert result[0] == "item2"
+
+ def test_dequeue_empty_returns_none(self):
+ """Test that dequeue on empty queue returns None"""
+ from app.services.memory_manager import PriorityOperationQueue
+
+ queue = PriorityOperationQueue()
+ result = queue.dequeue(timeout=0.1)
+ assert result is None
+
+ def test_max_size_limit(self):
+ """Test that queue respects max size"""
+ from app.services.memory_manager import PriorityOperationQueue, BatchPriority
+
+ queue = PriorityOperationQueue(max_size=2)
+
+ assert queue.enqueue("item1", "data1") is True
+ assert queue.enqueue("item2", "data2") is True
+ # Third item should fail without timeout
+ assert queue.enqueue("item3", "data3", timeout=0.1) is False
+
+ def test_get_stats(self):
+ """Test queue statistics"""
+ from app.services.memory_manager import PriorityOperationQueue, BatchPriority
+
+ queue = PriorityOperationQueue()
+
+ queue.enqueue("item1", "data1", BatchPriority.HIGH)
+ queue.enqueue("item2", "data2", BatchPriority.LOW)
+ queue.dequeue()
+
+ stats = queue.get_stats()
+ assert stats["total_enqueued"] == 2
+ assert stats["total_dequeued"] == 1
+ assert stats["queue_size"] == 1
+
+
+class TestRecoveryManager:
+ """Tests for RecoveryManager class"""
+
+ def test_can_attempt_recovery(self):
+ """Test recovery attempt checking"""
+ from app.services.memory_manager import RecoveryManager
+
+ manager = RecoveryManager(
+ cooldown_seconds=1.0,
+ max_recovery_attempts=3
+ )
+
+ can_recover, reason = manager.can_attempt_recovery()
+ assert can_recover is True
+ assert "allowed" in reason.lower()
+
+ def test_cooldown_period(self):
+ """Test that cooldown period is enforced"""
+ from app.services.memory_manager import RecoveryManager
+
+ manager = RecoveryManager(
+ cooldown_seconds=60.0,
+ max_recovery_attempts=10
+ )
+
+ # First recovery
+ manager.attempt_recovery()
+
+ # Should be in cooldown
+ assert manager.is_in_cooldown() is True
+
+ can_recover, reason = manager.can_attempt_recovery()
+ assert can_recover is False
+ assert "cooldown" in reason.lower()
+
+ def test_max_recovery_attempts(self):
+ """Test that max recovery attempts are enforced"""
+ from app.services.memory_manager import RecoveryManager
+
+ manager = RecoveryManager(
+ cooldown_seconds=0.01, # Very short cooldown for testing
+ max_recovery_attempts=2,
+ recovery_window_seconds=60.0
+ )
+
+ # Perform max attempts
+ for _ in range(2):
+ manager.attempt_recovery()
+ time.sleep(0.02) # Wait for cooldown
+
+ # Next attempt should be blocked
+ can_recover, reason = manager.can_attempt_recovery()
+ assert can_recover is False
+ assert "max" in reason.lower()
+
+ def test_recovery_callbacks(self):
+ """Test recovery event callbacks"""
+ from app.services.memory_manager import RecoveryManager
+
+ manager = RecoveryManager(cooldown_seconds=0.01)
+
+ start_called = False
+ complete_called = False
+ complete_success = None
+
+ def on_start():
+ nonlocal start_called
+ start_called = True
+
+ def on_complete(success):
+ nonlocal complete_called, complete_success
+ complete_called = True
+ complete_success = success
+
+ manager.register_callbacks(on_start=on_start, on_complete=on_complete)
+ manager.attempt_recovery()
+
+ assert start_called is True
+ assert complete_called is True
+ assert complete_success is not None
+
+ def test_get_state(self):
+ """Test getting recovery state"""
+ from app.services.memory_manager import RecoveryManager
+
+ manager = RecoveryManager(cooldown_seconds=60.0)
+ manager.attempt_recovery(error="Test error")
+
+ state = manager.get_state()
+ assert state["recovery_count"] == 1
+ assert state["in_cooldown"] is True
+ assert state["last_error"] == "Test error"
+ assert state["cooldown_remaining_seconds"] > 0
+
+ def test_emergency_release(self):
+ """Test emergency memory release"""
+ from app.services.memory_manager import RecoveryManager
+
+ manager = RecoveryManager()
+
+ # Emergency release without model manager
+ result = manager.emergency_release(model_manager=None)
+ # Should complete without error (may or may not free memory in test env)
+ assert isinstance(result, bool)
+
+
+# =============================================================================
+# Section 1.2: Test model reload after unload
+# =============================================================================
+
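+# Lifecycle exercised below (a sketch; loader() and the memory estimates are
+# test stand-ins rather than production values):
+#
+#     model = manager.get_or_load_model("test_model", loader, estimated_memory_mb=100)
+#     manager.release_model("test_model")   # drop our reference
+#     manager.unload_model("test_model")    # free the model's memory
+#     model = manager.get_or_load_model("test_model", loader)  # reload on demand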
+class TestModelReloadAfterUnload:
+ """Tests for model reload after unload functionality"""
+
+ def setup_method(self):
+ """Reset singleton before each test"""
+ shutdown_model_manager()
+ ModelManager._instance = None
+ ModelManager._lock = threading.Lock()
+
+ def teardown_method(self):
+ """Cleanup after each test"""
+ shutdown_model_manager()
+ ModelManager._instance = None
+
+ def test_reload_after_unload(self):
+ """Test that a model can be reloaded after being unloaded"""
+ config = MemoryConfig()
+ manager = ModelManager(config)
+
+ load_count = 0
+
+ def loader():
+ nonlocal load_count
+ load_count += 1
+ return Mock(name=f"model_instance_{load_count}")
+
+ # First load
+ model1 = manager.get_or_load_model("test_model", loader, estimated_memory_mb=100)
+ assert load_count == 1
+ assert "test_model" in manager.models
+
+ # Release and unload
+ manager.release_model("test_model")
+ success = manager.unload_model("test_model")
+ assert success is True
+ assert "test_model" not in manager.models
+
+ # Reload
+ model2 = manager.get_or_load_model("test_model", loader, estimated_memory_mb=100)
+ assert load_count == 2 # Loader called again
+ assert "test_model" in manager.models
+
+ # Models should be different instances
+ assert model1 is not model2
+
+ manager.teardown()
+
+ def test_reload_preserves_config(self):
+ """Test that reloaded model uses the same configuration"""
+ config = MemoryConfig()
+ manager = ModelManager(config)
+
+ cleanup_count = 0
+
+ def cleanup():
+ nonlocal cleanup_count
+ cleanup_count += 1
+
+ # Load with cleanup callback
+ manager.get_or_load_model(
+ "test_model",
+ lambda: Mock(),
+ estimated_memory_mb=200,
+ cleanup_callback=cleanup
+ )
+
+ # Release, unload (cleanup should be called)
+ manager.release_model("test_model")
+ manager.unload_model("test_model")
+ assert cleanup_count == 1
+
+ # Reload with new cleanup callback
+ def new_cleanup():
+ nonlocal cleanup_count
+ cleanup_count += 10
+
+ manager.get_or_load_model(
+ "test_model",
+ lambda: Mock(),
+ estimated_memory_mb=300,
+ cleanup_callback=new_cleanup
+ )
+
+ # Verify new estimated memory
+ assert manager.models["test_model"].estimated_memory_mb == 300
+
+ # Release and unload again
+ manager.release_model("test_model")
+ manager.unload_model("test_model")
+ assert cleanup_count == 11 # New cleanup was called
+
+ manager.teardown()
+
+ def test_concurrent_reload(self):
+ """Test concurrent reload operations"""
+ config = MemoryConfig()
+ manager = ModelManager(config)
+
+ load_count = 0
+ lock = threading.Lock()
+
+ def loader():
+ nonlocal load_count
+ with lock:
+ load_count += 1
+ time.sleep(0.05) # Simulate slow load
+ return Mock()
+
+ # Load, release, unload
+ manager.get_or_load_model("test_model", loader)
+ manager.release_model("test_model")
+ manager.unload_model("test_model")
+
+ # Concurrent reload attempts
+ results = []
+
+ def worker():
+ model = manager.get_or_load_model("test_model", loader)
+ results.append(model)
+
+ threads = [threading.Thread(target=worker) for _ in range(5)]
+ for t in threads:
+ t.start()
+ for t in threads:
+ t.join()
+
+ # All threads should get the same model instance
+ assert len(results) == 5
+ assert len(set(id(r) for r in results)) == 1
+ # Only one additional load should have occurred
+ assert load_count == 2
+
+ manager.teardown()
+
+
+# =============================================================================
+# Section 4.2: Test memory savings with selective processing
+# =============================================================================
+
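+# Batch-processing pattern exercised below (a sketch; page/handle_page are
+# placeholders and the memory estimates are test values):
+#
+#     processor = BatchProcessor(max_batch_size=10, max_memory_per_batch_mb=100.0)
+#     processor.add_item(BatchItem(item_id="page_1", data=page, estimated_memory_mb=20.0))
+#     results = processor.process_all(handle_page)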
+class TestSelectiveProcessingMemorySavings:
+ """Tests for memory savings with selective processing"""
+
+ def test_batch_processor_memory_constraints(self):
+ """Test that batch processor respects memory constraints"""
+ from app.services.memory_manager import BatchProcessor, BatchItem
+
+ # Create processor with strict memory limit
+ processor = BatchProcessor(
+ max_batch_size=10,
+ max_memory_per_batch_mb=100.0,
+ cleanup_between_batches=False
+ )
+
+ # Add items with known memory estimates
+ processor.add_item(BatchItem(item_id="small1", data=1, estimated_memory_mb=20.0))
+ processor.add_item(BatchItem(item_id="small2", data=2, estimated_memory_mb=20.0))
+ processor.add_item(BatchItem(item_id="small3", data=3, estimated_memory_mb=20.0))
+ processor.add_item(BatchItem(item_id="large1", data=4, estimated_memory_mb=80.0))
+
+        # Process the first batch; the processor selects items so the batch
+        # stays within the 100MB-per-batch limit
+        processor.process_batch(lambda x: x)
+
+        # The 80MB item cannot share a batch with all three 20MB items,
+        # so at least one item should remain queued for a later batch
+        assert processor.get_queue_size() >= 1
+
+ def test_progressive_loader_memory_efficiency(self):
+ """Test that progressive loader manages memory efficiently"""
+ from app.services.memory_manager import ProgressiveLoader
+
+ loader = ProgressiveLoader(lookahead_pages=1, cleanup_after_pages=100)
+ loader.initialize(total_pages=10)
+
+ pages_loaded = []
+
+ def loader_func(page_num):
+ pages_loaded.append(page_num)
+ return f"page_data_{page_num}"
+
+ # Load pages sequentially
+ for i in range(10):
+ loader.load_page(i, loader_func)
+
+ # Only recent pages should be kept in memory
+ loaded = loader.get_loaded_pages()
+
+ # Should have unloaded distant pages
+ assert 0 not in loaded # First page should be unloaded
+ assert len(loaded) <= 3 # Current + lookahead
+
+ loader.clear()
+
+ def test_priority_queue_processing_order(self):
+ """Test that priority queue processes high priority items first"""
+ from app.services.memory_manager import PriorityOperationQueue, BatchPriority
+
+ queue = PriorityOperationQueue()
+
+ # Add items with different priorities
+ queue.enqueue("low1", "data", BatchPriority.LOW)
+ queue.enqueue("critical1", "data", BatchPriority.CRITICAL)
+ queue.enqueue("normal1", "data", BatchPriority.NORMAL)
+ queue.enqueue("high1", "data", BatchPriority.HIGH)
+
+ # Process in priority order
+ order = []
+ while True:
+ result = queue.dequeue(timeout=0.01)
+ if result is None:
+ break
+ order.append(result[0])
+
+ assert order[0] == "critical1"
+ assert order[1] == "high1"
+ assert order[2] == "normal1"
+ assert order[3] == "low1"
+
+
+# =============================================================================
+# Section 5.2: Test recovery under various scenarios
+# =============================================================================
+
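+# Guarded-recovery pattern exercised below (a sketch; cooldown and attempt
+# limits are test values, not production defaults):
+#
+#     can_recover, reason = manager.can_attempt_recovery()
+#     if can_recover:
+#         manager.attempt_recovery(error="CUDA out of memory")
+#     # further attempts are refused while in cooldown or once the attempt
+#     # limit within the recovery window is reached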
+class TestRecoveryScenarios:
+ """Tests for recovery under various scenarios"""
+
+ def test_recovery_after_oom_simulation(self):
+ """Test recovery behavior after simulated OOM"""
+ from app.services.memory_manager import RecoveryManager
+
+ manager = RecoveryManager(
+ cooldown_seconds=0.1,
+ max_recovery_attempts=5
+ )
+
+ # Simulate OOM recovery
+ success = manager.attempt_recovery(error="CUDA out of memory")
+ assert success is not None # Recovery attempted
+
+ # Check state
+ state = manager.get_state()
+ assert state["recovery_count"] == 1
+ assert state["last_error"] == "CUDA out of memory"
+
+ def test_recovery_cooldown_enforcement(self):
+ """Test that cooldown period is strictly enforced"""
+ from app.services.memory_manager import RecoveryManager
+
+ manager = RecoveryManager(
+ cooldown_seconds=1.0,
+ max_recovery_attempts=10
+ )
+
+ # First recovery
+ manager.attempt_recovery()
+ assert manager.is_in_cooldown() is True
+
+ # Try immediate second recovery
+ can_recover, reason = manager.can_attempt_recovery()
+ assert can_recover is False
+ assert "cooldown" in reason.lower()
+
+ # Wait for cooldown
+ time.sleep(1.1)
+ can_recover, reason = manager.can_attempt_recovery()
+ assert can_recover is True
+
+ def test_recovery_max_attempts_window(self):
+ """Test that max recovery attempts are enforced within window"""
+ from app.services.memory_manager import RecoveryManager
+
+ manager = RecoveryManager(
+ cooldown_seconds=0.01,
+ max_recovery_attempts=3,
+ recovery_window_seconds=60.0
+ )
+
+ # Perform max attempts
+ for i in range(3):
+ manager.attempt_recovery(error=f"Error {i}")
+ time.sleep(0.02) # Wait for cooldown
+
+ # Next attempt should be blocked
+ can_recover, reason = manager.can_attempt_recovery()
+ assert can_recover is False
+ assert "max" in reason.lower()
+
+ def test_emergency_release_with_model_manager(self):
+ """Test emergency release unloads models"""
+ from app.services.memory_manager import RecoveryManager
+
+ # Create a model manager with test models
+ shutdown_model_manager()
+ ModelManager._instance = None
+ config = MemoryConfig()
+ model_manager = ModelManager(config)
+
+ # Load some test models
+ model_manager.get_or_load_model("model1", lambda: Mock(), estimated_memory_mb=100)
+ model_manager.get_or_load_model("model2", lambda: Mock(), estimated_memory_mb=200)
+
+ # Release references
+ model_manager.release_model("model1")
+ model_manager.release_model("model2")
+
+ assert len(model_manager.models) == 2
+
+ # Emergency release
+ recovery_manager = RecoveryManager()
+ result = recovery_manager.emergency_release(model_manager=model_manager)
+
+ # Models should be unloaded
+ assert len(model_manager.models) == 0
+
+ model_manager.teardown()
+ shutdown_model_manager()
+ ModelManager._instance = None
+
+
+# =============================================================================
+# Section 6.1: Test shutdown sequence
+# =============================================================================
+
+class TestShutdownSequence:
+ """Tests for shutdown sequence"""
+
+ def test_model_manager_teardown(self):
+ """Test that teardown properly cleans up models"""
+ shutdown_model_manager()
+ ModelManager._instance = None
+
+ config = MemoryConfig()
+ manager = ModelManager(config)
+
+ # Load models
+ manager.get_or_load_model("model1", lambda: Mock())
+ manager.get_or_load_model("model2", lambda: Mock())
+
+ assert len(manager.models) == 2
+
+ # Teardown
+ manager.teardown()
+
+ assert len(manager.models) == 0
+ assert manager._monitor_running is False
+
+ shutdown_model_manager()
+ ModelManager._instance = None
+
+ def test_cleanup_callbacks_called_on_teardown(self):
+ """Test that cleanup callbacks are called during teardown"""
+ shutdown_model_manager()
+ ModelManager._instance = None
+
+ config = MemoryConfig()
+ manager = ModelManager(config)
+
+ cleanup_calls = []
+
+ def cleanup1():
+ cleanup_calls.append("model1")
+
+ def cleanup2():
+ cleanup_calls.append("model2")
+
+ manager.get_or_load_model("model1", lambda: Mock(), cleanup_callback=cleanup1)
+ manager.get_or_load_model("model2", lambda: Mock(), cleanup_callback=cleanup2)
+
+ # Teardown with force unload
+ manager.teardown()
+
+ # Both callbacks should have been called
+ assert "model1" in cleanup_calls
+ assert "model2" in cleanup_calls
+
+ shutdown_model_manager()
+ ModelManager._instance = None
+
+ def test_prediction_semaphore_shutdown(self):
+ """Test prediction semaphore shutdown"""
+ from app.services.memory_manager import (
+ get_prediction_semaphore,
+ shutdown_prediction_semaphore,
+ PredictionSemaphore
+ )
+
+ shutdown_prediction_semaphore()
+ PredictionSemaphore._instance = None
+ PredictionSemaphore._lock = threading.Lock()
+
+ sem = get_prediction_semaphore(max_concurrent=2)
+ sem.acquire(task_id="test1")
+
+ # Shutdown should reset the semaphore
+ shutdown_prediction_semaphore()
+
+ # New instance should be fresh
+ new_sem = get_prediction_semaphore(max_concurrent=3)
+ assert new_sem._active_predictions == 0
+
+ shutdown_prediction_semaphore()
+ PredictionSemaphore._instance = None
+
+
+# =============================================================================
+# Section 6.2: Test cleanup in error scenarios
+# =============================================================================
+
+class TestCleanupInErrorScenarios:
+ """Tests for cleanup in error scenarios"""
+
+ def test_cleanup_after_loader_exception(self):
+ """Test cleanup when model loader raises exception"""
+ shutdown_model_manager()
+ ModelManager._instance = None
+
+ config = MemoryConfig()
+ manager = ModelManager(config)
+
+ def failing_loader():
+ raise RuntimeError("Loader failed")
+
+ with pytest.raises(RuntimeError):
+ manager.get_or_load_model("failing_model", failing_loader)
+
+ # Model should not be in the manager
+ assert "failing_model" not in manager.models
+
+ manager.teardown()
+ shutdown_model_manager()
+ ModelManager._instance = None
+
+ def test_cleanup_after_processing_error(self):
+ """Test cleanup after processing error in batch processor"""
+ from app.services.memory_manager import BatchProcessor, BatchItem
+
+ processor = BatchProcessor(max_batch_size=3, cleanup_between_batches=False)
+
+ processor.add_item(BatchItem(item_id="good1", data=1))
+ processor.add_item(BatchItem(item_id="bad", data="error"))
+ processor.add_item(BatchItem(item_id="good2", data=2))
+
+ def processor_func(data):
+ if data == "error":
+ raise ValueError("Processing error")
+ return data * 2
+
+ results = processor.process_all(processor_func)
+
+ # Good items should succeed, bad item should fail
+ assert len(results) == 3
+ assert results[0].success is True
+ assert results[1].success is False
+ assert results[2].success is True
+
+ # Stats should reflect failure
+ stats = processor.get_stats()
+ assert stats["total_failures"] == 1
+
+ def test_pool_release_with_error(self):
+ """Test that pool properly handles release with error"""
+ from app.services.service_pool import (
+ OCRServicePool,
+ PoolConfig,
+ PooledService,
+ ServiceState,
+ shutdown_service_pool
+ )
+
+ shutdown_service_pool()
+ OCRServicePool._instance = None
+ OCRServicePool._lock = threading.Lock()
+
+ config = PoolConfig(max_consecutive_errors=2)
+ pool = OCRServicePool(config)
+
+ # Pre-populate with a mock service
+ mock_service = Mock()
+ pooled_service = PooledService(service=mock_service, device="GPU:0")
+ pool.services["GPU:0"].append(pooled_service)
+
+ # Acquire and release with errors
+ pooled = pool.acquire(device="GPU:0")
+ pool.release(pooled, error=Exception("Error 1"))
+ assert pooled.error_count == 1
+ assert pooled.state == ServiceState.AVAILABLE
+
+ pooled = pool.acquire(device="GPU:0")
+ pool.release(pooled, error=Exception("Error 2"))
+ assert pooled.error_count == 2
+ assert pooled.state == ServiceState.UNHEALTHY
+
+ pool.shutdown()
+ shutdown_service_pool()
+ OCRServicePool._instance = None
+
+
+# =============================================================================
+# Section 8.1: Memory leak detection tests
+# =============================================================================
+
+class TestMemoryLeakDetection:
+ """Tests for memory leak detection"""
+
+ def test_no_leak_on_model_cycle(self):
+ """Test that loading and unloading models doesn't leak"""
+ shutdown_model_manager()
+ ModelManager._instance = None
+
+ config = MemoryConfig()
+ manager = ModelManager(config)
+
+ initial_model_count = len(manager.models)
+
+ # Perform multiple load/unload cycles
+ for i in range(10):
+ manager.get_or_load_model(f"temp_model_{i}", lambda: Mock())
+ manager.release_model(f"temp_model_{i}")
+ manager.unload_model(f"temp_model_{i}")
+
+ # Should be back to initial state
+ assert len(manager.models) == initial_model_count
+
+        # Force a garbage collection pass so any lingering references would surface
+ gc.collect()
+
+ manager.teardown()
+ shutdown_model_manager()
+ ModelManager._instance = None
+
+ def test_no_leak_on_semaphore_cycle(self):
+ """Test that semaphore acquire/release doesn't leak"""
+ from app.services.memory_manager import (
+ PredictionSemaphore,
+ shutdown_prediction_semaphore
+ )
+
+ shutdown_prediction_semaphore()
+ PredictionSemaphore._instance = None
+ PredictionSemaphore._lock = threading.Lock()
+
+ sem = PredictionSemaphore(max_concurrent=2)
+
+ # Perform many acquire/release cycles
+ for i in range(100):
+ sem.acquire(task_id=f"task_{i}")
+ sem.release(task_id=f"task_{i}")
+
+ # Active predictions should be 0
+ assert sem._active_predictions == 0
+ assert sem._queue_depth == 0
+
+ shutdown_prediction_semaphore()
+ PredictionSemaphore._instance = None
+
+ def test_no_leak_in_batch_processor(self):
+ """Test that batch processor doesn't leak items"""
+ from app.services.memory_manager import BatchProcessor, BatchItem
+
+ processor = BatchProcessor(max_batch_size=5, cleanup_between_batches=False)
+
+ # Add and process many items
+ for i in range(50):
+ processor.add_item(BatchItem(item_id=f"item_{i}", data=i))
+
+ processor.process_all(lambda x: x * 2)
+
+ # Queue should be empty
+ assert processor.get_queue_size() == 0
+
+ # Stats should be accurate
+ stats = processor.get_stats()
+ assert stats["total_processed"] == 50
+
+
+# =============================================================================
+# Section 8.1: Stress tests with concurrent requests
+# =============================================================================
+
+class TestStressConcurrentRequests:
+ """Stress tests with concurrent requests"""
+
+ def test_concurrent_model_access_stress(self):
+ """Stress test concurrent model access"""
+ shutdown_model_manager()
+ ModelManager._instance = None
+
+ config = MemoryConfig()
+ manager = ModelManager(config)
+
+ load_count = 0
+ lock = threading.Lock()
+
+ def loader():
+ nonlocal load_count
+ with lock:
+ load_count += 1
+ return Mock()
+
+ results = []
+ errors = []
+
+ def worker(worker_id):
+ try:
+ model = manager.get_or_load_model("shared_model", loader)
+ time.sleep(0.01) # Simulate work
+ manager.release_model("shared_model")
+ results.append(worker_id)
+ except Exception as e:
+ errors.append(str(e))
+
+ # Launch many concurrent workers
+ threads = [threading.Thread(target=worker, args=(i,)) for i in range(20)]
+ for t in threads:
+ t.start()
+ for t in threads:
+ t.join()
+
+ # All workers should complete without errors
+ assert len(errors) == 0
+ assert len(results) == 20
+
+ # Loader should only be called once
+ assert load_count == 1
+
+ manager.teardown()
+ shutdown_model_manager()
+ ModelManager._instance = None
+
+ def test_concurrent_semaphore_stress(self):
+ """Stress test concurrent semaphore operations"""
+ from app.services.memory_manager import (
+ PredictionSemaphore,
+ shutdown_prediction_semaphore
+ )
+
+ shutdown_prediction_semaphore()
+ PredictionSemaphore._instance = None
+ PredictionSemaphore._lock = threading.Lock()
+
+ sem = PredictionSemaphore(max_concurrent=3)
+
+ results = []
+ max_concurrent_observed = 0
+ current_count = 0
+ lock = threading.Lock()
+
+ def worker(worker_id):
+ nonlocal max_concurrent_observed, current_count
+ if sem.acquire(timeout=10.0, task_id=f"task_{worker_id}"):
+ with lock:
+ current_count += 1
+ max_concurrent_observed = max(max_concurrent_observed, current_count)
+
+ time.sleep(0.02) # Simulate work
+
+ with lock:
+ current_count -= 1
+
+ sem.release(task_id=f"task_{worker_id}")
+ results.append(worker_id)
+
+ # Launch many concurrent workers
+ threads = [threading.Thread(target=worker, args=(i,)) for i in range(15)]
+ for t in threads:
+ t.start()
+ for t in threads:
+ t.join()
+
+ # All should complete
+ assert len(results) == 15
+
+ # Max concurrent should not exceed limit
+ assert max_concurrent_observed <= 3
+
+ shutdown_prediction_semaphore()
+ PredictionSemaphore._instance = None
+
+ def test_concurrent_batch_processing(self):
+ """Stress test concurrent batch processing"""
+ from app.services.memory_manager import BatchProcessor, BatchItem
+
+ processor = BatchProcessor(max_batch_size=3, cleanup_between_batches=False)
+
+ # Add items from multiple threads
+ def add_items(start_id):
+ for i in range(10):
+ processor.add_item(BatchItem(item_id=f"item_{start_id}_{i}", data=i))
+
+ threads = [threading.Thread(target=add_items, args=(i,)) for i in range(5)]
+ for t in threads:
+ t.start()
+ for t in threads:
+ t.join()
+
+ # Should have 50 items
+ assert processor.get_queue_size() == 50
+
+ # Process all
+ results = processor.process_all(lambda x: x)
+ assert len(results) == 50
+
+
+# =============================================================================
+# Section 8.1: Performance benchmarks
+# =============================================================================
+
+class TestPerformanceBenchmarks:
+ """Performance benchmark tests"""
+
+ def test_model_load_performance(self):
+ """Benchmark model loading performance"""
+ shutdown_model_manager()
+ ModelManager._instance = None
+
+ config = MemoryConfig()
+ manager = ModelManager(config)
+
+ load_times = []
+
+ for i in range(5):
+ start = time.time()
+ manager.get_or_load_model(f"model_{i}", lambda: Mock())
+ load_times.append(time.time() - start)
+
+ # Average load time should be reasonable (< 100ms for mock)
+ avg_load_time = sum(load_times) / len(load_times)
+ assert avg_load_time < 0.1 # 100ms
+
+ manager.teardown()
+ shutdown_model_manager()
+ ModelManager._instance = None
+
+ def test_semaphore_throughput(self):
+ """Benchmark semaphore throughput"""
+ from app.services.memory_manager import (
+ PredictionSemaphore,
+ shutdown_prediction_semaphore
+ )
+
+ shutdown_prediction_semaphore()
+ PredictionSemaphore._instance = None
+ PredictionSemaphore._lock = threading.Lock()
+
+ sem = PredictionSemaphore(max_concurrent=10)
+
+ start = time.time()
+ iterations = 1000
+
+ for i in range(iterations):
+ sem.acquire(timeout=1.0)
+ sem.release()
+
+ elapsed = time.time() - start
+
+        # Should handle at least 1,000 ops/sec
+ ops_per_sec = iterations / elapsed
+ assert ops_per_sec > 1000
+
+ shutdown_prediction_semaphore()
+ PredictionSemaphore._instance = None
+
+ def test_batch_processor_throughput(self):
+ """Benchmark batch processor throughput"""
+ from app.services.memory_manager import BatchProcessor, BatchItem
+
+ processor = BatchProcessor(max_batch_size=100, cleanup_between_batches=False)
+
+ # Add many items
+ for i in range(1000):
+ processor.add_item(BatchItem(item_id=f"item_{i}", data=i))
+
+ start = time.time()
+ results = processor.process_all(lambda x: x * 2)
+ elapsed = time.time() - start
+
+        # Should process at least 1,000 items/sec
+ items_per_sec = len(results) / elapsed
+ assert items_per_sec > 1000
+
+ stats = processor.get_stats()
+ assert stats["total_processed"] == 1000
+
+
+# =============================================================================
+# Tests for Memory Dump and Prometheus Metrics (Section 5.2 & 7.2)
+# =============================================================================
+
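+# Dump-and-compare pattern exercised below (a sketch; the keys shown are the
+# ones asserted in these tests):
+#
+#     dump1 = dumper.create_dump()
+#     ...  # run the operation under suspicion
+#     dump2 = dumper.create_dump()
+#     delta = dumper.compare_dumps(dump1, dump2)
+#     delta["gpu_memory_change_mb"], delta["cpu_memory_change_mb"]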
+class TestMemoryDumper:
+ """Tests for MemoryDumper class"""
+
+ def setup_method(self):
+ """Reset singletons before each test"""
+ from app.services.memory_manager import shutdown_memory_dumper
+ shutdown_memory_dumper()
+ shutdown_model_manager()
+ ModelManager._instance = None
+
+ def teardown_method(self):
+ """Cleanup after each test"""
+ from app.services.memory_manager import shutdown_memory_dumper
+ shutdown_memory_dumper()
+ shutdown_model_manager()
+ ModelManager._instance = None
+
+ def test_create_dump(self):
+ """Test creating a memory dump"""
+ from app.services.memory_manager import MemoryDumper, MemoryDump
+
+ dumper = MemoryDumper()
+ dump = dumper.create_dump()
+
+ assert isinstance(dump, MemoryDump)
+ assert dump.timestamp > 0
+ assert isinstance(dump.loaded_models, list)
+ assert isinstance(dump.gc_stats, dict)
+
+ def test_dump_history(self):
+ """Test dump history tracking"""
+ from app.services.memory_manager import MemoryDumper
+
+ dumper = MemoryDumper()
+
+ # Create multiple dumps
+ for _ in range(5):
+ dumper.create_dump()
+
+ history = dumper.get_dump_history()
+ assert len(history) == 5
+
+ latest = dumper.get_latest_dump()
+ assert latest is history[-1]
+
+ def test_dump_comparison(self):
+ """Test comparing two dumps"""
+ from app.services.memory_manager import MemoryDumper
+
+ dumper = MemoryDumper()
+
+ dump1 = dumper.create_dump()
+ time.sleep(0.1)
+ dump2 = dumper.create_dump()
+
+ comparison = dumper.compare_dumps(dump1, dump2)
+
+ assert "time_delta_seconds" in comparison
+ assert comparison["time_delta_seconds"] > 0
+ assert "gpu_memory_change_mb" in comparison
+ assert "cpu_memory_change_mb" in comparison
+
+ def test_dump_to_dict(self):
+ """Test converting dump to dictionary"""
+ from app.services.memory_manager import MemoryDumper
+
+ dumper = MemoryDumper()
+ dump = dumper.create_dump()
+ dump_dict = dumper.to_dict(dump)
+
+ assert "timestamp" in dump_dict
+ assert "gpu" in dump_dict
+ assert "cpu" in dump_dict
+ assert "models" in dump_dict
+ assert "predictions" in dump_dict
+
+
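+# The exporter is expected to emit Prometheus exposition text (illustrative
+# sample only; apart from the tool_ocr_memory_ prefix and the HELP/TYPE
+# comment lines, exact metric names are not asserted here):
+#
+#     # HELP tool_ocr_memory_gpu_used_mb GPU memory currently allocated
+#     # TYPE tool_ocr_memory_gpu_used_mb gauge
+#     tool_ocr_memory_gpu_used_mb 1234.5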
+class TestPrometheusMetrics:
+ """Tests for PrometheusMetrics class"""
+
+ def setup_method(self):
+ """Reset singletons before each test"""
+ from app.services.memory_manager import shutdown_prometheus_metrics
+ shutdown_prometheus_metrics()
+ shutdown_model_manager()
+ ModelManager._instance = None
+
+ def teardown_method(self):
+ """Cleanup after each test"""
+ from app.services.memory_manager import shutdown_prometheus_metrics
+ shutdown_prometheus_metrics()
+ shutdown_model_manager()
+ ModelManager._instance = None
+
+ def test_export_metrics(self):
+ """Test exporting metrics in Prometheus format"""
+ from app.services.memory_manager import PrometheusMetrics
+
+ prometheus = PrometheusMetrics()
+ metrics = prometheus.export_metrics()
+
+ # Should be a non-empty string
+ assert isinstance(metrics, str)
+ assert len(metrics) > 0
+
+ # Should contain expected metric prefixes
+ assert "tool_ocr_memory_" in metrics
+
+ def test_metric_format(self):
+ """Test that metrics follow Prometheus format"""
+ from app.services.memory_manager import PrometheusMetrics
+
+ prometheus = PrometheusMetrics()
+ metrics = prometheus.export_metrics()
+
+ lines = metrics.split("\n")
+
+ # Check for HELP and TYPE comments
+ help_lines = [l for l in lines if l.startswith("# HELP")]
+ type_lines = [l for l in lines if l.startswith("# TYPE")]
+
+ assert len(help_lines) > 0
+ assert len(type_lines) > 0
+
+ def test_custom_metrics(self):
+ """Test setting custom metrics"""
+ from app.services.memory_manager import PrometheusMetrics
+
+ prometheus = PrometheusMetrics()
+
+ prometheus.set_custom_metric("custom_value", 42.0)
+ prometheus.set_custom_metric("labeled_value", 100.0, {"env": "test"})
+
+ metrics = prometheus.export_metrics()
+
+ assert "custom_value" in metrics or "42" in metrics
+
+ def test_get_prometheus_metrics_singleton(self):
+ """Test prometheus metrics singleton"""
+ from app.services.memory_manager import get_prometheus_metrics, shutdown_prometheus_metrics
+
+ metrics1 = get_prometheus_metrics()
+ metrics2 = get_prometheus_metrics()
+
+ assert metrics1 is metrics2
+
+ shutdown_prometheus_metrics()
+
+
+if __name__ == "__main__":
+ pytest.main([__file__, "-v"])
diff --git a/backend/tests/services/test_ocr_memory_integration.py b/backend/tests/services/test_ocr_memory_integration.py
new file mode 100644
index 0000000..8de172a
--- /dev/null
+++ b/backend/tests/services/test_ocr_memory_integration.py
@@ -0,0 +1,380 @@
+"""
+Tests for OCR Service Memory Integration
+
+Tests the integration of MemoryGuard with OCRService patterns,
+including pre-operation memory checks and CPU fallback logic.
+"""
+
+import pytest
+from unittest.mock import Mock, patch, MagicMock
+import sys
+
+# Mock paddle before importing memory_manager
+paddle_mock = MagicMock()
+paddle_mock.is_compiled_with_cuda.return_value = False
+paddle_mock.device.cuda.device_count.return_value = 0
+paddle_mock.device.cuda.memory_allocated.return_value = 0
+paddle_mock.device.cuda.memory_reserved.return_value = 0
+paddle_mock.device.cuda.empty_cache = MagicMock()
+sys.modules['paddle'] = paddle_mock
+
+from app.services.memory_manager import (
+ MemoryGuard,
+ MemoryConfig,
+ MemoryStats,
+)
+
+
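+# Integration pattern exercised in this module (a sketch, not the actual
+# OCRService code; the 2000MB requirement mirrors the tests, not a tuned value):
+#
+#     guard = MemoryGuard(MemoryConfig(enable_cpu_fallback=True))
+#     is_available, stats = guard.check_memory(required_mb=2000)
+#     if not is_available:
+#         guard.clear_gpu_cache()           # try to reclaim GPU memory
+#         service._activate_cpu_fallback()  # or run this request on CPU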
+class TestMemoryGuardIntegration:
+ """Tests for MemoryGuard integration patterns used in OCRService"""
+
+ def setup_method(self):
+ """Setup for each test"""
+ self.config = MemoryConfig(
+ warning_threshold=0.80,
+ critical_threshold=0.95,
+ emergency_threshold=0.98,
+ enable_cpu_fallback=True,
+ )
+
+ def teardown_method(self):
+ """Cleanup after each test"""
+ pass
+
+ def test_memory_check_below_threshold_allows_processing(self):
+ """Test that memory check returns True when below thresholds"""
+ guard = MemoryGuard(self.config)
+
+ # Mock stats below warning threshold
+ with patch.object(guard, 'get_memory_stats') as mock_stats:
+ mock_stats.return_value = MemoryStats(
+ gpu_used_ratio=0.50,
+ gpu_free_mb=4000,
+ gpu_total_mb=8000,
+ )
+
+ is_available, stats = guard.check_memory(required_mb=2000)
+
+ assert is_available is True
+ assert stats.gpu_free_mb >= 2000
+
+ guard.shutdown()
+
+ def test_memory_check_above_critical_blocks_processing(self):
+ """Test that memory check returns False when above critical threshold"""
+ guard = MemoryGuard(self.config)
+
+ # Mock stats above critical threshold
+ with patch.object(guard, 'get_memory_stats') as mock_stats:
+ mock_stats.return_value = MemoryStats(
+ gpu_used_ratio=0.96,
+ gpu_free_mb=320,
+ gpu_total_mb=8000,
+ )
+
+ is_available, stats = guard.check_memory(required_mb=1000)
+
+ assert is_available is False
+
+ guard.shutdown()
+
+ def test_memory_check_insufficient_free_memory(self):
+ """Test that memory check returns False when free memory < required"""
+ guard = MemoryGuard(self.config)
+
+ # Mock stats with insufficient free memory but below critical ratio
+ with patch.object(guard, 'get_memory_stats') as mock_stats:
+ mock_stats.return_value = MemoryStats(
+ gpu_used_ratio=0.70,
+ gpu_free_mb=500,
+ gpu_total_mb=8000,
+ )
+
+ is_available, stats = guard.check_memory(required_mb=1000)
+
+ # Should return False (not enough free memory)
+ assert is_available is False
+
+ guard.shutdown()
+
+
+class TestCPUFallbackPattern:
+ """Tests for CPU fallback pattern as used in OCRService"""
+
+ def test_cpu_fallback_activation_pattern(self):
+ """Test the CPU fallback activation pattern"""
+ # Simulate the pattern used in OCRService._activate_cpu_fallback
+
+ class MockOCRService:
+ def __init__(self):
+ self._cpu_fallback_active = False
+ self.use_gpu = True
+ self.gpu_available = True
+ self.gpu_info = {'device_id': 0}
+ self._memory_guard = Mock()
+
+ def _activate_cpu_fallback(self):
+ if self._cpu_fallback_active:
+ return
+
+ self._cpu_fallback_active = True
+ self.use_gpu = False
+ self.gpu_info['cpu_fallback'] = True
+ self.gpu_info['fallback_reason'] = 'GPU memory insufficient'
+
+ if self._memory_guard:
+ self._memory_guard.clear_gpu_cache()
+
+ service = MockOCRService()
+
+ # Verify initial state
+ assert service._cpu_fallback_active is False
+ assert service.use_gpu is True
+
+ # Activate fallback
+ service._activate_cpu_fallback()
+
+ # Verify fallback state
+ assert service._cpu_fallback_active is True
+ assert service.use_gpu is False
+ assert service.gpu_info.get('cpu_fallback') is True
+ service._memory_guard.clear_gpu_cache.assert_called_once()
+
+ def test_cpu_fallback_idempotent(self):
+ """Test that CPU fallback activation is idempotent"""
+ class MockOCRService:
+ def __init__(self):
+ self._cpu_fallback_active = False
+ self.use_gpu = True
+ self._memory_guard = Mock()
+ self.gpu_info = {}
+
+ def _activate_cpu_fallback(self):
+ if self._cpu_fallback_active:
+ return
+ self._cpu_fallback_active = True
+ self.use_gpu = False
+ if self._memory_guard:
+ self._memory_guard.clear_gpu_cache()
+
+ service = MockOCRService()
+
+ # Activate twice
+ service._activate_cpu_fallback()
+ service._activate_cpu_fallback()
+
+ # clear_gpu_cache should only be called once
+ assert service._memory_guard.clear_gpu_cache.call_count == 1
+
+ def test_gpu_mode_restoration_pattern(self):
+ """Test the GPU mode restoration pattern"""
+ # Simulate the pattern used in OCRService._restore_gpu_mode
+
+ class MockOCRService:
+ def __init__(self):
+ self._cpu_fallback_active = True
+ self.use_gpu = False
+ self.gpu_available = True
+ self.gpu_info = {
+ 'device_id': 0,
+ 'cpu_fallback': True,
+ 'fallback_reason': 'test'
+ }
+ self._memory_guard = Mock()
+
+ def _restore_gpu_mode(self):
+ if not self._cpu_fallback_active:
+ return
+
+ if not self.gpu_available:
+ return
+
+ # Check if GPU memory is now available
+ if self._memory_guard:
+ is_available, stats = self._memory_guard.check_memory(required_mb=2000)
+ if is_available:
+ self._cpu_fallback_active = False
+ self.use_gpu = True
+ self.gpu_info.pop('cpu_fallback', None)
+ self.gpu_info.pop('fallback_reason', None)
+
+ service = MockOCRService()
+
+ # Mock memory guard to indicate sufficient memory
+ mock_stats = Mock()
+ mock_stats.gpu_free_mb = 5000
+ service._memory_guard.check_memory.return_value = (True, mock_stats)
+
+ # Restore GPU mode
+ service._restore_gpu_mode()
+
+ # Verify GPU mode restored
+ assert service._cpu_fallback_active is False
+ assert service.use_gpu is True
+ assert 'cpu_fallback' not in service.gpu_info
+
+ def test_gpu_mode_not_restored_when_memory_still_low(self):
+ """Test that GPU mode is not restored when memory is still low"""
+ class MockOCRService:
+ def __init__(self):
+ self._cpu_fallback_active = True
+ self.use_gpu = False
+ self.gpu_available = True
+ self.gpu_info = {'cpu_fallback': True}
+ self._memory_guard = Mock()
+
+ def _restore_gpu_mode(self):
+ if not self._cpu_fallback_active:
+ return
+ if not self.gpu_available:
+ return
+ if self._memory_guard:
+ is_available, stats = self._memory_guard.check_memory(required_mb=2000)
+ if is_available:
+ self._cpu_fallback_active = False
+ self.use_gpu = True
+
+ service = MockOCRService()
+
+ # Mock memory guard to indicate insufficient memory
+ mock_stats = Mock()
+ mock_stats.gpu_free_mb = 500
+ service._memory_guard.check_memory.return_value = (False, mock_stats)
+
+ # Try to restore GPU mode
+ service._restore_gpu_mode()
+
+ # Verify still in fallback mode
+ assert service._cpu_fallback_active is True
+ assert service.use_gpu is False
+
+
+class TestPreOperationMemoryCheckPattern:
+ """Tests for pre-operation memory check pattern as used in OCRService"""
+
+ def test_pre_operation_check_with_fallback(self):
+ """Test the pre-operation memory check pattern with fallback"""
+ guard = MemoryGuard(MemoryConfig(
+ warning_threshold=0.80,
+ critical_threshold=0.95,
+ enable_cpu_fallback=True,
+ ))
+
+ # Simulate the pattern:
+ # 1. Check if in CPU fallback mode
+ # 2. Try to restore GPU mode if memory available
+ # 3. Perform memory check for operation
+
+ class MockService:
+ def __init__(self):
+ self._cpu_fallback_active = False
+ self.use_gpu = True
+ self.gpu_available = True
+ self._memory_guard = guard
+
+ def _restore_gpu_mode(self):
+ pass # Simplified
+
+ def pre_operation_check(self, required_mb: int) -> bool:
+ # Try restore first
+ if self._cpu_fallback_active:
+ self._restore_gpu_mode()
+
+ # Perform memory check
+ if not self.use_gpu:
+ return True # CPU mode, no GPU check needed
+
+ is_available, stats = self._memory_guard.check_memory(required_mb=required_mb)
+ return is_available
+
+ service = MockService()
+
+ # Mock sufficient memory
+ with patch.object(guard, 'get_memory_stats') as mock_stats:
+ mock_stats.return_value = MemoryStats(
+ gpu_used_ratio=0.50,
+ gpu_free_mb=4000,
+ gpu_total_mb=8000,
+ )
+
+ result = service.pre_operation_check(required_mb=2000)
+ assert result is True
+
+ guard.shutdown()
+
+ def test_pre_operation_check_returns_true_in_cpu_mode(self):
+ """Test that pre-operation check returns True when in CPU mode"""
+ class MockService:
+ def __init__(self):
+ self._cpu_fallback_active = True
+ self.use_gpu = False
+ self._memory_guard = Mock()
+
+ def pre_operation_check(self, required_mb: int) -> bool:
+ if not self.use_gpu:
+ return True # CPU mode, no GPU check needed
+ return False
+
+ service = MockService()
+ result = service.pre_operation_check(required_mb=5000)
+
+ # Should return True because we're in CPU mode
+ assert result is True
+ # Memory guard should not be called
+ service._memory_guard.check_memory.assert_not_called()
+
+
+class TestMemoryCheckWithCleanup:
+ """Tests for memory check with cleanup pattern"""
+
+ def test_memory_check_triggers_cleanup_on_failure(self):
+ """Test that memory check triggers cleanup when insufficient"""
+ guard = MemoryGuard(MemoryConfig(
+ warning_threshold=0.80,
+ critical_threshold=0.95,
+ ))
+
+ # Track cleanup calls
+ cleanup_called = False
+
+ def mock_cleanup():
+ nonlocal cleanup_called
+ cleanup_called = True
+
+ class MockService:
+ def __init__(self):
+ self._memory_guard = guard
+ self.cleanup_func = mock_cleanup
+
+ def check_gpu_memory(self, required_mb: int) -> bool:
+                # Check available GPU memory, simulating a low-memory condition
+                with patch.object(self._memory_guard, 'get_memory_stats') as mock_stats:
+ mock_stats.return_value = MemoryStats(
+ gpu_used_ratio=0.96,
+ gpu_free_mb=300,
+ gpu_total_mb=8000,
+ )
+
+ is_available, stats = self._memory_guard.check_memory(required_mb=required_mb)
+
+ if not is_available:
+ # Trigger cleanup
+ self.cleanup_func()
+ self._memory_guard.clear_gpu_cache()
+ return False
+
+ return True
+
+ service = MockService()
+ result = service.check_gpu_memory(required_mb=1000)
+
+ # Cleanup should have been triggered
+ assert cleanup_called is True
+ assert result is False
+
+ guard.shutdown()
+
+
+if __name__ == "__main__":
+ pytest.main([__file__, "-v"])
diff --git a/backend/tests/services/test_service_pool.py b/backend/tests/services/test_service_pool.py
new file mode 100644
index 0000000..1f09e6d
--- /dev/null
+++ b/backend/tests/services/test_service_pool.py
@@ -0,0 +1,387 @@
+"""
+Tests for OCR Service Pool
+
+Tests OCRServicePool functionality including acquire, release, and concurrency.
+"""
+
+import pytest
+import threading
+import time
+from unittest.mock import Mock, patch, MagicMock
+import sys
+
+# Mock paddle before importing service_pool to avoid import errors
+# when paddle is not installed in the test environment
+paddle_mock = MagicMock()
+paddle_mock.is_compiled_with_cuda.return_value = False
+paddle_mock.device.cuda.device_count.return_value = 0
+paddle_mock.device.cuda.memory_allocated.return_value = 0
+paddle_mock.device.cuda.memory_reserved.return_value = 0
+paddle_mock.device.cuda.empty_cache = MagicMock()
+sys.modules['paddle'] = paddle_mock
+
+from app.services.service_pool import (
+ OCRServicePool,
+ PooledService,
+ PoolConfig,
+ ServiceState,
+ get_service_pool,
+ shutdown_service_pool,
+)
+
+
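+# Typical usage exercised below (a sketch; the Mock services stand in for real
+# OCRService instances, which are too heavy to construct in unit tests):
+#
+#     pool = get_service_pool()
+#     with pool.acquire_context(device="GPU:0") as pooled:
+#         pooled.service.process(...)   # released automatically on exit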
+class TestPoolConfig:
+ """Tests for PoolConfig class"""
+
+ def test_default_values(self):
+ """Test default configuration values"""
+ config = PoolConfig()
+ assert config.max_services_per_device == 1
+ assert config.max_total_services == 2
+ assert config.acquire_timeout_seconds == 300.0
+ assert config.max_queue_size == 50
+ assert config.max_consecutive_errors == 3
+
+ def test_custom_values(self):
+ """Test custom configuration values"""
+ config = PoolConfig(
+ max_services_per_device=2,
+ max_total_services=4,
+ acquire_timeout_seconds=60.0,
+ )
+ assert config.max_services_per_device == 2
+ assert config.max_total_services == 4
+ assert config.acquire_timeout_seconds == 60.0
+
+
+class TestPooledService:
+ """Tests for PooledService class"""
+
+ def test_creation(self):
+ """Test PooledService creation"""
+ mock_service = Mock()
+ pooled = PooledService(
+ service=mock_service,
+ device="GPU:0",
+ )
+ assert pooled.service is mock_service
+ assert pooled.device == "GPU:0"
+ assert pooled.state == ServiceState.AVAILABLE
+ assert pooled.use_count == 0
+ assert pooled.error_count == 0
+
+
+class TestOCRServicePool:
+ """Tests for OCRServicePool class"""
+
+ def setup_method(self):
+ """Reset singleton before each test"""
+ shutdown_service_pool()
+ OCRServicePool._instance = None
+ OCRServicePool._lock = threading.Lock()
+
+ def teardown_method(self):
+ """Cleanup after each test"""
+ shutdown_service_pool()
+ OCRServicePool._instance = None
+
+ def test_singleton_pattern(self):
+ """Test that OCRServicePool is a singleton"""
+ pool1 = OCRServicePool()
+ pool2 = OCRServicePool()
+ assert pool1 is pool2
+ pool1.shutdown()
+
+ def test_initialize_device(self):
+ """Test device initialization"""
+ config = PoolConfig()
+ pool = OCRServicePool(config)
+
+ # Default device should be initialized
+ assert "GPU:0" in pool.services
+ assert "GPU:0" in pool.semaphores
+
+ # Test adding new device
+ pool._initialize_device("GPU:1")
+ assert "GPU:1" in pool.services
+ assert "GPU:1" in pool.semaphores
+
+ pool.shutdown()
+
+ def test_acquire_creates_service(self):
+ """Test that acquire creates a new service if none available"""
+ config = PoolConfig(max_services_per_device=1)
+ pool = OCRServicePool(config)
+
+ # Pre-populate with a mock service
+ mock_service = Mock()
+ mock_service.process = Mock()
+ mock_service.get_gpu_status = Mock()
+ pooled_service = PooledService(service=mock_service, device="GPU:0")
+ pool.services["GPU:0"].append(pooled_service)
+
+ pooled = pool.acquire(device="GPU:0", timeout=5.0)
+ assert pooled is not None
+ assert pooled.state == ServiceState.IN_USE
+ assert pooled.use_count == 1
+
+ pool.shutdown()
+
+ def test_acquire_reuses_available_service(self):
+ """Test that acquire reuses available services"""
+ config = PoolConfig(max_services_per_device=1)
+ pool = OCRServicePool(config)
+
+ # Pre-populate with a mock service
+ mock_service = Mock()
+ pooled_service = PooledService(service=mock_service, device="GPU:0")
+ pool.services["GPU:0"].append(pooled_service)
+
+ # First acquire
+ pooled1 = pool.acquire(device="GPU:0")
+ service_id = id(pooled1.service)
+ pool.release(pooled1)
+
+ # Second acquire should get the same service
+ pooled2 = pool.acquire(device="GPU:0")
+ assert id(pooled2.service) == service_id
+ assert pooled2.use_count == 2
+
+ pool.shutdown()
+
+ def test_release_makes_service_available(self):
+ """Test that release makes service available again"""
+ config = PoolConfig()
+ pool = OCRServicePool(config)
+
+ # Pre-populate with a mock service
+ mock_service = Mock()
+ pooled_service = PooledService(service=mock_service, device="GPU:0")
+ pool.services["GPU:0"].append(pooled_service)
+
+ pooled = pool.acquire(device="GPU:0")
+ assert pooled.state == ServiceState.IN_USE
+
+ pool.release(pooled)
+ assert pooled.state == ServiceState.AVAILABLE
+
+ pool.shutdown()
+
+ def test_release_with_error(self):
+ """Test that release with error increments error count"""
+ config = PoolConfig(max_consecutive_errors=3)
+ pool = OCRServicePool(config)
+
+ # Pre-populate with a mock service
+ mock_service = Mock()
+ pooled_service = PooledService(service=mock_service, device="GPU:0")
+ pool.services["GPU:0"].append(pooled_service)
+
+ pooled = pool.acquire(device="GPU:0")
+ pool.release(pooled, error=Exception("Test error"))
+
+ assert pooled.error_count == 1
+ assert pooled.state == ServiceState.AVAILABLE
+
+ pool.shutdown()
+
+ def test_release_marks_unhealthy_after_errors(self):
+ """Test that service is marked unhealthy after too many errors"""
+ config = PoolConfig(max_consecutive_errors=2)
+ pool = OCRServicePool(config)
+
+ # Pre-populate with a mock service
+ mock_service = Mock()
+ pooled_service = PooledService(service=mock_service, device="GPU:0")
+ pool.services["GPU:0"].append(pooled_service)
+
+ pooled = pool.acquire(device="GPU:0")
+ pool.release(pooled, error=Exception("Error 1"))
+
+ pooled = pool.acquire(device="GPU:0")
+ pool.release(pooled, error=Exception("Error 2"))
+
+ assert pooled.state == ServiceState.UNHEALTHY
+ assert pooled.error_count == 2
+
+ pool.shutdown()
+
+ def test_acquire_context_manager(self):
+ """Test context manager for acquire/release"""
+ config = PoolConfig()
+ pool = OCRServicePool(config)
+
+ # Pre-populate with a mock service
+ mock_service = Mock()
+ pooled_service = PooledService(service=mock_service, device="GPU:0")
+ pool.services["GPU:0"].append(pooled_service)
+
+ with pool.acquire_context(device="GPU:0") as pooled:
+ assert pooled is not None
+ assert pooled.state == ServiceState.IN_USE
+
+ # After context, service should be available
+ assert pooled.state == ServiceState.AVAILABLE
+
+ pool.shutdown()
+
+ def test_acquire_context_manager_with_error(self):
+ """Test context manager releases on error"""
+ config = PoolConfig()
+ pool = OCRServicePool(config)
+
+ # Pre-populate with a mock service
+ mock_service = Mock()
+ pooled_service = PooledService(service=mock_service, device="GPU:0")
+ pool.services["GPU:0"].append(pooled_service)
+
+ with pytest.raises(ValueError):
+ with pool.acquire_context(device="GPU:0") as pooled:
+ raise ValueError("Test error")
+
+        # The error should be recorded and the service returned to the pool
+        assert pooled.error_count == 1
+        assert pooled.state == ServiceState.AVAILABLE
+
+ pool.shutdown()
+
+ def test_acquire_timeout(self):
+ """Test that acquire times out when no service available"""
+ config = PoolConfig(
+ max_services_per_device=1,
+ max_total_services=1,
+ )
+ pool = OCRServicePool(config)
+
+ # Pre-populate with a mock service
+ mock_service = Mock()
+ pooled_service = PooledService(service=mock_service, device="GPU:0")
+ pool.services["GPU:0"].append(pooled_service)
+
+ # Acquire the only service
+ pooled1 = pool.acquire(device="GPU:0")
+ assert pooled1 is not None
+
+ # Try to acquire another - should timeout
+ pooled2 = pool.acquire(device="GPU:0", timeout=0.5)
+ assert pooled2 is None
+
+ pool.shutdown()
+
+ def test_get_pool_stats(self):
+ """Test pool statistics"""
+ config = PoolConfig()
+ pool = OCRServicePool(config)
+
+ # Pre-populate with a mock service
+ mock_service = Mock()
+ pooled_service = PooledService(service=mock_service, device="GPU:0")
+ pool.services["GPU:0"].append(pooled_service)
+
+ # Acquire a service
+ pooled = pool.acquire(device="GPU:0")
+
+ stats = pool.get_pool_stats()
+ assert stats["total_services"] == 1
+ assert stats["in_use_services"] == 1
+ assert stats["available_services"] == 0
+ assert stats["metrics"]["total_acquisitions"] == 1
+
+ pool.release(pooled)
+
+ stats = pool.get_pool_stats()
+ assert stats["available_services"] == 1
+ assert stats["metrics"]["total_releases"] == 1
+
+ pool.shutdown()
+
+ def test_health_check(self):
+ """Test health check functionality"""
+ config = PoolConfig()
+ pool = OCRServicePool(config)
+
+ # Pre-populate with a mock service
+ mock_service = Mock()
+ mock_service.process = Mock()
+ mock_service.get_gpu_status = Mock()
+ pooled_service = PooledService(service=mock_service, device="GPU:0")
+ pool.services["GPU:0"].append(pooled_service)
+
+ # Acquire and release to update use_count
+ pooled = pool.acquire(device="GPU:0")
+ pool.release(pooled)
+
+ health = pool.health_check()
+ assert health["healthy"] is True
+ assert len(health["services"]) == 1
+ assert health["services"][0]["responsive"] is True
+
+ pool.shutdown()
+
+ def test_concurrent_acquire(self):
+ """Test concurrent service acquisition"""
+ config = PoolConfig(
+ max_services_per_device=2,
+ max_total_services=2,
+ )
+ pool = OCRServicePool(config)
+
+ # Pre-populate with 2 mock services
+ for i in range(2):
+ mock_service = Mock()
+ pooled_service = PooledService(service=mock_service, device="GPU:0")
+ pool.services["GPU:0"].append(pooled_service)
+
+ results = []
+
+ def worker(worker_id):
+ pooled = pool.acquire(device="GPU:0", timeout=5.0, task_id=f"task_{worker_id}")
+ if pooled:
+ results.append((worker_id, pooled))
+ time.sleep(0.1) # Simulate work
+ pool.release(pooled)
+
+ threads = [threading.Thread(target=worker, args=(i,)) for i in range(4)]
+ for t in threads:
+ t.start()
+ for t in threads:
+ t.join()
+
+ # All workers should have acquired a service
+ assert len(results) == 4
+
+ pool.shutdown()
+
+
+class TestGetServicePool:
+ """Tests for get_service_pool helper function"""
+
+ def setup_method(self):
+ """Reset singleton before each test"""
+ shutdown_service_pool()
+ OCRServicePool._instance = None
+
+ def teardown_method(self):
+ """Cleanup after each test"""
+ shutdown_service_pool()
+ OCRServicePool._instance = None
+
+ def test_get_service_pool_creates_singleton(self):
+ """Test that get_service_pool creates a singleton"""
+ pool1 = get_service_pool()
+ pool2 = get_service_pool()
+ assert pool1 is pool2
+ shutdown_service_pool()
+
+ def test_shutdown_service_pool(self):
+ """Test shutdown_service_pool cleans up"""
+ pool = get_service_pool()
+ shutdown_service_pool()
+
+ # Should be able to create new pool
+ new_pool = get_service_pool()
+ assert new_pool._initialized is True
+ shutdown_service_pool()
+
+
+if __name__ == "__main__":
+ pytest.main([__file__, "-v"])
diff --git a/frontend/package-lock.json b/frontend/package-lock.json
index 89113c9..ffec023 100644
--- a/frontend/package-lock.json
+++ b/frontend/package-lock.json
@@ -8,6 +8,7 @@
"name": "frontend",
"version": "0.0.0",
"dependencies": {
+ "@radix-ui/react-select": "^2.2.6",
"@tanstack/react-query": "^5.90.7",
"axios": "^1.13.2",
"class-variance-authority": "^0.7.0",
@@ -87,7 +88,6 @@
"integrity": "sha512-e7jT4DxYvIDLk1ZHmU/m/mB19rex9sv0c2ftBtjSBv+kVM/902eh0fINUzD7UwLLNR+jU585GxUJ8/EBfAM5fw==",
"dev": true,
"license": "MIT",
- "peer": true,
"dependencies": {
"@babel/code-frame": "^7.27.1",
"@babel/generator": "^7.28.5",
@@ -947,6 +947,44 @@
"node": "^18.18.0 || ^20.9.0 || >=21.1.0"
}
},
+ "node_modules/@floating-ui/core": {
+ "version": "1.7.3",
+ "resolved": "https://registry.npmjs.org/@floating-ui/core/-/core-1.7.3.tgz",
+ "integrity": "sha512-sGnvb5dmrJaKEZ+LDIpguvdX3bDlEllmv4/ClQ9awcmCZrlx5jQyyMWFM5kBI+EyNOCDDiKk8il0zeuX3Zlg/w==",
+ "license": "MIT",
+ "dependencies": {
+ "@floating-ui/utils": "^0.2.10"
+ }
+ },
+ "node_modules/@floating-ui/dom": {
+ "version": "1.7.4",
+ "resolved": "https://registry.npmjs.org/@floating-ui/dom/-/dom-1.7.4.tgz",
+ "integrity": "sha512-OOchDgh4F2CchOX94cRVqhvy7b3AFb+/rQXyswmzmGakRfkMgoWVjfnLWkRirfLEfuD4ysVW16eXzwt3jHIzKA==",
+ "license": "MIT",
+ "dependencies": {
+ "@floating-ui/core": "^1.7.3",
+ "@floating-ui/utils": "^0.2.10"
+ }
+ },
+ "node_modules/@floating-ui/react-dom": {
+ "version": "2.1.6",
+ "resolved": "https://registry.npmjs.org/@floating-ui/react-dom/-/react-dom-2.1.6.tgz",
+ "integrity": "sha512-4JX6rEatQEvlmgU80wZyq9RT96HZJa88q8hp0pBd+LrczeDI4o6uA2M+uvxngVHo4Ihr8uibXxH6+70zhAFrVw==",
+ "license": "MIT",
+ "dependencies": {
+ "@floating-ui/dom": "^1.7.4"
+ },
+ "peerDependencies": {
+ "react": ">=16.8.0",
+ "react-dom": ">=16.8.0"
+ }
+ },
+ "node_modules/@floating-ui/utils": {
+ "version": "0.2.10",
+ "resolved": "https://registry.npmjs.org/@floating-ui/utils/-/utils-0.2.10.tgz",
+ "integrity": "sha512-aGTxbpbg8/b5JfU1HXSrbH3wXZuLPJcNEcZQFMxLs3oSzgtVu6nFPkbbGGUvBcUjKV2YyB9Wxxabo+HEH9tcRQ==",
+ "license": "MIT"
+ },
"node_modules/@humanfs/core": {
"version": "0.19.1",
"resolved": "https://registry.npmjs.org/@humanfs/core/-/core-0.19.1.tgz",
@@ -1272,6 +1310,502 @@
"node": ">= 8"
}
},
+ "node_modules/@radix-ui/number": {
+ "version": "1.1.1",
+ "resolved": "https://registry.npmjs.org/@radix-ui/number/-/number-1.1.1.tgz",
+ "integrity": "sha512-MkKCwxlXTgz6CFoJx3pCwn07GKp36+aZyu/u2Ln2VrA5DcdyCZkASEDBTd8x5whTQQL5CiYf4prXKLcgQdv29g==",
+ "license": "MIT"
+ },
+ "node_modules/@radix-ui/primitive": {
+ "version": "1.1.3",
+ "resolved": "https://registry.npmjs.org/@radix-ui/primitive/-/primitive-1.1.3.tgz",
+ "integrity": "sha512-JTF99U/6XIjCBo0wqkU5sK10glYe27MRRsfwoiq5zzOEZLHU3A3KCMa5X/azekYRCJ0HlwI0crAXS/5dEHTzDg==",
+ "license": "MIT"
+ },
+ "node_modules/@radix-ui/react-arrow": {
+ "version": "1.1.7",
+ "resolved": "https://registry.npmjs.org/@radix-ui/react-arrow/-/react-arrow-1.1.7.tgz",
+ "integrity": "sha512-F+M1tLhO+mlQaOWspE8Wstg+z6PwxwRd8oQ8IXceWz92kfAmalTRf0EjrouQeo7QssEPfCn05B4Ihs1K9WQ/7w==",
+ "license": "MIT",
+ "dependencies": {
+ "@radix-ui/react-primitive": "2.1.3"
+ },
+ "peerDependencies": {
+ "@types/react": "*",
+ "@types/react-dom": "*",
+ "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc",
+ "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
+ },
+ "peerDependenciesMeta": {
+ "@types/react": {
+ "optional": true
+ },
+ "@types/react-dom": {
+ "optional": true
+ }
+ }
+ },
+ "node_modules/@radix-ui/react-collection": {
+ "version": "1.1.7",
+ "resolved": "https://registry.npmjs.org/@radix-ui/react-collection/-/react-collection-1.1.7.tgz",
+ "integrity": "sha512-Fh9rGN0MoI4ZFUNyfFVNU4y9LUz93u9/0K+yLgA2bwRojxM8JU1DyvvMBabnZPBgMWREAJvU2jjVzq+LrFUglw==",
+ "license": "MIT",
+ "dependencies": {
+ "@radix-ui/react-compose-refs": "1.1.2",
+ "@radix-ui/react-context": "1.1.2",
+ "@radix-ui/react-primitive": "2.1.3",
+ "@radix-ui/react-slot": "1.2.3"
+ },
+ "peerDependencies": {
+ "@types/react": "*",
+ "@types/react-dom": "*",
+ "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc",
+ "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
+ },
+ "peerDependenciesMeta": {
+ "@types/react": {
+ "optional": true
+ },
+ "@types/react-dom": {
+ "optional": true
+ }
+ }
+ },
+ "node_modules/@radix-ui/react-compose-refs": {
+ "version": "1.1.2",
+ "resolved": "https://registry.npmjs.org/@radix-ui/react-compose-refs/-/react-compose-refs-1.1.2.tgz",
+ "integrity": "sha512-z4eqJvfiNnFMHIIvXP3CY57y2WJs5g2v3X0zm9mEJkrkNv4rDxu+sg9Jh8EkXyeqBkB7SOcboo9dMVqhyrACIg==",
+ "license": "MIT",
+ "peerDependencies": {
+ "@types/react": "*",
+ "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
+ },
+ "peerDependenciesMeta": {
+ "@types/react": {
+ "optional": true
+ }
+ }
+ },
+ "node_modules/@radix-ui/react-context": {
+ "version": "1.1.2",
+ "resolved": "https://registry.npmjs.org/@radix-ui/react-context/-/react-context-1.1.2.tgz",
+ "integrity": "sha512-jCi/QKUM2r1Ju5a3J64TH2A5SpKAgh0LpknyqdQ4m6DCV0xJ2HG1xARRwNGPQfi1SLdLWZ1OJz6F4OMBBNiGJA==",
+ "license": "MIT",
+ "peerDependencies": {
+ "@types/react": "*",
+ "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
+ },
+ "peerDependenciesMeta": {
+ "@types/react": {
+ "optional": true
+ }
+ }
+ },
+ "node_modules/@radix-ui/react-direction": {
+ "version": "1.1.1",
+ "resolved": "https://registry.npmjs.org/@radix-ui/react-direction/-/react-direction-1.1.1.tgz",
+ "integrity": "sha512-1UEWRX6jnOA2y4H5WczZ44gOOjTEmlqv1uNW4GAJEO5+bauCBhv8snY65Iw5/VOS/ghKN9gr2KjnLKxrsvoMVw==",
+ "license": "MIT",
+ "peerDependencies": {
+ "@types/react": "*",
+ "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
+ },
+ "peerDependenciesMeta": {
+ "@types/react": {
+ "optional": true
+ }
+ }
+ },
+ "node_modules/@radix-ui/react-dismissable-layer": {
+ "version": "1.1.11",
+ "resolved": "https://registry.npmjs.org/@radix-ui/react-dismissable-layer/-/react-dismissable-layer-1.1.11.tgz",
+ "integrity": "sha512-Nqcp+t5cTB8BinFkZgXiMJniQH0PsUt2k51FUhbdfeKvc4ACcG2uQniY/8+h1Yv6Kza4Q7lD7PQV0z0oicE0Mg==",
+ "license": "MIT",
+ "dependencies": {
+ "@radix-ui/primitive": "1.1.3",
+ "@radix-ui/react-compose-refs": "1.1.2",
+ "@radix-ui/react-primitive": "2.1.3",
+ "@radix-ui/react-use-callback-ref": "1.1.1",
+ "@radix-ui/react-use-escape-keydown": "1.1.1"
+ },
+ "peerDependencies": {
+ "@types/react": "*",
+ "@types/react-dom": "*",
+ "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc",
+ "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
+ },
+ "peerDependenciesMeta": {
+ "@types/react": {
+ "optional": true
+ },
+ "@types/react-dom": {
+ "optional": true
+ }
+ }
+ },
+ "node_modules/@radix-ui/react-focus-guards": {
+ "version": "1.1.3",
+ "resolved": "https://registry.npmjs.org/@radix-ui/react-focus-guards/-/react-focus-guards-1.1.3.tgz",
+ "integrity": "sha512-0rFg/Rj2Q62NCm62jZw0QX7a3sz6QCQU0LpZdNrJX8byRGaGVTqbrW9jAoIAHyMQqsNpeZ81YgSizOt5WXq0Pw==",
+ "license": "MIT",
+ "peerDependencies": {
+ "@types/react": "*",
+ "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
+ },
+ "peerDependenciesMeta": {
+ "@types/react": {
+ "optional": true
+ }
+ }
+ },
+ "node_modules/@radix-ui/react-focus-scope": {
+ "version": "1.1.7",
+ "resolved": "https://registry.npmjs.org/@radix-ui/react-focus-scope/-/react-focus-scope-1.1.7.tgz",
+ "integrity": "sha512-t2ODlkXBQyn7jkl6TNaw/MtVEVvIGelJDCG41Okq/KwUsJBwQ4XVZsHAVUkK4mBv3ewiAS3PGuUWuY2BoK4ZUw==",
+ "license": "MIT",
+ "dependencies": {
+ "@radix-ui/react-compose-refs": "1.1.2",
+ "@radix-ui/react-primitive": "2.1.3",
+ "@radix-ui/react-use-callback-ref": "1.1.1"
+ },
+ "peerDependencies": {
+ "@types/react": "*",
+ "@types/react-dom": "*",
+ "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc",
+ "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
+ },
+ "peerDependenciesMeta": {
+ "@types/react": {
+ "optional": true
+ },
+ "@types/react-dom": {
+ "optional": true
+ }
+ }
+ },
+ "node_modules/@radix-ui/react-id": {
+ "version": "1.1.1",
+ "resolved": "https://registry.npmjs.org/@radix-ui/react-id/-/react-id-1.1.1.tgz",
+ "integrity": "sha512-kGkGegYIdQsOb4XjsfM97rXsiHaBwco+hFI66oO4s9LU+PLAC5oJ7khdOVFxkhsmlbpUqDAvXw11CluXP+jkHg==",
+ "license": "MIT",
+ "dependencies": {
+ "@radix-ui/react-use-layout-effect": "1.1.1"
+ },
+ "peerDependencies": {
+ "@types/react": "*",
+ "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
+ },
+ "peerDependenciesMeta": {
+ "@types/react": {
+ "optional": true
+ }
+ }
+ },
+ "node_modules/@radix-ui/react-popper": {
+ "version": "1.2.8",
+ "resolved": "https://registry.npmjs.org/@radix-ui/react-popper/-/react-popper-1.2.8.tgz",
+ "integrity": "sha512-0NJQ4LFFUuWkE7Oxf0htBKS6zLkkjBH+hM1uk7Ng705ReR8m/uelduy1DBo0PyBXPKVnBA6YBlU94MBGXrSBCw==",
+ "license": "MIT",
+ "dependencies": {
+ "@floating-ui/react-dom": "^2.0.0",
+ "@radix-ui/react-arrow": "1.1.7",
+ "@radix-ui/react-compose-refs": "1.1.2",
+ "@radix-ui/react-context": "1.1.2",
+ "@radix-ui/react-primitive": "2.1.3",
+ "@radix-ui/react-use-callback-ref": "1.1.1",
+ "@radix-ui/react-use-layout-effect": "1.1.1",
+ "@radix-ui/react-use-rect": "1.1.1",
+ "@radix-ui/react-use-size": "1.1.1",
+ "@radix-ui/rect": "1.1.1"
+ },
+ "peerDependencies": {
+ "@types/react": "*",
+ "@types/react-dom": "*",
+ "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc",
+ "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
+ },
+ "peerDependenciesMeta": {
+ "@types/react": {
+ "optional": true
+ },
+ "@types/react-dom": {
+ "optional": true
+ }
+ }
+ },
+ "node_modules/@radix-ui/react-portal": {
+ "version": "1.1.9",
+ "resolved": "https://registry.npmjs.org/@radix-ui/react-portal/-/react-portal-1.1.9.tgz",
+ "integrity": "sha512-bpIxvq03if6UNwXZ+HTK71JLh4APvnXntDc6XOX8UVq4XQOVl7lwok0AvIl+b8zgCw3fSaVTZMpAPPagXbKmHQ==",
+ "license": "MIT",
+ "dependencies": {
+ "@radix-ui/react-primitive": "2.1.3",
+ "@radix-ui/react-use-layout-effect": "1.1.1"
+ },
+ "peerDependencies": {
+ "@types/react": "*",
+ "@types/react-dom": "*",
+ "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc",
+ "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
+ },
+ "peerDependenciesMeta": {
+ "@types/react": {
+ "optional": true
+ },
+ "@types/react-dom": {
+ "optional": true
+ }
+ }
+ },
+ "node_modules/@radix-ui/react-primitive": {
+ "version": "2.1.3",
+ "resolved": "https://registry.npmjs.org/@radix-ui/react-primitive/-/react-primitive-2.1.3.tgz",
+ "integrity": "sha512-m9gTwRkhy2lvCPe6QJp4d3G1TYEUHn/FzJUtq9MjH46an1wJU+GdoGC5VLof8RX8Ft/DlpshApkhswDLZzHIcQ==",
+ "license": "MIT",
+ "dependencies": {
+ "@radix-ui/react-slot": "1.2.3"
+ },
+ "peerDependencies": {
+ "@types/react": "*",
+ "@types/react-dom": "*",
+ "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc",
+ "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
+ },
+ "peerDependenciesMeta": {
+ "@types/react": {
+ "optional": true
+ },
+ "@types/react-dom": {
+ "optional": true
+ }
+ }
+ },
+ "node_modules/@radix-ui/react-select": {
+ "version": "2.2.6",
+ "resolved": "https://registry.npmjs.org/@radix-ui/react-select/-/react-select-2.2.6.tgz",
+ "integrity": "sha512-I30RydO+bnn2PQztvo25tswPH+wFBjehVGtmagkU78yMdwTwVf12wnAOF+AeP8S2N8xD+5UPbGhkUfPyvT+mwQ==",
+ "license": "MIT",
+ "dependencies": {
+ "@radix-ui/number": "1.1.1",
+ "@radix-ui/primitive": "1.1.3",
+ "@radix-ui/react-collection": "1.1.7",
+ "@radix-ui/react-compose-refs": "1.1.2",
+ "@radix-ui/react-context": "1.1.2",
+ "@radix-ui/react-direction": "1.1.1",
+ "@radix-ui/react-dismissable-layer": "1.1.11",
+ "@radix-ui/react-focus-guards": "1.1.3",
+ "@radix-ui/react-focus-scope": "1.1.7",
+ "@radix-ui/react-id": "1.1.1",
+ "@radix-ui/react-popper": "1.2.8",
+ "@radix-ui/react-portal": "1.1.9",
+ "@radix-ui/react-primitive": "2.1.3",
+ "@radix-ui/react-slot": "1.2.3",
+ "@radix-ui/react-use-callback-ref": "1.1.1",
+ "@radix-ui/react-use-controllable-state": "1.2.2",
+ "@radix-ui/react-use-layout-effect": "1.1.1",
+ "@radix-ui/react-use-previous": "1.1.1",
+ "@radix-ui/react-visually-hidden": "1.2.3",
+ "aria-hidden": "^1.2.4",
+ "react-remove-scroll": "^2.6.3"
+ },
+ "peerDependencies": {
+ "@types/react": "*",
+ "@types/react-dom": "*",
+ "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc",
+ "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
+ },
+ "peerDependenciesMeta": {
+ "@types/react": {
+ "optional": true
+ },
+ "@types/react-dom": {
+ "optional": true
+ }
+ }
+ },
+ "node_modules/@radix-ui/react-slot": {
+ "version": "1.2.3",
+ "resolved": "https://registry.npmjs.org/@radix-ui/react-slot/-/react-slot-1.2.3.tgz",
+ "integrity": "sha512-aeNmHnBxbi2St0au6VBVC7JXFlhLlOnvIIlePNniyUNAClzmtAUEY8/pBiK3iHjufOlwA+c20/8jngo7xcrg8A==",
+ "license": "MIT",
+ "dependencies": {
+ "@radix-ui/react-compose-refs": "1.1.2"
+ },
+ "peerDependencies": {
+ "@types/react": "*",
+ "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
+ },
+ "peerDependenciesMeta": {
+ "@types/react": {
+ "optional": true
+ }
+ }
+ },
+ "node_modules/@radix-ui/react-use-callback-ref": {
+ "version": "1.1.1",
+ "resolved": "https://registry.npmjs.org/@radix-ui/react-use-callback-ref/-/react-use-callback-ref-1.1.1.tgz",
+ "integrity": "sha512-FkBMwD+qbGQeMu1cOHnuGB6x4yzPjho8ap5WtbEJ26umhgqVXbhekKUQO+hZEL1vU92a3wHwdp0HAcqAUF5iDg==",
+ "license": "MIT",
+ "peerDependencies": {
+ "@types/react": "*",
+ "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
+ },
+ "peerDependenciesMeta": {
+ "@types/react": {
+ "optional": true
+ }
+ }
+ },
+ "node_modules/@radix-ui/react-use-controllable-state": {
+ "version": "1.2.2",
+ "resolved": "https://registry.npmjs.org/@radix-ui/react-use-controllable-state/-/react-use-controllable-state-1.2.2.tgz",
+ "integrity": "sha512-BjasUjixPFdS+NKkypcyyN5Pmg83Olst0+c6vGov0diwTEo6mgdqVR6hxcEgFuh4QrAs7Rc+9KuGJ9TVCj0Zzg==",
+ "license": "MIT",
+ "dependencies": {
+ "@radix-ui/react-use-effect-event": "0.0.2",
+ "@radix-ui/react-use-layout-effect": "1.1.1"
+ },
+ "peerDependencies": {
+ "@types/react": "*",
+ "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
+ },
+ "peerDependenciesMeta": {
+ "@types/react": {
+ "optional": true
+ }
+ }
+ },
+ "node_modules/@radix-ui/react-use-effect-event": {
+ "version": "0.0.2",
+ "resolved": "https://registry.npmjs.org/@radix-ui/react-use-effect-event/-/react-use-effect-event-0.0.2.tgz",
+ "integrity": "sha512-Qp8WbZOBe+blgpuUT+lw2xheLP8q0oatc9UpmiemEICxGvFLYmHm9QowVZGHtJlGbS6A6yJ3iViad/2cVjnOiA==",
+ "license": "MIT",
+ "dependencies": {
+ "@radix-ui/react-use-layout-effect": "1.1.1"
+ },
+ "peerDependencies": {
+ "@types/react": "*",
+ "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
+ },
+ "peerDependenciesMeta": {
+ "@types/react": {
+ "optional": true
+ }
+ }
+ },
+ "node_modules/@radix-ui/react-use-escape-keydown": {
+ "version": "1.1.1",
+ "resolved": "https://registry.npmjs.org/@radix-ui/react-use-escape-keydown/-/react-use-escape-keydown-1.1.1.tgz",
+ "integrity": "sha512-Il0+boE7w/XebUHyBjroE+DbByORGR9KKmITzbR7MyQ4akpORYP/ZmbhAr0DG7RmmBqoOnZdy2QlvajJ2QA59g==",
+ "license": "MIT",
+ "dependencies": {
+ "@radix-ui/react-use-callback-ref": "1.1.1"
+ },
+ "peerDependencies": {
+ "@types/react": "*",
+ "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
+ },
+ "peerDependenciesMeta": {
+ "@types/react": {
+ "optional": true
+ }
+ }
+ },
+ "node_modules/@radix-ui/react-use-layout-effect": {
+ "version": "1.1.1",
+ "resolved": "https://registry.npmjs.org/@radix-ui/react-use-layout-effect/-/react-use-layout-effect-1.1.1.tgz",
+ "integrity": "sha512-RbJRS4UWQFkzHTTwVymMTUv8EqYhOp8dOOviLj2ugtTiXRaRQS7GLGxZTLL1jWhMeoSCf5zmcZkqTl9IiYfXcQ==",
+ "license": "MIT",
+ "peerDependencies": {
+ "@types/react": "*",
+ "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
+ },
+ "peerDependenciesMeta": {
+ "@types/react": {
+ "optional": true
+ }
+ }
+ },
+ "node_modules/@radix-ui/react-use-previous": {
+ "version": "1.1.1",
+ "resolved": "https://registry.npmjs.org/@radix-ui/react-use-previous/-/react-use-previous-1.1.1.tgz",
+ "integrity": "sha512-2dHfToCj/pzca2Ck724OZ5L0EVrr3eHRNsG/b3xQJLA2hZpVCS99bLAX+hm1IHXDEnzU6by5z/5MIY794/a8NQ==",
+ "license": "MIT",
+ "peerDependencies": {
+ "@types/react": "*",
+ "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
+ },
+ "peerDependenciesMeta": {
+ "@types/react": {
+ "optional": true
+ }
+ }
+ },
+ "node_modules/@radix-ui/react-use-rect": {
+ "version": "1.1.1",
+ "resolved": "https://registry.npmjs.org/@radix-ui/react-use-rect/-/react-use-rect-1.1.1.tgz",
+ "integrity": "sha512-QTYuDesS0VtuHNNvMh+CjlKJ4LJickCMUAqjlE3+j8w+RlRpwyX3apEQKGFzbZGdo7XNG1tXa+bQqIE7HIXT2w==",
+ "license": "MIT",
+ "dependencies": {
+ "@radix-ui/rect": "1.1.1"
+ },
+ "peerDependencies": {
+ "@types/react": "*",
+ "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
+ },
+ "peerDependenciesMeta": {
+ "@types/react": {
+ "optional": true
+ }
+ }
+ },
+ "node_modules/@radix-ui/react-use-size": {
+ "version": "1.1.1",
+ "resolved": "https://registry.npmjs.org/@radix-ui/react-use-size/-/react-use-size-1.1.1.tgz",
+ "integrity": "sha512-ewrXRDTAqAXlkl6t/fkXWNAhFX9I+CkKlw6zjEwk86RSPKwZr3xpBRso655aqYafwtnbpHLj6toFzmd6xdVptQ==",
+ "license": "MIT",
+ "dependencies": {
+ "@radix-ui/react-use-layout-effect": "1.1.1"
+ },
+ "peerDependencies": {
+ "@types/react": "*",
+ "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
+ },
+ "peerDependenciesMeta": {
+ "@types/react": {
+ "optional": true
+ }
+ }
+ },
+ "node_modules/@radix-ui/react-visually-hidden": {
+ "version": "1.2.3",
+ "resolved": "https://registry.npmjs.org/@radix-ui/react-visually-hidden/-/react-visually-hidden-1.2.3.tgz",
+ "integrity": "sha512-pzJq12tEaaIhqjbzpCuv/OypJY/BPavOofm+dbab+MHLajy277+1lLm6JFcGgF5eskJ6mquGirhXY2GD/8u8Ug==",
+ "license": "MIT",
+ "dependencies": {
+ "@radix-ui/react-primitive": "2.1.3"
+ },
+ "peerDependencies": {
+ "@types/react": "*",
+ "@types/react-dom": "*",
+ "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc",
+ "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
+ },
+ "peerDependenciesMeta": {
+ "@types/react": {
+ "optional": true
+ },
+ "@types/react-dom": {
+ "optional": true
+ }
+ }
+ },
+ "node_modules/@radix-ui/rect": {
+ "version": "1.1.1",
+ "resolved": "https://registry.npmjs.org/@radix-ui/rect/-/rect-1.1.1.tgz",
+ "integrity": "sha512-HPwpGIzkl28mWyZqG52jiqDJ12waP11Pa1lGoiyUkIEuMLBP0oeK/C89esbXrxsky5we7dfd8U58nm0SgAWpVw==",
+ "license": "MIT"
+ },
"node_modules/@rolldown/pluginutils": {
"version": "1.0.0-beta.47",
"resolved": "https://registry.npmjs.org/@rolldown/pluginutils/-/pluginutils-1.0.0-beta.47.tgz",
@@ -1990,7 +2524,6 @@
"integrity": "sha512-GNWcUTRBgIRJD5zj+Tq0fKOJ5XZajIiBroOF0yvj2bSU1WvNdYS/dn9UxwsujGW4JX06dnHyjV2y9rRaybH0iQ==",
"dev": true,
"license": "MIT",
- "peer": true,
"dependencies": {
"undici-types": "~7.16.0"
}
@@ -2000,7 +2533,6 @@
"resolved": "https://registry.npmjs.org/@types/react/-/react-19.2.4.tgz",
"integrity": "sha512-tBFxBp9Nfyy5rsmefN+WXc1JeW/j2BpBHFdLZbEVfs9wn3E3NRFxwV0pJg8M1qQAexFpvz73hJXFofV0ZAu92A==",
"license": "MIT",
- "peer": true,
"dependencies": {
"csstype": "^3.0.2"
}
@@ -2009,7 +2541,7 @@
"version": "19.2.3",
"resolved": "https://registry.npmjs.org/@types/react-dom/-/react-dom-19.2.3.tgz",
"integrity": "sha512-jp2L/eY6fn+KgVVQAOqYItbF0VY/YApe5Mz2F0aykSO8gx31bYCZyvSeYxCHKvzHG5eZjc+zyaS5BrBWya2+kQ==",
- "dev": true,
+ "devOptional": true,
"license": "MIT",
"peerDependencies": {
"@types/react": "^19.2.0"
@@ -2067,7 +2599,6 @@
"integrity": "sha512-tK3GPFWbirvNgsNKto+UmB/cRtn6TZfyw0D6IKrW55n6Vbs7KJoZtI//kpTKzE/DUmmnAFD8/Ca46s7Obs92/w==",
"dev": true,
"license": "MIT",
- "peer": true,
"dependencies": {
"@typescript-eslint/scope-manager": "8.46.4",
"@typescript-eslint/types": "8.46.4",
@@ -2326,7 +2857,6 @@
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
"dev": true,
"license": "MIT",
- "peer": true,
"bin": {
"acorn": "bin/acorn"
},
@@ -2384,6 +2914,18 @@
"dev": true,
"license": "Python-2.0"
},
+ "node_modules/aria-hidden": {
+ "version": "1.2.6",
+ "resolved": "https://registry.npmjs.org/aria-hidden/-/aria-hidden-1.2.6.tgz",
+ "integrity": "sha512-ik3ZgC9dY/lYVVM++OISsaYDeg1tb0VtP5uL3ouh1koGOaUMDPpbFIei4JkFimWUFPn90sbMNMXQAIVOlnYKJA==",
+ "license": "MIT",
+ "dependencies": {
+ "tslib": "^2.0.0"
+ },
+ "engines": {
+ "node": ">=10"
+ }
+ },
"node_modules/asynckit": {
"version": "0.4.0",
"resolved": "https://registry.npmjs.org/asynckit/-/asynckit-0.4.0.tgz",
@@ -2519,7 +3061,6 @@
}
],
"license": "MIT",
- "peer": true,
"dependencies": {
"baseline-browser-mapping": "^2.8.25",
"caniuse-lite": "^1.0.30001754",
@@ -2817,6 +3358,12 @@
"node": ">=8"
}
},
+ "node_modules/detect-node-es": {
+ "version": "1.1.0",
+ "resolved": "https://registry.npmjs.org/detect-node-es/-/detect-node-es-1.1.0.tgz",
+ "integrity": "sha512-ypdmJU/TbBby2Dxibuv7ZLW3Bs1QEmM7nHjEANfohJLvE0XVujisn1qPJcZxg+qDucsr+bP6fLD1rPS3AhJ7EQ==",
+ "license": "MIT"
+ },
"node_modules/devlop": {
"version": "1.1.0",
"resolved": "https://registry.npmjs.org/devlop/-/devlop-1.1.0.tgz",
@@ -2981,7 +3528,6 @@
"integrity": "sha512-BhHmn2yNOFA9H9JmmIVKJmd288g9hrVRDkdoIgRCRuSySRUHH7r/DI6aAXW9T1WwUuY3DFgrcaqB+deURBLR5g==",
"dev": true,
"license": "MIT",
- "peer": true,
"dependencies": {
"@eslint-community/eslint-utils": "^4.8.0",
"@eslint-community/regexpp": "^4.12.1",
@@ -3414,6 +3960,15 @@
"url": "https://github.com/sponsors/ljharb"
}
},
+ "node_modules/get-nonce": {
+ "version": "1.0.1",
+ "resolved": "https://registry.npmjs.org/get-nonce/-/get-nonce-1.0.1.tgz",
+ "integrity": "sha512-FJhYRoDaiatfEkUK8HKlicmu/3SGFD51q3itKDGoSTysQJBnfOcxU5GxnhE1E6soB76MbT0MBtnKJuXyAx+96Q==",
+ "license": "MIT",
+ "engines": {
+ "node": ">=6"
+ }
+ },
"node_modules/get-proto": {
"version": "1.0.1",
"resolved": "https://registry.npmjs.org/get-proto/-/get-proto-1.0.1.tgz",
@@ -3606,7 +4161,6 @@
}
],
"license": "MIT",
- "peer": true,
"dependencies": {
"@babel/runtime": "^7.27.6"
},
@@ -5096,7 +5650,6 @@
}
],
"license": "MIT",
- "peer": true,
"dependencies": {
"nanoid": "^3.3.11",
"picocolors": "^1.1.1",
@@ -5186,7 +5739,6 @@
"resolved": "https://registry.npmjs.org/react/-/react-19.2.0.tgz",
"integrity": "sha512-tmbWg6W31tQLeB5cdIBOicJDJRR2KzXsV7uSK9iNfLWQ5bIZfxuPEHp7M8wiHyHnn0DD1i7w3Zmin0FtkrwoCQ==",
"license": "MIT",
- "peer": true,
"engines": {
"node": ">=0.10.0"
}
@@ -5196,7 +5748,6 @@
"resolved": "https://registry.npmjs.org/react-dom/-/react-dom-19.2.0.tgz",
"integrity": "sha512-UlbRu4cAiGaIewkPyiRGJk0imDN2T3JjieT6spoL2UeSf5od4n5LB/mQ4ejmxhCFT1tYe8IvaFulzynWovsEFQ==",
"license": "MIT",
- "peer": true,
"dependencies": {
"scheduler": "^0.27.0"
},
@@ -5332,6 +5883,53 @@
"node": ">=0.10.0"
}
},
+ "node_modules/react-remove-scroll": {
+ "version": "2.7.1",
+ "resolved": "https://registry.npmjs.org/react-remove-scroll/-/react-remove-scroll-2.7.1.tgz",
+ "integrity": "sha512-HpMh8+oahmIdOuS5aFKKY6Pyog+FNaZV/XyJOq7b4YFwsFHe5yYfdbIalI4k3vU2nSDql7YskmUseHsRrJqIPA==",
+ "license": "MIT",
+ "dependencies": {
+ "react-remove-scroll-bar": "^2.3.7",
+ "react-style-singleton": "^2.2.3",
+ "tslib": "^2.1.0",
+ "use-callback-ref": "^1.3.3",
+ "use-sidecar": "^1.1.3"
+ },
+ "engines": {
+ "node": ">=10"
+ },
+ "peerDependencies": {
+ "@types/react": "*",
+ "react": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 || ^19.0.0-rc"
+ },
+ "peerDependenciesMeta": {
+ "@types/react": {
+ "optional": true
+ }
+ }
+ },
+ "node_modules/react-remove-scroll-bar": {
+ "version": "2.3.8",
+ "resolved": "https://registry.npmjs.org/react-remove-scroll-bar/-/react-remove-scroll-bar-2.3.8.tgz",
+ "integrity": "sha512-9r+yi9+mgU33AKcj6IbT9oRCO78WriSj6t/cF8DWBZJ9aOGPOTEDvdUDz1FwKim7QXWwmHqtdHnRJfhAxEG46Q==",
+ "license": "MIT",
+ "dependencies": {
+ "react-style-singleton": "^2.2.2",
+ "tslib": "^2.0.0"
+ },
+ "engines": {
+ "node": ">=10"
+ },
+ "peerDependencies": {
+ "@types/react": "*",
+ "react": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0"
+ },
+ "peerDependenciesMeta": {
+ "@types/react": {
+ "optional": true
+ }
+ }
+ },
"node_modules/react-router": {
"version": "7.9.6",
"resolved": "https://registry.npmjs.org/react-router/-/react-router-7.9.6.tgz",
@@ -5370,6 +5968,28 @@
"react-dom": ">=18"
}
},
+ "node_modules/react-style-singleton": {
+ "version": "2.2.3",
+ "resolved": "https://registry.npmjs.org/react-style-singleton/-/react-style-singleton-2.2.3.tgz",
+ "integrity": "sha512-b6jSvxvVnyptAiLjbkWLE/lOnR4lfTtDAl+eUC7RZy+QQWc6wRzIV2CE6xBuMmDxc2qIihtDCZD5NPOFl7fRBQ==",
+ "license": "MIT",
+ "dependencies": {
+ "get-nonce": "^1.0.0",
+ "tslib": "^2.0.0"
+ },
+ "engines": {
+ "node": ">=10"
+ },
+ "peerDependencies": {
+ "@types/react": "*",
+ "react": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 || ^19.0.0-rc"
+ },
+ "peerDependenciesMeta": {
+ "@types/react": {
+ "optional": true
+ }
+ }
+ },
"node_modules/remark-parse": {
"version": "11.0.0",
"resolved": "https://registry.npmjs.org/remark-parse/-/remark-parse-11.0.0.tgz",
@@ -5691,7 +6311,6 @@
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
"dev": true,
"license": "MIT",
- "peer": true,
"engines": {
"node": ">=12"
},
@@ -5770,7 +6389,6 @@
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
"devOptional": true,
"license": "Apache-2.0",
- "peer": true,
"bin": {
"tsc": "bin/tsc",
"tsserver": "bin/tsserver"
@@ -5938,6 +6556,49 @@
"punycode": "^2.1.0"
}
},
+ "node_modules/use-callback-ref": {
+ "version": "1.3.3",
+ "resolved": "https://registry.npmjs.org/use-callback-ref/-/use-callback-ref-1.3.3.tgz",
+ "integrity": "sha512-jQL3lRnocaFtu3V00JToYz/4QkNWswxijDaCVNZRiRTO3HQDLsdu1ZtmIUvV4yPp+rvWm5j0y0TG/S61cuijTg==",
+ "license": "MIT",
+ "dependencies": {
+ "tslib": "^2.0.0"
+ },
+ "engines": {
+ "node": ">=10"
+ },
+ "peerDependencies": {
+ "@types/react": "*",
+ "react": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 || ^19.0.0-rc"
+ },
+ "peerDependenciesMeta": {
+ "@types/react": {
+ "optional": true
+ }
+ }
+ },
+ "node_modules/use-sidecar": {
+ "version": "1.1.3",
+ "resolved": "https://registry.npmjs.org/use-sidecar/-/use-sidecar-1.1.3.tgz",
+ "integrity": "sha512-Fedw0aZvkhynoPYlA5WXrMCAMm+nSWdZt6lzJQ7Ok8S6Q+VsHmHpRWndVRJ8Be0ZbkfPc5LRYH+5XrzXcEeLRQ==",
+ "license": "MIT",
+ "dependencies": {
+ "detect-node-es": "^1.1.0",
+ "tslib": "^2.0.0"
+ },
+ "engines": {
+ "node": ">=10"
+ },
+ "peerDependencies": {
+ "@types/react": "*",
+ "react": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 || ^19.0.0-rc"
+ },
+ "peerDependenciesMeta": {
+ "@types/react": {
+ "optional": true
+ }
+ }
+ },
"node_modules/use-sync-external-store": {
"version": "1.6.0",
"resolved": "https://registry.npmjs.org/use-sync-external-store/-/use-sync-external-store-1.6.0.tgz",
@@ -5981,7 +6642,6 @@
"integrity": "sha512-BxAKBWmIbrDgrokdGZH1IgkIk/5mMHDreLDmCJ0qpyJaAteP8NvMhkwr/ZCQNqNH97bw/dANTE9PDzqwJghfMQ==",
"dev": true,
"license": "MIT",
- "peer": true,
"dependencies": {
"esbuild": "^0.25.0",
"fdir": "^6.5.0",
@@ -6075,7 +6735,6 @@
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
"dev": true,
"license": "MIT",
- "peer": true,
"engines": {
"node": ">=12"
},
diff --git a/frontend/package.json b/frontend/package.json
index 1f85631..5528d8b 100644
--- a/frontend/package.json
+++ b/frontend/package.json
@@ -10,6 +10,7 @@
"preview": "vite preview"
},
"dependencies": {
+ "@radix-ui/react-select": "^2.2.6",
"@tanstack/react-query": "^5.90.7",
"axios": "^1.13.2",
"class-variance-authority": "^0.7.0",
diff --git a/frontend/src/components/PDFViewer.tsx b/frontend/src/components/PDFViewer.tsx
index db4c8c6..cc65274 100644
--- a/frontend/src/components/PDFViewer.tsx
+++ b/frontend/src/components/PDFViewer.tsx
@@ -1,11 +1,17 @@
-import { useState, useMemo } from 'react'
-import { Document, Page } from 'react-pdf'
+import { useState, useCallback, useMemo, useRef, useEffect } from 'react'
+import { Document, Page, pdfjs } from 'react-pdf'
+import type { PDFDocumentProxy } from 'pdfjs-dist'
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card'
import { Button } from '@/components/ui/button'
-import { ChevronLeft, ChevronRight, ZoomIn, ZoomOut } from 'lucide-react'
+import { ChevronLeft, ChevronRight, ZoomIn, ZoomOut, Loader2 } from 'lucide-react'
import 'react-pdf/dist/Page/AnnotationLayer.css'
import 'react-pdf/dist/Page/TextLayer.css'
+// Configure standard font data URL for proper font rendering
+const pdfOptions = {
+ standardFontDataUrl: `https://unpkg.com/pdfjs-dist@${pdfjs.version}/standard_fonts/`,
+}
+
interface PDFViewerProps {
title?: string
pdfUrl: string
@@ -17,41 +23,56 @@ export default function PDFViewer({ title, pdfUrl, className, httpHeaders }: PDF
const [numPages, setNumPages] = useState(0)
const [pageNumber, setPageNumber] = useState(1)
const [scale, setScale] = useState(1.0)
- const [loading, setLoading] = useState(true)
+ const [documentLoaded, setDocumentLoaded] = useState(false)
  const [error, setError] = useState<string | null>(null)
- // Memoize the file prop to prevent unnecessary reloads
+ // Store PDF document reference
+ const pdfDocRef = useRef<PDFDocumentProxy | null>(null)
+
+ // Memoize file config to prevent unnecessary reloads
const fileConfig = useMemo(() => {
return httpHeaders ? { url: pdfUrl, httpHeaders } : pdfUrl
}, [pdfUrl, httpHeaders])
- const onDocumentLoadSuccess = ({ numPages }: { numPages: number }) => {
- setNumPages(numPages)
- setLoading(false)
+ // Reset state when URL changes
+ useEffect(() => {
+ setDocumentLoaded(false)
setError(null)
- }
+ setNumPages(0)
+ setPageNumber(1)
+ pdfDocRef.current = null
+ }, [pdfUrl])
- const onDocumentLoadError = (error: Error) => {
- console.error('Error loading PDF:', error)
- setError('Failed to load PDF. Please try again later.')
- setLoading(false)
- }
+ const onDocumentLoadSuccess = useCallback((pdf: { numPages: number }) => {
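+ // Cache the loaded document proxy and reset paging state for the newly loaded file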
+ pdfDocRef.current = pdf as unknown as PDFDocumentProxy
+ setNumPages(pdf.numPages)
+ setPageNumber(1)
+ setDocumentLoaded(true)
+ setError(null)
+ }, [])
- const goToPreviousPage = () => {
+ const onDocumentLoadError = useCallback((err: Error) => {
+ console.error('Error loading PDF:', err)
+ setError('無法載入 PDF 檔案。請稍後再試。')
+ setDocumentLoaded(false)
+ pdfDocRef.current = null
+ }, [])
+
+ const goToPreviousPage = useCallback(() => {
setPageNumber((prev) => Math.max(prev - 1, 1))
- }
+ }, [])
- const goToNextPage = () => {
+ const goToNextPage = useCallback(() => {
setPageNumber((prev) => Math.min(prev + 1, numPages))
- }
+ }, [numPages])
- const zoomIn = () => {
+ const zoomIn = useCallback(() => {
setScale((prev) => Math.min(prev + 0.2, 3.0))
- }
+ }, [])
- const zoomOut = () => {
+ const zoomOut = useCallback(() => {
setScale((prev) => Math.max(prev - 0.2, 0.5))
- }
+ }, [])
return (
@@ -69,18 +90,18 @@ export default function PDFViewer({ title, pdfUrl, className, httpHeaders }: PDF
variant="outline"
size="sm"
onClick={goToPreviousPage}
- disabled={pageNumber <= 1 || loading}
+ disabled={pageNumber <= 1 || !documentLoaded}
>
- Page {pageNumber} of {numPages || '...'}
+ 第 {pageNumber} 頁 / 共 {numPages || '...'} 頁
@@ -92,7 +113,7 @@ export default function PDFViewer({ title, pdfUrl, className, httpHeaders }: PDF
variant="outline"
size="sm"
onClick={zoomOut}
- disabled={scale <= 0.5 || loading}
+ disabled={scale <= 0.5 || !documentLoaded}
>
@@ -103,7 +124,7 @@ export default function PDFViewer({ title, pdfUrl, className, httpHeaders }: PDF
variant="outline"
size="sm"
onClick={zoomIn}
- disabled={scale >= 3.0 || loading}
+ disabled={scale >= 3.0 || !documentLoaded}
>
@@ -113,39 +134,48 @@ export default function PDFViewer({ title, pdfUrl, className, httpHeaders }: PDF
{/* PDF Document */}