## Backend Changes

- **Service Layer Refactoring**:
  - Add ProcessingOrchestrator for unified document processing
  - Add PDFTableRenderer for table rendering extraction
  - Add PDFFontManager for font management with CJK support
  - Add MemoryPolicyEngine (73% code reduction from MemoryGuard)
- **Bug Fixes**:
  - Fix Direct Track table row span calculation
  - Fix OCR Track image path handling
  - Add cell_boxes coordinate validation (illustrated in the sketch below)
  - Filter out small decorative images
  - Add covering image detection

## Frontend Changes

- **State Management**:
  - Add TaskStore for centralized task state management
  - Add localStorage persistence for recent tasks
  - Add processing state tracking
- **Type Consolidation**:
  - Merge shared types from api.ts to apiV2.ts
  - Update imports in authStore, uploadStore, ResultsTable, SettingsPage
- **Page Integration**:
  - Integrate TaskStore in ProcessingPage and TaskDetailPage
  - Update useTaskValidation hook with cache sync

## Testing

- Direct Track: edit.pdf (3 pages, 1.281s), edit3.pdf (2 pages, 0.203s)
- Cell boxes validation: 43 valid, 0 invalid
- Table merging: 12 merged cells verified

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
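The validation code itself is not shown in this diff; purely as an illustration of what "cell_boxes coordinate validation" typically checks, here is a minimal hypothetical sketch. The `[x0, y0, x1, y1]` box format, the page-bounds check, and the function name are assumptions, not the project's actual implementation.

```python
# Hypothetical sketch: validate cell_boxes given as [x0, y0, x1, y1] rectangles.
# The box format and the page-bounds check are assumptions, not the project's code.
from typing import List, Sequence, Tuple


def validate_cell_boxes(
    cell_boxes: Sequence[Sequence[float]],
    page_width: float,
    page_height: float,
) -> Tuple[List[List[float]], List[List[float]]]:
    """Split boxes into (valid, invalid) based on basic coordinate checks."""
    valid: List[List[float]] = []
    invalid: List[List[float]] = []
    for box in cell_boxes:
        if len(box) != 4:
            invalid.append(list(box))
            continue
        x0, y0, x1, y1 = box
        in_order = x0 < x1 and y0 < y1  # non-degenerate rectangle
        in_page = 0 <= x0 and 0 <= y0 and x1 <= page_width and y1 <= page_height
        (valid if in_order and in_page else invalid).append(list(box))
    return valid, invalid
```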
792 lines · 25 KiB · Python
"""
|
|
Memory Policy Engine - Simplified memory management for OCR processing.
|
|
|
|
This module consolidates the essential memory management features:
|
|
- GPU memory monitoring
|
|
- Prediction concurrency control
|
|
- Model lifecycle management
|
|
|
|
Removed unused features from the original memory_manager.py:
|
|
- BatchProcessor
|
|
- ProgressiveLoader
|
|
- PriorityOperationQueue
|
|
- RecoveryManager
|
|
- MemoryDumper
|
|
- PrometheusMetrics
|
|
"""
|
|
|
|
import gc
|
|
import logging
|
|
import threading
|
|
import time
|
|
from collections import deque
|
|
from contextlib import contextmanager
|
|
from dataclasses import dataclass, field
|
|
from datetime import datetime
|
|
from enum import Enum
|
|
from typing import Any, Callable, Dict, List, Optional, Tuple
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# ============================================================================
|
|
# Configuration
|
|
# ============================================================================
|
|
|
|
@dataclass
|
|
class MemoryPolicyConfig:
|
|
"""
|
|
Simplified memory policy configuration.
|
|
|
|
Only includes parameters that are actually used in production.
|
|
"""
|
|
# GPU memory thresholds (ratio 0.0-1.0)
|
|
warning_threshold: float = 0.80
|
|
critical_threshold: float = 0.95
|
|
emergency_threshold: float = 0.98
|
|
|
|
# Model management
|
|
model_idle_timeout_seconds: int = 300 # 5 minutes
|
|
memory_check_interval_seconds: int = 30
|
|
|
|
# Concurrency control
|
|
max_concurrent_predictions: int = 2
|
|
prediction_timeout_seconds: float = 300.0
|
|
|
|
# GPU settings
|
|
gpu_memory_limit_mb: int = 6144 # 6GB default
|
|
|
|
|
|
class MemoryBackend(Enum):
|
|
"""Available memory monitoring backends."""
|
|
PYNVML = "pynvml"
|
|
TORCH = "torch"
|
|
PADDLE = "paddle"
|
|
NONE = "none"
|
|
|
|
|
|
# ============================================================================
|
|
# Memory Statistics
|
|
# ============================================================================
|
|
|
|
@dataclass
|
|
class MemoryStats:
|
|
"""Memory usage statistics."""
|
|
gpu_used_mb: float = 0.0
|
|
gpu_free_mb: float = 0.0
|
|
gpu_total_mb: float = 0.0
|
|
gpu_used_ratio: float = 0.0
|
|
cpu_used_mb: float = 0.0
|
|
cpu_available_mb: float = 0.0
|
|
timestamp: datetime = field(default_factory=datetime.now)
|
|
backend: str = "none"
|
|
|
|
|
|
@dataclass
|
|
class MemoryAlert:
|
|
"""Memory alert record."""
|
|
level: str # warning, critical, emergency
|
|
message: str
|
|
stats: MemoryStats
|
|
timestamp: datetime = field(default_factory=datetime.now)
|
|
|
|
|
|
# ============================================================================
|
|
# GPU Memory Monitor
|
|
# ============================================================================
|
|
|
|
class GPUMemoryMonitor:
|
|
"""
|
|
Monitors GPU memory usage with multiple backend support.
|
|
|
|
Priority: pynvml > torch > paddle > none
|
|
"""
|
|
|
|
def __init__(self, config: MemoryPolicyConfig):
|
|
self.config = config
|
|
self._backend: MemoryBackend = MemoryBackend.NONE
|
|
self._nvml_handle = None
|
|
self._history: deque = deque(maxlen=100)
|
|
self._alerts: deque = deque(maxlen=50)
|
|
self._lock = threading.Lock()
|
|
|
|
self._init_backend()
|
|
|
|
def _init_backend(self):
|
|
"""Initialize the best available memory monitoring backend."""
|
|
# Try pynvml first (most accurate)
|
|
try:
|
|
import pynvml
|
|
pynvml.nvmlInit()
|
|
self._nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(0)
|
|
self._backend = MemoryBackend.PYNVML
|
|
logger.info("GPU memory monitoring using pynvml")
|
|
return
|
|
except Exception as e:
|
|
logger.debug(f"pynvml not available: {e}")
|
|
|
|
# Try torch
|
|
try:
|
|
import torch
|
|
if torch.cuda.is_available():
|
|
self._backend = MemoryBackend.TORCH
|
|
logger.info("GPU memory monitoring using torch")
|
|
return
|
|
except Exception as e:
|
|
logger.debug(f"torch CUDA not available: {e}")
|
|
|
|
# Try paddle
|
|
try:
|
|
import paddle
|
|
if paddle.is_compiled_with_cuda():
|
|
self._backend = MemoryBackend.PADDLE
|
|
logger.info("GPU memory monitoring using paddle")
|
|
return
|
|
except Exception as e:
|
|
logger.debug(f"paddle CUDA not available: {e}")
|
|
|
|
logger.warning("No GPU memory monitoring available")
|
|
|
|
def get_stats(self, device_id: int = 0) -> MemoryStats:
|
|
"""Get current memory statistics."""
|
|
stats = MemoryStats(backend=self._backend.value)
|
|
|
|
try:
|
|
if self._backend == MemoryBackend.PYNVML:
|
|
stats = self._get_pynvml_stats(device_id)
|
|
elif self._backend == MemoryBackend.TORCH:
|
|
stats = self._get_torch_stats(device_id)
|
|
elif self._backend == MemoryBackend.PADDLE:
|
|
stats = self._get_paddle_stats(device_id)
|
|
|
|
# Add CPU stats
|
|
stats = self._add_cpu_stats(stats)
|
|
|
|
# Store in history
|
|
with self._lock:
|
|
self._history.append(stats)
|
|
|
|
except Exception as e:
|
|
logger.error(f"Failed to get memory stats: {e}")
|
|
|
|
return stats
|
|
|
|
def _get_pynvml_stats(self, device_id: int) -> MemoryStats:
|
|
"""Get stats using pynvml."""
|
|
import pynvml
|
|
handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
|
|
info = pynvml.nvmlDeviceGetMemoryInfo(handle)
|
|
|
|
return MemoryStats(
|
|
gpu_used_mb=info.used / (1024 * 1024),
|
|
gpu_free_mb=info.free / (1024 * 1024),
|
|
gpu_total_mb=info.total / (1024 * 1024),
|
|
gpu_used_ratio=info.used / info.total if info.total > 0 else 0,
|
|
backend="pynvml"
|
|
)
|
|
|
|
def _get_torch_stats(self, device_id: int) -> MemoryStats:
|
|
"""Get stats using torch."""
|
|
import torch
|
|
allocated = torch.cuda.memory_allocated(device_id)
|
|
reserved = torch.cuda.memory_reserved(device_id)
|
|
total = torch.cuda.get_device_properties(device_id).total_memory
|
|
|
|
return MemoryStats(
|
|
gpu_used_mb=reserved / (1024 * 1024),
|
|
gpu_free_mb=(total - reserved) / (1024 * 1024),
|
|
gpu_total_mb=total / (1024 * 1024),
|
|
gpu_used_ratio=reserved / total if total > 0 else 0,
|
|
backend="torch"
|
|
)
|
|
|
|
def _get_paddle_stats(self, device_id: int) -> MemoryStats:
|
|
"""Get stats using paddle."""
|
|
import paddle
|
|
allocated = paddle.device.cuda.memory_allocated(device_id)
|
|
reserved = paddle.device.cuda.memory_reserved(device_id)
|
|
total = paddle.device.cuda.get_device_properties(device_id).total_memory
|
|
|
|
return MemoryStats(
|
|
gpu_used_mb=reserved / (1024 * 1024),
|
|
gpu_free_mb=(total - reserved) / (1024 * 1024),
|
|
gpu_total_mb=total / (1024 * 1024),
|
|
gpu_used_ratio=reserved / total if total > 0 else 0,
|
|
backend="paddle"
|
|
)
|
|
|
|
def _add_cpu_stats(self, stats: MemoryStats) -> MemoryStats:
|
|
"""Add CPU memory stats."""
|
|
try:
|
|
import psutil
|
|
mem = psutil.virtual_memory()
|
|
stats.cpu_used_mb = mem.used / (1024 * 1024)
|
|
stats.cpu_available_mb = mem.available / (1024 * 1024)
|
|
except Exception:
|
|
pass
|
|
return stats
|
|
|
|
def check_memory(self, required_mb: float = 0, device_id: int = 0) -> Tuple[bool, str]:
|
|
"""
|
|
Check if memory is available.
|
|
|
|
Returns:
|
|
Tuple of (is_available, message)
|
|
"""
|
|
stats = self.get_stats(device_id)
|
|
|
|
# Check thresholds
|
|
if stats.gpu_used_ratio >= self.config.emergency_threshold:
|
|
msg = f"Emergency: GPU at {stats.gpu_used_ratio*100:.1f}%"
|
|
self._add_alert("emergency", msg, stats)
|
|
return False, msg
|
|
|
|
if stats.gpu_used_ratio >= self.config.critical_threshold:
|
|
msg = f"Critical: GPU at {stats.gpu_used_ratio*100:.1f}%"
|
|
self._add_alert("critical", msg, stats)
|
|
return False, msg
|
|
|
|
if stats.gpu_used_ratio >= self.config.warning_threshold:
|
|
msg = f"Warning: GPU at {stats.gpu_used_ratio*100:.1f}%"
|
|
self._add_alert("warning", msg, stats)
|
|
# Warning doesn't block, just logs
|
|
|
|
# Check if required memory is available
|
|
if required_mb > 0 and stats.gpu_free_mb < required_mb:
|
|
msg = f"Insufficient memory: need {required_mb}MB, have {stats.gpu_free_mb:.0f}MB"
|
|
return False, msg
|
|
|
|
return True, "OK"
|
|
|
|
def _add_alert(self, level: str, message: str, stats: MemoryStats):
|
|
"""Add alert to history."""
|
|
with self._lock:
|
|
self._alerts.append(MemoryAlert(
|
|
level=level,
|
|
message=message,
|
|
stats=stats
|
|
))
|
|
log_func = getattr(logger, level if level != "emergency" else "error")
|
|
log_func(message)
|
|
|
|
def clear_cache(self):
|
|
"""Clear GPU memory caches."""
|
|
try:
|
|
import torch
|
|
if torch.cuda.is_available():
|
|
torch.cuda.empty_cache()
|
|
torch.cuda.synchronize()
|
|
except Exception:
|
|
pass
|
|
|
|
try:
|
|
import paddle
|
|
if paddle.is_compiled_with_cuda():
|
|
paddle.device.cuda.empty_cache()
|
|
except Exception:
|
|
pass
|
|
|
|
gc.collect()
|
|
|
|
def get_alerts(self, limit: int = 10) -> List[MemoryAlert]:
|
|
"""Get recent alerts."""
|
|
with self._lock:
|
|
return list(self._alerts)[-limit:]
|
|
|
|
|
|
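As a usage note for the monitor above: a minimal sketch of checking GPU headroom before a memory-heavy step, assuming the module's names are in scope. The 500 MB requirement is an arbitrary example value, not a project default.

```python
# Sketch: checking GPU headroom with GPUMemoryMonitor before a heavy step.
# The 500 MB requirement is an arbitrary example value.
config = MemoryPolicyConfig(warning_threshold=0.80, critical_threshold=0.95)
monitor = GPUMemoryMonitor(config)

ok, message = monitor.check_memory(required_mb=500)
if not ok:
    monitor.clear_cache()  # try to free cached blocks first
    ok, message = monitor.check_memory(required_mb=500)

stats = monitor.get_stats()
print(f"{stats.backend}: {stats.gpu_used_mb:.0f}/{stats.gpu_total_mb:.0f} MB ({message})")
for alert in monitor.get_alerts(limit=5):
    print(alert.level, alert.message)
```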
# ============================================================================
# Prediction Semaphore
# ============================================================================

class PredictionSemaphore:
    """
    Controls concurrent predictions to prevent GPU OOM.

    Singleton pattern ensures single point of concurrency control.
    """

    _instance: Optional['PredictionSemaphore'] = None
    _lock = threading.Lock()

    def __new__(cls, max_concurrent: int = 2):
        with cls._lock:
            if cls._instance is None:
                instance = super().__new__(cls)
                instance._initialized = False
                cls._instance = instance
            return cls._instance

    def __init__(self, max_concurrent: int = 2):
        if self._initialized:
            return

        self._semaphore = threading.Semaphore(max_concurrent)
        self._max_concurrent = max_concurrent
        self._active_count = 0
        self._queue_depth = 0
        self._stats_lock = threading.Lock()

        # Metrics
        self._total_predictions = 0
        self._total_timeouts = 0
        self._total_wait_time = 0.0

        self._initialized = True
        logger.info(f"PredictionSemaphore initialized: max_concurrent={max_concurrent}")

    @classmethod
    def reset(cls):
        """Reset singleton (for testing)."""
        with cls._lock:
            cls._instance = None

    def acquire(self, timeout: float = 300.0, task_id: str = "") -> bool:
        """
        Acquire a prediction slot.

        Args:
            timeout: Maximum wait time in seconds
            task_id: Optional task identifier for logging

        Returns:
            True if acquired, False on timeout
        """
        start_time = time.time()

        with self._stats_lock:
            self._queue_depth += 1

        try:
            acquired = self._semaphore.acquire(timeout=timeout)

            wait_time = time.time() - start_time

            with self._stats_lock:
                self._queue_depth -= 1
                if acquired:
                    self._active_count += 1
                    self._total_predictions += 1
                    self._total_wait_time += wait_time
                else:
                    self._total_timeouts += 1

            if not acquired:
                logger.warning(f"Prediction semaphore timeout after {timeout}s")

            return acquired

        except Exception as e:
            with self._stats_lock:
                self._queue_depth -= 1
            logger.error(f"Semaphore acquire error: {e}")
            return False

    def release(self):
        """Release a prediction slot."""
        with self._stats_lock:
            if self._active_count > 0:
                self._active_count -= 1
                self._semaphore.release()

    def get_stats(self) -> Dict[str, Any]:
        """Get semaphore statistics."""
        with self._stats_lock:
            avg_wait = (self._total_wait_time / self._total_predictions
                        if self._total_predictions > 0 else 0)
            return {
                "max_concurrent": self._max_concurrent,
                "active_predictions": self._active_count,
                "queue_depth": self._queue_depth,
                "total_predictions": self._total_predictions,
                "total_timeouts": self._total_timeouts,
                "average_wait_seconds": avg_wait
            }


@contextmanager
def prediction_context(timeout: float = 300.0, task_id: str = ""):
    """
    Context manager for prediction semaphore.

    Usage:
        with prediction_context(timeout=300) as acquired:
            if acquired:
                # run prediction
    """
    semaphore = get_prediction_semaphore()
    acquired = semaphore.acquire(timeout=timeout, task_id=task_id)
    try:
        yield acquired
    finally:
        if acquired:
            semaphore.release()
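To make the intended call pattern concrete, here is a small sketch that bounds prediction concurrency with `prediction_context`; `run_ocr` and its `page` argument are hypothetical stand-ins for the caller's real predictor.

```python
# Sketch: bounding prediction concurrency with prediction_context.
# run_ocr() and the `page` argument are hypothetical stand-ins for the real predictor.
def run_ocr(page):
    """Hypothetical predictor stand-in; replace with the real OCR call."""
    return {"page": page, "text": ""}


def predict_with_limit(page, task_id: str = "demo-task"):
    with prediction_context(timeout=300.0, task_id=task_id) as acquired:
        if not acquired:
            raise TimeoutError("No prediction slot became available within 300s")
        return run_ocr(page)


# Queue/utilization metrics are available at any time:
print(get_prediction_semaphore().get_stats())
```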
# ============================================================================
# Model Manager
# ============================================================================

@dataclass
class ModelInfo:
    """Information about a loaded model."""
    model_id: str
    model: Any
    reference_count: int = 0
    loaded_at: datetime = field(default_factory=datetime.now)
    last_used: datetime = field(default_factory=datetime.now)
    estimated_memory_mb: float = 0.0
    cleanup_callback: Optional[Callable] = None


class ModelManager:
    """
    Manages model lifecycle with reference counting and idle cleanup.

    Features:
    - Reference-counted model loading
    - Automatic unload after idle timeout
    - LRU eviction on memory pressure
    """

    _instance: Optional['ModelManager'] = None
    _lock = threading.Lock()

    def __new__(cls, config: Optional[MemoryPolicyConfig] = None):
        with cls._lock:
            if cls._instance is None:
                instance = super().__new__(cls)
                instance._initialized = False
                cls._instance = instance
            return cls._instance

    def __init__(self, config: Optional[MemoryPolicyConfig] = None):
        if self._initialized:
            return

        self.config = config or MemoryPolicyConfig()
        self._models: Dict[str, ModelInfo] = {}
        # Reentrant lock: get_or_load() holds it while _evict_lru() calls
        # unload(), which acquires it again; a plain Lock would deadlock here.
        self._models_lock = threading.RLock()
        self._monitor = GPUMemoryMonitor(self.config)

        # Background cleanup thread
        self._shutdown = threading.Event()
        self._cleanup_thread = threading.Thread(
            target=self._cleanup_loop,
            daemon=True,
            name="ModelManager-Cleanup"
        )
        self._cleanup_thread.start()

        self._initialized = True
        logger.info("ModelManager initialized")

    @classmethod
    def reset(cls):
        """Reset singleton (for testing)."""
        with cls._lock:
            if cls._instance is not None:
                cls._instance.shutdown()
            cls._instance = None

    def get_or_load(
        self,
        model_id: str,
        loader_func: Callable[[], Any],
        estimated_memory_mb: float = 0,
        cleanup_callback: Optional[Callable] = None
    ) -> Optional[Any]:
        """
        Get a model, loading it if necessary.

        Args:
            model_id: Unique identifier for the model
            loader_func: Function to load the model if not cached
            estimated_memory_mb: Estimated GPU memory usage
            cleanup_callback: Optional cleanup function when unloading

        Returns:
            The model, or None if loading failed
        """
        with self._models_lock:
            # Check if already loaded
            if model_id in self._models:
                info = self._models[model_id]
                info.reference_count += 1
                info.last_used = datetime.now()
                logger.debug(f"Model {model_id} retrieved (refs={info.reference_count})")
                return info.model

            # Check memory before loading
            if estimated_memory_mb > 0:
                available, msg = self._monitor.check_memory(estimated_memory_mb)
                if not available:
                    logger.warning(f"Cannot load {model_id}: {msg}")
                    # Try eviction
                    if not self._evict_lru(estimated_memory_mb):
                        return None

            # Load the model
            try:
                logger.info(f"Loading model {model_id}")
                model = loader_func()

                self._models[model_id] = ModelInfo(
                    model_id=model_id,
                    model=model,
                    reference_count=1,
                    estimated_memory_mb=estimated_memory_mb,
                    cleanup_callback=cleanup_callback
                )

                logger.info(f"Model {model_id} loaded successfully")
                return model

            except Exception as e:
                logger.error(f"Failed to load model {model_id}: {e}")
                return None

    def release(self, model_id: str):
        """Release a reference to a model."""
        with self._models_lock:
            if model_id in self._models:
                info = self._models[model_id]
                info.reference_count = max(0, info.reference_count - 1)
                logger.debug(f"Model {model_id} released (refs={info.reference_count})")

    def unload(self, model_id: str, force: bool = False) -> bool:
        """
        Unload a model from memory.

        Args:
            model_id: Model to unload
            force: If True, unload even if references exist

        Returns:
            True if unloaded
        """
        with self._models_lock:
            if model_id not in self._models:
                return False

            info = self._models[model_id]

            if not force and info.reference_count > 0:
                logger.warning(f"Cannot unload {model_id}: {info.reference_count} refs")
                return False

            # Run cleanup callback
            if info.cleanup_callback:
                try:
                    info.cleanup_callback(info.model)
                except Exception as e:
                    logger.error(f"Cleanup callback failed for {model_id}: {e}")

            # Remove model
            del self._models[model_id]
            logger.info(f"Model {model_id} unloaded")

        # Clear GPU cache
        self._monitor.clear_cache()
        return True

    def _evict_lru(self, required_mb: float) -> bool:
        """Evict least-recently-used models to free memory."""
        freed_mb = 0.0

        # Sort by last_used (oldest first)
        candidates = sorted(
            [(k, v) for k, v in self._models.items() if v.reference_count == 0],
            key=lambda x: x[1].last_used
        )

        for model_id, info in candidates:
            if freed_mb >= required_mb:
                break

            if self.unload(model_id, force=True):
                freed_mb += info.estimated_memory_mb
                logger.info(f"Evicted {model_id}, freed ~{info.estimated_memory_mb}MB")

        return freed_mb >= required_mb

    def _cleanup_loop(self):
        """Background thread for idle model cleanup."""
        while not self._shutdown.is_set():
            self._shutdown.wait(self.config.memory_check_interval_seconds)

            if self._shutdown.is_set():
                break

            self._cleanup_idle_models()

    def _cleanup_idle_models(self):
        """Unload models that have been idle too long."""
        now = datetime.now()
        timeout = self.config.model_idle_timeout_seconds

        with self._models_lock:
            to_unload = []

            for model_id, info in self._models.items():
                if info.reference_count > 0:
                    continue

                idle_seconds = (now - info.last_used).total_seconds()
                if idle_seconds > timeout:
                    to_unload.append(model_id)

        for model_id in to_unload:
            self.unload(model_id)

    def get_stats(self) -> Dict[str, Any]:
        """Get model manager statistics."""
        with self._models_lock:
            models_info = {}
            for model_id, info in self._models.items():
                models_info[model_id] = {
                    "reference_count": info.reference_count,
                    "loaded_at": info.loaded_at.isoformat(),
                    "last_used": info.last_used.isoformat(),
                    "estimated_memory_mb": info.estimated_memory_mb
                }

            return {
                "total_models": len(self._models),
                "models": models_info,
                "memory": self._monitor.get_stats().__dict__
            }

    def shutdown(self):
        """Shutdown the model manager."""
        logger.info("Shutting down ModelManager")
        self._shutdown.set()

        # Unload all models
        with self._models_lock:
            for model_id in list(self._models.keys()):
                self.unload(model_id, force=True)
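A short sketch of the reference-counting flow above: load (or reuse) a model through `ModelManager`, use it, then release the reference so the idle-cleanup thread can reclaim it later. `load_ocr_model`, the `"ocr-main"` id, and the 2048 MB estimate are placeholders, not the project's actual values.

```python
# Sketch: reference-counted model access through ModelManager.
# load_ocr_model(), "ocr-main", and the 2048 MB estimate are placeholders.
def load_ocr_model():
    """Hypothetical loader; replace with the real model construction."""
    return object()


manager = get_model_manager()
model = manager.get_or_load(
    model_id="ocr-main",
    loader_func=load_ocr_model,
    estimated_memory_mb=2048,
    cleanup_callback=lambda m: None,  # e.g. move weights off the GPU
)
try:
    if model is not None:
        pass  # run inference with `model`
finally:
    manager.release("ocr-main")  # idle cleanup or LRU eviction may unload it later
```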
# ============================================================================
# Memory Policy Engine (Unified Interface)
# ============================================================================

class MemoryPolicyEngine:
    """
    Unified memory policy engine.

    Provides a single entry point for all memory management operations:
    - GPU memory monitoring
    - Prediction concurrency control
    - Model lifecycle management
    """

    _instance: Optional['MemoryPolicyEngine'] = None
    _lock = threading.Lock()

    def __new__(cls, config: Optional[MemoryPolicyConfig] = None):
        with cls._lock:
            if cls._instance is None:
                instance = super().__new__(cls)
                instance._initialized = False
                cls._instance = instance
            return cls._instance

    def __init__(self, config: Optional[MemoryPolicyConfig] = None):
        if self._initialized:
            return

        self.config = config or MemoryPolicyConfig()
        self._monitor = GPUMemoryMonitor(self.config)
        self._model_manager = ModelManager(self.config)
        self._prediction_semaphore = PredictionSemaphore(
            self.config.max_concurrent_predictions
        )

        self._initialized = True
        logger.info("MemoryPolicyEngine initialized")

    @classmethod
    def reset(cls):
        """Reset singleton (for testing)."""
        with cls._lock:
            if cls._instance is not None:
                cls._instance.shutdown()
            cls._instance = None
        ModelManager.reset()
        PredictionSemaphore.reset()

    @property
    def monitor(self) -> GPUMemoryMonitor:
        """Get the GPU memory monitor."""
        return self._monitor

    @property
    def model_manager(self) -> ModelManager:
        """Get the model manager."""
        return self._model_manager

    @property
    def prediction_semaphore(self) -> PredictionSemaphore:
        """Get the prediction semaphore."""
        return self._prediction_semaphore

    def check_memory(self, required_mb: float = 0) -> Tuple[bool, str]:
        """Check if memory is available."""
        return self._monitor.check_memory(required_mb)

    def get_memory_stats(self) -> MemoryStats:
        """Get current memory statistics."""
        return self._monitor.get_stats()

    def clear_cache(self):
        """Clear GPU memory caches."""
        self._monitor.clear_cache()

    def get_stats(self) -> Dict[str, Any]:
        """Get comprehensive statistics."""
        return {
            "memory": self._monitor.get_stats().__dict__,
            "models": self._model_manager.get_stats(),
            "predictions": self._prediction_semaphore.get_stats()
        }

    def shutdown(self):
        """Shutdown all components."""
        logger.info("Shutting down MemoryPolicyEngine")
        self._model_manager.shutdown()


# ============================================================================
# Convenience Functions
# ============================================================================

_engine: Optional[MemoryPolicyEngine] = None


def get_memory_policy_engine(config: Optional[MemoryPolicyConfig] = None) -> MemoryPolicyEngine:
    """Get the global MemoryPolicyEngine instance."""
    global _engine
    if _engine is None:
        _engine = MemoryPolicyEngine(config)
    return _engine


def get_prediction_semaphore(max_concurrent: int = 2) -> PredictionSemaphore:
    """Get the global PredictionSemaphore instance."""
    return PredictionSemaphore(max_concurrent)


def get_model_manager(config: Optional[MemoryPolicyConfig] = None) -> ModelManager:
    """Get the global ModelManager instance."""
    return ModelManager(config)


def shutdown_memory_policy():
    """Shutdown all memory management components."""
    global _engine
    if _engine is not None:
        _engine.shutdown()
        _engine = None
    MemoryPolicyEngine.reset()
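Putting the pieces together, a typical caller would go through the module-level helpers rather than instantiate the classes directly; a minimal sketch, assuming the module's names are in scope and using example values for the memory requirement:

```python
# Sketch: end-to-end use of the unified engine via the convenience helpers.
engine = get_memory_policy_engine(MemoryPolicyConfig(max_concurrent_predictions=2))

ok, msg = engine.check_memory(required_mb=1024)  # 1 GB is an example value
if ok:
    with prediction_context(timeout=engine.config.prediction_timeout_seconds) as acquired:
        if acquired:
            pass  # run the GPU-bound work here

print(engine.get_stats()["predictions"])  # concurrency metrics
shutdown_memory_policy()                  # unload models, reset singletons
```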