feat: implement hybrid image extraction and memory management
Backend: - Add hybrid image extraction for Direct track (inline image blocks) - Add render_inline_image_regions() fallback when OCR doesn't find images - Add check_document_for_missing_images() for detecting missing images - Add memory management system (MemoryGuard, ModelManager, ServicePool) - Update pdf_generator_service to handle HYBRID processing track - Add ElementType.LOGO for logo extraction Frontend: - Fix PDF viewer re-rendering issues with memoization - Add TaskNotFound component and useTaskValidation hook - Disable StrictMode due to react-pdf incompatibility - Fix task detail and results page loading states 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -7,10 +7,103 @@ from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from contextlib import asynccontextmanager
|
||||
import logging
|
||||
import signal
|
||||
import sys
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
# =============================================================================
# Section 6.1: Signal Handlers
# =============================================================================

# Flag to indicate graceful shutdown is in progress.
# Written only by _signal_handler(), read via is_shutdown_requested().
_shutdown_requested: bool = False
# Event set by _signal_handler() (via call_soon_threadsafe) to wake waiters.
# NOTE(review): created at import time — on Python < 3.10 asyncio.Event binds
# the loop active at creation; confirm the runtime is 3.10+ or create lazily.
_shutdown_complete = asyncio.Event()

# Track active connections for draining (see increment/decrement_connections
# and drain_connections below).
_active_connections: int = 0
_connection_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def increment_connections():
    """Register one newly opened connection in the module-wide counter."""
    global _active_connections
    # Serialize counter updates with the companion decrement coroutine.
    async with _connection_lock:
        _active_connections = _active_connections + 1
|
||||
|
||||
|
||||
async def decrement_connections():
    """Remove one finished connection from the module-wide counter."""
    global _active_connections
    # Serialize counter updates with the companion increment coroutine.
    async with _connection_lock:
        _active_connections = _active_connections - 1
|
||||
|
||||
|
||||
def get_active_connections() -> int:
    """Return a snapshot of the number of connections currently tracked."""
    # Lock-free read: a single attribute load is sufficient for monitoring.
    snapshot = _active_connections
    return snapshot
|
||||
|
||||
|
||||
def is_shutdown_requested() -> bool:
    """Report whether a graceful shutdown has been initiated by a signal."""
    # Flag is set once by _signal_handler(); plain read is race-free enough.
    requested = _shutdown_requested
    return requested
|
||||
|
||||
|
||||
def _signal_handler(signum: int, frame) -> None:
    """
    Signal handler for SIGTERM and SIGINT.

    Initiates graceful shutdown by setting the shutdown flag; the actual
    cleanup is performed by the lifespan context manager.  Receiving a
    second signal forces an immediate process exit.

    Args:
        signum: POSIX signal number delivered to the process.
        frame: Current stack frame (unused; required by the signal API).
    """
    global _shutdown_requested
    signal_name = signal.Signals(signum).name
    logger = logging.getLogger(__name__)

    if _shutdown_requested:
        # Second signal: the operator wants out *now*.
        logger.warning(f"Received {signal_name} again, forcing immediate exit...")
        sys.exit(1)

    logger.info(f"Received {signal_name}, initiating graceful shutdown...")
    _shutdown_requested = True

    # Wake anything awaiting the shutdown event.  get_running_loop() replaces
    # the deprecated get_event_loop() call: the handler executes in the main
    # thread, so a running loop (if any) is returned directly, and when no
    # loop is running it raises RuntimeError — which also makes the old
    # loop.is_running() check redundant.
    try:
        loop = asyncio.get_running_loop()
        loop.call_soon_threadsafe(_shutdown_complete.set)
    except RuntimeError:
        pass  # No event loop running
|
||||
|
||||
|
||||
def setup_signal_handlers():
    """
    Set up signal handlers for graceful shutdown.

    Handles:
    - SIGTERM: Standard termination signal (from systemd, docker, etc.)
    - SIGINT: Keyboard interrupt (Ctrl+C)
    """
    logger = logging.getLogger(__name__)

    # Table-driven install: both signals share the same handler.
    handlers = (
        (signal.SIGTERM, "SIGTERM handler installed"),
        (signal.SIGINT, "SIGINT handler installed"),
    )
    try:
        for sig, installed_msg in handlers:
            signal.signal(sig, _signal_handler)
            logger.info(installed_msg)
    except (ValueError, OSError) as e:
        # Signal handling may not be available in all contexts
        logger.warning(f"Could not install signal handlers: {e}")
