""" Memory Policy Engine - Simplified memory management for OCR processing. This module consolidates the essential memory management features: - GPU memory monitoring - Prediction concurrency control - Model lifecycle management Removed unused features from the original memory_manager.py: - BatchProcessor - ProgressiveLoader - PriorityOperationQueue - RecoveryManager - MemoryDumper - PrometheusMetrics """ import gc import logging import threading import time from collections import deque from contextlib import contextmanager from dataclasses import dataclass, field from datetime import datetime from enum import Enum from typing import Any, Callable, Dict, List, Optional, Tuple logger = logging.getLogger(__name__) # ============================================================================ # Configuration # ============================================================================ @dataclass class MemoryPolicyConfig: """ Simplified memory policy configuration. Only includes parameters that are actually used in production. """ # GPU memory thresholds (ratio 0.0-1.0) warning_threshold: float = 0.80 critical_threshold: float = 0.95 emergency_threshold: float = 0.98 # Model management model_idle_timeout_seconds: int = 300 # 5 minutes memory_check_interval_seconds: int = 30 # Concurrency control max_concurrent_predictions: int = 2 prediction_timeout_seconds: float = 300.0 # GPU settings gpu_memory_limit_mb: int = 6144 # 6GB default class MemoryBackend(Enum): """Available memory monitoring backends.""" PYNVML = "pynvml" TORCH = "torch" PADDLE = "paddle" NONE = "none" # ============================================================================ # Memory Statistics # ============================================================================ @dataclass class MemoryStats: """Memory usage statistics.""" gpu_used_mb: float = 0.0 gpu_free_mb: float = 0.0 gpu_total_mb: float = 0.0 gpu_used_ratio: float = 0.0 cpu_used_mb: float = 0.0 cpu_available_mb: float = 0.0 timestamp: datetime = field(default_factory=datetime.now) backend: str = "none" @dataclass class MemoryAlert: """Memory alert record.""" level: str # warning, critical, emergency message: str stats: MemoryStats timestamp: datetime = field(default_factory=datetime.now) # ============================================================================ # GPU Memory Monitor # ============================================================================ class GPUMemoryMonitor: """ Monitors GPU memory usage with multiple backend support. 

# ============================================================================
# Memory Statistics
# ============================================================================

@dataclass
class MemoryStats:
    """Memory usage statistics."""
    gpu_used_mb: float = 0.0
    gpu_free_mb: float = 0.0
    gpu_total_mb: float = 0.0
    gpu_used_ratio: float = 0.0
    cpu_used_mb: float = 0.0
    cpu_available_mb: float = 0.0
    timestamp: datetime = field(default_factory=datetime.now)
    backend: str = "none"


@dataclass
class MemoryAlert:
    """Memory alert record."""
    level: str  # warning, critical, emergency
    message: str
    stats: MemoryStats
    timestamp: datetime = field(default_factory=datetime.now)


# ============================================================================
# GPU Memory Monitor
# ============================================================================

class GPUMemoryMonitor:
    """
    Monitors GPU memory usage with multiple backend support.

    Priority: pynvml > torch > paddle > none
    """

    def __init__(self, config: MemoryPolicyConfig):
        self.config = config
        self._backend: MemoryBackend = MemoryBackend.NONE
        self._nvml_handle = None
        self._history: deque = deque(maxlen=100)
        self._alerts: deque = deque(maxlen=50)
        self._lock = threading.Lock()
        self._init_backend()

    def _init_backend(self):
        """Initialize the best available memory monitoring backend."""
        # Try pynvml first (most accurate)
        try:
            import pynvml
            pynvml.nvmlInit()
            self._nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(0)
            self._backend = MemoryBackend.PYNVML
            logger.info("GPU memory monitoring using pynvml")
            return
        except Exception as e:
            logger.debug(f"pynvml not available: {e}")

        # Try torch
        try:
            import torch
            if torch.cuda.is_available():
                self._backend = MemoryBackend.TORCH
                logger.info("GPU memory monitoring using torch")
                return
        except Exception as e:
            logger.debug(f"torch CUDA not available: {e}")

        # Try paddle
        try:
            import paddle
            if paddle.is_compiled_with_cuda():
                self._backend = MemoryBackend.PADDLE
                logger.info("GPU memory monitoring using paddle")
                return
        except Exception as e:
            logger.debug(f"paddle CUDA not available: {e}")

        logger.warning("No GPU memory monitoring available")

    def get_stats(self, device_id: int = 0) -> MemoryStats:
        """Get current memory statistics."""
        stats = MemoryStats(backend=self._backend.value)

        try:
            if self._backend == MemoryBackend.PYNVML:
                stats = self._get_pynvml_stats(device_id)
            elif self._backend == MemoryBackend.TORCH:
                stats = self._get_torch_stats(device_id)
            elif self._backend == MemoryBackend.PADDLE:
                stats = self._get_paddle_stats(device_id)

            # Add CPU stats
            stats = self._add_cpu_stats(stats)

            # Store in history
            with self._lock:
                self._history.append(stats)
        except Exception as e:
            logger.error(f"Failed to get memory stats: {e}")

        return stats

    def _get_pynvml_stats(self, device_id: int) -> MemoryStats:
        """Get stats using pynvml."""
        import pynvml
        handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
        info = pynvml.nvmlDeviceGetMemoryInfo(handle)
        return MemoryStats(
            gpu_used_mb=info.used / (1024 * 1024),
            gpu_free_mb=info.free / (1024 * 1024),
            gpu_total_mb=info.total / (1024 * 1024),
            gpu_used_ratio=info.used / info.total if info.total > 0 else 0,
            backend="pynvml"
        )

    def _get_torch_stats(self, device_id: int) -> MemoryStats:
        """Get stats using torch."""
        import torch
        allocated = torch.cuda.memory_allocated(device_id)
        reserved = torch.cuda.memory_reserved(device_id)
        total = torch.cuda.get_device_properties(device_id).total_memory
        return MemoryStats(
            gpu_used_mb=reserved / (1024 * 1024),
            gpu_free_mb=(total - reserved) / (1024 * 1024),
            gpu_total_mb=total / (1024 * 1024),
            gpu_used_ratio=reserved / total if total > 0 else 0,
            backend="torch"
        )

    def _get_paddle_stats(self, device_id: int) -> MemoryStats:
        """Get stats using paddle."""
        import paddle
        allocated = paddle.device.cuda.memory_allocated(device_id)
        reserved = paddle.device.cuda.memory_reserved(device_id)
        total = paddle.device.cuda.get_device_properties(device_id).total_memory
        return MemoryStats(
            gpu_used_mb=reserved / (1024 * 1024),
            gpu_free_mb=(total - reserved) / (1024 * 1024),
            gpu_total_mb=total / (1024 * 1024),
            gpu_used_ratio=reserved / total if total > 0 else 0,
            backend="paddle"
        )

    def _add_cpu_stats(self, stats: MemoryStats) -> MemoryStats:
        """Add CPU memory stats."""
        try:
            import psutil
            mem = psutil.virtual_memory()
            stats.cpu_used_mb = mem.used / (1024 * 1024)
            stats.cpu_available_mb = mem.available / (1024 * 1024)
        except Exception:
            pass
        return stats
    def check_memory(self, required_mb: float = 0, device_id: int = 0) -> Tuple[bool, str]:
        """
        Check if memory is available.

        Returns:
            Tuple of (is_available, message)
        """
        stats = self.get_stats(device_id)

        # Check thresholds
        if stats.gpu_used_ratio >= self.config.emergency_threshold:
            msg = f"Emergency: GPU at {stats.gpu_used_ratio*100:.1f}%"
            self._add_alert("emergency", msg, stats)
            return False, msg

        if stats.gpu_used_ratio >= self.config.critical_threshold:
            msg = f"Critical: GPU at {stats.gpu_used_ratio*100:.1f}%"
            self._add_alert("critical", msg, stats)
            return False, msg

        if stats.gpu_used_ratio >= self.config.warning_threshold:
            msg = f"Warning: GPU at {stats.gpu_used_ratio*100:.1f}%"
            self._add_alert("warning", msg, stats)
            # Warning doesn't block, just logs

        # Check if required memory is available
        if required_mb > 0 and stats.gpu_free_mb < required_mb:
            msg = f"Insufficient memory: need {required_mb}MB, have {stats.gpu_free_mb:.0f}MB"
            return False, msg

        return True, "OK"

    def _add_alert(self, level: str, message: str, stats: MemoryStats):
        """Add alert to history."""
        with self._lock:
            self._alerts.append(MemoryAlert(
                level=level,
                message=message,
                stats=stats
            ))
        log_func = getattr(logger, level if level != "emergency" else "error")
        log_func(message)

    def clear_cache(self):
        """Clear GPU memory caches."""
        try:
            import torch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
        except Exception:
            pass

        try:
            import paddle
            if paddle.is_compiled_with_cuda():
                paddle.device.cuda.empty_cache()
        except Exception:
            pass

        gc.collect()

    def get_alerts(self, limit: int = 10) -> List[MemoryAlert]:
        """Get recent alerts."""
        with self._lock:
            return list(self._alerts)[-limit:]
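
# Example: checking headroom before scheduling a GPU-heavy step. A minimal
# sketch assuming the default configuration; the 1500 MB figure is an
# illustrative assumption, not a measured model footprint.
#
#     monitor = GPUMemoryMonitor(MemoryPolicyConfig())
#     ok, reason = monitor.check_memory(required_mb=1500)
#     if not ok:
#         monitor.clear_cache()  # best-effort cache release
#         ok, reason = monitor.check_memory(required_mb=1500)
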

# ============================================================================
# Prediction Semaphore
# ============================================================================

class PredictionSemaphore:
    """
    Controls concurrent predictions to prevent GPU OOM.

    Singleton pattern ensures single point of concurrency control.
    """
    _instance: Optional['PredictionSemaphore'] = None
    _lock = threading.Lock()

    def __new__(cls, max_concurrent: int = 2):
        with cls._lock:
            if cls._instance is None:
                instance = super().__new__(cls)
                instance._initialized = False
                cls._instance = instance
            return cls._instance

    def __init__(self, max_concurrent: int = 2):
        if self._initialized:
            return

        self._semaphore = threading.Semaphore(max_concurrent)
        self._max_concurrent = max_concurrent
        self._active_count = 0
        self._queue_depth = 0
        self._stats_lock = threading.Lock()

        # Metrics
        self._total_predictions = 0
        self._total_timeouts = 0
        self._total_wait_time = 0.0

        self._initialized = True
        logger.info(f"PredictionSemaphore initialized: max_concurrent={max_concurrent}")

    @classmethod
    def reset(cls):
        """Reset singleton (for testing)."""
        with cls._lock:
            cls._instance = None

    def acquire(self, timeout: float = 300.0, task_id: str = "") -> bool:
        """
        Acquire a prediction slot.

        Args:
            timeout: Maximum wait time in seconds
            task_id: Optional task identifier for logging

        Returns:
            True if acquired, False on timeout
        """
        start_time = time.time()

        with self._stats_lock:
            self._queue_depth += 1

        try:
            acquired = self._semaphore.acquire(timeout=timeout)
            wait_time = time.time() - start_time

            with self._stats_lock:
                self._queue_depth -= 1
                if acquired:
                    self._active_count += 1
                    self._total_predictions += 1
                    self._total_wait_time += wait_time
                else:
                    self._total_timeouts += 1

            if not acquired:
                logger.warning(f"Prediction semaphore timeout after {timeout}s")

            return acquired
        except Exception as e:
            with self._stats_lock:
                self._queue_depth -= 1
            logger.error(f"Semaphore acquire error: {e}")
            return False

    def release(self):
        """Release a prediction slot."""
        with self._stats_lock:
            if self._active_count > 0:
                self._active_count -= 1
                self._semaphore.release()

    def get_stats(self) -> Dict[str, Any]:
        """Get semaphore statistics."""
        with self._stats_lock:
            avg_wait = (self._total_wait_time / self._total_predictions
                        if self._total_predictions > 0 else 0)
            return {
                "max_concurrent": self._max_concurrent,
                "active_predictions": self._active_count,
                "queue_depth": self._queue_depth,
                "total_predictions": self._total_predictions,
                "total_timeouts": self._total_timeouts,
                "average_wait_seconds": avg_wait
            }


@contextmanager
def prediction_context(timeout: float = 300.0, task_id: str = ""):
    """
    Context manager for prediction semaphore.

    Usage:
        with prediction_context(timeout=300) as acquired:
            if acquired:
                # run prediction
                ...
    """
    semaphore = get_prediction_semaphore()
    acquired = semaphore.acquire(timeout=timeout, task_id=task_id)
    try:
        yield acquired
    finally:
        if acquired:
            semaphore.release()
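
# Example: guarding an inference call with the shared semaphore. A minimal
# sketch; `run_ocr` and `image` are hypothetical placeholders for whatever
# prediction callable and input the caller actually uses.
#
#     with prediction_context(timeout=120.0, task_id="job-42") as acquired:
#         if not acquired:
#             raise RuntimeError("No prediction slot available")
#         result = run_ocr(image)
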

# ============================================================================
# Model Manager
# ============================================================================

@dataclass
class ModelInfo:
    """Information about a loaded model."""
    model_id: str
    model: Any
    reference_count: int = 0
    loaded_at: datetime = field(default_factory=datetime.now)
    last_used: datetime = field(default_factory=datetime.now)
    estimated_memory_mb: float = 0.0
    cleanup_callback: Optional[Callable] = None


class ModelManager:
    """
    Manages model lifecycle with reference counting and idle cleanup.

    Features:
    - Reference-counted model loading
    - Automatic unload after idle timeout
    - LRU eviction on memory pressure
    """
    _instance: Optional['ModelManager'] = None
    _lock = threading.Lock()

    def __new__(cls, config: Optional[MemoryPolicyConfig] = None):
        with cls._lock:
            if cls._instance is None:
                instance = super().__new__(cls)
                instance._initialized = False
                cls._instance = instance
            return cls._instance

    def __init__(self, config: Optional[MemoryPolicyConfig] = None):
        if self._initialized:
            return

        self.config = config or MemoryPolicyConfig()
        self._models: Dict[str, ModelInfo] = {}
        # Reentrant lock: get_or_load() calls _evict_lru() -> unload(), and
        # shutdown() calls unload(), while already holding this lock.
        self._models_lock = threading.RLock()
        self._monitor = GPUMemoryMonitor(self.config)

        # Background cleanup thread
        self._shutdown = threading.Event()
        self._cleanup_thread = threading.Thread(
            target=self._cleanup_loop,
            daemon=True,
            name="ModelManager-Cleanup"
        )
        self._cleanup_thread.start()

        self._initialized = True
        logger.info("ModelManager initialized")

    @classmethod
    def reset(cls):
        """Reset singleton (for testing)."""
        with cls._lock:
            if cls._instance is not None:
                cls._instance.shutdown()
            cls._instance = None

    def get_or_load(
        self,
        model_id: str,
        loader_func: Callable[[], Any],
        estimated_memory_mb: float = 0,
        cleanup_callback: Optional[Callable] = None
    ) -> Optional[Any]:
        """
        Get a model, loading it if necessary.

        Args:
            model_id: Unique identifier for the model
            loader_func: Function to load the model if not cached
            estimated_memory_mb: Estimated GPU memory usage
            cleanup_callback: Optional cleanup function when unloading

        Returns:
            The model, or None if loading failed
        """
        with self._models_lock:
            # Check if already loaded
            if model_id in self._models:
                info = self._models[model_id]
                info.reference_count += 1
                info.last_used = datetime.now()
                logger.debug(f"Model {model_id} retrieved (refs={info.reference_count})")
                return info.model

            # Check memory before loading
            if estimated_memory_mb > 0:
                available, msg = self._monitor.check_memory(estimated_memory_mb)
                if not available:
                    logger.warning(f"Cannot load {model_id}: {msg}")
                    # Try eviction
                    if not self._evict_lru(estimated_memory_mb):
                        return None

            # Load the model
            try:
                logger.info(f"Loading model {model_id}")
                model = loader_func()

                self._models[model_id] = ModelInfo(
                    model_id=model_id,
                    model=model,
                    reference_count=1,
                    estimated_memory_mb=estimated_memory_mb,
                    cleanup_callback=cleanup_callback
                )

                logger.info(f"Model {model_id} loaded successfully")
                return model
            except Exception as e:
                logger.error(f"Failed to load model {model_id}: {e}")
                return None

    def release(self, model_id: str):
        """Release a reference to a model."""
        with self._models_lock:
            if model_id in self._models:
                info = self._models[model_id]
                info.reference_count = max(0, info.reference_count - 1)
                logger.debug(f"Model {model_id} released (refs={info.reference_count})")

    def unload(self, model_id: str, force: bool = False) -> bool:
        """
        Unload a model from memory.

        Args:
            model_id: Model to unload
            force: If True, unload even if references exist

        Returns:
            True if unloaded
        """
        with self._models_lock:
            if model_id not in self._models:
                return False

            info = self._models[model_id]

            if not force and info.reference_count > 0:
                logger.warning(f"Cannot unload {model_id}: {info.reference_count} refs")
                return False

            # Run cleanup callback
            if info.cleanup_callback:
                try:
                    info.cleanup_callback(info.model)
                except Exception as e:
                    logger.error(f"Cleanup callback failed for {model_id}: {e}")

            # Remove model
            del self._models[model_id]
            logger.info(f"Model {model_id} unloaded")

        # Clear GPU cache
        self._monitor.clear_cache()
        return True

    def _evict_lru(self, required_mb: float) -> bool:
        """Evict least-recently-used models to free memory."""
        freed_mb = 0.0

        # Sort by last_used (oldest first)
        candidates = sorted(
            [(k, v) for k, v in self._models.items() if v.reference_count == 0],
            key=lambda x: x[1].last_used
        )

        for model_id, info in candidates:
            if freed_mb >= required_mb:
                break
            if self.unload(model_id, force=True):
                freed_mb += info.estimated_memory_mb
                logger.info(f"Evicted {model_id}, freed ~{info.estimated_memory_mb}MB")

        return freed_mb >= required_mb

    def _cleanup_loop(self):
        """Background thread for idle model cleanup."""
        while not self._shutdown.is_set():
            self._shutdown.wait(self.config.memory_check_interval_seconds)
            if self._shutdown.is_set():
                break
            self._cleanup_idle_models()

    def _cleanup_idle_models(self):
        """Unload models that have been idle too long."""
        now = datetime.now()
        timeout = self.config.model_idle_timeout_seconds

        with self._models_lock:
            to_unload = []
            for model_id, info in self._models.items():
                if info.reference_count > 0:
                    continue
                idle_seconds = (now - info.last_used).total_seconds()
                if idle_seconds > timeout:
                    to_unload.append(model_id)

        for model_id in to_unload:
            self.unload(model_id)
    def get_stats(self) -> Dict[str, Any]:
        """Get model manager statistics."""
        with self._models_lock:
            models_info = {}
            for model_id, info in self._models.items():
                models_info[model_id] = {
                    "reference_count": info.reference_count,
                    "loaded_at": info.loaded_at.isoformat(),
                    "last_used": info.last_used.isoformat(),
                    "estimated_memory_mb": info.estimated_memory_mb
                }

            return {
                "total_models": len(self._models),
                "models": models_info,
                "memory": self._monitor.get_stats().__dict__
            }

    def shutdown(self):
        """Shutdown the model manager."""
        logger.info("Shutting down ModelManager")
        self._shutdown.set()

        # Unload all models
        with self._models_lock:
            for model_id in list(self._models.keys()):
                self.unload(model_id, force=True)
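
# Example: reference-counted access to a cached model. A minimal sketch;
# `load_det_model` is a hypothetical loader callable and the 1500 MB estimate
# is an assumption, not a measured footprint.
#
#     manager = get_model_manager()
#     model = manager.get_or_load(
#         "pp-ocr-det",
#         loader_func=load_det_model,
#         estimated_memory_mb=1500,
#     )
#     try:
#         ...  # run inference with `model`
#     finally:
#         manager.release("pp-ocr-det")
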

# ============================================================================
# Memory Policy Engine (Unified Interface)
# ============================================================================

class MemoryPolicyEngine:
    """
    Unified memory policy engine.

    Provides a single entry point for all memory management operations:
    - GPU memory monitoring
    - Prediction concurrency control
    - Model lifecycle management
    """
    _instance: Optional['MemoryPolicyEngine'] = None
    _lock = threading.Lock()

    def __new__(cls, config: Optional[MemoryPolicyConfig] = None):
        with cls._lock:
            if cls._instance is None:
                instance = super().__new__(cls)
                instance._initialized = False
                cls._instance = instance
            return cls._instance

    def __init__(self, config: Optional[MemoryPolicyConfig] = None):
        if self._initialized:
            return

        self.config = config or MemoryPolicyConfig()
        self._monitor = GPUMemoryMonitor(self.config)
        self._model_manager = ModelManager(self.config)
        self._prediction_semaphore = PredictionSemaphore(
            self.config.max_concurrent_predictions
        )

        self._initialized = True
        logger.info("MemoryPolicyEngine initialized")

    @classmethod
    def reset(cls):
        """Reset singleton (for testing)."""
        with cls._lock:
            if cls._instance is not None:
                cls._instance.shutdown()
            cls._instance = None
        ModelManager.reset()
        PredictionSemaphore.reset()

    @property
    def monitor(self) -> GPUMemoryMonitor:
        """Get the GPU memory monitor."""
        return self._monitor

    @property
    def model_manager(self) -> ModelManager:
        """Get the model manager."""
        return self._model_manager

    @property
    def prediction_semaphore(self) -> PredictionSemaphore:
        """Get the prediction semaphore."""
        return self._prediction_semaphore

    def check_memory(self, required_mb: float = 0) -> Tuple[bool, str]:
        """Check if memory is available."""
        return self._monitor.check_memory(required_mb)

    def get_memory_stats(self) -> MemoryStats:
        """Get current memory statistics."""
        return self._monitor.get_stats()

    def clear_cache(self):
        """Clear GPU memory caches."""
        self._monitor.clear_cache()

    def get_stats(self) -> Dict[str, Any]:
        """Get comprehensive statistics."""
        return {
            "memory": self._monitor.get_stats().__dict__,
            "models": self._model_manager.get_stats(),
            "predictions": self._prediction_semaphore.get_stats()
        }

    def shutdown(self):
        """Shutdown all components."""
        logger.info("Shutting down MemoryPolicyEngine")
        self._model_manager.shutdown()


# ============================================================================
# Convenience Functions
# ============================================================================

_engine: Optional[MemoryPolicyEngine] = None


def get_memory_policy_engine(config: Optional[MemoryPolicyConfig] = None) -> MemoryPolicyEngine:
    """Get the global MemoryPolicyEngine instance."""
    global _engine
    if _engine is None:
        _engine = MemoryPolicyEngine(config)
    return _engine


def get_prediction_semaphore(max_concurrent: int = 2) -> PredictionSemaphore:
    """Get the global PredictionSemaphore instance."""
    return PredictionSemaphore(max_concurrent)


def get_model_manager(config: Optional[MemoryPolicyConfig] = None) -> ModelManager:
    """Get the global ModelManager instance."""
    return ModelManager(config)


def shutdown_memory_policy():
    """Shutdown all memory management components."""
    global _engine
    if _engine is not None:
        _engine.shutdown()
        _engine = None
    MemoryPolicyEngine.reset()
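

# Minimal smoke test when the module is run directly. An illustrative sketch,
# not part of the production API: it only prints current statistics (works
# even without a GPU, falling back to the "none" backend) and tears the
# singletons down again.
if __name__ == "__main__":
    import json

    engine = get_memory_policy_engine()
    try:
        ok, reason = engine.check_memory()
        print(f"memory check: ok={ok}, reason={reason}")
        print(json.dumps(engine.get_stats(), indent=2, default=str))
    finally:
        shutdown_memory_policy()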