Files
OCR/backend/app/main.py
egg 86a6633000 feat: consolidate env config and add deployment files
- Add debug_font_path, demo_docs_dir, e2e_api_base_url to config.py
- Fix hardcoded paths in pp_structure_debug.py, create_demo_images.py
- Fix hardcoded paths in test files
- Update .env.example with new configuration options
- Update .gitignore to exclude AI development files (.claude/, openspec/, AGENTS.md, CLAUDE.md)
- Add production startup script (start-prod.sh)
- Add README.md with project documentation
- Add 1panel Docker deployment files (docker-compose.yml, Dockerfiles, nginx.conf)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-14 15:02:16 +08:00

538 lines
18 KiB
Python

"""
Tool_OCR - FastAPI Application Entry Point (V2)
Main application setup with CORS, routes, and startup/shutdown events
"""
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
import logging
import signal
import sys
import asyncio
from pathlib import Path
from typing import Optional
from app.core.config import settings
# =============================================================================
# Section 6.1: Signal Handlers
# =============================================================================

# --- Shutdown state (written by _signal_handler) ---------------------------
# True once the first SIGTERM/SIGINT has been handled; read through
# is_shutdown_requested().
_shutdown_requested: bool = False
# Set on the running loop when shutdown is *requested*; despite the name it
# does not indicate that cleanup has finished.
_shutdown_complete = asyncio.Event()

# --- Connection-draining state ----------------------------------------------
# In-flight connection count, changed by increment_connections() /
# decrement_connections() while holding _connection_lock, and polled by
# drain_connections() through get_active_connections().
# NOTE(review): on Python < 3.10 asyncio primitives bind to the event loop that
# is current when they are constructed; creating them at import time can fail
# if the server later runs a different loop. Confirm the minimum Python version.
_active_connections: int = 0
_connection_lock = asyncio.Lock()
async def increment_connections():
    """Add one to the module-wide in-flight connection counter.

    The update happens while holding ``_connection_lock``.
    """
    global _active_connections
    async with _connection_lock:
        _active_connections = _active_connections + 1
async def decrement_connections():
    """Subtract one from the module-wide in-flight connection counter.

    The update happens while holding ``_connection_lock``. No lower bound is
    enforced, so unmatched calls can drive the count below zero.
    """
    global _active_connections
    async with _connection_lock:
        _active_connections = _active_connections - 1
def get_active_connections() -> int:
    """Return how many connections are currently counted as in flight.

    The counter is read without acquiring ``_connection_lock``.
    """
    current_count = _active_connections
    return current_count
def is_shutdown_requested() -> bool:
    """Report whether the signal handler has already started a graceful shutdown."""
    requested = _shutdown_requested
    return requested
def _signal_handler(signum: int, frame) -> None:
    """
    Signal handler for SIGTERM and SIGINT.

    On the first signal, set the module-level shutdown flag. If an asyncio
    event loop is running in this thread, also schedule
    ``_shutdown_complete.set`` on it. A second signal received while the flag
    is already set calls ``sys.exit(1)``.

    The actual cleanup is handled by the lifespan context manager.

    Args:
        signum: Number of the received signal.
        frame: Current stack frame (unused).
    """
    global _shutdown_requested
    signal_name = signal.Signals(signum).name
    logger = logging.getLogger(__name__)

    if _shutdown_requested:
        logger.warning(f"Received {signal_name} again, forcing immediate exit...")
        sys.exit(1)

    logger.info(f"Received {signal_name}, initiating graceful shutdown...")
    _shutdown_requested = True

    # Python runs signal handlers in the main thread, so when the server's loop
    # runs there get_running_loop() returns it. Unlike get_event_loop(), it
    # never creates a fresh loop and is not deprecated when none is running.
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        return  # No event loop running
    # Despite its name, _shutdown_complete marks that shutdown was *requested*.
    loop.call_soon_threadsafe(_shutdown_complete.set)
