feat: implement hybrid image extraction and memory management
Backend: - Add hybrid image extraction for Direct track (inline image blocks) - Add render_inline_image_regions() fallback when OCR doesn't find images - Add check_document_for_missing_images() for detecting missing images - Add memory management system (MemoryGuard, ModelManager, ServicePool) - Update pdf_generator_service to handle HYBRID processing track - Add ElementType.LOGO for logo extraction Frontend: - Fix PDF viewer re-rendering issues with memoization - Add TaskNotFound component and useTaskValidation hook - Disable StrictMode due to react-pdf incompatibility - Fix task detail and results page loading states 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
468
backend/app/services/service_pool.py
Normal file
468
backend/app/services/service_pool.py
Normal file
@@ -0,0 +1,468 @@
|
||||
"""
|
||||
Tool_OCR - OCR Service Pool
|
||||
Manages a pool of OCRService instances to prevent duplicate model loading
|
||||
and control concurrent GPU operations.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, TYPE_CHECKING
|
||||
|
||||
from app.services.memory_manager import get_model_manager, MemoryConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.services.ocr_service import OCRService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ServiceState(Enum):
|
||||
"""State of a pooled service"""
|
||||
AVAILABLE = "available"
|
||||
IN_USE = "in_use"
|
||||
UNHEALTHY = "unhealthy"
|
||||
INITIALIZING = "initializing"
|
||||
|
||||
|
||||
@dataclass
|
||||
class PooledService:
|
||||
"""Wrapper for a pooled OCRService instance"""
|
||||
service: Any # OCRService
|
||||
device: str
|
||||
state: ServiceState = ServiceState.AVAILABLE
|
||||
created_at: float = field(default_factory=time.time)
|
||||
last_used: float = field(default_factory=time.time)
|
||||
use_count: int = 0
|
||||
error_count: int = 0
|
||||
current_task_id: Optional[str] = None
|
||||
|
||||
|
||||
class PoolConfig:
|
||||
"""Configuration for the service pool"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_services_per_device: int = 1,
|
||||
max_total_services: int = 2,
|
||||
acquire_timeout_seconds: float = 300.0,
|
||||
max_queue_size: int = 50,
|
||||
health_check_interval_seconds: int = 60,
|
||||
max_consecutive_errors: int = 3,
|
||||
service_idle_timeout_seconds: int = 600,
|
||||
enable_auto_scaling: bool = False,
|
||||
):
|
||||
self.max_services_per_device = max_services_per_device
|
||||
self.max_total_services = max_total_services
|
||||
self.acquire_timeout_seconds = acquire_timeout_seconds
|
||||
self.max_queue_size = max_queue_size
|
||||
self.health_check_interval_seconds = health_check_interval_seconds
|
||||
self.max_consecutive_errors = max_consecutive_errors
|
||||
self.service_idle_timeout_seconds = service_idle_timeout_seconds
|
||||
self.enable_auto_scaling = enable_auto_scaling
|
||||
|
||||
|
||||
class OCRServicePool:
|
||||
"""
|
||||
Pool of OCRService instances with concurrency control.
|
||||
|
||||
Features:
|
||||
- Per-device instance management (one service per GPU)
|
||||
- Queue-based task distribution
|
||||
- Semaphore-based concurrency limits
|
||||
- Health monitoring
|
||||
- Automatic service recovery
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
"""Singleton pattern"""
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, config: Optional[PoolConfig] = None):
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
self.config = config or PoolConfig()
|
||||
self.services: Dict[str, List[PooledService]] = {}
|
||||
self.semaphores: Dict[str, threading.Semaphore] = {}
|
||||
self.queues: Dict[str, List] = {}
|
||||
self._pool_lock = threading.RLock()
|
||||
self._condition = threading.Condition(self._pool_lock)
|
||||
|
||||
# Metrics
|
||||
self._metrics = {
|
||||
"total_acquisitions": 0,
|
||||
"total_releases": 0,
|
||||
"total_timeouts": 0,
|
||||
"total_errors": 0,
|
||||
"queue_waits": 0,
|
||||
}
|
||||
|
||||
# Initialize default device pool
|
||||
self._initialize_device("GPU:0")
|
||||
|
||||
self._initialized = True
|
||||
logger.info("OCRServicePool initialized")
|
||||
|
||||
def _initialize_device(self, device: str):
|
||||
"""Initialize pool resources for a device"""
|
||||
with self._pool_lock:
|
||||
if device not in self.services:
|
||||
self.services[device] = []
|
||||
self.semaphores[device] = threading.Semaphore(
|
||||
self.config.max_services_per_device
|
||||
)
|
||||
self.queues[device] = []
|
||||
logger.info(f"Initialized pool for device {device}")
|
||||
|
||||
def _create_service(self, device: str) -> PooledService:
|
||||
"""
|
||||
Create a new OCRService instance for the pool.
|
||||
|
||||
Args:
|
||||
device: Device identifier (e.g., "GPU:0", "CPU")
|
||||
|
||||
Returns:
|
||||
PooledService wrapper
|
||||
"""
|
||||
# Import here to avoid circular imports
|
||||
from app.services.ocr_service import OCRService
|
||||
|
||||
logger.info(f"Creating new OCRService for device {device}")
|
||||
start_time = time.time()
|
||||
|
||||
# Create service instance
|
||||
service = OCRService()
|
||||
|
||||
creation_time = time.time() - start_time
|
||||
logger.info(f"OCRService created in {creation_time:.2f}s for device {device}")
|
||||
|
||||
return PooledService(
|
||||
service=service,
|
||||
device=device,
|
||||
state=ServiceState.AVAILABLE
|
||||
)
|
||||
|
||||
def acquire(
|
||||
self,
|
||||
device: str = "GPU:0",
|
||||
timeout: Optional[float] = None,
|
||||
task_id: Optional[str] = None
|
||||
) -> Optional[PooledService]:
|
||||
"""
|
||||
Acquire an OCRService from the pool.
|
||||
|
||||
Args:
|
||||
device: Preferred device (e.g., "GPU:0")
|
||||
timeout: Maximum time to wait for a service
|
||||
task_id: Optional task ID for tracking
|
||||
|
||||
Returns:
|
||||
PooledService if available, None if timeout
|
||||
"""
|
||||
timeout = timeout or self.config.acquire_timeout_seconds
|
||||
self._initialize_device(device)
|
||||
|
||||
start_time = time.time()
|
||||
deadline = start_time + timeout
|
||||
|
||||
with self._condition:
|
||||
while True:
|
||||
# Try to get an available service
|
||||
service = self._try_acquire_service(device, task_id)
|
||||
if service:
|
||||
self._metrics["total_acquisitions"] += 1
|
||||
return service
|
||||
|
||||
# Check if we can create a new service
|
||||
if self._can_create_service(device):
|
||||
try:
|
||||
pooled = self._create_service(device)
|
||||
pooled.state = ServiceState.IN_USE
|
||||
pooled.current_task_id = task_id
|
||||
pooled.use_count += 1
|
||||
self.services[device].append(pooled)
|
||||
self._metrics["total_acquisitions"] += 1
|
||||
logger.info(f"Created and acquired new service for {device}")
|
||||
return pooled
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create service for {device}: {e}")
|
||||
self._metrics["total_errors"] += 1
|
||||
|
||||
# Wait for a service to become available
|
||||
remaining = deadline - time.time()
|
||||
if remaining <= 0:
|
||||
self._metrics["total_timeouts"] += 1
|
||||
logger.warning(f"Timeout waiting for service on {device}")
|
||||
return None
|
||||
|
||||
self._metrics["queue_waits"] += 1
|
||||
logger.debug(f"Waiting for service on {device} (timeout: {remaining:.1f}s)")
|
||||
self._condition.wait(timeout=min(remaining, 1.0))
|
||||
|
||||
def _try_acquire_service(self, device: str, task_id: Optional[str]) -> Optional[PooledService]:
|
||||
"""Try to acquire an available service without waiting"""
|
||||
for pooled in self.services.get(device, []):
|
||||
if pooled.state == ServiceState.AVAILABLE:
|
||||
pooled.state = ServiceState.IN_USE
|
||||
pooled.last_used = time.time()
|
||||
pooled.use_count += 1
|
||||
pooled.current_task_id = task_id
|
||||
logger.debug(f"Acquired existing service for {device} (use #{pooled.use_count})")
|
||||
return pooled
|
||||
return None
|
||||
|
||||
def _can_create_service(self, device: str) -> bool:
|
||||
"""Check if a new service can be created"""
|
||||
device_count = len(self.services.get(device, []))
|
||||
total_count = sum(len(services) for services in self.services.values())
|
||||
|
||||
return (
|
||||
device_count < self.config.max_services_per_device and
|
||||
total_count < self.config.max_total_services
|
||||
)
|
||||
|
||||
def release(self, pooled: PooledService, error: Optional[Exception] = None):
|
||||
"""
|
||||
Release a service back to the pool.
|
||||
|
||||
Args:
|
||||
pooled: The pooled service to release
|
||||
error: Optional error that occurred during use
|
||||
"""
|
||||
with self._condition:
|
||||
if error:
|
||||
pooled.error_count += 1
|
||||
self._metrics["total_errors"] += 1
|
||||
logger.warning(f"Service released with error: {error}")
|
||||
|
||||
# Mark unhealthy if too many errors
|
||||
if pooled.error_count >= self.config.max_consecutive_errors:
|
||||
pooled.state = ServiceState.UNHEALTHY
|
||||
logger.error(f"Service marked unhealthy after {pooled.error_count} errors")
|
||||
else:
|
||||
pooled.state = ServiceState.AVAILABLE
|
||||
else:
|
||||
pooled.error_count = 0 # Reset error count on success
|
||||
pooled.state = ServiceState.AVAILABLE
|
||||
|
||||
pooled.last_used = time.time()
|
||||
pooled.current_task_id = None
|
||||
self._metrics["total_releases"] += 1
|
||||
|
||||
# Clean up GPU memory after release
|
||||
try:
|
||||
model_manager = get_model_manager()
|
||||
model_manager.memory_guard.clear_gpu_cache()
|
||||
except Exception as e:
|
||||
logger.debug(f"Cache clear after release failed: {e}")
|
||||
|
||||
# Notify waiting threads
|
||||
self._condition.notify_all()
|
||||
|
||||
logger.debug(f"Service released for device {pooled.device}")
|
||||
|
||||
@contextmanager
|
||||
def acquire_context(
|
||||
self,
|
||||
device: str = "GPU:0",
|
||||
timeout: Optional[float] = None,
|
||||
task_id: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
Context manager for acquiring and releasing a service.
|
||||
|
||||
Usage:
|
||||
with pool.acquire_context("GPU:0") as pooled:
|
||||
result = pooled.service.process(...)
|
||||
"""
|
||||
pooled = None
|
||||
error = None
|
||||
try:
|
||||
pooled = self.acquire(device, timeout, task_id)
|
||||
if pooled is None:
|
||||
raise TimeoutError(f"Failed to acquire service for {device}")
|
||||
yield pooled
|
||||
except Exception as e:
|
||||
error = e
|
||||
raise
|
||||
finally:
|
||||
if pooled:
|
||||
self.release(pooled, error)
|
||||
|
||||
def get_service(self, device: str = "GPU:0") -> Optional["OCRService"]:
|
||||
"""
|
||||
Get a service directly (for backward compatibility).
|
||||
|
||||
This acquires a service and returns the underlying OCRService.
|
||||
The caller is responsible for calling release_service() when done.
|
||||
|
||||
Args:
|
||||
device: Device identifier
|
||||
|
||||
Returns:
|
||||
OCRService instance or None
|
||||
"""
|
||||
pooled = self.acquire(device)
|
||||
if pooled:
|
||||
return pooled.service
|
||||
return None
|
||||
|
||||
def get_pool_stats(self) -> Dict:
|
||||
"""Get current pool statistics"""
|
||||
with self._pool_lock:
|
||||
stats = {
|
||||
"devices": {},
|
||||
"metrics": self._metrics.copy(),
|
||||
"total_services": 0,
|
||||
"available_services": 0,
|
||||
"in_use_services": 0,
|
||||
}
|
||||
|
||||
for device, services in self.services.items():
|
||||
available = sum(1 for s in services if s.state == ServiceState.AVAILABLE)
|
||||
in_use = sum(1 for s in services if s.state == ServiceState.IN_USE)
|
||||
unhealthy = sum(1 for s in services if s.state == ServiceState.UNHEALTHY)
|
||||
|
||||
stats["devices"][device] = {
|
||||
"total": len(services),
|
||||
"available": available,
|
||||
"in_use": in_use,
|
||||
"unhealthy": unhealthy,
|
||||
"max_allowed": self.config.max_services_per_device,
|
||||
}
|
||||
|
||||
stats["total_services"] += len(services)
|
||||
stats["available_services"] += available
|
||||
stats["in_use_services"] += in_use
|
||||
|
||||
return stats
|
||||
|
||||
def health_check(self) -> Dict:
|
||||
"""
|
||||
Perform health check on all pooled services.
|
||||
|
||||
Returns:
|
||||
Health check results
|
||||
"""
|
||||
results = {
|
||||
"healthy": True,
|
||||
"services": [],
|
||||
"timestamp": time.time()
|
||||
}
|
||||
|
||||
with self._pool_lock:
|
||||
for device, services in self.services.items():
|
||||
for idx, pooled in enumerate(services):
|
||||
service_health = {
|
||||
"device": device,
|
||||
"index": idx,
|
||||
"state": pooled.state.value,
|
||||
"error_count": pooled.error_count,
|
||||
"use_count": pooled.use_count,
|
||||
"idle_seconds": time.time() - pooled.last_used,
|
||||
}
|
||||
|
||||
# Check if service is responsive
|
||||
if pooled.state == ServiceState.AVAILABLE:
|
||||
try:
|
||||
# Simple check - verify service has required attributes
|
||||
has_process = hasattr(pooled.service, 'process')
|
||||
has_gpu_status = hasattr(pooled.service, 'get_gpu_status')
|
||||
service_health["responsive"] = has_process and has_gpu_status
|
||||
except Exception as e:
|
||||
service_health["responsive"] = False
|
||||
service_health["error"] = str(e)
|
||||
results["healthy"] = False
|
||||
else:
|
||||
service_health["responsive"] = pooled.state != ServiceState.UNHEALTHY
|
||||
|
||||
if pooled.state == ServiceState.UNHEALTHY:
|
||||
results["healthy"] = False
|
||||
|
||||
results["services"].append(service_health)
|
||||
|
||||
return results
|
||||
|
||||
def recover_unhealthy(self):
|
||||
"""
|
||||
Attempt to recover unhealthy services.
|
||||
"""
|
||||
with self._pool_lock:
|
||||
for device, services in self.services.items():
|
||||
for idx, pooled in enumerate(services):
|
||||
if pooled.state == ServiceState.UNHEALTHY:
|
||||
logger.info(f"Attempting to recover unhealthy service {device}:{idx}")
|
||||
try:
|
||||
# Remove old service
|
||||
services.remove(pooled)
|
||||
|
||||
# Create new service
|
||||
new_pooled = self._create_service(device)
|
||||
services.append(new_pooled)
|
||||
logger.info(f"Successfully recovered service {device}:{idx}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to recover service {device}:{idx}: {e}")
|
||||
|
||||
def shutdown(self):
|
||||
"""
|
||||
Shutdown the pool and cleanup all services.
|
||||
"""
|
||||
logger.info("OCRServicePool shutdown started")
|
||||
|
||||
with self._pool_lock:
|
||||
for device, services in self.services.items():
|
||||
for pooled in services:
|
||||
try:
|
||||
# Clean up service resources
|
||||
if hasattr(pooled.service, 'cleanup_gpu_memory'):
|
||||
pooled.service.cleanup_gpu_memory()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error cleaning up service: {e}")
|
||||
|
||||
# Clear all pools
|
||||
self.services.clear()
|
||||
self.semaphores.clear()
|
||||
self.queues.clear()
|
||||
|
||||
logger.info("OCRServicePool shutdown completed")
|
||||
|
||||
|
||||
# Global singleton instance
|
||||
_service_pool: Optional[OCRServicePool] = None
|
||||
|
||||
|
||||
def get_service_pool(config: Optional[PoolConfig] = None) -> OCRServicePool:
|
||||
"""
|
||||
Get the global OCRServicePool instance.
|
||||
|
||||
Args:
|
||||
config: Optional configuration (only used on first call)
|
||||
|
||||
Returns:
|
||||
OCRServicePool singleton instance
|
||||
"""
|
||||
global _service_pool
|
||||
if _service_pool is None:
|
||||
_service_pool = OCRServicePool(config)
|
||||
return _service_pool
|
||||
|
||||
|
||||
def shutdown_service_pool():
|
||||
"""Shutdown the global service pool"""
|
||||
global _service_pool
|
||||
if _service_pool is not None:
|
||||
_service_pool.shutdown()
|
||||
_service_pool = None
|
||||
Reference in New Issue
Block a user