""" Tool_OCR - OCR Service Pool Manages a pool of OCRService instances to prevent duplicate model loading and control concurrent GPU operations. """ import asyncio import logging import threading import time from contextlib import contextmanager from dataclasses import dataclass, field from enum import Enum from typing import Any, Dict, List, Optional, TYPE_CHECKING from app.services.memory_manager import get_model_manager, MemoryConfig from app.services.memory_policy_engine import get_memory_policy_engine if TYPE_CHECKING: from app.services.ocr_service import OCRService logger = logging.getLogger(__name__) class ServiceState(Enum): """State of a pooled service""" AVAILABLE = "available" IN_USE = "in_use" UNHEALTHY = "unhealthy" INITIALIZING = "initializing" @dataclass class PooledService: """Wrapper for a pooled OCRService instance""" service: Any # OCRService device: str state: ServiceState = ServiceState.AVAILABLE created_at: float = field(default_factory=time.time) last_used: float = field(default_factory=time.time) use_count: int = 0 error_count: int = 0 current_task_id: Optional[str] = None class PoolConfig: """Configuration for the service pool""" def __init__( self, max_services_per_device: int = 1, max_total_services: int = 2, acquire_timeout_seconds: float = 300.0, max_queue_size: int = 50, health_check_interval_seconds: int = 60, max_consecutive_errors: int = 3, service_idle_timeout_seconds: int = 600, enable_auto_scaling: bool = False, ): self.max_services_per_device = max_services_per_device self.max_total_services = max_total_services self.acquire_timeout_seconds = acquire_timeout_seconds self.max_queue_size = max_queue_size self.health_check_interval_seconds = health_check_interval_seconds self.max_consecutive_errors = max_consecutive_errors self.service_idle_timeout_seconds = service_idle_timeout_seconds self.enable_auto_scaling = enable_auto_scaling class OCRServicePool: """ Pool of OCRService instances with concurrency control. Features: - Per-device instance management (one service per GPU) - Queue-based task distribution - Semaphore-based concurrency limits - Health monitoring - Automatic service recovery """ _instance = None _lock = threading.Lock() def __new__(cls, *args, **kwargs): """Singleton pattern""" with cls._lock: if cls._instance is None: cls._instance = super().__new__(cls) cls._instance._initialized = False return cls._instance def __init__(self, config: Optional[PoolConfig] = None): if self._initialized: return self.config = config or PoolConfig() self.services: Dict[str, List[PooledService]] = {} self.semaphores: Dict[str, threading.Semaphore] = {} self.queues: Dict[str, List] = {} self._pool_lock = threading.RLock() self._condition = threading.Condition(self._pool_lock) # Metrics self._metrics = { "total_acquisitions": 0, "total_releases": 0, "total_timeouts": 0, "total_errors": 0, "queue_waits": 0, } # Initialize default device pool self._initialize_device("GPU:0") self._initialized = True logger.info("OCRServicePool initialized") def _initialize_device(self, device: str): """Initialize pool resources for a device""" with self._pool_lock: if device not in self.services: self.services[device] = [] self.semaphores[device] = threading.Semaphore( self.config.max_services_per_device ) self.queues[device] = [] logger.info(f"Initialized pool for device {device}") def _create_service(self, device: str) -> PooledService: """ Create a new OCRService instance for the pool. Args: device: Device identifier (e.g., "GPU:0", "CPU") Returns: PooledService wrapper """ # Import here to avoid circular imports from app.services.ocr_service import OCRService logger.info(f"Creating new OCRService for device {device}") start_time = time.time() # Create service instance service = OCRService() creation_time = time.time() - start_time logger.info(f"OCRService created in {creation_time:.2f}s for device {device}") return PooledService( service=service, device=device, state=ServiceState.AVAILABLE ) def acquire( self, device: str = "GPU:0", timeout: Optional[float] = None, task_id: Optional[str] = None ) -> Optional[PooledService]: """ Acquire an OCRService from the pool. Args: device: Preferred device (e.g., "GPU:0") timeout: Maximum time to wait for a service task_id: Optional task ID for tracking Returns: PooledService if available, None if timeout """ timeout = timeout or self.config.acquire_timeout_seconds self._initialize_device(device) start_time = time.time() deadline = start_time + timeout with self._condition: while True: # Try to get an available service service = self._try_acquire_service(device, task_id) if service: self._metrics["total_acquisitions"] += 1 return service # Check if we can create a new service if self._can_create_service(device): try: pooled = self._create_service(device) pooled.state = ServiceState.IN_USE pooled.current_task_id = task_id pooled.use_count += 1 self.services[device].append(pooled) self._metrics["total_acquisitions"] += 1 logger.info(f"Created and acquired new service for {device}") return pooled except Exception as e: logger.error(f"Failed to create service for {device}: {e}") self._metrics["total_errors"] += 1 # Wait for a service to become available remaining = deadline - time.time() if remaining <= 0: self._metrics["total_timeouts"] += 1 logger.warning(f"Timeout waiting for service on {device}") return None self._metrics["queue_waits"] += 1 logger.debug(f"Waiting for service on {device} (timeout: {remaining:.1f}s)") self._condition.wait(timeout=min(remaining, 1.0)) def _try_acquire_service(self, device: str, task_id: Optional[str]) -> Optional[PooledService]: """Try to acquire an available service without waiting""" for pooled in self.services.get(device, []): if pooled.state == ServiceState.AVAILABLE: pooled.state = ServiceState.IN_USE pooled.last_used = time.time() pooled.use_count += 1 pooled.current_task_id = task_id logger.debug(f"Acquired existing service for {device} (use #{pooled.use_count})") return pooled return None def _can_create_service(self, device: str) -> bool: """Check if a new service can be created""" device_count = len(self.services.get(device, [])) total_count = sum(len(services) for services in self.services.values()) return ( device_count < self.config.max_services_per_device and total_count < self.config.max_total_services ) def release(self, pooled: PooledService, error: Optional[Exception] = None): """ Release a service back to the pool. Args: pooled: The pooled service to release error: Optional error that occurred during use """ with self._condition: if error: pooled.error_count += 1 self._metrics["total_errors"] += 1 logger.warning(f"Service released with error: {error}") # Mark unhealthy if too many errors if pooled.error_count >= self.config.max_consecutive_errors: pooled.state = ServiceState.UNHEALTHY logger.error(f"Service marked unhealthy after {pooled.error_count} errors") else: pooled.state = ServiceState.AVAILABLE else: pooled.error_count = 0 # Reset error count on success pooled.state = ServiceState.AVAILABLE pooled.last_used = time.time() pooled.current_task_id = None self._metrics["total_releases"] += 1 # Clean up GPU memory after release try: # Prefer new MemoryPolicyEngine engine = get_memory_policy_engine() engine.clear_cache() except Exception: # Fallback to legacy model_manager try: model_manager = get_model_manager() model_manager.memory_guard.clear_gpu_cache() except Exception as e: logger.debug(f"Cache clear after release failed: {e}") # Notify waiting threads self._condition.notify_all() logger.debug(f"Service released for device {pooled.device}") @contextmanager def acquire_context( self, device: str = "GPU:0", timeout: Optional[float] = None, task_id: Optional[str] = None ): """ Context manager for acquiring and releasing a service. Usage: with pool.acquire_context("GPU:0") as pooled: result = pooled.service.process(...) """ pooled = None error = None try: pooled = self.acquire(device, timeout, task_id) if pooled is None: raise TimeoutError(f"Failed to acquire service for {device}") yield pooled except Exception as e: error = e raise finally: if pooled: self.release(pooled, error) def get_service(self, device: str = "GPU:0") -> Optional["OCRService"]: """ Get a service directly (for backward compatibility). This acquires a service and returns the underlying OCRService. The caller is responsible for calling release_service() when done. Args: device: Device identifier Returns: OCRService instance or None """ pooled = self.acquire(device) if pooled: return pooled.service return None def get_pool_stats(self) -> Dict: """Get current pool statistics""" with self._pool_lock: stats = { "devices": {}, "metrics": self._metrics.copy(), "total_services": 0, "available_services": 0, "in_use_services": 0, } for device, services in self.services.items(): available = sum(1 for s in services if s.state == ServiceState.AVAILABLE) in_use = sum(1 for s in services if s.state == ServiceState.IN_USE) unhealthy = sum(1 for s in services if s.state == ServiceState.UNHEALTHY) stats["devices"][device] = { "total": len(services), "available": available, "in_use": in_use, "unhealthy": unhealthy, "max_allowed": self.config.max_services_per_device, } stats["total_services"] += len(services) stats["available_services"] += available stats["in_use_services"] += in_use return stats def health_check(self) -> Dict: """ Perform health check on all pooled services. Returns: Health check results """ results = { "healthy": True, "services": [], "timestamp": time.time() } with self._pool_lock: for device, services in self.services.items(): for idx, pooled in enumerate(services): service_health = { "device": device, "index": idx, "state": pooled.state.value, "error_count": pooled.error_count, "use_count": pooled.use_count, "idle_seconds": time.time() - pooled.last_used, } # Check if service is responsive if pooled.state == ServiceState.AVAILABLE: try: # Simple check - verify service has required attributes has_process = hasattr(pooled.service, 'process') has_gpu_status = hasattr(pooled.service, 'get_gpu_status') service_health["responsive"] = has_process and has_gpu_status except Exception as e: service_health["responsive"] = False service_health["error"] = str(e) results["healthy"] = False else: service_health["responsive"] = pooled.state != ServiceState.UNHEALTHY if pooled.state == ServiceState.UNHEALTHY: results["healthy"] = False results["services"].append(service_health) return results def recover_unhealthy(self): """ Attempt to recover unhealthy services. """ with self._pool_lock: for device, services in self.services.items(): for idx, pooled in enumerate(services): if pooled.state == ServiceState.UNHEALTHY: logger.info(f"Attempting to recover unhealthy service {device}:{idx}") try: # Remove old service services.remove(pooled) # Create new service new_pooled = self._create_service(device) services.append(new_pooled) logger.info(f"Successfully recovered service {device}:{idx}") except Exception as e: logger.error(f"Failed to recover service {device}:{idx}: {e}") def shutdown(self): """ Shutdown the pool and cleanup all services. """ logger.info("OCRServicePool shutdown started") with self._pool_lock: for device, services in self.services.items(): for pooled in services: try: # Clean up service resources if hasattr(pooled.service, 'cleanup_gpu_memory'): pooled.service.cleanup_gpu_memory() except Exception as e: logger.warning(f"Error cleaning up service: {e}") # Clear all pools self.services.clear() self.semaphores.clear() self.queues.clear() logger.info("OCRServicePool shutdown completed") # Global singleton instance _service_pool: Optional[OCRServicePool] = None def get_service_pool(config: Optional[PoolConfig] = None) -> OCRServicePool: """ Get the global OCRServicePool instance. Args: config: Optional configuration (only used on first call) Returns: OCRServicePool singleton instance """ global _service_pool if _service_pool is None: _service_pool = OCRServicePool(config) return _service_pool def shutdown_service_pool(): """Shutdown the global service pool""" global _service_pool if _service_pool is not None: _service_pool.shutdown() _service_pool = None