feat: implement hybrid image extraction and memory management
Backend:
- Add hybrid image extraction for Direct track (inline image blocks)
- Add render_inline_image_regions() fallback when OCR doesn't find images
- Add check_document_for_missing_images() for detecting missing images
- Add memory management system (MemoryGuard, ModelManager, ServicePool)
- Update pdf_generator_service to handle HYBRID processing track
- Add ElementType.LOGO for logo extraction

Frontend:
- Fix PDF viewer re-rendering issues with memoization
- Add TaskNotFound component and useTaskValidation hook
- Disable StrictMode due to react-pdf incompatibility
- Fix task detail and results page loading states

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
@@ -104,6 +104,37 @@ class Settings(BaseSettings):
enable_cudnn_benchmark: bool = Field(default=True) # Optimize convolution algorithms
num_threads: int = Field(default=4) # CPU threads for preprocessing

# ===== Enhanced Memory Management Configuration =====
# Memory thresholds (as ratio of total GPU memory)
memory_warning_threshold: float = Field(default=0.80) # 80% - start warning
memory_critical_threshold: float = Field(default=0.95) # 95% - throttle operations
memory_emergency_threshold: float = Field(default=0.98) # 98% - emergency cleanup

# Memory monitoring
memory_check_interval_seconds: int = Field(default=30) # Background check interval
enable_memory_alerts: bool = Field(default=True) # Enable memory alerts

# Model lifecycle management
enable_model_lifecycle_management: bool = Field(default=True) # Use ModelManager
pp_structure_idle_timeout_seconds: int = Field(default=300) # Unload PP-Structure after idle
structure_model_memory_mb: int = Field(default=2000) # Estimated memory for PP-StructureV3
ocr_model_memory_mb: int = Field(default=500) # Estimated memory per OCR language model

# Service pool configuration
enable_service_pool: bool = Field(default=True) # Use OCRServicePool
max_services_per_device: int = Field(default=1) # Max OCRService per GPU
max_total_services: int = Field(default=2) # Max total OCRService instances
service_acquire_timeout_seconds: float = Field(default=300.0) # Timeout for acquiring service
max_queue_size: int = Field(default=50) # Max pending tasks per device

# Concurrency control
max_concurrent_predictions: int = Field(default=2) # Max concurrent PP-StructureV3 predictions
enable_cpu_fallback: bool = Field(default=True) # Fall back to CPU when GPU memory low

# Emergency recovery
enable_emergency_cleanup: bool = Field(default=True) # Auto-cleanup on memory pressure
enable_worker_restart: bool = Field(default=False) # Restart workers on OOM (requires supervisor)

# ===== File Upload Configuration =====
max_upload_size: int = Field(default=52428800) # 50MB
allowed_extensions: str = Field(default="png,jpg,jpeg,pdf,bmp,tiff,doc,docx,ppt,pptx")
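The three thresholds form an escalation ladder over the GPU usage ratio. A minimal sketch of how a monitor might map a ratio to an action under these defaults; the real logic lives in the new memory_manager.py, whose diff is suppressed further down, so this is an illustration rather than the actual MemoryGuard code:

def classify_memory_pressure(used_ratio: float,
                             warning: float = 0.80,
                             critical: float = 0.95,
                             emergency: float = 0.98) -> str:
    """Map a GPU memory usage ratio to an escalation level.

    Illustrative only: mirrors the warning/critical/emergency thresholds
    defined in Settings, not the actual MemoryGuard implementation.
    """
    if used_ratio >= emergency:
        return "emergency"   # force cleanup, consider CPU fallback
    if used_ratio >= critical:
        return "critical"    # throttle new predictions
    if used_ratio >= warning:
        return "warning"     # log and start unloading idle models
    return "ok"

# Example: 7.6 GB used on an 8 GB card -> 0.95 -> "critical"
print(classify_memory_pressure(7.6 / 8.0))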
@@ -7,10 +7,103 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
import logging
import signal
import sys
import asyncio
from pathlib import Path
from typing import Optional

from app.core.config import settings


# =============================================================================
# Section 6.1: Signal Handlers
# =============================================================================

# Flag to indicate graceful shutdown is in progress
_shutdown_requested = False
_shutdown_complete = asyncio.Event()

# Track active connections for draining
_active_connections = 0
_connection_lock = asyncio.Lock()


async def increment_connections():
"""Track active connection count"""
global _active_connections
async with _connection_lock:
_active_connections += 1


async def decrement_connections():
"""Track active connection count"""
global _active_connections
async with _connection_lock:
_active_connections -= 1


def get_active_connections() -> int:
"""Get current active connection count"""
return _active_connections


def is_shutdown_requested() -> bool:
"""Check if graceful shutdown has been requested"""
return _shutdown_requested


def _signal_handler(signum: int, frame) -> None:
"""
Signal handler for SIGTERM and SIGINT.

Initiates graceful shutdown by setting the shutdown flag.
The actual cleanup is handled by the lifespan context manager.
"""
global _shutdown_requested
signal_name = signal.Signals(signum).name
logger = logging.getLogger(__name__)

if _shutdown_requested:
logger.warning(f"Received {signal_name} again, forcing immediate exit...")
sys.exit(1)

logger.info(f"Received {signal_name}, initiating graceful shutdown...")
_shutdown_requested = True

# Try to stop the event loop gracefully
try:
loop = asyncio.get_event_loop()
if loop.is_running():
# Schedule shutdown event
loop.call_soon_threadsafe(_shutdown_complete.set)
except RuntimeError:
pass # No event loop running


def setup_signal_handlers():
"""
Set up signal handlers for graceful shutdown.

Handles:
- SIGTERM: Standard termination signal (from systemd, docker, etc.)
- SIGINT: Keyboard interrupt (Ctrl+C)
"""
logger = logging.getLogger(__name__)

try:
# SIGTERM - Standard termination signal
signal.signal(signal.SIGTERM, _signal_handler)
logger.info("SIGTERM handler installed")

# SIGINT - Keyboard interrupt
signal.signal(signal.SIGINT, _signal_handler)
logger.info("SIGINT handler installed")

except (ValueError, OSError) as e:
# Signal handling may not be available in all contexts
logger.warning(f"Could not install signal handlers: {e}")

# Ensure log directory exists before configuring logging
Path(settings.log_file).parent.mkdir(parents=True, exist_ok=True)

@@ -38,16 +131,91 @@ for logger_name in ['app', 'app.services', 'app.services.pdf_generator_service',
logger = logging.getLogger(__name__)


async def drain_connections(timeout: float = 30.0):
"""
Wait for active connections to complete (connection draining).

Args:
timeout: Maximum time to wait for connections to drain
"""
logger.info(f"Draining connections (timeout={timeout}s)...")
start_time = asyncio.get_event_loop().time()

while get_active_connections() > 0:
elapsed = asyncio.get_event_loop().time() - start_time
if elapsed >= timeout:
logger.warning(
f"Connection drain timeout after {timeout}s. "
f"{get_active_connections()} connections still active."
)
break

logger.info(f"Waiting for {get_active_connections()} active connections...")
await asyncio.sleep(1.0)

if get_active_connections() == 0:
logger.info("All connections drained successfully")
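The counters above are defined here, but the middleware that drives them is not part of this diff. A plausible wiring, assuming the increment_connections() and decrement_connections() helpers from this module are in scope, is an HTTP middleware that brackets each request:

from fastapi import FastAPI, Request

app = FastAPI()

@app.middleware("http")
async def track_connections(request: Request, call_next):
    # Hypothetical wiring (not part of this commit): count each in-flight
    # request so drain_connections() has something to wait for at shutdown.
    await increment_connections()
    try:
        return await call_next(request)
    finally:
        await decrement_connections()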
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Application lifespan events"""
# Startup
logger.info("Starting Tool_OCR V2 application...")

# Set up signal handlers for graceful shutdown
setup_signal_handlers()

# Ensure all directories exist
settings.ensure_directories()
logger.info("All directories created/verified")

# Initialize memory management if enabled
if settings.enable_model_lifecycle_management:
try:
from app.services.memory_manager import get_model_manager, MemoryConfig

memory_config = MemoryConfig(
warning_threshold=settings.memory_warning_threshold,
critical_threshold=settings.memory_critical_threshold,
emergency_threshold=settings.memory_emergency_threshold,
model_idle_timeout_seconds=settings.pp_structure_idle_timeout_seconds,
memory_check_interval_seconds=settings.memory_check_interval_seconds,
enable_auto_cleanup=settings.enable_memory_optimization,
enable_emergency_cleanup=settings.enable_emergency_cleanup,
max_concurrent_predictions=settings.max_concurrent_predictions,
enable_cpu_fallback=settings.enable_cpu_fallback,
gpu_memory_limit_mb=settings.gpu_memory_limit_mb,
)
get_model_manager(memory_config)
logger.info("Memory management initialized")
except Exception as e:
logger.warning(f"Failed to initialize memory management: {e}")

# Initialize service pool if enabled
if settings.enable_service_pool:
try:
from app.services.service_pool import get_service_pool, PoolConfig

pool_config = PoolConfig(
max_services_per_device=settings.max_services_per_device,
max_total_services=settings.max_total_services,
acquire_timeout_seconds=settings.service_acquire_timeout_seconds,
max_queue_size=settings.max_queue_size,
)
get_service_pool(pool_config)
logger.info("OCR service pool initialized")
except Exception as e:
logger.warning(f"Failed to initialize service pool: {e}")

# Initialize prediction semaphore for controlling concurrent PP-StructureV3 predictions
try:
from app.services.memory_manager import get_prediction_semaphore
get_prediction_semaphore(max_concurrent=settings.max_concurrent_predictions)
logger.info(f"Prediction semaphore initialized (max_concurrent={settings.max_concurrent_predictions})")
except Exception as e:
logger.warning(f"Failed to initialize prediction semaphore: {e}")

logger.info("Application startup complete")

yield
@@ -55,6 +223,45 @@ async def lifespan(app: FastAPI):
# Shutdown
logger.info("Shutting down Tool_OCR application...")

# Connection draining - wait for active requests to complete
await drain_connections(timeout=30.0)

# Shutdown recovery manager if initialized
try:
from app.services.memory_manager import shutdown_recovery_manager
shutdown_recovery_manager()
logger.info("Recovery manager shutdown complete")
except Exception as e:
logger.debug(f"Recovery manager shutdown skipped: {e}")

# Shutdown service pool
if settings.enable_service_pool:
try:
from app.services.service_pool import shutdown_service_pool
shutdown_service_pool()
logger.info("Service pool shutdown complete")
except Exception as e:
logger.warning(f"Error shutting down service pool: {e}")

# Shutdown prediction semaphore
try:
from app.services.memory_manager import shutdown_prediction_semaphore
shutdown_prediction_semaphore()
logger.info("Prediction semaphore shutdown complete")
except Exception as e:
logger.warning(f"Error shutting down prediction semaphore: {e}")

# Shutdown memory manager
if settings.enable_model_lifecycle_management:
try:
from app.services.memory_manager import shutdown_model_manager
shutdown_model_manager()
logger.info("Memory manager shutdown complete")
except Exception as e:
logger.warning(f"Error shutting down memory manager: {e}")

logger.info("Tool_OCR shutdown complete")


# Create FastAPI application
app = FastAPI(
@@ -77,9 +284,7 @@ app.add_middleware(
# Health check endpoint
@app.get("/health")
async def health_check():
"""Health check endpoint with GPU status"""
from app.services.ocr_service import OCRService

"""Health check endpoint with GPU status and memory management info"""
response = {
"status": "healthy",
"service": "Tool_OCR V2",
@@ -88,10 +293,31 @@ async def health_check():

# Add GPU status information
try:
# Create temporary OCRService instance to get GPU status
# In production, this should be a singleton service
ocr_service = OCRService()
gpu_status = ocr_service.get_gpu_status()
# Use service pool if available to avoid creating new instances
gpu_status = None
if settings.enable_service_pool:
try:
from app.services.service_pool import get_service_pool
pool = get_service_pool()
pool_stats = pool.get_pool_stats()
response["service_pool"] = pool_stats

# Get GPU status from first available service
for device, services in pool.services.items():
for pooled in services:
if hasattr(pooled.service, 'get_gpu_status'):
gpu_status = pooled.service.get_gpu_status()
break
if gpu_status:
break
except Exception as e:
logger.debug(f"Could not get service pool stats: {e}")

# Fallback: create temporary instance if no pool or no service available
if gpu_status is None:
from app.services.ocr_service import OCRService
ocr_service = OCRService()
gpu_status = ocr_service.get_gpu_status()

response["gpu"] = {
"available": gpu_status.get("gpu_available", False),
@@ -120,6 +346,15 @@ async def health_check():
"error": str(e),
}

# Add memory management status
if settings.enable_model_lifecycle_management:
try:
from app.services.memory_manager import get_model_manager
model_manager = get_model_manager()
response["memory_management"] = model_manager.get_model_stats()
except Exception as e:
logger.debug(f"Could not get memory management stats: {e}")

return response



@@ -212,26 +212,44 @@ class TableData:
if self.caption:
html.append(f"<caption>{self.caption}</caption>")

# Group cells by row
rows_data = {}
# Group cells by row and column for quick lookup
cell_map = {}
for cell in self.cells:
if cell.row not in rows_data:
rows_data[cell.row] = []
rows_data[cell.row].append(cell)
cell_map[(cell.row, cell.col)] = cell

# Generate HTML
# Track which cells are covered by row/col spans
covered = set()
for cell in self.cells:
if cell.row_span > 1 or cell.col_span > 1:
for r in range(cell.row, cell.row + cell.row_span):
for c in range(cell.col, cell.col + cell.col_span):
if (r, c) != (cell.row, cell.col):
covered.add((r, c))

# Generate HTML with proper column filling
for row_idx in range(self.rows):
html.append("<tr>")
if row_idx in rows_data:
for cell in sorted(rows_data[row_idx], key=lambda c: c.col):
for col_idx in range(self.cols):
# Skip cells covered by row/col spans
if (row_idx, col_idx) in covered:
continue

cell = cell_map.get((row_idx, col_idx))
tag = "th" if row_idx == 0 and self.headers else "td"

if cell:
span_attrs = []
if cell.row_span > 1:
span_attrs.append(f'rowspan="{cell.row_span}"')
if cell.col_span > 1:
span_attrs.append(f'colspan="{cell.col_span}"')
span_str = " ".join(span_attrs)
tag = "th" if row_idx == 0 and self.headers else "td"
html.append(f'<{tag} {span_str}>{cell.content}</{tag}>')
content = cell.content if cell.content else ""
html.append(f'<{tag} {span_str}>{content}</{tag}>')
else:
# Fill in empty cell for missing positions
html.append(f'<{tag}></{tag}>')

html.append("</tr>")

html.append("</table>")
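The covered-set bookkeeping above is easiest to see on a tiny table. A standalone sketch of the same idea, using plain tuples instead of TableCell so it runs on its own:

# Cells as (row, col, row_span, col_span, text); the first cell spans two columns.
cells = [(0, 0, 1, 2, "Name"), (0, 2, 1, 1, "Qty"),
         (1, 0, 1, 1, "Apple"), (1, 1, 1, 1, "red"), (1, 2, 1, 1, "3")]
rows, cols = 2, 3

# Mark grid positions hidden by a span so they are skipped, not emitted twice.
covered = set()
for r, c, rs, cs, _ in cells:
    for rr in range(r, r + rs):
        for cc in range(c, c + cs):
            if (rr, cc) != (r, c):
                covered.add((rr, cc))          # here: {(0, 1)}

cell_map = {(r, c): (rs, cs, text) for r, c, rs, cs, text in cells}
html = ["<table>"]
for r in range(rows):
    html.append("<tr>")
    for c in range(cols):
        if (r, c) in covered:
            continue                           # absorbed by a neighbouring span
        rs, cs, text = cell_map.get((r, c), (1, 1, ""))
        attrs = (f' rowspan="{rs}"' if rs > 1 else "") + (f' colspan="{cs}"' if cs > 1 else "")
        html.append(f"<td{attrs}>{text}</td>")
    html.append("</tr>")
html.append("</table>")
print("".join(html))
# -> <table><tr><td colspan="2">Name</td><td>Qty</td></tr><tr><td>Apple</td><td>red</td><td>3</td></tr></table>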
@@ -39,6 +39,7 @@ from app.schemas.task import (
from app.services.task_service import task_service
from app.services.file_access_service import file_access_service
from app.services.ocr_service import OCRService
from app.services.service_pool import get_service_pool, PoolConfig

# Import dual-track components
try:
@@ -47,6 +48,13 @@ try:
except ImportError:
DUAL_TRACK_AVAILABLE = False

# Service pool availability
SERVICE_POOL_AVAILABLE = True
try:
from app.services.memory_manager import get_model_manager
except ImportError:
SERVICE_POOL_AVAILABLE = False

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/v2/tasks", tags=["Tasks"])
@@ -63,7 +71,10 @@ def process_task_ocr(
pp_structure_params: Optional[dict] = None
):
"""
Background task to process OCR for a task with dual-track support
Background task to process OCR for a task with dual-track support.

Uses OCRServicePool to acquire a shared service instance instead of
creating a new one, preventing GPU memory proliferation.

Args:
task_id: Task UUID string
@@ -80,6 +91,7 @@ def process_task_ocr(

db = SessionLocal()
start_time = datetime.now()
pooled_service = None

try:
logger.info(f"Starting OCR processing for task {task_id}, file: {filename}")
@@ -91,16 +103,39 @@ def process_task_ocr(
logger.error(f"Task {task_id} not found in database")
return

# Initialize OCR service
ocr_service = OCRService()
# Acquire OCR service from pool (or create new if pool disabled)
ocr_service = None
if settings.enable_service_pool and SERVICE_POOL_AVAILABLE:
try:
service_pool = get_service_pool()
pooled_service = service_pool.acquire(
device="GPU:0",
timeout=settings.service_acquire_timeout_seconds,
task_id=task_id
)
if pooled_service:
ocr_service = pooled_service.service
logger.info(f"Acquired OCR service from pool for task {task_id}")
else:
logger.warning(f"Timeout acquiring service from pool, creating new instance")
except Exception as e:
logger.warning(f"Failed to acquire service from pool: {e}, creating new instance")

# Fallback: create new instance if pool acquisition failed
if ocr_service is None:
logger.info("Creating new OCRService instance (pool disabled or unavailable)")
ocr_service = OCRService()

# Create result directory before OCR processing (needed for saving extracted images)
result_dir = Path(settings.result_dir) / task_id
result_dir.mkdir(parents=True, exist_ok=True)

# Process the file with OCR (use dual-track if available)
if use_dual_track and hasattr(ocr_service, 'process'):
# Use new dual-track processing
# Process the file with OCR
# Use dual-track processing if:
# 1. use_dual_track is True (auto-detection)
# 2. OR force_track is specified (explicit track selection)
if (use_dual_track or force_track) and hasattr(ocr_service, 'process'):
# Use new dual-track processing (or forced track)
ocr_result = ocr_service.process(
file_path=Path(file_path),
lang=language,
@@ -111,7 +146,7 @@ def process_task_ocr(
pp_structure_params=pp_structure_params
)
else:
# Fall back to traditional processing
# Fall back to traditional processing (no force_track support)
ocr_result = ocr_service.process_image(
image_path=Path(file_path),
lang=language,
@@ -131,6 +166,16 @@ def process_task_ocr(
source_file_path=Path(file_path)
)

# Release service back to pool (success case)
if pooled_service:
try:
service_pool = get_service_pool()
service_pool.release(pooled_service, error=None)
logger.info(f"Released OCR service back to pool for task {task_id}")
pooled_service = None # Prevent double release in finally
except Exception as e:
logger.warning(f"Failed to release service to pool: {e}")

# Close old session and create fresh one to avoid MySQL timeout
# (long OCR processing may cause connection to become stale)
db.close()
@@ -158,6 +203,15 @@ def process_task_ocr(
except Exception as e:
logger.exception(f"OCR processing failed for task {task_id}")

# Release service back to pool with error
if pooled_service:
try:
service_pool = get_service_pool()
service_pool.release(pooled_service, error=e)
pooled_service = None
except Exception as release_error:
logger.warning(f"Failed to release service to pool: {release_error}")

# Update task status to failed (direct database update)
try:
task = db.query(Task).filter(Task.id == task_db_id).first()
@@ -170,6 +224,13 @@ def process_task_ocr(
logger.error(f"Failed to update task status: {update_error}")

finally:
# Ensure service is released in case of any missed release
if pooled_service:
try:
service_pool = get_service_pool()
service_pool.release(pooled_service, error=None)
except Exception:
pass
db.close()


@@ -330,7 +391,13 @@ async def get_task(
with open(result_path) as f:
result_data = json.load(f)
metadata = result_data.get("metadata", {})
processing_track = metadata.get("processing_track")
track_str = metadata.get("processing_track")
# Convert string to enum to avoid Pydantic serialization warning
if track_str:
try:
processing_track = ProcessingTrackEnum(track_str)
except ValueError:
processing_track = None
except Exception:
pass # Silently ignore errors reading the result file


@@ -247,9 +247,11 @@ class DirectExtractionEngine:
element_counter += len(image_elements)

# Extract vector graphics (charts, diagrams) from drawing commands
# Pass table_bboxes to filter out table border drawings before clustering
if self.enable_image_extraction:
vector_elements = self._extract_vector_graphics(
page, page_num, document_id, element_counter, output_dir
page, page_num, document_id, element_counter, output_dir,
table_bboxes=table_bboxes
)
elements.extend(vector_elements)
element_counter += len(vector_elements)
@@ -705,40 +707,52 @@ class DirectExtractionEngine:
y1=bbox_data[3]
)

# Extract column widths from table cells
# Extract column widths from table cells by analyzing X boundaries
column_widths = []
if hasattr(table, 'cells') and table.cells:
# Group cells by column
cols_x = {}
# Collect all unique X boundaries (both left and right edges)
x_boundaries = set()
for cell in table.cells:
col_idx = None
# Determine column index by x0 position
for idx, x0 in enumerate(sorted(set(c[0] for c in table.cells))):
if abs(cell[0] - x0) < 1.0: # Within 1pt tolerance
col_idx = idx
break
x_boundaries.add(round(cell[0], 1)) # x0 (left edge)
x_boundaries.add(round(cell[2], 1)) # x1 (right edge)

if col_idx is not None:
if col_idx not in cols_x:
cols_x[col_idx] = {'x0': cell[0], 'x1': cell[2]}
else:
cols_x[col_idx]['x1'] = max(cols_x[col_idx]['x1'], cell[2])
# Sort boundaries to get column edges
sorted_x = sorted(x_boundaries)

# Calculate width for each column
for col_idx in sorted(cols_x.keys()):
width = cols_x[col_idx]['x1'] - cols_x[col_idx]['x0']
column_widths.append(width)
# Calculate column widths from adjacent boundaries
if len(sorted_x) >= 2:
column_widths = [sorted_x[i+1] - sorted_x[i] for i in range(len(sorted_x)-1)]
logger.debug(f"Calculated column widths from {len(sorted_x)} boundaries: {column_widths}")

# Extract row heights from table cells by analyzing Y boundaries
row_heights = []
if hasattr(table, 'cells') and table.cells:
# Collect all unique Y boundaries (both top and bottom edges)
y_boundaries = set()
for cell in table.cells:
y_boundaries.add(round(cell[1], 1)) # y0 (top edge)
y_boundaries.add(round(cell[3], 1)) # y1 (bottom edge)

# Sort boundaries to get row edges
sorted_y = sorted(y_boundaries)

# Calculate row heights from adjacent boundaries
if len(sorted_y) >= 2:
row_heights = [sorted_y[i+1] - sorted_y[i] for i in range(len(sorted_y)-1)]
logger.debug(f"Calculated row heights from {len(sorted_y)} boundaries: {row_heights}")
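A quick worked example of the boundary approach, assuming cell rectangles are 4-tuples (x0, y0, x1, y1) in points as in the extraction library used above:

# Two rows by two columns; the first column is 100 pt wide, the second 200 pt.
cells = [(72.0, 100.0, 172.0, 120.0), (172.0, 100.0, 372.0, 120.0),
         (72.0, 120.0, 172.0, 140.0), (172.0, 120.0, 372.0, 140.0)]

x_bounds = sorted({round(c[0], 1) for c in cells} | {round(c[2], 1) for c in cells})
y_bounds = sorted({round(c[1], 1) for c in cells} | {round(c[3], 1) for c in cells})

column_widths = [x_bounds[i + 1] - x_bounds[i] for i in range(len(x_bounds) - 1)]
row_heights = [y_bounds[i + 1] - y_bounds[i] for i in range(len(y_bounds) - 1)]
print(column_widths, row_heights)   # [100.0, 200.0] [20.0, 20.0]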
# Create table cells
# Note: Include ALL cells (even empty ones) to preserve table structure
# This is critical for correct HTML generation and PDF rendering
cells = []
for row_idx, row in enumerate(data):
for col_idx, cell_text in enumerate(row):
if cell_text:
cells.append(TableCell(
row=row_idx,
col=col_idx,
content=str(cell_text) if cell_text else ""
))
# Always add cell, even if empty, to maintain table structure
cells.append(TableCell(
row=row_idx,
col=col_idx,
content=str(cell_text) if cell_text else ""
))

# Create table data
table_data = TableData(
@@ -748,8 +762,13 @@ class DirectExtractionEngine:
headers=data[0] if data else None # Assume first row is header
)

# Store column widths in metadata
metadata = {"column_widths": column_widths} if column_widths else None
# Store column widths and row heights in metadata
metadata = {}
if column_widths:
metadata["column_widths"] = column_widths
if row_heights:
metadata["row_heights"] = row_heights
metadata = metadata if metadata else None

return DocumentElement(
element_id=f"table_{page_num}_{counter}",
@@ -978,7 +997,9 @@ class DirectExtractionEngine:
image_filename = f"{document_id}_p{page_num}_img{img_idx}.png"
image_path = output_dir / image_filename
pix.save(str(image_path))
image_data["saved_path"] = str(image_path)
# Store relative filename only (consistent with OCR track)
# PDF generator will join with result_dir to get full path
image_data["saved_path"] = image_filename
logger.debug(f"Saved image to {image_path}")

element = DocumentElement(
@@ -1001,12 +1022,272 @@ class DirectExtractionEngine:

return elements

def has_missing_images(self, page: fitz.Page) -> bool:
"""
Detect if a page likely has images that weren't extracted.

This checks for inline image blocks (type=1 in text dict) which indicate
graphics composed of many small image blocks (like logos) that
page.get_images() cannot detect.

Args:
page: PyMuPDF page object

Returns:
True if there are likely missing images that need OCR extraction
"""
try:
# Check if get_images found anything
standard_images = page.get_images()
if standard_images:
return False # Standard images were found, no need for fallback

# Check for inline image blocks (type=1)
text_dict = page.get_text("dict", sort=True)
blocks = text_dict.get("blocks", [])

image_block_count = sum(1 for b in blocks if b.get("type") == 1)

# If there are many inline image blocks, likely there's a logo or graphic
if image_block_count >= 10:
logger.info(f"Detected {image_block_count} inline image blocks - may need OCR for image extraction")
return True

return False

except Exception as e:
logger.warning(f"Error checking for missing images: {e}")
return False

def check_document_for_missing_images(self, pdf_path: Path) -> List[int]:
"""
Check a PDF document for pages that likely have missing images.

This opens the PDF and checks each page for inline image blocks
that weren't extracted by get_images().

Args:
pdf_path: Path to the PDF file

Returns:
List of page numbers (1-indexed) that have missing images
"""
pages_with_missing_images = []

try:
doc = fitz.open(str(pdf_path))
for page_num in range(len(doc)):
page = doc[page_num]
if self.has_missing_images(page):
pages_with_missing_images.append(page_num + 1) # 1-indexed
doc.close()

if pages_with_missing_images:
logger.info(f"Document has missing images on pages: {pages_with_missing_images}")

except Exception as e:
logger.error(f"Error checking document for missing images: {e}")

return pages_with_missing_images
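Taken together, the two checks give the pipeline a cheap way to decide when the more expensive fallback is worth running. A hedged usage sketch; the constructor arguments, file names, and the unified_doc object are assumptions, not part of this commit:

from pathlib import Path

# Hypothetical driver code.
engine = DirectExtractionEngine()
pdf_path = Path("invoice.pdf")

pages = engine.check_document_for_missing_images(pdf_path)
if pages:
    # unified_doc is the UnifiedDocument produced by the earlier extraction pass.
    added = engine.render_inline_image_regions(pdf_path, unified_doc, pages,
                                               output_dir=Path("results/images"))
    print(f"Recovered {added} image regions on pages {pages}")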
def render_inline_image_regions(
self,
pdf_path: Path,
unified_doc: 'UnifiedDocument',
pages: List[int],
output_dir: Optional[Path] = None
) -> int:
"""
Render inline image regions and add them to the unified document.

This is a fallback when OCR doesn't detect images. It clusters inline
image blocks (type=1) and renders them as images.

Args:
pdf_path: Path to the PDF file
unified_doc: UnifiedDocument to add images to
pages: List of page numbers (1-indexed) to process
output_dir: Directory to save rendered images

Returns:
Number of images added
"""
images_added = 0

try:
doc = fitz.open(str(pdf_path))

for page_num in pages:
if page_num < 1 or page_num > len(doc):
continue

page = doc[page_num - 1] # 0-indexed
page_rect = page.rect

# Get inline image blocks
text_dict = page.get_text("dict", sort=True)
blocks = text_dict.get("blocks", [])

image_blocks = []
for block in blocks:
if block.get("type") == 1: # Image block
bbox = block.get("bbox")
if bbox:
image_blocks.append(fitz.Rect(bbox))

if len(image_blocks) < 5: # Reduced from 10
logger.debug(f"Page {page_num}: Only {len(image_blocks)} inline image blocks, skipping")
continue

logger.info(f"Page {page_num}: Found {len(image_blocks)} inline image blocks")

# Cluster nearby image blocks
regions = self._cluster_nearby_rects(image_blocks, tolerance=5.0)
logger.info(f"Page {page_num}: Clustered into {len(regions)} regions")

# Find the corresponding page in unified_doc
target_page = None
for p in unified_doc.pages:
if p.page_number == page_num:
target_page = p
break

if not target_page:
continue

for region_idx, region_rect in enumerate(regions):
logger.info(f"Page {page_num} region {region_idx}: {region_rect} (w={region_rect.width:.1f}, h={region_rect.height:.1f})")

# Skip very small regions
if region_rect.width < 30 or region_rect.height < 30:
logger.info(f"  -> Skipped: too small (min 30x30)")
continue

# Skip regions that are primarily in the table area (below top 40%)
# But allow regions that START in the top portion
page_30_pct = page_rect.height * 0.3
page_40_pct = page_rect.height * 0.4
if region_rect.y0 > page_40_pct:
logger.info(f"  -> Skipped: y0={region_rect.y0:.1f} > 40% of page ({page_40_pct:.1f})")
continue

logger.info(f"Rendering inline image region {region_idx} on page {page_num}: {region_rect}")

try:
# Add small padding
clip_rect = region_rect + (-2, -2, 2, 2)
clip_rect.intersect(page_rect)

# Render at 2x resolution
mat = fitz.Matrix(2, 2)
pix = page.get_pixmap(clip=clip_rect, matrix=mat, alpha=False)

# Create bounding box
bbox = BoundingBox(
x0=clip_rect.x0,
y0=clip_rect.y0,
x1=clip_rect.x1,
y1=clip_rect.y1
)

image_data = {
"width": pix.width,
"height": pix.height,
"colorspace": "rgb",
"type": "inline_region"
}

# Save image if output directory provided
if output_dir:
output_dir.mkdir(parents=True, exist_ok=True)
doc_id = unified_doc.document_id or "unknown"
image_filename = f"{doc_id}_p{page_num}_logo{region_idx}.png"
image_path = output_dir / image_filename
pix.save(str(image_path))
image_data["saved_path"] = image_filename
logger.info(f"Saved inline image region to {image_path}")

element = DocumentElement(
element_id=f"logo_{page_num}_{region_idx}",
type=ElementType.LOGO,
content=image_data,
bbox=bbox,
confidence=0.9,
metadata={
"region_type": "inline_image_blocks",
"block_count": len(image_blocks)
}
)
target_page.elements.append(element)
images_added += 1

pix = None # Free memory

except Exception as e:
logger.error(f"Error rendering inline image region {region_idx}: {e}")

doc.close()

if images_added > 0:
current_images = unified_doc.metadata.total_images or 0
unified_doc.metadata.total_images = current_images + images_added
logger.info(f"Added {images_added} inline image regions to document")

except Exception as e:
logger.error(f"Error rendering inline image regions: {e}")

return images_added

def _cluster_nearby_rects(self, rects: List[fitz.Rect], tolerance: float = 5.0) -> List[fitz.Rect]:
"""Cluster nearby rectangles into regions."""
if not rects:
return []

sorted_rects = sorted(rects, key=lambda r: (r.y0, r.x0))

merged = []
for rect in sorted_rects:
merged_with_existing = False
for i, region in enumerate(merged):
expanded = region + (-tolerance, -tolerance, tolerance, tolerance)
if expanded.intersects(rect):
merged[i] = region | rect
merged_with_existing = True
break
if not merged_with_existing:
merged.append(rect)

# Second pass: merge any regions that now overlap
changed = True
while changed:
changed = False
new_merged = []
skip = set()

for i, r1 in enumerate(merged):
if i in skip:
continue
current = r1
for j, r2 in enumerate(merged[i+1:], start=i+1):
if j in skip:
continue
expanded = current + (-tolerance, -tolerance, tolerance, tolerance)
if expanded.intersects(r2):
current = current | r2
skip.add(j)
changed = True
new_merged.append(current)
merged = new_merged

return merged
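The two-pass merge above is the standard union-of-overlapping-rectangles trick: the first pass greedily absorbs each rect into the first nearby region, the second pass keeps merging regions until nothing changes. A tiny self-contained illustration of the first greedy pass on plain coordinate tuples (not the fitz.Rect API; the second pass simply repeats the same merge until the region list stops changing):

def union(a, b):
    return (min(a[0], b[0]), min(a[1], b[1]), max(a[2], b[2]), max(a[3], b[3]))

def near(a, b, tol=5.0):
    # True if the rectangles overlap once 'a' is grown by tol on every side.
    return not (a[2] + tol < b[0] or b[2] < a[0] - tol or
                a[3] + tol < b[1] or b[3] < a[1] - tol)

rects = [(0, 0, 10, 10), (12, 0, 22, 10), (100, 100, 110, 110)]
regions = []
for r in sorted(rects, key=lambda r: (r[1], r[0])):
    for i, reg in enumerate(regions):
        if near(reg, r):
            regions[i] = union(reg, r)
            break
    else:
        regions.append(r)
print(regions)   # [(0, 0, 22, 10), (100, 100, 110, 110)]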
def _extract_vector_graphics(self,
page: fitz.Page,
page_num: int,
document_id: str,
counter: int,
output_dir: Optional[Path]) -> List[DocumentElement]:
output_dir: Optional[Path],
table_bboxes: Optional[List[BoundingBox]] = None) -> List[DocumentElement]:
"""
Extract vector graphics (charts, diagrams) from page.

@@ -1020,6 +1301,7 @@ class DirectExtractionEngine:
document_id: Unique document identifier
counter: Starting counter for element IDs
output_dir: Directory to save rendered graphics
table_bboxes: List of table bounding boxes to exclude table border drawings

Returns:
List of DocumentElement objects representing vector graphics
@@ -1034,16 +1316,25 @@ class DirectExtractionEngine:

logger.debug(f"Page {page_num} contains {len(drawings)} vector drawing commands")

# Filter out drawings that are likely table borders
# Table borders are typically thin rectangular lines within table regions
non_table_drawings = self._filter_table_border_drawings(drawings, table_bboxes)
logger.debug(f"After filtering table borders: {len(non_table_drawings)} drawings remain")

if not non_table_drawings:
logger.debug("All drawings appear to be table borders, no vector graphics to extract")
return elements

# Cluster drawings into groups (charts, diagrams, etc.)
try:
# PyMuPDF's cluster_drawings() groups nearby drawings automatically
drawing_clusters = page.cluster_drawings()
# Use custom clustering that only considers non-table drawings
drawing_clusters = self._cluster_non_table_drawings(page, non_table_drawings)
logger.debug(f"Clustered into {len(drawing_clusters)} groups")
except (AttributeError, TypeError) as e:
# cluster_drawings not available or has different signature
# Fallback: try to identify charts by analyzing drawing density
logger.warning(f"cluster_drawings() failed ({e}), using fallback method")
drawing_clusters = self._cluster_drawings_fallback(page, drawings)
logger.warning(f"Custom clustering failed ({e}), using fallback method")
drawing_clusters = self._cluster_drawings_fallback(page, non_table_drawings)

for cluster_idx, bbox in enumerate(drawing_clusters):
# Ignore small regions (likely noise or separator lines)
@@ -1148,6 +1439,124 @@ class DirectExtractionEngine:

return filtered_clusters

def _filter_table_border_drawings(self, drawings: list, table_bboxes: Optional[List[BoundingBox]]) -> list:
"""
Filter out drawings that are likely table borders.

Table borders are typically:
- Thin rectangular lines (height or width < 5pt)
- Located within or on the edge of table bounding boxes

Args:
drawings: List of PyMuPDF drawing objects
table_bboxes: List of table bounding boxes

Returns:
List of drawings that are NOT table borders (likely logos, charts, etc.)
"""
if not table_bboxes:
return drawings

non_table_drawings = []
table_border_count = 0

for drawing in drawings:
rect = drawing.get('rect')
if not rect:
continue

draw_rect = fitz.Rect(rect)

# Check if this drawing is a thin line (potential table border)
is_thin_line = draw_rect.width < 5 or draw_rect.height < 5

# Check if drawing overlaps significantly with any table
overlaps_table = False
for table_bbox in table_bboxes:
table_rect = fitz.Rect(table_bbox.x0, table_bbox.y0, table_bbox.x1, table_bbox.y1)

# Expand table rect slightly to include border lines on edges
expanded_table = table_rect + (-5, -5, 5, 5)

if expanded_table.contains(draw_rect) or expanded_table.intersects(draw_rect):
# Calculate overlap ratio
intersection = draw_rect & expanded_table
if not intersection.is_empty:
overlap_ratio = intersection.get_area() / draw_rect.get_area() if draw_rect.get_area() > 0 else 0

# If drawing is mostly inside table region, it's likely a border
if overlap_ratio > 0.8:
overlaps_table = True
break

# Keep drawing if it's NOT (thin line AND overlapping table)
# This keeps: logos (complex shapes), charts outside tables, etc.
if is_thin_line and overlaps_table:
table_border_count += 1
else:
non_table_drawings.append(drawing)

if table_border_count > 0:
logger.debug(f"Filtered out {table_border_count} table border drawings")

return non_table_drawings
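The border test boils down to "thin line" plus "mostly inside a slightly expanded table". A small numeric check of the overlap-ratio part, written with plain tuples rather than fitz.Rect so it stands alone:

def area(r):
    return max(0.0, r[2] - r[0]) * max(0.0, r[3] - r[1])

def overlap_ratio(drawing, table, pad=5.0):
    t = (table[0] - pad, table[1] - pad, table[2] + pad, table[3] + pad)
    ix = (max(drawing[0], t[0]), max(drawing[1], t[1]),
          min(drawing[2], t[2]), min(drawing[3], t[3]))
    return area(ix) / area(drawing) if area(drawing) > 0 else 0.0

table = (100.0, 200.0, 400.0, 300.0)
border_line = (100.0, 250.0, 400.0, 251.0)    # 1 pt tall rule inside the table
logo = (50.0, 50.0, 150.0, 120.0)             # shape well above the table

print(overlap_ratio(border_line, table))       # 1.0 -> treated as a table border
print(overlap_ratio(logo, table))              # 0.0 -> kept as a graphic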
def _cluster_non_table_drawings(self, page: fitz.Page, drawings: list) -> list:
"""
Cluster non-table drawings into groups.

This method clusters drawings that have been pre-filtered to exclude table borders.
It uses a more conservative clustering approach suitable for logos and charts.

Args:
page: PyMuPDF page object
drawings: Pre-filtered list of drawings (excluding table borders)

Returns:
List of fitz.Rect representing clustered drawing regions
"""
if not drawings:
return []

# Collect all drawing bounding boxes
bboxes = []
for drawing in drawings:
rect = drawing.get('rect')
if rect:
bboxes.append(fitz.Rect(rect))

if not bboxes:
return []

# More conservative clustering with smaller tolerance
# This prevents grouping distant graphics together
clusters = []
tolerance = 10 # Smaller tolerance than fallback (was 20)

for bbox in bboxes:
# Try to merge with existing cluster
merged = False
for i, cluster in enumerate(clusters):
# Check if bbox is close to this cluster
expanded_cluster = cluster + (-tolerance, -tolerance, tolerance, tolerance)
if expanded_cluster.intersects(bbox):
# Merge bbox into cluster
clusters[i] = cluster | bbox # Union of rectangles
merged = True
break

if not merged:
# Create new cluster
clusters.append(bbox)

# Filter out very small clusters (noise)
# Keep minimum 30x30 for logos (smaller than default 50x50)
filtered_clusters = [c for c in clusters if c.width >= 30 and c.height >= 30]

logger.debug(f"Non-table clustering: {len(bboxes)} drawings -> {len(clusters)} clusters -> {len(filtered_clusters)} filtered")

return filtered_clusters

def _deduplicate_table_chart_overlap(self, elements: List[DocumentElement]) -> List[DocumentElement]:
"""
Intelligently resolve TABLE-CHART overlaps based on table structure completeness.
backend/app/services/memory_manager.py: new file, 2269 lines (diff suppressed because it is too large)
@@ -25,6 +25,7 @@ except ImportError:
|
||||
|
||||
from app.core.config import settings
|
||||
from app.services.office_converter import OfficeConverter, OfficeConverterError
|
||||
from app.services.memory_manager import get_model_manager, MemoryConfig, MemoryGuard, prediction_context
|
||||
|
||||
# Import dual-track components
|
||||
try:
|
||||
@@ -96,6 +97,26 @@ class OCRService:
|
||||
self._model_last_used = {} # Track last usage time for each model
|
||||
self._memory_warning_logged = False
|
||||
|
||||
# Initialize MemoryGuard for enhanced memory monitoring
|
||||
self._memory_guard = None
|
||||
if settings.enable_model_lifecycle_management:
|
||||
try:
|
||||
memory_config = MemoryConfig(
|
||||
warning_threshold=settings.memory_warning_threshold,
|
||||
critical_threshold=settings.memory_critical_threshold,
|
||||
emergency_threshold=settings.memory_emergency_threshold,
|
||||
model_idle_timeout_seconds=settings.pp_structure_idle_timeout_seconds,
|
||||
gpu_memory_limit_mb=settings.gpu_memory_limit_mb,
|
||||
enable_cpu_fallback=settings.enable_cpu_fallback,
|
||||
)
|
||||
self._memory_guard = MemoryGuard(memory_config)
|
||||
logger.debug("MemoryGuard initialized for OCRService")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to initialize MemoryGuard: {e}")
|
||||
|
||||
# Track if CPU fallback was activated
|
||||
self._cpu_fallback_active = False
|
||||
|
||||
self._detect_and_configure_gpu()
|
||||
|
||||
# Log GPU optimization settings
|
||||
@@ -217,53 +238,91 @@ class OCRService:
|
||||
def _check_gpu_memory_usage(self):
|
||||
"""
|
||||
Check GPU memory usage and log warnings if approaching limits.
|
||||
Implements memory optimization for RTX 4060 8GB.
|
||||
Uses MemoryGuard for enhanced monitoring with multiple backends.
|
||||
"""
|
||||
if not self.use_gpu or not settings.enable_memory_optimization:
|
||||
return
|
||||
|
||||
try:
|
||||
device_id = self.gpu_info.get('device_id', 0)
|
||||
memory_allocated = paddle.device.cuda.memory_allocated(device_id)
|
||||
memory_allocated_mb = memory_allocated / (1024**2)
|
||||
memory_limit_mb = settings.gpu_memory_limit_mb
|
||||
# Use MemoryGuard if available for better monitoring
|
||||
if self._memory_guard:
|
||||
stats = self._memory_guard.get_memory_stats()
|
||||
|
||||
utilization = (memory_allocated_mb / memory_limit_mb * 100) if memory_limit_mb > 0 else 0
|
||||
# Log based on usage ratio
|
||||
if stats.gpu_used_ratio > 0.90 and not self._memory_warning_logged:
|
||||
logger.warning(
|
||||
f"GPU memory usage critical: {stats.gpu_used_mb:.0f}MB / {stats.gpu_total_mb:.0f}MB "
|
||||
f"({stats.gpu_used_ratio*100:.1f}%)"
|
||||
)
|
||||
logger.warning("Consider enabling auto_unload_unused_models or reducing batch size")
|
||||
self._memory_warning_logged = True
|
||||
|
||||
if utilization > 90 and not self._memory_warning_logged:
|
||||
logger.warning(f"GPU memory usage high: {memory_allocated_mb:.0f}MB / {memory_limit_mb}MB ({utilization:.1f}%)")
|
||||
logger.warning("Consider enabling auto_unload_unused_models or reducing batch size")
|
||||
self._memory_warning_logged = True
|
||||
elif utilization > 75:
|
||||
logger.info(f"GPU memory: {memory_allocated_mb:.0f}MB / {memory_limit_mb}MB ({utilization:.1f}%)")
|
||||
# Trigger emergency cleanup if enabled
|
||||
if settings.enable_emergency_cleanup:
|
||||
self._cleanup_unused_models()
|
||||
self._memory_guard.clear_gpu_cache()
|
||||
|
||||
elif stats.gpu_used_ratio > 0.75:
|
||||
logger.info(
|
||||
f"GPU memory: {stats.gpu_used_mb:.0f}MB / {stats.gpu_total_mb:.0f}MB "
|
||||
f"({stats.gpu_used_ratio*100:.1f}%)"
|
||||
)
|
||||
else:
|
||||
# Fallback to original implementation
|
||||
device_id = self.gpu_info.get('device_id', 0)
|
||||
memory_allocated = paddle.device.cuda.memory_allocated(device_id)
|
||||
memory_allocated_mb = memory_allocated / (1024**2)
|
||||
memory_limit_mb = settings.gpu_memory_limit_mb
|
||||
|
||||
utilization = (memory_allocated_mb / memory_limit_mb * 100) if memory_limit_mb > 0 else 0
|
||||
|
||||
if utilization > 90 and not self._memory_warning_logged:
|
||||
logger.warning(f"GPU memory usage high: {memory_allocated_mb:.0f}MB / {memory_limit_mb}MB ({utilization:.1f}%)")
|
||||
logger.warning("Consider enabling auto_unload_unused_models or reducing batch size")
|
||||
self._memory_warning_logged = True
|
||||
elif utilization > 75:
|
||||
logger.info(f"GPU memory: {memory_allocated_mb:.0f}MB / {memory_limit_mb}MB ({utilization:.1f}%)")
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Memory check failed: {e}")
|
||||
|
||||
def _cleanup_unused_models(self):
|
||||
"""
|
||||
Clean up unused language models to free GPU memory.
|
||||
Clean up unused models (including PP-StructureV3) to free GPU memory.
|
||||
Models idle longer than model_idle_timeout_seconds will be unloaded.
|
||||
|
||||
Note: PP-StructureV3 is NO LONGER exempted from cleanup - it will be
|
||||
unloaded based on pp_structure_idle_timeout_seconds configuration.
|
||||
"""
|
||||
if not settings.auto_unload_unused_models:
|
||||
return
|
||||
|
||||
current_time = datetime.now()
|
||||
timeout = settings.model_idle_timeout_seconds
|
||||
models_to_remove = []
|
||||
|
||||
for lang, last_used in self._model_last_used.items():
|
||||
if lang == 'structure': # Don't unload structure engine
|
||||
continue
|
||||
# Use different timeout for structure engine vs language models
|
||||
if lang == 'structure':
|
||||
timeout = settings.pp_structure_idle_timeout_seconds
|
||||
else:
|
||||
timeout = settings.model_idle_timeout_seconds
|
||||
|
||||
idle_seconds = (current_time - last_used).total_seconds()
|
||||
if idle_seconds > timeout:
|
||||
models_to_remove.append(lang)
|
||||
|
||||
for lang in models_to_remove:
|
||||
if lang in self.ocr_engines:
|
||||
logger.info(f"Unloading idle OCR engine for {lang} (idle {timeout}s)")
|
||||
del self.ocr_engines[lang]
|
||||
del self._model_last_used[lang]
|
||||
for model_key in models_to_remove:
|
||||
if model_key == 'structure':
|
||||
if self.structure_engine is not None:
|
||||
logger.info(f"Unloading idle PP-StructureV3 engine (idle {settings.pp_structure_idle_timeout_seconds}s)")
|
||||
self._unload_structure_engine()
|
||||
if model_key in self._model_last_used:
|
||||
del self._model_last_used[model_key]
|
||||
elif model_key in self.ocr_engines:
|
||||
logger.info(f"Unloading idle OCR engine for {model_key} (idle {settings.model_idle_timeout_seconds}s)")
|
||||
del self.ocr_engines[model_key]
|
||||
if model_key in self._model_last_used:
|
||||
del self._model_last_used[model_key]
|
||||
|
||||
if models_to_remove and self.use_gpu:
|
||||
# Clear CUDA cache
|
||||
@@ -273,6 +332,41 @@ class OCRService:
|
||||
except Exception as e:
|
||||
logger.debug(f"Cache clear failed: {e}")
|
||||
|
||||
def _unload_structure_engine(self):
|
||||
"""
|
||||
Properly unload PP-StructureV3 engine and free GPU memory.
|
||||
"""
|
||||
if self.structure_engine is None:
|
||||
return
|
||||
|
||||
try:
|
||||
# Clear internal engine components
|
||||
if hasattr(self.structure_engine, 'table_engine'):
|
||||
self.structure_engine.table_engine = None
|
||||
if hasattr(self.structure_engine, 'text_detector'):
|
||||
self.structure_engine.text_detector = None
|
||||
if hasattr(self.structure_engine, 'text_recognizer'):
|
||||
self.structure_engine.text_recognizer = None
|
||||
if hasattr(self.structure_engine, 'layout_predictor'):
|
||||
self.structure_engine.layout_predictor = None
|
||||
|
||||
# Delete the engine
|
||||
del self.structure_engine
|
||||
self.structure_engine = None
|
||||
|
||||
# Force garbage collection
|
||||
gc.collect()
|
||||
|
||||
# Clear GPU cache
|
||||
if self.use_gpu:
|
||||
paddle.device.cuda.empty_cache()
|
||||
|
||||
logger.info("PP-StructureV3 engine unloaded successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error unloading PP-StructureV3: {e}")
|
||||
self.structure_engine = None
|
||||
|
||||
def clear_gpu_cache(self):
|
||||
"""
|
||||
Manually clear GPU memory cache.
|
||||
@@ -519,46 +613,160 @@ class OCRService:
|
||||
logger.warning(f"GPU memory cleanup failed (non-critical): {e}")
|
||||
# Don't fail the processing if cleanup fails
|
||||
|
||||
def check_gpu_memory(self, required_mb: int = 2000) -> bool:
|
||||
def check_gpu_memory(self, required_mb: int = 2000, enable_fallback: bool = True) -> bool:
|
||||
"""
|
||||
Check if sufficient GPU memory is available.
|
||||
Check if sufficient GPU memory is available using MemoryGuard.
|
||||
|
||||
This method now uses MemoryGuard for accurate memory queries across
|
||||
multiple backends (pynvml, torch, paddle) instead of returning True
|
||||
blindly for PaddlePaddle-only environments.
|
||||
|
||||
Args:
|
||||
required_mb: Required memory in MB (default 2000MB for OCR models)
|
||||
enable_fallback: If True and CPU fallback is enabled, switch to CPU mode
|
||||
when memory is insufficient instead of returning False
|
||||
|
||||
Returns:
|
||||
True if sufficient memory is available or GPU is not used
|
||||
True if sufficient memory is available, GPU is not used, or CPU fallback activated
|
||||
"""
|
||||
try:
|
||||
# Check GPU memory using torch if available, otherwise use PaddlePaddle
|
||||
free_memory = None
|
||||
# If not using GPU, always return True
|
||||
if not self.use_gpu:
|
||||
return True
|
||||
|
||||
if TORCH_AVAILABLE and torch.cuda.is_available():
|
||||
free_memory = torch.cuda.mem_get_info()[0] / 1024**2
|
||||
elif paddle.device.is_compiled_with_cuda():
|
||||
# PaddlePaddle doesn't have direct API to get free memory,
|
||||
# so we rely on cleanup and continue
|
||||
logger.debug("Using PaddlePaddle GPU, memory info not directly available")
|
||||
try:
|
||||
# Use MemoryGuard if available for accurate multi-backend memory queries
|
||||
if self._memory_guard:
|
||||
is_available, stats = self._memory_guard.check_memory(
|
||||
required_mb=required_mb,
|
||||
                    device_id=self.gpu_info.get('device_id', 0)
                )

                if not is_available:
                    logger.warning(
                        f"GPU memory check failed: {stats.gpu_free_mb:.0f}MB free, "
                        f"{required_mb}MB required ({stats.gpu_used_ratio*100:.1f}% used)"
                    )

                    # Try to free memory
                    logger.info("Attempting memory cleanup before retry...")
                    self._cleanup_unused_models()
                    self._memory_guard.clear_gpu_cache()

                    # Check again
                    is_available, stats = self._memory_guard.check_memory(required_mb=required_mb)

                    if not is_available:
                        # Memory still insufficient after cleanup
                        if enable_fallback and settings.enable_cpu_fallback:
                            logger.warning(
                                f"Insufficient GPU memory ({stats.gpu_free_mb:.0f}MB) after cleanup. "
                                f"Activating CPU fallback mode."
                            )
                            self._activate_cpu_fallback()
                            return True  # Continue with CPU
                        else:
                            logger.error(
                                f"Insufficient GPU memory: {stats.gpu_free_mb:.0f}MB available, "
                                f"{required_mb}MB required"
                            )
                            return False

                logger.debug(
                    f"GPU memory check passed: {stats.gpu_free_mb:.0f}MB free "
                    f"({stats.gpu_used_ratio*100:.1f}% used)"
                )
                return True

            if free_memory is not None:
                if free_memory < required_mb:
                    logger.warning(f"Low GPU memory: {free_memory:.0f}MB available, {required_mb}MB required")
                    # Try to free memory
                    self.cleanup_gpu_memory()
                    # Check again
                    if TORCH_AVAILABLE and torch.cuda.is_available():
                        free_memory = torch.cuda.mem_get_info()[0] / 1024**2
                        if free_memory < required_mb:
                            logger.error(f"Insufficient GPU memory after cleanup: {free_memory:.0f}MB")
                            return False
                logger.debug(f"GPU memory check passed: {free_memory:.0f}MB available")
            else:
                # Fallback to original implementation
                free_memory = None

                if TORCH_AVAILABLE and torch.cuda.is_available():
                    free_memory = torch.cuda.mem_get_info()[0] / 1024**2
                elif paddle.device.is_compiled_with_cuda():
                    # PaddlePaddle doesn't have direct API to get free memory,
                    # use allocated memory to estimate
                    device_id = self.gpu_info.get('device_id', 0)
                    allocated = paddle.device.cuda.memory_allocated(device_id) / (1024**2)
                    total = settings.gpu_memory_limit_mb
                    free_memory = max(0, total - allocated)
                    logger.debug(f"Estimated free GPU memory: {free_memory:.0f}MB (total: {total}MB, allocated: {allocated:.0f}MB)")

                if free_memory is not None:
                    if free_memory < required_mb:
                        logger.warning(f"Low GPU memory: {free_memory:.0f}MB available, {required_mb}MB required")
                        self.cleanup_gpu_memory()

                        # Recheck
                        if TORCH_AVAILABLE and torch.cuda.is_available():
                            free_memory = torch.cuda.mem_get_info()[0] / 1024**2
                        else:
                            allocated = paddle.device.cuda.memory_allocated(device_id) / (1024**2)
                            free_memory = max(0, total - allocated)

                        if free_memory < required_mb:
                            if enable_fallback and settings.enable_cpu_fallback:
                                logger.warning("Insufficient GPU memory after cleanup. Activating CPU fallback.")
                                self._activate_cpu_fallback()
                                return True
                            else:
                                logger.error(f"Insufficient GPU memory after cleanup: {free_memory:.0f}MB")
                                return False

                    logger.debug(f"GPU memory check passed: {free_memory:.0f}MB available")

                return True

            return True
        except Exception as e:
            logger.warning(f"GPU memory check failed: {e}")
            return True  # Continue processing even if check fails

    def _activate_cpu_fallback(self):
        """
        Activate CPU fallback mode when GPU memory is insufficient.
        This disables GPU usage for the current service instance.
        """
        if self._cpu_fallback_active:
            return  # Already in CPU mode

        logger.warning("=== CPU FALLBACK MODE ACTIVATED ===")
        logger.warning("GPU memory insufficient, switching to CPU processing")
        logger.warning("Performance will be significantly reduced")

        self._cpu_fallback_active = True
        self.use_gpu = False

        # Update GPU info to reflect fallback
        self.gpu_info['cpu_fallback'] = True
        self.gpu_info['fallback_reason'] = 'GPU memory insufficient'

        # Clear GPU cache to free memory
        if self._memory_guard:
            self._memory_guard.clear_gpu_cache()

    def _restore_gpu_mode(self):
        """
        Attempt to restore GPU mode after CPU fallback.
        Called when memory pressure has been relieved.
        """
        if not self._cpu_fallback_active:
            return

        if not self.gpu_available:
            return

        # Check if GPU memory is now available
        if self._memory_guard:
            is_available, stats = self._memory_guard.check_memory(
                required_mb=settings.structure_model_memory_mb
            )
            if is_available:
                logger.info("GPU memory available, restoring GPU mode")
                self._cpu_fallback_active = False
                self.use_gpu = True
                self.gpu_info.pop('cpu_fallback', None)
                self.gpu_info.pop('fallback_reason', None)

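The enhanced path above leans on a MemoryGuard from app.services.memory_manager, which is not part of this diff. As a rough orientation only, here is a minimal sketch of the check_memory contract the code assumes; the field names gpu_free_mb and gpu_used_ratio are taken from the log lines above, and everything else (including the use of torch.cuda.mem_get_info) is an assumption rather than the actual implementation:

# Illustrative sketch only - not the actual app.services.memory_manager code.
from dataclasses import dataclass

import torch


@dataclass
class MemoryStats:
    gpu_free_mb: float
    gpu_total_mb: float

    @property
    def gpu_used_ratio(self) -> float:
        # Fraction of total GPU memory currently in use
        return 1.0 - (self.gpu_free_mb / self.gpu_total_mb) if self.gpu_total_mb else 0.0


def check_memory(required_mb: int, device_id: int = 0):
    """Return (is_available, stats), mirroring how the guard is called above."""
    if not torch.cuda.is_available():
        return True, MemoryStats(gpu_free_mb=0.0, gpu_total_mb=0.0)
    free_bytes, total_bytes = torch.cuda.mem_get_info(device_id)
    stats = MemoryStats(gpu_free_mb=free_bytes / 1024**2, gpu_total_mb=total_bytes / 1024**2)
    return stats.gpu_free_mb >= required_mb, stats
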
    def convert_pdf_to_images(self, pdf_path: Path, output_dir: Path) -> List[Path]:
        """
        Convert PDF to images (one per page)
@@ -626,6 +834,24 @@ class OCRService:
        threshold = confidence_threshold if confidence_threshold is not None else self.confidence_threshold

        try:
            # Pre-operation memory check: Try to restore GPU if in fallback and memory available
            if self._cpu_fallback_active:
                self._restore_gpu_mode()
                if not self._cpu_fallback_active:
                    logger.info("GPU mode restored for processing")

            # Initial memory check before starting any heavy processing
            # Estimate memory requirement based on image type
            estimated_memory_mb = 2500  # Conservative estimate for full OCR + layout
            if detect_layout:
                estimated_memory_mb += 500  # Additional for PP-StructureV3

            if not self.check_gpu_memory(required_mb=estimated_memory_mb, enable_fallback=True):
                logger.warning(
                    f"Pre-operation memory check failed ({estimated_memory_mb}MB required). "
                    f"Processing will attempt to proceed but may encounter issues."
                )

            # Check if file is Office document
            if self.office_converter.is_office_document(image_path):
                logger.info(f"Detected Office document: {image_path.name}, converting to PDF")
@@ -748,9 +974,12 @@ class OCRService:
        # Get OCR engine (for non-PDF images)
        ocr_engine = self.get_ocr_engine(lang)

        # Check GPU memory before OCR processing
        if not self.check_gpu_memory(required_mb=1500):
            logger.warning("Insufficient GPU memory for OCR, attempting to proceed anyway")
        # Secondary memory check before OCR processing
        if not self.check_gpu_memory(required_mb=1500, enable_fallback=True):
            logger.warning(
                f"OCR memory check: insufficient GPU memory (1500MB required). "
                f"Mode: {'CPU fallback' if self._cpu_fallback_active else 'GPU (low memory)'}"
            )

        # Get the actual image dimensions that OCR will use
        from PIL import Image
@@ -950,6 +1179,18 @@ class OCRService:
            Tuple of (layout_data, images_metadata)
        """
        try:
            # Pre-operation memory check for layout analysis
            if self._cpu_fallback_active:
                self._restore_gpu_mode()
                if not self._cpu_fallback_active:
                    logger.info("GPU mode restored for layout analysis")

            if not self.check_gpu_memory(required_mb=2000, enable_fallback=True):
                logger.warning(
                    f"Layout analysis pre-check: insufficient GPU memory (2000MB required). "
                    f"Mode: {'CPU fallback' if self._cpu_fallback_active else 'GPU'}"
                )

            structure_engine = self._ensure_structure_engine(pp_structure_params)

            # Try enhanced processing first
@@ -998,11 +1239,21 @@ class OCRService:
            # Standard processing (original implementation)
            logger.info(f"Running standard layout analysis on {image_path.name}")

            # Check GPU memory before processing
            if not self.check_gpu_memory(required_mb=2000):
                logger.warning("Insufficient GPU memory for PP-StructureV3, attempting to proceed anyway")
            # Memory check before PP-StructureV3 processing
            if not self.check_gpu_memory(required_mb=2000, enable_fallback=True):
                logger.warning(
                    f"PP-StructureV3 memory check: insufficient GPU memory (2000MB required). "
                    f"Mode: {'CPU fallback' if self._cpu_fallback_active else 'GPU (low memory)'}"
                )

            results = structure_engine.predict(str(image_path))
            # Use prediction semaphore to control concurrent predictions
            # This prevents OOM errors from multiple simultaneous PP-StructureV3.predict() calls
            with prediction_context(timeout=settings.service_acquire_timeout_seconds) as acquired:
                if not acquired:
                    logger.error("Failed to acquire prediction slot (timeout), returning empty layout")
                    return None, []

                results = structure_engine.predict(str(image_path))

            layout_elements = []
            images_metadata = []

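prediction_context is imported from app.services.memory_manager and its body is not shown in this commit. A minimal sketch of the semaphore-with-timeout pattern that the usage above implies (the limit of 2 mirrors the max_concurrent_predictions default; the real helper may differ):

# Illustrative sketch of a semaphore-backed prediction slot; the actual helper
# lives in app.services.memory_manager and may be implemented differently.
import threading
from contextlib import contextmanager

_prediction_semaphore = threading.Semaphore(2)  # e.g. settings.max_concurrent_predictions


@contextmanager
def prediction_context(timeout: float = 300.0):
    """Yield True if a prediction slot was acquired within `timeout`, else False."""
    acquired = _prediction_semaphore.acquire(timeout=timeout)
    try:
        yield acquired
    finally:
        if acquired:
            _prediction_semaphore.release()
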
@@ -1254,6 +1505,46 @@ class OCRService:
                if temp_pdf_path:
                    unified_doc.metadata.original_filename = file_path.name

                # HYBRID MODE: Check if Direct track missed images (e.g., inline image blocks)
                # If so, use OCR to extract images and merge them into the Direct result
                pages_with_missing_images = self.direct_extraction_engine.check_document_for_missing_images(
                    actual_file_path
                )
                if pages_with_missing_images:
                    logger.info(f"Hybrid mode: Direct track missing images on pages {pages_with_missing_images}, using OCR to extract images")
                    try:
                        # Run OCR on the file to extract images
                        ocr_result = self.process_file_traditional(
                            actual_file_path, lang, detect_layout=True,
                            confidence_threshold=confidence_threshold,
                            output_dir=output_dir, pp_structure_params=pp_structure_params
                        )

                        # Convert OCR result to extract images
                        ocr_unified = self.ocr_to_unified_converter.convert(
                            ocr_result, actual_file_path, 0.0, lang
                        )

                        # Merge OCR-extracted images into Direct track result
                        images_added = self._merge_ocr_images_into_direct(
                            unified_doc, ocr_unified, pages_with_missing_images
                        )
                        if images_added > 0:
                            logger.info(f"Hybrid mode: Added {images_added} images from OCR to Direct track result")
                            unified_doc.metadata.processing_track = ProcessingTrack.HYBRID
                        else:
                            # Fallback: OCR didn't find images either, render inline image blocks directly
                            logger.info("Hybrid mode: OCR didn't find images, falling back to inline image rendering")
                            images_added = self.direct_extraction_engine.render_inline_image_regions(
                                actual_file_path, unified_doc, pages_with_missing_images, output_dir
                            )
                            if images_added > 0:
                                logger.info(f"Hybrid mode: Rendered {images_added} inline image regions")
                                unified_doc.metadata.processing_track = ProcessingTrack.HYBRID
                    except Exception as e:
                        logger.warning(f"Hybrid mode image extraction failed: {e}")
                        # Continue with Direct track result without images

        # Use OCR track (either by recommendation or fallback)
        if recommendation.track == "ocr":
            # Use OCR for scanned documents, images, etc.
@@ -1269,17 +1560,19 @@ class OCRService:
            )
            unified_doc.document_id = document_id

        # Update processing track metadata
        unified_doc.metadata.processing_track = (
            ProcessingTrack.DIRECT if recommendation.track == "direct"
            else ProcessingTrack.OCR
        )
        # Update processing track metadata (only if not already set to HYBRID)
        if unified_doc.metadata.processing_track != ProcessingTrack.HYBRID:
            unified_doc.metadata.processing_track = (
                ProcessingTrack.DIRECT if recommendation.track == "direct"
                else ProcessingTrack.OCR
            )

        # Calculate total processing time
        processing_time = (datetime.now() - start_time).total_seconds()
        unified_doc.metadata.processing_time = processing_time

        logger.info(f"Document processing completed in {processing_time:.2f}s using {recommendation.track} track")
        actual_track = unified_doc.metadata.processing_track.value
        logger.info(f"Document processing completed in {processing_time:.2f}s using {actual_track} track")

        return unified_doc

@@ -1290,6 +1583,75 @@ class OCRService:
            file_path, lang, detect_layout, confidence_threshold, output_dir, pp_structure_params
        )

    def _merge_ocr_images_into_direct(
        self,
        direct_doc: 'UnifiedDocument',
        ocr_doc: 'UnifiedDocument',
        pages_with_missing_images: List[int]
    ) -> int:
        """
        Merge OCR-extracted images into Direct track result.

        This is used in hybrid mode when Direct track couldn't extract certain
        images (like logos composed of inline image blocks).

        Args:
            direct_doc: UnifiedDocument from Direct track
            ocr_doc: UnifiedDocument from OCR track
            pages_with_missing_images: List of page numbers (1-indexed) that need images

        Returns:
            Number of images added
        """
        images_added = 0

        try:
            # Get image element types to look for
            image_types = {ElementType.FIGURE, ElementType.IMAGE, ElementType.LOGO}

            for page_num in pages_with_missing_images:
                # Find the target page in direct_doc
                direct_page = None
                for page in direct_doc.pages:
                    if page.page_number == page_num:
                        direct_page = page
                        break

                if not direct_page:
                    continue

                # Find the source page in ocr_doc
                ocr_page = None
                for page in ocr_doc.pages:
                    if page.page_number == page_num:
                        ocr_page = page
                        break

                if not ocr_page:
                    continue

                # Extract image elements from OCR page
                for element in ocr_page.elements:
                    if element.type in image_types:
                        # Assign new element ID to avoid conflicts
                        new_element_id = f"hybrid_{element.element_id}"
                        element.element_id = new_element_id

                        # Add to direct page
                        direct_page.elements.append(element)
                        images_added += 1
                        logger.debug(f"Added image element {new_element_id} to page {page_num}")

            # Update image count in direct_doc metadata
            if images_added > 0:
                current_images = direct_doc.metadata.total_images or 0
                direct_doc.metadata.total_images = current_images + images_added

        except Exception as e:
            logger.error(f"Error merging OCR images into Direct track: {e}")

        return images_added

    def process_file_traditional(
        self,
        file_path: Path,
@@ -1441,13 +1803,16 @@ class OCRService:
            UnifiedDocument if dual-track is enabled and use_dual_track=True,
            Dict with legacy format otherwise
        """
        if use_dual_track and self.dual_track_enabled:
            # Use dual-track processing
        # Use dual-track processing if:
        # 1. use_dual_track is True (auto-detection), OR
        # 2. force_track is specified (explicit track selection)
        if (use_dual_track or force_track) and self.dual_track_enabled:
            # Use dual-track processing (or forced track)
            return self.process_with_dual_track(
                file_path, lang, detect_layout, confidence_threshold, output_dir, force_track, pp_structure_params
            )
        else:
            # Use traditional OCR processing
            # Use traditional OCR processing (no force_track support)
            return self.process_file_traditional(
                file_path, lang, detect_layout, confidence_threshold, output_dir, pp_structure_params
            )

@@ -572,8 +572,10 @@ class PDFGeneratorService:
        processing_track = unified_doc.metadata.get('processing_track')

        # Route to track-specific rendering method
        is_direct_track = (processing_track == 'direct' or
                           processing_track == ProcessingTrack.DIRECT)
        # ProcessingTrack is (str, Enum), so comparing with enum value works for both string and enum
        # HYBRID track uses Direct track rendering (Direct text/tables + OCR images)
        is_direct_track = (processing_track == ProcessingTrack.DIRECT or
                           processing_track == ProcessingTrack.HYBRID)

        logger.info(f"Processing track: {processing_track}, using {'Direct' if is_direct_track else 'OCR'} track rendering")

@@ -675,8 +677,11 @@ class PDFGeneratorService:
        logger.info("=== Direct Track PDF Generation ===")
        logger.info(f"Total pages: {len(unified_doc.pages)}")

        # Set current track for helper methods
        self.current_processing_track = 'direct'
        # Set current track for helper methods (may be DIRECT or HYBRID)
        if hasattr(unified_doc, 'metadata') and unified_doc.metadata:
            self.current_processing_track = unified_doc.metadata.processing_track
        else:
            self.current_processing_track = ProcessingTrack.DIRECT

        # Get page dimensions from first page (for canvas initialization)
        if not unified_doc.pages:
@@ -1074,11 +1079,16 @@ class PDFGeneratorService:
        # *** Priority 1: check ocr_dimensions (converted from UnifiedDocument) ***
        if 'ocr_dimensions' in ocr_data:
            dims = ocr_data['ocr_dimensions']
            w = float(dims.get('width', 0))
            h = float(dims.get('height', 0))
            if w > 0 and h > 0:
                logger.info(f"Using page dimensions from ocr_dimensions field: {w:.1f} x {h:.1f}")
                return (w, h)
            # Handle both dict format {'width': w, 'height': h} and
            # list format [{'page': 1, 'width': w, 'height': h}, ...]
            if isinstance(dims, list) and len(dims) > 0:
                dims = dims[0]  # Use first page dimensions
            if isinstance(dims, dict):
                w = float(dims.get('width', 0))
                h = float(dims.get('height', 0))
                if w > 0 and h > 0:
                    logger.info(f"Using page dimensions from ocr_dimensions field: {w:.1f} x {h:.1f}")
                    return (w, h)

        # *** Priority 2: check the original JSON's dimensions ***
        if 'dimensions' in ocr_data:
@@ -1418,8 +1428,8 @@ class PDFGeneratorService:
        # Set font with track-specific styling
        # Note: OCR track has no StyleInfo (extracted from images), so no advanced formatting
        style_info = region.get('style')
        is_direct_track = (self.current_processing_track == 'direct' or
                           self.current_processing_track == ProcessingTrack.DIRECT)
        is_direct_track = (self.current_processing_track == ProcessingTrack.DIRECT or
                           self.current_processing_track == ProcessingTrack.HYBRID)

        if style_info and is_direct_track:
            # Direct track: Apply rich styling from StyleInfo
@@ -1661,10 +1671,15 @@ class PDFGeneratorService:
            return

        # Construct full path to image
        # saved_path is relative to result_dir (e.g., "imgs/element_id.png")
        image_path = result_dir / image_path_str

        # Fallback for legacy data
        if not image_path.exists():
            logger.warning(f"Image not found: {image_path}")
            image_path = result_dir / Path(image_path_str).name

        if not image_path.exists():
            logger.warning(f"Image not found: {image_path_str} (in {result_dir})")
            return

        # Get bbox for positioning
@@ -2289,12 +2304,30 @@ class PDFGeneratorService:
            col_widths = element.metadata['column_widths']
            logger.debug(f"Using extracted column widths: {col_widths}")

        # Create table without rowHeights (will use canvas scaling instead)
        t = Table(table_content, colWidths=col_widths)
        # Use original row heights from extraction if available
        # Row heights must match the number of data rows exactly
        row_heights_list = None
        if element.metadata and 'row_heights' in element.metadata:
            extracted_row_heights = element.metadata['row_heights']
            num_data_rows = len(table_content)
            num_height_rows = len(extracted_row_heights)

            if num_height_rows == num_data_rows:
                row_heights_list = extracted_row_heights
                logger.debug(f"Using extracted row heights ({num_height_rows} rows): {row_heights_list}")
            else:
                # Row counts don't match - this can happen with merged cells or empty rows
                logger.warning(f"Row height mismatch: {num_height_rows} heights for {num_data_rows} data rows, falling back to auto-sizing")

        # Create table with both column widths and row heights for accurate sizing
        t = Table(table_content, colWidths=col_widths, rowHeights=row_heights_list)

        # Apply style with minimal padding to reduce table extension
        # Use Chinese font to support special characters (℃, μm, ≦, ×, Ω, etc.)
        font_for_table = self.font_name if self.font_registered else 'Helvetica'
        style = TableStyle([
            ('GRID', (0, 0), (-1, -1), 0.5, colors.grey),
            ('FONTNAME', (0, 0), (-1, -1), font_for_table),
            ('FONTSIZE', (0, 0), (-1, -1), 8),
            ('ALIGN', (0, 0), (-1, -1), 'LEFT'),
            ('VALIGN', (0, 0), (-1, -1), 'TOP'),
@@ -2307,8 +2340,8 @@ class PDFGeneratorService:
        ])
        t.setStyle(style)

        # CRITICAL: Use canvas scaling to fit table within bbox
        # This is more reliable than rowHeights which doesn't always work
        # Use canvas scaling as fallback to fit table within bbox
        # With proper row heights, scaling should be minimal (close to 1.0)

        # Step 1: Wrap to get actual rendered size
        actual_width, actual_height = t.wrapOn(pdf_canvas, table_width * 10, table_height * 10)
@@ -2358,11 +2391,16 @@ class PDFGeneratorService:
            logger.warning(f"No image path for element {element.element_id}")
            return

        # Construct full path
        # Construct full path to image
        # saved_path is relative to result_dir (e.g., "document_id_p1_img0.png")
        image_path = result_dir / image_path_str

        # Fallback for legacy data
        if not image_path.exists():
            logger.warning(f"Image not found: {image_path}")
            image_path = result_dir / Path(image_path_str).name

        if not image_path.exists():
            logger.warning(f"Image not found: {image_path_str} (in {result_dir})")
            return

        # Get bbox
@@ -2388,7 +2426,7 @@ class PDFGeneratorService:
                preserveAspectRatio=True
            )

            logger.debug(f"Drew image: {image_path_str}")
            logger.debug(f"Drew image: {image_path} (from: {original_path_str})")

        except Exception as e:
            logger.error(f"Failed to draw image element {element.element_id}: {e}")

@@ -21,6 +21,8 @@ except ImportError:
import paddle
from paddleocr import PPStructureV3
from app.models.unified_document import ElementType
from app.core.config import settings
from app.services.memory_manager import prediction_context

logger = logging.getLogger(__name__)

@@ -96,8 +98,22 @@ class PPStructureEnhanced:
        try:
            logger.info(f"Enhanced PP-StructureV3 analysis on {image_path.name}")

            # Perform structure analysis
            results = self.structure_engine.predict(str(image_path))
            # Perform structure analysis with semaphore control
            # This prevents OOM errors from multiple simultaneous predictions
            with prediction_context(timeout=settings.service_acquire_timeout_seconds) as acquired:
                if not acquired:
                    logger.error("Failed to acquire prediction slot (timeout), returning empty result")
                    return {
                        'has_parsing_res_list': False,
                        'elements': [],
                        'total_elements': 0,
                        'images': [],
                        'tables': [],
                        'element_types': {},
                        'error': 'Prediction slot timeout'
                    }

                results = self.structure_engine.predict(str(image_path))

            all_elements = []
            all_images = []

468  backend/app/services/service_pool.py  Normal file
@@ -0,0 +1,468 @@
"""
Tool_OCR - OCR Service Pool
Manages a pool of OCRService instances to prevent duplicate model loading
and control concurrent GPU operations.
"""

import asyncio
import logging
import threading
import time
from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, TYPE_CHECKING

from app.services.memory_manager import get_model_manager, MemoryConfig

if TYPE_CHECKING:
    from app.services.ocr_service import OCRService

logger = logging.getLogger(__name__)


class ServiceState(Enum):
    """State of a pooled service"""
    AVAILABLE = "available"
    IN_USE = "in_use"
    UNHEALTHY = "unhealthy"
    INITIALIZING = "initializing"


@dataclass
class PooledService:
    """Wrapper for a pooled OCRService instance"""
    service: Any  # OCRService
    device: str
    state: ServiceState = ServiceState.AVAILABLE
    created_at: float = field(default_factory=time.time)
    last_used: float = field(default_factory=time.time)
    use_count: int = 0
    error_count: int = 0
    current_task_id: Optional[str] = None


class PoolConfig:
    """Configuration for the service pool"""

    def __init__(
        self,
        max_services_per_device: int = 1,
        max_total_services: int = 2,
        acquire_timeout_seconds: float = 300.0,
        max_queue_size: int = 50,
        health_check_interval_seconds: int = 60,
        max_consecutive_errors: int = 3,
        service_idle_timeout_seconds: int = 600,
        enable_auto_scaling: bool = False,
    ):
        self.max_services_per_device = max_services_per_device
        self.max_total_services = max_total_services
        self.acquire_timeout_seconds = acquire_timeout_seconds
        self.max_queue_size = max_queue_size
        self.health_check_interval_seconds = health_check_interval_seconds
        self.max_consecutive_errors = max_consecutive_errors
        self.service_idle_timeout_seconds = service_idle_timeout_seconds
        self.enable_auto_scaling = enable_auto_scaling


class OCRServicePool:
    """
    Pool of OCRService instances with concurrency control.

    Features:
    - Per-device instance management (one service per GPU)
    - Queue-based task distribution
    - Semaphore-based concurrency limits
    - Health monitoring
    - Automatic service recovery
    """

    _instance = None
    _lock = threading.Lock()

    def __new__(cls, *args, **kwargs):
        """Singleton pattern"""
        with cls._lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
                cls._instance._initialized = False
            return cls._instance

    def __init__(self, config: Optional[PoolConfig] = None):
        if self._initialized:
            return

        self.config = config or PoolConfig()
        self.services: Dict[str, List[PooledService]] = {}
        self.semaphores: Dict[str, threading.Semaphore] = {}
        self.queues: Dict[str, List] = {}
        self._pool_lock = threading.RLock()
        self._condition = threading.Condition(self._pool_lock)

        # Metrics
        self._metrics = {
            "total_acquisitions": 0,
            "total_releases": 0,
            "total_timeouts": 0,
            "total_errors": 0,
            "queue_waits": 0,
        }

        # Initialize default device pool
        self._initialize_device("GPU:0")

        self._initialized = True
        logger.info("OCRServicePool initialized")

    def _initialize_device(self, device: str):
        """Initialize pool resources for a device"""
        with self._pool_lock:
            if device not in self.services:
                self.services[device] = []
                self.semaphores[device] = threading.Semaphore(
                    self.config.max_services_per_device
                )
                self.queues[device] = []
                logger.info(f"Initialized pool for device {device}")

    def _create_service(self, device: str) -> PooledService:
        """
        Create a new OCRService instance for the pool.

        Args:
            device: Device identifier (e.g., "GPU:0", "CPU")

        Returns:
            PooledService wrapper
        """
        # Import here to avoid circular imports
        from app.services.ocr_service import OCRService

        logger.info(f"Creating new OCRService for device {device}")
        start_time = time.time()

        # Create service instance
        service = OCRService()

        creation_time = time.time() - start_time
        logger.info(f"OCRService created in {creation_time:.2f}s for device {device}")

        return PooledService(
            service=service,
            device=device,
            state=ServiceState.AVAILABLE
        )

    def acquire(
        self,
        device: str = "GPU:0",
        timeout: Optional[float] = None,
        task_id: Optional[str] = None
    ) -> Optional[PooledService]:
        """
        Acquire an OCRService from the pool.

        Args:
            device: Preferred device (e.g., "GPU:0")
            timeout: Maximum time to wait for a service
            task_id: Optional task ID for tracking

        Returns:
            PooledService if available, None if timeout
        """
        timeout = timeout or self.config.acquire_timeout_seconds
        self._initialize_device(device)

        start_time = time.time()
        deadline = start_time + timeout

        with self._condition:
            while True:
                # Try to get an available service
                service = self._try_acquire_service(device, task_id)
                if service:
                    self._metrics["total_acquisitions"] += 1
                    return service

                # Check if we can create a new service
                if self._can_create_service(device):
                    try:
                        pooled = self._create_service(device)
                        pooled.state = ServiceState.IN_USE
                        pooled.current_task_id = task_id
                        pooled.use_count += 1
                        self.services[device].append(pooled)
                        self._metrics["total_acquisitions"] += 1
                        logger.info(f"Created and acquired new service for {device}")
                        return pooled
                    except Exception as e:
                        logger.error(f"Failed to create service for {device}: {e}")
                        self._metrics["total_errors"] += 1

                # Wait for a service to become available
                remaining = deadline - time.time()
                if remaining <= 0:
                    self._metrics["total_timeouts"] += 1
                    logger.warning(f"Timeout waiting for service on {device}")
                    return None

                self._metrics["queue_waits"] += 1
                logger.debug(f"Waiting for service on {device} (timeout: {remaining:.1f}s)")
                self._condition.wait(timeout=min(remaining, 1.0))

    def _try_acquire_service(self, device: str, task_id: Optional[str]) -> Optional[PooledService]:
        """Try to acquire an available service without waiting"""
        for pooled in self.services.get(device, []):
            if pooled.state == ServiceState.AVAILABLE:
                pooled.state = ServiceState.IN_USE
                pooled.last_used = time.time()
                pooled.use_count += 1
                pooled.current_task_id = task_id
                logger.debug(f"Acquired existing service for {device} (use #{pooled.use_count})")
                return pooled
        return None

    def _can_create_service(self, device: str) -> bool:
        """Check if a new service can be created"""
        device_count = len(self.services.get(device, []))
        total_count = sum(len(services) for services in self.services.values())

        return (
            device_count < self.config.max_services_per_device and
            total_count < self.config.max_total_services
        )

    def release(self, pooled: PooledService, error: Optional[Exception] = None):
        """
        Release a service back to the pool.

        Args:
            pooled: The pooled service to release
            error: Optional error that occurred during use
        """
        with self._condition:
            if error:
                pooled.error_count += 1
                self._metrics["total_errors"] += 1
                logger.warning(f"Service released with error: {error}")

                # Mark unhealthy if too many errors
                if pooled.error_count >= self.config.max_consecutive_errors:
                    pooled.state = ServiceState.UNHEALTHY
                    logger.error(f"Service marked unhealthy after {pooled.error_count} errors")
                else:
                    pooled.state = ServiceState.AVAILABLE
            else:
                pooled.error_count = 0  # Reset error count on success
                pooled.state = ServiceState.AVAILABLE

            pooled.last_used = time.time()
            pooled.current_task_id = None
            self._metrics["total_releases"] += 1

            # Clean up GPU memory after release
            try:
                model_manager = get_model_manager()
                model_manager.memory_guard.clear_gpu_cache()
            except Exception as e:
                logger.debug(f"Cache clear after release failed: {e}")

            # Notify waiting threads
            self._condition.notify_all()

        logger.debug(f"Service released for device {pooled.device}")

    @contextmanager
    def acquire_context(
        self,
        device: str = "GPU:0",
        timeout: Optional[float] = None,
        task_id: Optional[str] = None
    ):
        """
        Context manager for acquiring and releasing a service.

        Usage:
            with pool.acquire_context("GPU:0") as pooled:
                result = pooled.service.process(...)
        """
        pooled = None
        error = None
        try:
            pooled = self.acquire(device, timeout, task_id)
            if pooled is None:
                raise TimeoutError(f"Failed to acquire service for {device}")
            yield pooled
        except Exception as e:
            error = e
            raise
        finally:
            if pooled:
                self.release(pooled, error)

    def get_service(self, device: str = "GPU:0") -> Optional["OCRService"]:
        """
        Get a service directly (for backward compatibility).

        This acquires a service and returns the underlying OCRService.
        The caller is responsible for calling release_service() when done.

        Args:
            device: Device identifier

        Returns:
            OCRService instance or None
        """
        pooled = self.acquire(device)
        if pooled:
            return pooled.service
        return None

    def get_pool_stats(self) -> Dict:
        """Get current pool statistics"""
        with self._pool_lock:
            stats = {
                "devices": {},
                "metrics": self._metrics.copy(),
                "total_services": 0,
                "available_services": 0,
                "in_use_services": 0,
            }

            for device, services in self.services.items():
                available = sum(1 for s in services if s.state == ServiceState.AVAILABLE)
                in_use = sum(1 for s in services if s.state == ServiceState.IN_USE)
                unhealthy = sum(1 for s in services if s.state == ServiceState.UNHEALTHY)

                stats["devices"][device] = {
                    "total": len(services),
                    "available": available,
                    "in_use": in_use,
                    "unhealthy": unhealthy,
                    "max_allowed": self.config.max_services_per_device,
                }

                stats["total_services"] += len(services)
                stats["available_services"] += available
                stats["in_use_services"] += in_use

            return stats

    def health_check(self) -> Dict:
        """
        Perform health check on all pooled services.

        Returns:
            Health check results
        """
        results = {
            "healthy": True,
            "services": [],
            "timestamp": time.time()
        }

        with self._pool_lock:
            for device, services in self.services.items():
                for idx, pooled in enumerate(services):
                    service_health = {
                        "device": device,
                        "index": idx,
                        "state": pooled.state.value,
                        "error_count": pooled.error_count,
                        "use_count": pooled.use_count,
                        "idle_seconds": time.time() - pooled.last_used,
                    }

                    # Check if service is responsive
                    if pooled.state == ServiceState.AVAILABLE:
                        try:
                            # Simple check - verify service has required attributes
                            has_process = hasattr(pooled.service, 'process')
                            has_gpu_status = hasattr(pooled.service, 'get_gpu_status')
                            service_health["responsive"] = has_process and has_gpu_status
                        except Exception as e:
                            service_health["responsive"] = False
                            service_health["error"] = str(e)
                            results["healthy"] = False
                    else:
                        service_health["responsive"] = pooled.state != ServiceState.UNHEALTHY

                    if pooled.state == ServiceState.UNHEALTHY:
                        results["healthy"] = False

                    results["services"].append(service_health)

        return results

    def recover_unhealthy(self):
        """
        Attempt to recover unhealthy services.
        """
        with self._pool_lock:
            for device, services in self.services.items():
                for idx, pooled in enumerate(services):
                    if pooled.state == ServiceState.UNHEALTHY:
                        logger.info(f"Attempting to recover unhealthy service {device}:{idx}")
                        try:
                            # Remove old service
                            services.remove(pooled)

                            # Create new service
                            new_pooled = self._create_service(device)
                            services.append(new_pooled)
                            logger.info(f"Successfully recovered service {device}:{idx}")

                        except Exception as e:
                            logger.error(f"Failed to recover service {device}:{idx}: {e}")

    def shutdown(self):
        """
        Shutdown the pool and cleanup all services.
        """
        logger.info("OCRServicePool shutdown started")

        with self._pool_lock:
            for device, services in self.services.items():
                for pooled in services:
                    try:
                        # Clean up service resources
                        if hasattr(pooled.service, 'cleanup_gpu_memory'):
                            pooled.service.cleanup_gpu_memory()
                    except Exception as e:
                        logger.warning(f"Error cleaning up service: {e}")

            # Clear all pools
            self.services.clear()
            self.semaphores.clear()
            self.queues.clear()

        logger.info("OCRServicePool shutdown completed")


# Global singleton instance
_service_pool: Optional[OCRServicePool] = None


def get_service_pool(config: Optional[PoolConfig] = None) -> OCRServicePool:
    """
    Get the global OCRServicePool instance.

    Args:
        config: Optional configuration (only used on first call)

    Returns:
        OCRServicePool singleton instance
    """
    global _service_pool
    if _service_pool is None:
        _service_pool = OCRServicePool(config)
    return _service_pool


def shutdown_service_pool():
    """Shutdown the global service pool"""
    global _service_pool
    if _service_pool is not None:
        _service_pool.shutdown()
        _service_pool = None
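As a usage note, here is a minimal sketch of how a caller might consume this pool. The worker function name and the process_file_traditional arguments are illustrative assumptions, not part of this commit:

# Hypothetical caller sketch; only get_service_pool/acquire_context come from this diff,
# the rest of the names and arguments are placeholders.
from pathlib import Path

from app.services.service_pool import get_service_pool


def run_ocr_task(file_path: Path, task_id: str):
    pool = get_service_pool()
    # acquire_context releases the service (and clears GPU cache) even if processing raises
    with pool.acquire_context(device="GPU:0", timeout=300.0, task_id=task_id) as pooled:
        return pooled.service.process_file_traditional(file_path, "en")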