|
||||
|
||||
# Ensure log directory exists before configuring logging
# (a file handler pointed at settings.log_file would otherwise fail
# to open when the parent directory is missing).
Path(settings.log_file).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@@ -38,16 +131,91 @@ for logger_name in ['app', 'app.services', 'app.services.pdf_generator_service',
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def drain_connections(timeout: float = 30.0):
    """
    Wait for active connections to complete (connection draining).

    Polls the active-connection counter once per second until it reaches
    zero or *timeout* elapses, logging progress each iteration.

    Args:
        timeout: Maximum time to wait for connections to drain
    """
    logger.info(f"Draining connections (timeout={timeout}s)...")
    # Inside a coroutine a loop is guaranteed to be running, so
    # get_running_loop() is the non-deprecated replacement for the old
    # get_event_loop() calls; hoisting it also avoids a lookup per tick.
    loop = asyncio.get_running_loop()
    start_time = loop.time()

    while get_active_connections() > 0:
        elapsed = loop.time() - start_time
        if elapsed >= timeout:
            logger.warning(
                f"Connection drain timeout after {timeout}s. "
                f"{get_active_connections()} connections still active."
            )
            break

        logger.info(f"Waiting for {get_active_connections()} active connections...")
        await asyncio.sleep(1.0)

    if get_active_connections() == 0:
        logger.info("All connections drained successfully")
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Application lifespan events.

    Startup: install signal handlers, create required directories, and —
    gated by settings flags — initialize the memory manager, the OCR
    service pool, and the prediction semaphore.  Each optional subsystem
    is initialized best-effort: a failure is logged as a warning and
    startup continues without it.

    Shutdown: drain active connections, then tear subsystems down
    (recovery manager, service pool, prediction semaphore, memory
    manager), again best-effort so one failure cannot block the rest.
    """
    # Startup
    logger.info("Starting Tool_OCR V2 application...")

    # Set up signal handlers for graceful shutdown
    setup_signal_handlers()

    # Ensure all directories exist
    settings.ensure_directories()
    logger.info("All directories created/verified")

    # Initialize memory management if enabled.  Imported lazily so the app
    # can start even when the optional subsystem is unavailable.
    if settings.enable_model_lifecycle_management:
        try:
            from app.services.memory_manager import get_model_manager, MemoryConfig

            memory_config = MemoryConfig(
                warning_threshold=settings.memory_warning_threshold,
                critical_threshold=settings.memory_critical_threshold,
                emergency_threshold=settings.memory_emergency_threshold,
                model_idle_timeout_seconds=settings.pp_structure_idle_timeout_seconds,
                memory_check_interval_seconds=settings.memory_check_interval_seconds,
                enable_auto_cleanup=settings.enable_memory_optimization,
                enable_emergency_cleanup=settings.enable_emergency_cleanup,
                max_concurrent_predictions=settings.max_concurrent_predictions,
                enable_cpu_fallback=settings.enable_cpu_fallback,
                gpu_memory_limit_mb=settings.gpu_memory_limit_mb,
            )
            # Factory call caches the manager; presumably a singleton — the
            # instance is re-fetched elsewhere without the config.
            get_model_manager(memory_config)
            logger.info("Memory management initialized")
        except Exception as e:
            logger.warning(f"Failed to initialize memory management: {e}")

    # Initialize service pool if enabled
    if settings.enable_service_pool:
        try:
            from app.services.service_pool import get_service_pool, PoolConfig

            pool_config = PoolConfig(
                max_services_per_device=settings.max_services_per_device,
                max_total_services=settings.max_total_services,
                acquire_timeout_seconds=settings.service_acquire_timeout_seconds,
                max_queue_size=settings.max_queue_size,
            )
            get_service_pool(pool_config)
            logger.info("OCR service pool initialized")
        except Exception as e:
            logger.warning(f"Failed to initialize service pool: {e}")

    # Initialize prediction semaphore for controlling concurrent PP-StructureV3 predictions
    try:
        from app.services.memory_manager import get_prediction_semaphore
        get_prediction_semaphore(max_concurrent=settings.max_concurrent_predictions)
        logger.info(f"Prediction semaphore initialized (max_concurrent={settings.max_concurrent_predictions})")
    except Exception as e:
        logger.warning(f"Failed to initialize prediction semaphore: {e}")

    logger.info("Application startup complete")

    # Application serves requests while suspended here.
    yield

    # Shutdown
    logger.info("Shutting down Tool_OCR application...")

    # Connection draining - wait for active requests to complete
    await drain_connections(timeout=30.0)

    # Shutdown recovery manager if initialized
    try:
        from app.services.memory_manager import shutdown_recovery_manager
        shutdown_recovery_manager()
        logger.info("Recovery manager shutdown complete")
    except Exception as e:
        # debug (not warning): absence of the recovery manager is expected.
        logger.debug(f"Recovery manager shutdown skipped: {e}")

    # Shutdown service pool
    if settings.enable_service_pool:
        try:
            from app.services.service_pool import shutdown_service_pool
            shutdown_service_pool()
            logger.info("Service pool shutdown complete")
        except Exception as e:
            logger.warning(f"Error shutting down service pool: {e}")

    # Shutdown prediction semaphore
    try:
        from app.services.memory_manager import shutdown_prediction_semaphore
        shutdown_prediction_semaphore()
        logger.info("Prediction semaphore shutdown complete")
    except Exception as e:
        logger.warning(f"Error shutting down prediction semaphore: {e}")

    # Shutdown memory manager
    if settings.enable_model_lifecycle_management:
        try:
            from app.services.memory_manager import shutdown_model_manager
            shutdown_model_manager()
            logger.info("Memory manager shutdown complete")
        except Exception as e:
            logger.warning(f"Error shutting down memory manager: {e}")

    logger.info("Tool_OCR shutdown complete")
|
||||
|
||||
|
||||
# Create FastAPI application
|
||||
app = FastAPI(
|
||||
@@ -77,9 +284,7 @@ app.add_middleware(
|
||||
# Health check endpoint
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""Health check endpoint with GPU status"""
|
||||
from app.services.ocr_service import OCRService
|
||||
|
||||
"""Health check endpoint with GPU status and memory management info"""
|
||||
response = {
|
||||
"status": "healthy",
|
||||
"service": "Tool_OCR V2",
|
||||
@@ -88,10 +293,31 @@ async def health_check():
|
||||
|
||||
# Add GPU status information
|
||||
try:
|
||||
# Create temporary OCRService instance to get GPU status
|
||||
# In production, this should be a singleton service
|
||||
ocr_service = OCRService()
|
||||
gpu_status = ocr_service.get_gpu_status()
|
||||
# Use service pool if available to avoid creating new instances
|
||||
gpu_status = None
|
||||
if settings.enable_service_pool:
|
||||
try:
|
||||
from app.services.service_pool import get_service_pool
|
||||
pool = get_service_pool()
|
||||
pool_stats = pool.get_pool_stats()
|
||||
response["service_pool"] = pool_stats
|
||||
|
||||
# Get GPU status from first available service
|
||||
for device, services in pool.services.items():
|
||||
for pooled in services:
|
||||
if hasattr(pooled.service, 'get_gpu_status'):
|
||||
gpu_status = pooled.service.get_gpu_status()
|
||||
break
|
||||
if gpu_status:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not get service pool stats: {e}")
|
||||
|
||||
# Fallback: create temporary instance if no pool or no service available
|
||||
if gpu_status is None:
|
||||
from app.services.ocr_service import OCRService
|
||||
ocr_service = OCRService()
|
||||
gpu_status = ocr_service.get_gpu_status()
|
||||
|
||||
response["gpu"] = {
|
||||
"available": gpu_status.get("gpu_available", False),
|
||||
@@ -120,6 +346,15 @@ async def health_check():
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
# Add memory management status
|
||||
if settings.enable_model_lifecycle_management:
|
||||
try:
|
||||
from app.services.memory_manager import get_model_manager
|
||||
model_manager = get_model_manager()
|
||||
response["memory_management"] = model_manager.get_model_stats()
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not get memory management stats: {e}")
|
||||
|
||||
return response
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user