def setup_signal_handlers():
    """
    Set up signal handlers for graceful shutdown.

    Handles:
    - SIGTERM: Standard termination signal (from systemd, docker, etc.)
    - SIGINT: Keyboard interrupt (Ctrl+C)

    ``_signal_handler`` is chained in front of whichever handler was already
    installed, rather than replacing it. For example, the ASGI server's own
    exit handler still runs, so the server begins its shutdown and the
    lifespan cleanup actually executes.

    Calling this function again does not stack another wrapper.
    Failures to install (e.g. when not called from the main thread) are
    logged as a warning.
    """
    logger = logging.getLogger(__name__)

    def _install(sig):
        previous = signal.getsignal(sig)
        if getattr(previous, "_chains_tool_ocr_handler", False):
            return  # Our wrapper is already installed for this signal.

        def _chained(signum, frame):
            _signal_handler(signum, frame)
            # Forward to the displaced handler. SIG_DFL, SIG_IGN and None
            # are not callable and are skipped.
            if callable(previous):
                previous(signum, frame)

        _chained._chains_tool_ocr_handler = True
        signal.signal(sig, _chained)

    try:
        # SIGTERM - Standard termination signal
        _install(signal.SIGTERM)
        logger.info("SIGTERM handler installed")
        # SIGINT - Keyboard interrupt
        _install(signal.SIGINT)
        logger.info("SIGINT handler installed")
    except (ValueError, OSError) as e:
        # Signal handling may not be available in all contexts
        logger.warning(f"Could not install signal handlers: {e}")
# Ensure log directory exists before configuring logging
Path(settings.log_file).parent.mkdir(parents=True, exist_ok=True)

# Resolve the configured level name once. Upper-casing lets values such as
# "info" work: getattr(logging, "info") would return the logging.info
# *function*, which is not a valid logging level.
_log_level = getattr(logging, str(settings.log_level).upper())

# Configure logging - Force configuration to override uvicorn's settings
logging.basicConfig(
    level=_log_level,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[
        # Explicit UTF-8 so non-ASCII text (e.g. uploaded filenames that end
        # up in log messages) can be written regardless of the platform's
        # default locale encoding.
        logging.FileHandler(settings.log_file, encoding="utf-8"),
        logging.StreamHandler(),
    ],
    force=True,  # Force reconfiguration (Python 3.8+)
)

# Also explicitly configure root logger and app loggers
root_logger = logging.getLogger()
root_logger.setLevel(_log_level)

# Configure app-specific loggers
for logger_name in ['app', 'app.services', 'app.services.pdf_generator_service', 'app.services.ocr_service']:
    app_logger = logging.getLogger(logger_name)
    app_logger.setLevel(_log_level)
    app_logger.propagate = True  # Ensure logs propagate to root logger

logger = logging.getLogger(__name__)
async def drain_connections(timeout: float = 30.0):
    """
    Wait for active connections to complete (connection draining).

    Polls ``get_active_connections()`` about once per second until it reaches
    zero or ``timeout`` seconds have elapsed. The final sleep is shortened so
    the deadline is not overshot. Logs progress and the outcome; never raises
    on timeout.

    Args:
        timeout: Maximum time to wait for connections to drain, in seconds.
    """
    logger.info(f"Draining connections (timeout={timeout}s)...")
    # get_running_loop() is the recommended way to reach the loop from a coroutine.
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout

    while get_active_connections() > 0:
        remaining = deadline - loop.time()
        if remaining <= 0:
            logger.warning(
                f"Connection drain timeout after {timeout}s. "
                f"{get_active_connections()} connections still active."
            )
            break
        logger.info(f"Waiting for {get_active_connections()} active connections...")
        await asyncio.sleep(min(1.0, remaining))

    if get_active_connections() == 0:
        logger.info("All connections drained successfully")
