OCR/backend/app/main.py
egg · 1afdb822c3 · feat: implement hybrid image extraction and memory management
Backend:
- Add hybrid image extraction for Direct track (inline image blocks)
- Add render_inline_image_regions() fallback when OCR doesn't find images
- Add check_document_for_missing_images() for detecting missing images
- Add memory management system (MemoryGuard, ModelManager, ServicePool)
- Update pdf_generator_service to handle HYBRID processing track
- Add ElementType.LOGO for logo extraction

Frontend:
- Fix PDF viewer re-rendering issues with memoization
- Add TaskNotFound component and useTaskValidation hook
- Disable StrictMode due to react-pdf incompatibility
- Fix task detail and results page loading states

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-26 10:56:22 +08:00


"""
Tool_OCR - FastAPI Application Entry Point (V2)
Main application setup with CORS, routes, and startup/shutdown events
"""
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
import logging
import signal
import sys
import asyncio
from pathlib import Path
from typing import Optional
from app.core.config import settings
# =============================================================================
# Section 6.1: Signal Handlers
# =============================================================================
# Flag to indicate graceful shutdown is in progress
_shutdown_requested = False
_shutdown_complete = asyncio.Event()
# Track active connections for draining
_active_connections = 0
_connection_lock = asyncio.Lock()
async def increment_connections():
    """Increment the active connection count"""
    global _active_connections
    async with _connection_lock:
        _active_connections += 1


async def decrement_connections():
    """Decrement the active connection count"""
    global _active_connections
    async with _connection_lock:
        _active_connections -= 1


def get_active_connections() -> int:
    """Get current active connection count"""
    return _active_connections


def is_shutdown_requested() -> bool:
    """Check if graceful shutdown has been requested"""
    return _shutdown_requested
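
# The counters above only drain correctly if every request path updates them.
# A minimal sketch of that wiring, assuming it would be registered with
# `app.middleware("http")` after the app is created further down (the actual
# registration is not shown in this file):
async def _track_connections_sketch(request, call_next):
    """Hypothetical per-request hook that keeps the connection counter accurate."""
    await increment_connections()
    try:
        return await call_next(request)
    finally:
        await decrement_connections()
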
def _signal_handler(signum: int, frame) -> None:
    """
    Signal handler for SIGTERM and SIGINT.

    Initiates graceful shutdown by setting the shutdown flag.
    The actual cleanup is handled by the lifespan context manager.
    """
    global _shutdown_requested
    signal_name = signal.Signals(signum).name
    logger = logging.getLogger(__name__)

    if _shutdown_requested:
        logger.warning(f"Received {signal_name} again, forcing immediate exit...")
        sys.exit(1)

    logger.info(f"Received {signal_name}, initiating graceful shutdown...")
    _shutdown_requested = True

    # Try to stop the event loop gracefully
    try:
        loop = asyncio.get_event_loop()
        if loop.is_running():
            # Schedule shutdown event
            loop.call_soon_threadsafe(_shutdown_complete.set)
    except RuntimeError:
        pass  # No event loop running


def setup_signal_handlers():
    """
    Set up signal handlers for graceful shutdown.

    Handles:
    - SIGTERM: Standard termination signal (from systemd, docker, etc.)
    - SIGINT: Keyboard interrupt (Ctrl+C)
    """
    logger = logging.getLogger(__name__)
    try:
        # SIGTERM - Standard termination signal
        signal.signal(signal.SIGTERM, _signal_handler)
        logger.info("SIGTERM handler installed")

        # SIGINT - Keyboard interrupt
        signal.signal(signal.SIGINT, _signal_handler)
        logger.info("SIGINT handler installed")
    except (ValueError, OSError) as e:
        # Signal handling may not be available in all contexts
        logger.warning(f"Could not install signal handlers: {e}")

# Ensure log directory exists before configuring logging
Path(settings.log_file).parent.mkdir(parents=True, exist_ok=True)

# Configure logging - force configuration to override uvicorn's settings
logging.basicConfig(
    level=getattr(logging, settings.log_level),
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[
        logging.FileHandler(settings.log_file),
        logging.StreamHandler(),
    ],
    force=True,  # Force reconfiguration (Python 3.8+)
)

# Also explicitly configure the root logger and app loggers
root_logger = logging.getLogger()
root_logger.setLevel(getattr(logging, settings.log_level))

# Configure app-specific loggers
for logger_name in ['app', 'app.services', 'app.services.pdf_generator_service', 'app.services.ocr_service']:
    app_logger = logging.getLogger(logger_name)
    app_logger.setLevel(getattr(logging, settings.log_level))
    app_logger.propagate = True  # Ensure logs propagate to the root logger

logger = logging.getLogger(__name__)
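
# Because getattr(logging, settings.log_level) is used above, settings.log_level
# must be the name of a standard logging level ("DEBUG", "INFO", ...), and
# settings.log_file's parent directory is created before the file handler opens.
# A plausible configuration shape (values are assumptions, not taken from this
# repository's config) would be:
#
#     LOG_LEVEL=INFO
#     LOG_FILE=logs/backend.log
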
async def drain_connections(timeout: float = 30.0):
    """
    Wait for active connections to complete (connection draining).

    Args:
        timeout: Maximum time to wait for connections to drain
    """
    logger.info(f"Draining connections (timeout={timeout}s)...")
    start_time = asyncio.get_event_loop().time()

    while get_active_connections() > 0:
        elapsed = asyncio.get_event_loop().time() - start_time
        if elapsed >= timeout:
            logger.warning(
                f"Connection drain timeout after {timeout}s. "
                f"{get_active_connections()} connections still active."
            )
            break
        logger.info(f"Waiting for {get_active_connections()} active connections...")
        await asyncio.sleep(1.0)

    if get_active_connections() == 0:
        logger.info("All connections drained successfully")

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan events"""
    # Startup
    logger.info("Starting Tool_OCR V2 application...")

    # Set up signal handlers for graceful shutdown
    setup_signal_handlers()

    # Ensure all directories exist
    settings.ensure_directories()
    logger.info("All directories created/verified")

    # Initialize memory management if enabled
    if settings.enable_model_lifecycle_management:
        try:
            from app.services.memory_manager import get_model_manager, MemoryConfig
            memory_config = MemoryConfig(
                warning_threshold=settings.memory_warning_threshold,
                critical_threshold=settings.memory_critical_threshold,
                emergency_threshold=settings.memory_emergency_threshold,
                model_idle_timeout_seconds=settings.pp_structure_idle_timeout_seconds,
                memory_check_interval_seconds=settings.memory_check_interval_seconds,
                enable_auto_cleanup=settings.enable_memory_optimization,
                enable_emergency_cleanup=settings.enable_emergency_cleanup,
                max_concurrent_predictions=settings.max_concurrent_predictions,
                enable_cpu_fallback=settings.enable_cpu_fallback,
                gpu_memory_limit_mb=settings.gpu_memory_limit_mb,
            )
            get_model_manager(memory_config)
            logger.info("Memory management initialized")
        except Exception as e:
            logger.warning(f"Failed to initialize memory management: {e}")

    # Initialize service pool if enabled
    if settings.enable_service_pool:
        try:
            from app.services.service_pool import get_service_pool, PoolConfig
            pool_config = PoolConfig(
                max_services_per_device=settings.max_services_per_device,
                max_total_services=settings.max_total_services,
                acquire_timeout_seconds=settings.service_acquire_timeout_seconds,
                max_queue_size=settings.max_queue_size,
            )
            get_service_pool(pool_config)
            logger.info("OCR service pool initialized")
        except Exception as e:
            logger.warning(f"Failed to initialize service pool: {e}")

    # Initialize prediction semaphore for controlling concurrent PP-StructureV3 predictions
    try:
        from app.services.memory_manager import get_prediction_semaphore
        get_prediction_semaphore(max_concurrent=settings.max_concurrent_predictions)
        logger.info(f"Prediction semaphore initialized (max_concurrent={settings.max_concurrent_predictions})")
    except Exception as e:
        logger.warning(f"Failed to initialize prediction semaphore: {e}")
logger.info("Application startup complete")
yield
# Shutdown
logger.info("Shutting down Tool_OCR application...")
# Connection draining - wait for active requests to complete
await drain_connections(timeout=30.0)
# Shutdown recovery manager if initialized
try:
from app.services.memory_manager import shutdown_recovery_manager
shutdown_recovery_manager()
logger.info("Recovery manager shutdown complete")
except Exception as e:
logger.debug(f"Recovery manager shutdown skipped: {e}")
# Shutdown service pool
if settings.enable_service_pool:
try:
from app.services.service_pool import shutdown_service_pool
shutdown_service_pool()
logger.info("Service pool shutdown complete")
except Exception as e:
logger.warning(f"Error shutting down service pool: {e}")
# Shutdown prediction semaphore
try:
from app.services.memory_manager import shutdown_prediction_semaphore
shutdown_prediction_semaphore()
logger.info("Prediction semaphore shutdown complete")
except Exception as e:
logger.warning(f"Error shutting down prediction semaphore: {e}")
# Shutdown memory manager
if settings.enable_model_lifecycle_management:
try:
from app.services.memory_manager import shutdown_model_manager
shutdown_model_manager()
logger.info("Memory manager shutdown complete")
except Exception as e:
logger.warning(f"Error shutting down memory manager: {e}")
logger.info("Tool_OCR shutdown complete")
# Create FastAPI application
app = FastAPI(
    title="Tool_OCR V2",
    description="OCR Processing System with External Authentication & Task Isolation",
    version="2.0.0",
    lifespan=lifespan,
)

# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins_list,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
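
# Since allow_credentials=True, browsers reject a wildcard
# Access-Control-Allow-Origin, so settings.cors_origins_list presumably
# enumerates the frontend origins explicitly. An assumed configuration value
# (not taken from this repository's config) might look like:
#
#     CORS_ORIGINS=http://localhost:5173,https://ocr.example.com
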
# Health check endpoint
@app.get("/health")
async def health_check():
    """Health check endpoint with GPU status and memory management info"""
    response = {
        "status": "healthy",
        "service": "Tool_OCR V2",
        "version": "2.0.0",
    }

    # Add GPU status information
    try:
        # Use service pool if available to avoid creating new instances
        gpu_status = None
        if settings.enable_service_pool:
            try:
                from app.services.service_pool import get_service_pool
                pool = get_service_pool()
                pool_stats = pool.get_pool_stats()
                response["service_pool"] = pool_stats
                # Get GPU status from the first available service
                for device, services in pool.services.items():
                    for pooled in services:
                        if hasattr(pooled.service, 'get_gpu_status'):
                            gpu_status = pooled.service.get_gpu_status()
                            break
                    if gpu_status:
                        break
            except Exception as e:
                logger.debug(f"Could not get service pool stats: {e}")

        # Fallback: create a temporary instance if no pool or no service is available
        if gpu_status is None:
            from app.services.ocr_service import OCRService
            ocr_service = OCRService()
            gpu_status = ocr_service.get_gpu_status()

        response["gpu"] = {
            "available": gpu_status.get("gpu_available", False),
            "enabled": gpu_status.get("gpu_enabled", False),
            "device_name": gpu_status.get("device_name", "N/A"),
            "device_count": gpu_status.get("device_count", 0),
            "compute_capability": gpu_status.get("compute_capability", "N/A"),
        }

        # Add memory info if available
        if gpu_status.get("memory_total_mb"):
            response["gpu"]["memory"] = {
                "total_mb": round(gpu_status.get("memory_total_mb", 0), 2),
                "allocated_mb": round(gpu_status.get("memory_allocated_mb", 0), 2),
                "utilization_percent": round(gpu_status.get("memory_utilization", 0), 2),
            }

        # Add reason if GPU is not available
        if not gpu_status.get("gpu_available") and gpu_status.get("reason"):
            response["gpu"]["reason"] = gpu_status.get("reason")
    except Exception as e:
        logger.warning(f"Failed to get GPU status: {e}")
        response["gpu"] = {
            "available": False,
            "error": str(e),
        }

    # Add memory management status
    if settings.enable_model_lifecycle_management:
        try:
            from app.services.memory_manager import get_model_manager
            model_manager = get_model_manager()
            response["memory_management"] = model_manager.get_model_stats()
        except Exception as e:
            logger.debug(f"Could not get memory management stats: {e}")

    return response
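
# For reference, a successful /health response built by the handler above has
# roughly this shape (values are illustrative, not real output):
#
#     {
#       "status": "healthy",
#       "service": "Tool_OCR V2",
#       "version": "2.0.0",
#       "service_pool": {...},          # only when the service pool is enabled
#       "gpu": {
#         "available": true, "enabled": true, "device_name": "...",
#         "device_count": 1, "compute_capability": "...",
#         "memory": {"total_mb": 0.0, "allocated_mb": 0.0, "utilization_percent": 0.0},
#         "reason": "..."               # only when the GPU is unavailable
#       },
#       "memory_management": {...}      # only when lifecycle management is enabled
#     }
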
# Root endpoint
@app.get("/")
async def root():
    """Root endpoint with API information"""
    return {
        "message": "Tool_OCR API V2 - External Authentication",
        "version": "2.0.0",
        "docs_url": "/docs",
        "health_check": "/health",
    }

# Include V2 API routers
from app.routers import auth, tasks, admin
from fastapi import UploadFile, File, Depends, HTTPException, status
from sqlalchemy.orm import Session
import hashlib
from app.core.deps import get_db, get_current_user
from app.models.user import User
from app.models.task import TaskFile
from app.schemas.task import UploadResponse, TaskStatusEnum
from app.services.task_service import task_service
app.include_router(auth.router)
app.include_router(tasks.router)
app.include_router(admin.router)
# File upload endpoint
@app.post("/api/v2/upload", response_model=UploadResponse, tags=["Upload"], summary="Upload file for OCR")
async def upload_file(
    file: UploadFile = File(..., description="File to upload (PNG, JPG, PDF, etc.)"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Upload a file for OCR processing

    Creates a new task and uploads the file

    - **file**: File to upload
    """
    try:
        # Validate file extension
        file_ext = Path(file.filename).suffix.lower().lstrip('.')
        if file_ext not in settings.allowed_extensions_list:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"File type .{file_ext} not allowed. Allowed types: {', '.join(settings.allowed_extensions_list)}"
            )

        # Read file content
        file_content = await file.read()
        file_size = len(file_content)

        # Calculate file hash
        file_hash = hashlib.sha256(file_content).hexdigest()

        # Create task
        task = task_service.create_task(
            db=db,
            user_id=current_user.id,
            filename=file.filename,
            file_type=file.content_type
        )

        # Save file to disk
        upload_dir = Path(settings.upload_dir)
        upload_dir.mkdir(parents=True, exist_ok=True)

        # Create a unique filename using the task_id
        unique_filename = f"{task.task_id}_{file.filename}"
        file_path = upload_dir / unique_filename

        # Write file
        with open(file_path, "wb") as f:
            f.write(file_content)

        # Create TaskFile record
        task_file = TaskFile(
            task_id=task.id,
            original_name=file.filename,
            stored_path=str(file_path),
            file_size=file_size,
            mime_type=file.content_type,
            file_hash=file_hash
        )
        db.add(task_file)
        db.commit()

        logger.info(f"Uploaded file {file.filename} ({file_size} bytes) for task {task.task_id}, user {current_user.email}")

        return {
            "task_id": task.task_id,
            "filename": file.filename,
            "file_size": file_size,
            "file_type": file.content_type or "application/octet-stream",
            "status": TaskStatusEnum.PENDING
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Failed to upload file for user {current_user.id}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to upload file: {str(e)}"
        )
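
# Example request against the endpoint above (illustrative only; the host/port
# depend on deployment and settings.backend_port, and the bearer-token scheme
# is assumed from get_current_user and may differ in practice):
#
#     curl -X POST http://localhost:8000/api/v2/upload \
#          -H "Authorization: Bearer <token>" \
#          -F "file=@sample.pdf"
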
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        "app.main:app",
        host="0.0.0.0",
        port=settings.backend_port,
        reload=True,
        log_level=settings.log_level.lower(),
    )
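
# In production this module would typically be served via the uvicorn CLI
# rather than the reload-enabled block above, e.g. (assumed invocation; port
# should match settings.backend_port):
#
#     uvicorn app.main:app --host 0.0.0.0 --port 8000
#
# With multiple worker processes, note that the module-level shutdown flag and
# connection counter defined above are per-process.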