async def _start_optional_services() -> None:
    """Start the subsystems that follow the application directories.

    Order:
    1. memory manager (if enabled)
    2. OCR service pool (if enabled)
    3. prediction semaphore (always attempted)
    4. cleanup scheduler (if enabled)

    A failing step is logged as a warning and startup continues.
    """
    if settings.enable_model_lifecycle_management:
        try:
            from app.services.memory_manager import get_model_manager, MemoryConfig
            get_model_manager(
                MemoryConfig(
                    warning_threshold=settings.memory_warning_threshold,
                    critical_threshold=settings.memory_critical_threshold,
                    emergency_threshold=settings.memory_emergency_threshold,
                    model_idle_timeout_seconds=settings.pp_structure_idle_timeout_seconds,
                    memory_check_interval_seconds=settings.memory_check_interval_seconds,
                    enable_auto_cleanup=settings.enable_memory_optimization,
                    enable_emergency_cleanup=settings.enable_emergency_cleanup,
                    max_concurrent_predictions=settings.max_concurrent_predictions,
                    enable_cpu_fallback=settings.enable_cpu_fallback,
                    gpu_memory_limit_mb=settings.gpu_memory_limit_mb,
                )
            )
            logger.info("Memory management initialized")
        except Exception as exc:
            logger.warning(f"Failed to initialize memory management: {exc}")

    if settings.enable_service_pool:
        try:
            from app.services.service_pool import get_service_pool, PoolConfig
            get_service_pool(
                PoolConfig(
                    max_services_per_device=settings.max_services_per_device,
                    max_total_services=settings.max_total_services,
                    acquire_timeout_seconds=settings.service_acquire_timeout_seconds,
                    max_queue_size=settings.max_queue_size,
                )
            )
            logger.info("OCR service pool initialized")
        except Exception as exc:
            logger.warning(f"Failed to initialize service pool: {exc}")

    # Limits concurrent PP-StructureV3 predictions; set up regardless of the flags above.
    try:
        from app.services.memory_manager import get_prediction_semaphore
        get_prediction_semaphore(max_concurrent=settings.max_concurrent_predictions)
        logger.info(f"Prediction semaphore initialized (max_concurrent={settings.max_concurrent_predictions})")
    except Exception as exc:
        logger.warning(f"Failed to initialize prediction semaphore: {exc}")

    if settings.cleanup_enabled:
        try:
            from app.services.cleanup_scheduler import start_cleanup_scheduler
            await start_cleanup_scheduler()
            logger.info("Cleanup scheduler initialized")
        except Exception as exc:
            logger.warning(f"Failed to initialize cleanup scheduler: {exc}")


async def _stop_optional_services() -> None:
    """Tear the subsystems down again.

    Order:
    1. stop the cleanup scheduler
    2. drain in-flight connections (up to 30 s)
    3. recovery manager
    4. service pool
    5. prediction semaphore
    6. memory manager

    Errors are logged and shutdown continues.
    """
    if settings.cleanup_enabled:
        try:
            from app.services.cleanup_scheduler import stop_cleanup_scheduler
            await stop_cleanup_scheduler()
            logger.info("Cleanup scheduler stopped")
        except Exception as exc:
            logger.warning(f"Error stopping cleanup scheduler: {exc}")

    # Let active requests finish before the services below are shut down.
    await drain_connections(timeout=30.0)

    try:
        from app.services.memory_manager import shutdown_recovery_manager
        shutdown_recovery_manager()
        logger.info("Recovery manager shutdown complete")
    except Exception as exc:
        # Only logged at debug level: the recovery manager may not have been initialized.
        logger.debug(f"Recovery manager shutdown skipped: {exc}")

    if settings.enable_service_pool:
        try:
            from app.services.service_pool import shutdown_service_pool
            shutdown_service_pool()
            logger.info("Service pool shutdown complete")
        except Exception as exc:
            logger.warning(f"Error shutting down service pool: {exc}")

    try:
        from app.services.memory_manager import shutdown_prediction_semaphore
        shutdown_prediction_semaphore()
        logger.info("Prediction semaphore shutdown complete")
    except Exception as exc:
        logger.warning(f"Error shutting down prediction semaphore: {exc}")

    if settings.enable_model_lifecycle_management:
        try:
            from app.services.memory_manager import shutdown_model_manager
            shutdown_model_manager()
            logger.info("Memory manager shutdown complete")
        except Exception as exc:
            logger.warning(f"Error shutting down memory manager: {exc}")


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup work before the app serves requests and cleanup after it stops.

    Startup installs signal handlers, makes sure the configured directories
    exist, then starts the optional services. Shutdown stops them again
    (including connection draining).
    """
    # ---- startup ----
    logger.info("Starting Tool_OCR V2 application...")
    setup_signal_handlers()
    settings.ensure_directories()
    logger.info("All directories created/verified")
    await _start_optional_services()
    logger.info("Application startup complete")

    yield

    # ---- shutdown ----
    logger.info("Shutting down Tool_OCR application...")
    await _stop_optional_services()
    logger.info("Tool_OCR shutdown complete")
# The FastAPI application object; startup/shutdown are driven by `lifespan`.
app = FastAPI(
    lifespan=lifespan,
    title="Tool_OCR V2",
    version="2.0.0",
    description="OCR Processing System with External Authentication & Task Isolation",
)

# CORS policy:
# - origins come from settings.cors_origins_list
# - every method and header is allowed
# - credentialed requests are permitted
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins_list,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
def _gpu_status_from_pool(pool):
    """Query already-pooled OCR services for GPU status instead of creating one.

    For each device, only the first service exposing ``get_gpu_status`` is
    asked. Scanning stops at the first device whose status is truthy. Returns
    that status, the last (falsy) status seen, or None when no pooled service
    exposes the method.
    """
    status = None
    for _device, services in pool.services.items():
        for pooled in services:
            if hasattr(pooled.service, "get_gpu_status"):
                status = pooled.service.get_gpu_status()
                break
        if status:
            break
    return status


def _summarize_gpu_status(gpu_status):
    """Shape a raw GPU status dict into the ``gpu`` section of the /health response.

    Memory figures are rounded to two decimals. A memory key that is present
    but None is reported as 0 instead of failing the whole section.
    """
    def _rounded(key):
        return round(gpu_status.get(key) or 0, 2)

    summary = {
        "available": gpu_status.get("gpu_available", False),
        "enabled": gpu_status.get("gpu_enabled", False),
        "device_name": gpu_status.get("device_name", "N/A"),
        "device_count": gpu_status.get("device_count", 0),
        "compute_capability": gpu_status.get("compute_capability", "N/A"),
    }
    # Add memory info if available
    if gpu_status.get("memory_total_mb"):
        summary["memory"] = {
            "total_mb": _rounded("memory_total_mb"),
            "allocated_mb": _rounded("memory_allocated_mb"),
            "utilization_percent": _rounded("memory_utilization"),
        }
    # Add reason if GPU is not available
    if not gpu_status.get("gpu_available") and gpu_status.get("reason"):
        summary["reason"] = gpu_status.get("reason")
    return summary


# Health check endpoint
@app.get("/health")
async def health_check():
    """Health check endpoint with GPU status and memory management info"""
    # NOTE: the docstring above is the OpenAPI description; keep it stable.
    response = {
        "status": "healthy",
        "service": "Tool_OCR V2",
        "version": "2.0.0",
    }

    # Add GPU status information
    try:
        # Use service pool if available to avoid creating new instances
        gpu_status = None
        if settings.enable_service_pool:
            try:
                from app.services.service_pool import get_service_pool
                pool = get_service_pool()
                response["service_pool"] = pool.get_pool_stats()
                gpu_status = _gpu_status_from_pool(pool)
            except Exception as e:
                logger.debug(f"Could not get service pool stats: {e}")

        # Fallback: create temporary instance if no pool or no service available
        if gpu_status is None:
            from app.services.ocr_service import OCRService
            gpu_status = OCRService().get_gpu_status()

        # A None status falls back to the defaults instead of raising AttributeError.
        response["gpu"] = _summarize_gpu_status(gpu_status or {})
    except Exception as e:
        logger.warning(f"Failed to get GPU status: {e}")
        response["gpu"] = {
            "available": False,
            "error": str(e),
        }

    # Add memory management status
    if settings.enable_model_lifecycle_management:
        try:
            from app.services.memory_manager import get_model_manager
            model_manager = get_model_manager()
            response["memory_management"] = model_manager.get_model_stats()
        except Exception as e:
            logger.debug(f"Could not get memory management stats: {e}")

    return response
# Root endpoint
@app.get("/")
async def root():
    """Root endpoint with API information"""
    # The docstring above doubles as the OpenAPI description, so it is kept
    # verbatim. The payload points clients at the interactive docs and the
    # health probe.
    info = dict(
        message="Tool_OCR API V2 - External Authentication",
        version="2.0.0",
        docs_url="/docs",
        health_check="/health",
    )
    return info
# Include V2 API routers
from app.routers import auth, tasks, admin, translate
from fastapi import UploadFile, File, Depends, HTTPException, status, Request
from sqlalchemy.orm import Session
import hashlib
from app.core.deps import get_db, get_current_user
from app.models.user import User
from app.models.task import TaskFile
from app.schemas.task import UploadResponse, TaskStatusEnum
from app.services.task_service import task_service
from app.services.audit_service import audit_service
def get_client_ip(request: "Request") -> str:
    """Return a best-effort client IP address for audit logging.

    Lookup order:
    1. the first non-empty entry of ``X-Forwarded-For``
    2. ``X-Real-IP``
    3. the socket peer address
    4. the literal ``"unknown"``

    Header values are stripped of surrounding whitespace.

    NOTE(review): both headers are client-controllable unless a trusted reverse
    proxy overwrites them, so treat the result as informational only.
    """
    forwarded = request.headers.get("X-Forwarded-For")
    if forwarded:
        # Format is "client, proxy1, proxy2"; the originating client comes
        # first. Skip blank entries (e.g. ", 10.0.0.1") instead of returning "".
        for candidate in forwarded.split(","):
            candidate = candidate.strip()
            if candidate:
                return candidate

    real_ip = (request.headers.get("X-Real-IP") or "").strip()
    if real_ip:
        return real_ip

    return request.client.host if request.client else "unknown"
def get_user_agent(request: Request) -> str:
    """Return the User-Agent header, or "unknown" when absent, capped at 500 characters."""
    user_agent = request.headers.get("User-Agent", "unknown")
    return user_agent[:500]
# Mount the V2 API routers, in this fixed order.
for _router_module in (auth, tasks, admin, translate):
    app.include_router(_router_module.router)
del _router_module
# File upload endpoint
@app.post("/api/v2/upload", response_model=UploadResponse, tags=["Upload"], summary="Upload file for OCR")
async def upload_file(
    request: Request,
    file: UploadFile = File(..., description="File to upload (PNG, JPG, PDF, etc.)"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Upload a file for OCR processing
    Creates a new task and uploads the file
    - **file**: File to upload
    """
    # NOTE: the docstring above is rendered as the OpenAPI description.
    try:
        # The client-supplied filename may be missing or carry directory parts
        # ("../x.png", "C:\\dir\\x.png"). Keep only the last component (split
        # on both "/" and "\\") so the stored file always lands directly
        # inside upload_dir.
        safe_name = (file.filename or "").replace("\\", "/").rsplit("/", 1)[-1]
        if not safe_name:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Uploaded file must have a filename",
            )

        # Validate file extension
        file_ext = Path(safe_name).suffix.lower().lstrip('.')
        if file_ext not in settings.allowed_extensions_list:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"File type .{file_ext} not allowed. Allowed types: {', '.join(settings.allowed_extensions_list)}"
            )

        # Read the whole upload into memory; it is hashed and written out below.
        file_content = await file.read()
        file_size = len(file_content)
        file_hash = hashlib.sha256(file_content).hexdigest()

        # Create task
        task = task_service.create_task(
            db=db,
            user_id=current_user.id,
            filename=file.filename,
            file_type=file.content_type
        )

        # Save file to disk
        upload_dir = Path(settings.upload_dir)
        upload_dir.mkdir(parents=True, exist_ok=True)
        # The task_id prefix keeps on-disk names unique across uploads.
        file_path = upload_dir / f"{task.task_id}_{safe_name}"

        try:
            with open(file_path, "wb") as f:
                f.write(file_content)

            # Create TaskFile record
            task_file = TaskFile(
                task_id=task.id,
                original_name=file.filename,
                stored_path=str(file_path),
                file_size=file_size,
                mime_type=file.content_type,
                file_hash=file_hash
            )
            db.add(task_file)
            db.commit()
        except Exception:
            # Best-effort cleanup: don't leave a partial or unreferenced file on
            # disk when the write or the DB insert fails. The outer handler
            # reports the original error.
            try:
                file_path.unlink(missing_ok=True)
            except (OSError, ValueError):
                pass
            raise

        logger.info(f"Uploaded file {file.filename} ({file_size} bytes) for task {task.task_id}, user {current_user.email}")

        # Log file upload event
        audit_service.log_event(
            db=db,
            event_type="file_upload",
            event_category="file",
            description=f"File uploaded: {file.filename} ({file_size} bytes)",
            user_id=current_user.id,
            ip_address=get_client_ip(request),
            user_agent=get_user_agent(request),
            resource_type="task",
            resource_id=task.task_id,
            success=True,
            metadata={
                "filename": file.filename,
                "file_size": file_size,
                "file_type": file.content_type
            }
        )

        return {
            "task_id": task.task_id,
            "filename": file.filename,
            "file_size": file_size,
            "file_type": file.content_type or "application/octet-stream",
            "status": TaskStatusEnum.PENDING
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Failed to upload file for user {current_user.id}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to upload file: {str(e)}"
        )
if __name__ == "__main__":
    # Local development runner. Auto-reload is always enabled here. Host, port
    # and log level come from settings; uvicorn expects a lower-case level.
    import uvicorn

    uvicorn.run(
        "app.main:app",
        reload=True,
        host=settings.backend_host,
        port=settings.backend_port,
        log_level=settings.log_level.lower(),
    )