feat: refactor dual-track architecture (Phase 1-5)
## Backend Changes

- **Service Layer Refactoring**:
  - Add ProcessingOrchestrator for unified document processing
  - Add PDFTableRenderer for table rendering extraction
  - Add PDFFontManager for font management with CJK support
  - Add MemoryPolicyEngine (73% code reduction from MemoryGuard)
- **Bug Fixes**:
  - Fix Direct Track table row span calculation
  - Fix OCR Track image path handling
  - Add cell_boxes coordinate validation
  - Filter out small decorative images
  - Add covering image detection

## Frontend Changes

- **State Management**:
  - Add TaskStore for centralized task state management
  - Add localStorage persistence for recent tasks
  - Add processing state tracking
- **Type Consolidation**:
  - Merge shared types from api.ts to apiV2.ts
  - Update imports in authStore, uploadStore, ResultsTable, SettingsPage
- **Page Integration**:
  - Integrate TaskStore in ProcessingPage and TaskDetailPage
  - Update useTaskValidation hook with cache sync

## Testing

- Direct Track: edit.pdf (3 pages, 1.281s), edit3.pdf (2 pages, 0.203s)
- Cell boxes validation: 43 valid, 0 invalid
- Table merging: 12 merged cells verified

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
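As a reading aid for reviewers, here is a minimal sketch of how the new memory-management API added in this commit is expected to be used from OCR pipeline code. The `run_ocr` wrapper, the `"ppocr-v4"` model id, the 2048 MB estimate, and the `model.predict(...)` call are illustrative placeholders, not code from this commit; the import path assumes the `backend/` directory is the package root.

```python
from app.services.memory_policy_engine import (
    get_memory_policy_engine,
    prediction_context,
)


def run_ocr(image_path: str):
    engine = get_memory_policy_engine()

    # Hypothetical loader; the real pipeline supplies its own model factory.
    def load_ocr_model():
        raise NotImplementedError("replace with the actual OCR model loader")

    # Reference-counted load with an estimated GPU footprint; LRU eviction
    # kicks in automatically if the free-memory check fails.
    model = engine.model_manager.get_or_load(
        model_id="ppocr-v4",
        loader_func=load_ocr_model,
        estimated_memory_mb=2048,
    )
    if model is None:
        raise RuntimeError("OCR model could not be loaded (memory pressure or loader failure)")

    try:
        # At most max_concurrent_predictions (default 2) enter this block at once.
        with prediction_context(timeout=300.0, task_id=image_path) as acquired:
            if not acquired:
                raise TimeoutError("prediction slot not acquired within 300s")
            return model.predict(image_path)  # hypothetical predict() call
    finally:
        engine.model_manager.release("ppocr-v4")
```

Because the prediction semaphore and model manager are process-wide singletons, the same concurrency slots and model cache are shared no matter where in the pipeline this helper is called.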
791
backend/app/services/memory_policy_engine.py
Normal file
@@ -0,0 +1,791 @@
"""
Memory Policy Engine - Simplified memory management for OCR processing.

This module consolidates the essential memory management features:
- GPU memory monitoring
- Prediction concurrency control
- Model lifecycle management

Removed unused features from the original memory_manager.py:
- BatchProcessor
- ProgressiveLoader
- PriorityOperationQueue
- RecoveryManager
- MemoryDumper
- PrometheusMetrics
"""

import gc
import logging
import threading
import time
from collections import deque
from contextlib import contextmanager
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple

logger = logging.getLogger(__name__)


# ============================================================================
# Configuration
# ============================================================================

@dataclass
class MemoryPolicyConfig:
    """
    Simplified memory policy configuration.

    Only includes parameters that are actually used in production.
    """
    # GPU memory thresholds (ratio 0.0-1.0)
    warning_threshold: float = 0.80
    critical_threshold: float = 0.95
    emergency_threshold: float = 0.98

    # Model management
    model_idle_timeout_seconds: int = 300  # 5 minutes
    memory_check_interval_seconds: int = 30

    # Concurrency control
    max_concurrent_predictions: int = 2
    prediction_timeout_seconds: float = 300.0

    # GPU settings
    gpu_memory_limit_mb: int = 6144  # 6GB default


class MemoryBackend(Enum):
    """Available memory monitoring backends."""
    PYNVML = "pynvml"
    TORCH = "torch"
    PADDLE = "paddle"
    NONE = "none"


# ============================================================================
# Memory Statistics
# ============================================================================

@dataclass
class MemoryStats:
    """Memory usage statistics."""
    gpu_used_mb: float = 0.0
    gpu_free_mb: float = 0.0
    gpu_total_mb: float = 0.0
    gpu_used_ratio: float = 0.0
    cpu_used_mb: float = 0.0
    cpu_available_mb: float = 0.0
    timestamp: datetime = field(default_factory=datetime.now)
    backend: str = "none"


@dataclass
class MemoryAlert:
    """Memory alert record."""
    level: str  # warning, critical, emergency
    message: str
    stats: MemoryStats
    timestamp: datetime = field(default_factory=datetime.now)


# ============================================================================
# GPU Memory Monitor
# ============================================================================

class GPUMemoryMonitor:
    """
    Monitors GPU memory usage with multiple backend support.

    Priority: pynvml > torch > paddle > none
    """

    def __init__(self, config: MemoryPolicyConfig):
        self.config = config
        self._backend: MemoryBackend = MemoryBackend.NONE
        self._nvml_handle = None
        self._history: deque = deque(maxlen=100)
        self._alerts: deque = deque(maxlen=50)
        self._lock = threading.Lock()

        self._init_backend()

    def _init_backend(self):
        """Initialize the best available memory monitoring backend."""
        # Try pynvml first (most accurate)
        try:
            import pynvml
            pynvml.nvmlInit()
            self._nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(0)
            self._backend = MemoryBackend.PYNVML
            logger.info("GPU memory monitoring using pynvml")
            return
        except Exception as e:
            logger.debug(f"pynvml not available: {e}")

        # Try torch
        try:
            import torch
            if torch.cuda.is_available():
                self._backend = MemoryBackend.TORCH
                logger.info("GPU memory monitoring using torch")
                return
        except Exception as e:
            logger.debug(f"torch CUDA not available: {e}")

        # Try paddle
        try:
            import paddle
            if paddle.is_compiled_with_cuda():
                self._backend = MemoryBackend.PADDLE
                logger.info("GPU memory monitoring using paddle")
                return
        except Exception as e:
            logger.debug(f"paddle CUDA not available: {e}")

        logger.warning("No GPU memory monitoring available")

    def get_stats(self, device_id: int = 0) -> MemoryStats:
        """Get current memory statistics."""
        stats = MemoryStats(backend=self._backend.value)

        try:
            if self._backend == MemoryBackend.PYNVML:
                stats = self._get_pynvml_stats(device_id)
            elif self._backend == MemoryBackend.TORCH:
                stats = self._get_torch_stats(device_id)
            elif self._backend == MemoryBackend.PADDLE:
                stats = self._get_paddle_stats(device_id)

            # Add CPU stats
            stats = self._add_cpu_stats(stats)

            # Store in history
            with self._lock:
                self._history.append(stats)

        except Exception as e:
            logger.error(f"Failed to get memory stats: {e}")

        return stats

    def _get_pynvml_stats(self, device_id: int) -> MemoryStats:
        """Get stats using pynvml."""
        import pynvml
        handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
        info = pynvml.nvmlDeviceGetMemoryInfo(handle)

        return MemoryStats(
            gpu_used_mb=info.used / (1024 * 1024),
            gpu_free_mb=info.free / (1024 * 1024),
            gpu_total_mb=info.total / (1024 * 1024),
            gpu_used_ratio=info.used / info.total if info.total > 0 else 0,
            backend="pynvml"
        )

    def _get_torch_stats(self, device_id: int) -> MemoryStats:
        """Get stats using torch."""
        import torch
        allocated = torch.cuda.memory_allocated(device_id)
        reserved = torch.cuda.memory_reserved(device_id)
        total = torch.cuda.get_device_properties(device_id).total_memory

        return MemoryStats(
            gpu_used_mb=reserved / (1024 * 1024),
            gpu_free_mb=(total - reserved) / (1024 * 1024),
            gpu_total_mb=total / (1024 * 1024),
            gpu_used_ratio=reserved / total if total > 0 else 0,
            backend="torch"
        )

    def _get_paddle_stats(self, device_id: int) -> MemoryStats:
        """Get stats using paddle."""
        import paddle
        allocated = paddle.device.cuda.memory_allocated(device_id)
        reserved = paddle.device.cuda.memory_reserved(device_id)
        total = paddle.device.cuda.get_device_properties(device_id).total_memory

        return MemoryStats(
            gpu_used_mb=reserved / (1024 * 1024),
            gpu_free_mb=(total - reserved) / (1024 * 1024),
            gpu_total_mb=total / (1024 * 1024),
            gpu_used_ratio=reserved / total if total > 0 else 0,
            backend="paddle"
        )

    def _add_cpu_stats(self, stats: MemoryStats) -> MemoryStats:
        """Add CPU memory stats."""
        try:
            import psutil
            mem = psutil.virtual_memory()
            stats.cpu_used_mb = mem.used / (1024 * 1024)
            stats.cpu_available_mb = mem.available / (1024 * 1024)
        except Exception:
            pass
        return stats

    def check_memory(self, required_mb: float = 0, device_id: int = 0) -> Tuple[bool, str]:
        """
        Check if memory is available.

        Returns:
            Tuple of (is_available, message)
        """
        stats = self.get_stats(device_id)

        # Check thresholds
        if stats.gpu_used_ratio >= self.config.emergency_threshold:
            msg = f"Emergency: GPU at {stats.gpu_used_ratio*100:.1f}%"
            self._add_alert("emergency", msg, stats)
            return False, msg

        if stats.gpu_used_ratio >= self.config.critical_threshold:
            msg = f"Critical: GPU at {stats.gpu_used_ratio*100:.1f}%"
            self._add_alert("critical", msg, stats)
            return False, msg

        if stats.gpu_used_ratio >= self.config.warning_threshold:
            msg = f"Warning: GPU at {stats.gpu_used_ratio*100:.1f}%"
            self._add_alert("warning", msg, stats)
            # Warning doesn't block, just logs

        # Check if required memory is available
        if required_mb > 0 and stats.gpu_free_mb < required_mb:
            msg = f"Insufficient memory: need {required_mb}MB, have {stats.gpu_free_mb:.0f}MB"
            return False, msg

        return True, "OK"

    def _add_alert(self, level: str, message: str, stats: MemoryStats):
        """Add alert to history."""
        with self._lock:
            self._alerts.append(MemoryAlert(
                level=level,
                message=message,
                stats=stats
            ))
        log_func = getattr(logger, level if level != "emergency" else "error")
        log_func(message)

    def clear_cache(self):
        """Clear GPU memory caches."""
        try:
            import torch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
        except Exception:
            pass

        try:
            import paddle
            if paddle.is_compiled_with_cuda():
                paddle.device.cuda.empty_cache()
        except Exception:
            pass

        gc.collect()

    def get_alerts(self, limit: int = 10) -> List[MemoryAlert]:
        """Get recent alerts."""
        with self._lock:
            return list(self._alerts)[-limit:]


# ============================================================================
# Prediction Semaphore
# ============================================================================

class PredictionSemaphore:
    """
    Controls concurrent predictions to prevent GPU OOM.

    Singleton pattern ensures single point of concurrency control.
    """

    _instance: Optional['PredictionSemaphore'] = None
    _lock = threading.Lock()

    def __new__(cls, max_concurrent: int = 2):
        with cls._lock:
            if cls._instance is None:
                instance = super().__new__(cls)
                instance._initialized = False
                cls._instance = instance
            return cls._instance

    def __init__(self, max_concurrent: int = 2):
        if self._initialized:
            return

        self._semaphore = threading.Semaphore(max_concurrent)
        self._max_concurrent = max_concurrent
        self._active_count = 0
        self._queue_depth = 0
        self._stats_lock = threading.Lock()

        # Metrics
        self._total_predictions = 0
        self._total_timeouts = 0
        self._total_wait_time = 0.0

        self._initialized = True
        logger.info(f"PredictionSemaphore initialized: max_concurrent={max_concurrent}")

    @classmethod
    def reset(cls):
        """Reset singleton (for testing)."""
        with cls._lock:
            cls._instance = None

    def acquire(self, timeout: float = 300.0, task_id: str = "") -> bool:
        """
        Acquire a prediction slot.

        Args:
            timeout: Maximum wait time in seconds
            task_id: Optional task identifier for logging

        Returns:
            True if acquired, False on timeout
        """
        start_time = time.time()

        with self._stats_lock:
            self._queue_depth += 1

        try:
            acquired = self._semaphore.acquire(timeout=timeout)

            wait_time = time.time() - start_time

            with self._stats_lock:
                self._queue_depth -= 1
                if acquired:
                    self._active_count += 1
                    self._total_predictions += 1
                    self._total_wait_time += wait_time
                else:
                    self._total_timeouts += 1

            if not acquired:
                logger.warning(f"Prediction semaphore timeout after {timeout}s")

            return acquired

        except Exception as e:
            with self._stats_lock:
                self._queue_depth -= 1
            logger.error(f"Semaphore acquire error: {e}")
            return False

    def release(self):
        """Release a prediction slot."""
        with self._stats_lock:
            if self._active_count > 0:
                self._active_count -= 1
        self._semaphore.release()

    def get_stats(self) -> Dict[str, Any]:
        """Get semaphore statistics."""
        with self._stats_lock:
            avg_wait = (self._total_wait_time / self._total_predictions
                        if self._total_predictions > 0 else 0)
            return {
                "max_concurrent": self._max_concurrent,
                "active_predictions": self._active_count,
                "queue_depth": self._queue_depth,
                "total_predictions": self._total_predictions,
                "total_timeouts": self._total_timeouts,
                "average_wait_seconds": avg_wait
            }


@contextmanager
def prediction_context(timeout: float = 300.0, task_id: str = ""):
    """
    Context manager for prediction semaphore.

    Usage:
        with prediction_context(timeout=300) as acquired:
            if acquired:
                # run prediction
    """
    semaphore = get_prediction_semaphore()
    acquired = semaphore.acquire(timeout=timeout, task_id=task_id)
    try:
        yield acquired
    finally:
        if acquired:
            semaphore.release()


# ============================================================================
# Model Manager
# ============================================================================

@dataclass
class ModelInfo:
    """Information about a loaded model."""
    model_id: str
    model: Any
    reference_count: int = 0
    loaded_at: datetime = field(default_factory=datetime.now)
    last_used: datetime = field(default_factory=datetime.now)
    estimated_memory_mb: float = 0.0
    cleanup_callback: Optional[Callable] = None


class ModelManager:
    """
    Manages model lifecycle with reference counting and idle cleanup.

    Features:
    - Reference-counted model loading
    - Automatic unload after idle timeout
    - LRU eviction on memory pressure
    """

    _instance: Optional['ModelManager'] = None
    _lock = threading.Lock()

    def __new__(cls, config: Optional[MemoryPolicyConfig] = None):
        with cls._lock:
            if cls._instance is None:
                instance = super().__new__(cls)
                instance._initialized = False
                cls._instance = instance
            return cls._instance

    def __init__(self, config: Optional[MemoryPolicyConfig] = None):
        if self._initialized:
            return

        self.config = config or MemoryPolicyConfig()
        self._models: Dict[str, ModelInfo] = {}
        # Re-entrant lock: unload() re-acquires it during LRU eviction and idle cleanup.
        self._models_lock = threading.RLock()
        self._monitor = GPUMemoryMonitor(self.config)

        # Background cleanup thread
        self._shutdown = threading.Event()
        self._cleanup_thread = threading.Thread(
            target=self._cleanup_loop,
            daemon=True,
            name="ModelManager-Cleanup"
        )
        self._cleanup_thread.start()

        self._initialized = True
        logger.info("ModelManager initialized")

    @classmethod
    def reset(cls):
        """Reset singleton (for testing)."""
        with cls._lock:
            if cls._instance is not None:
                cls._instance.shutdown()
            cls._instance = None

    def get_or_load(
        self,
        model_id: str,
        loader_func: Callable[[], Any],
        estimated_memory_mb: float = 0,
        cleanup_callback: Optional[Callable] = None
    ) -> Optional[Any]:
        """
        Get a model, loading it if necessary.

        Args:
            model_id: Unique identifier for the model
            loader_func: Function to load the model if not cached
            estimated_memory_mb: Estimated GPU memory usage
            cleanup_callback: Optional cleanup function when unloading

        Returns:
            The model, or None if loading failed
        """
        with self._models_lock:
            # Check if already loaded
            if model_id in self._models:
                info = self._models[model_id]
                info.reference_count += 1
                info.last_used = datetime.now()
                logger.debug(f"Model {model_id} retrieved (refs={info.reference_count})")
                return info.model

            # Check memory before loading
            if estimated_memory_mb > 0:
                available, msg = self._monitor.check_memory(estimated_memory_mb)
                if not available:
                    logger.warning(f"Cannot load {model_id}: {msg}")
                    # Try eviction
                    if not self._evict_lru(estimated_memory_mb):
                        return None

            # Load the model
            try:
                logger.info(f"Loading model {model_id}")
                model = loader_func()

                self._models[model_id] = ModelInfo(
                    model_id=model_id,
                    model=model,
                    reference_count=1,
                    estimated_memory_mb=estimated_memory_mb,
                    cleanup_callback=cleanup_callback
                )

                logger.info(f"Model {model_id} loaded successfully")
                return model

            except Exception as e:
                logger.error(f"Failed to load model {model_id}: {e}")
                return None

    def release(self, model_id: str):
        """Release a reference to a model."""
        with self._models_lock:
            if model_id in self._models:
                info = self._models[model_id]
                info.reference_count = max(0, info.reference_count - 1)
                logger.debug(f"Model {model_id} released (refs={info.reference_count})")

    def unload(self, model_id: str, force: bool = False) -> bool:
        """
        Unload a model from memory.

        Args:
            model_id: Model to unload
            force: If True, unload even if references exist

        Returns:
            True if unloaded
        """
        with self._models_lock:
            if model_id not in self._models:
                return False

            info = self._models[model_id]

            if not force and info.reference_count > 0:
                logger.warning(f"Cannot unload {model_id}: {info.reference_count} refs")
                return False

            # Run cleanup callback
            if info.cleanup_callback:
                try:
                    info.cleanup_callback(info.model)
                except Exception as e:
                    logger.error(f"Cleanup callback failed for {model_id}: {e}")

            # Remove model
            del self._models[model_id]
            logger.info(f"Model {model_id} unloaded")

        # Clear GPU cache
        self._monitor.clear_cache()
        return True

    def _evict_lru(self, required_mb: float) -> bool:
        """Evict least-recently-used models to free memory."""
        freed_mb = 0.0

        # Sort by last_used (oldest first)
        candidates = sorted(
            [(k, v) for k, v in self._models.items() if v.reference_count == 0],
            key=lambda x: x[1].last_used
        )

        for model_id, info in candidates:
            if freed_mb >= required_mb:
                break

            if self.unload(model_id, force=True):
                freed_mb += info.estimated_memory_mb
                logger.info(f"Evicted {model_id}, freed ~{info.estimated_memory_mb}MB")

        return freed_mb >= required_mb

    def _cleanup_loop(self):
        """Background thread for idle model cleanup."""
        while not self._shutdown.is_set():
            self._shutdown.wait(self.config.memory_check_interval_seconds)

            if self._shutdown.is_set():
                break

            self._cleanup_idle_models()

    def _cleanup_idle_models(self):
        """Unload models that have been idle too long."""
        now = datetime.now()
        timeout = self.config.model_idle_timeout_seconds

        with self._models_lock:
            to_unload = []

            for model_id, info in self._models.items():
                if info.reference_count > 0:
                    continue

                idle_seconds = (now - info.last_used).total_seconds()
                if idle_seconds > timeout:
                    to_unload.append(model_id)

        for model_id in to_unload:
            self.unload(model_id)

    def get_stats(self) -> Dict[str, Any]:
        """Get model manager statistics."""
        with self._models_lock:
            models_info = {}
            for model_id, info in self._models.items():
                models_info[model_id] = {
                    "reference_count": info.reference_count,
                    "loaded_at": info.loaded_at.isoformat(),
                    "last_used": info.last_used.isoformat(),
                    "estimated_memory_mb": info.estimated_memory_mb
                }

            return {
                "total_models": len(self._models),
                "models": models_info,
                "memory": self._monitor.get_stats().__dict__
            }

    def shutdown(self):
        """Shutdown the model manager."""
        logger.info("Shutting down ModelManager")
        self._shutdown.set()

        # Unload all models
        with self._models_lock:
            for model_id in list(self._models.keys()):
                self.unload(model_id, force=True)


# ============================================================================
# Memory Policy Engine (Unified Interface)
# ============================================================================

class MemoryPolicyEngine:
    """
    Unified memory policy engine.

    Provides a single entry point for all memory management operations:
    - GPU memory monitoring
    - Prediction concurrency control
    - Model lifecycle management
    """

    _instance: Optional['MemoryPolicyEngine'] = None
    _lock = threading.Lock()

    def __new__(cls, config: Optional[MemoryPolicyConfig] = None):
        with cls._lock:
            if cls._instance is None:
                instance = super().__new__(cls)
                instance._initialized = False
                cls._instance = instance
            return cls._instance

    def __init__(self, config: Optional[MemoryPolicyConfig] = None):
        if self._initialized:
            return

        self.config = config or MemoryPolicyConfig()
        self._monitor = GPUMemoryMonitor(self.config)
        self._model_manager = ModelManager(self.config)
        self._prediction_semaphore = PredictionSemaphore(
            self.config.max_concurrent_predictions
        )

        self._initialized = True
        logger.info("MemoryPolicyEngine initialized")

    @classmethod
    def reset(cls):
        """Reset singleton (for testing)."""
        with cls._lock:
            if cls._instance is not None:
                cls._instance.shutdown()
            cls._instance = None
        ModelManager.reset()
        PredictionSemaphore.reset()

    @property
    def monitor(self) -> GPUMemoryMonitor:
        """Get the GPU memory monitor."""
        return self._monitor

    @property
    def model_manager(self) -> ModelManager:
        """Get the model manager."""
        return self._model_manager

    @property
    def prediction_semaphore(self) -> PredictionSemaphore:
        """Get the prediction semaphore."""
        return self._prediction_semaphore

    def check_memory(self, required_mb: float = 0) -> Tuple[bool, str]:
        """Check if memory is available."""
        return self._monitor.check_memory(required_mb)

    def get_memory_stats(self) -> MemoryStats:
        """Get current memory statistics."""
        return self._monitor.get_stats()

    def clear_cache(self):
        """Clear GPU memory caches."""
        self._monitor.clear_cache()

    def get_stats(self) -> Dict[str, Any]:
        """Get comprehensive statistics."""
        return {
            "memory": self._monitor.get_stats().__dict__,
            "models": self._model_manager.get_stats(),
            "predictions": self._prediction_semaphore.get_stats()
        }

    def shutdown(self):
        """Shutdown all components."""
        logger.info("Shutting down MemoryPolicyEngine")
        self._model_manager.shutdown()


# ============================================================================
# Convenience Functions
# ============================================================================

_engine: Optional[MemoryPolicyEngine] = None


def get_memory_policy_engine(config: Optional[MemoryPolicyConfig] = None) -> MemoryPolicyEngine:
    """Get the global MemoryPolicyEngine instance."""
    global _engine
    if _engine is None:
        _engine = MemoryPolicyEngine(config)
    return _engine


def get_prediction_semaphore(max_concurrent: int = 2) -> PredictionSemaphore:
    """Get the global PredictionSemaphore instance."""
    return PredictionSemaphore(max_concurrent)


def get_model_manager(config: Optional[MemoryPolicyConfig] = None) -> ModelManager:
    """Get the global ModelManager instance."""
    return ModelManager(config)


def shutdown_memory_policy():
    """Shutdown all memory management components."""
    global _engine
    if _engine is not None:
        _engine.shutdown()
        _engine = None
    MemoryPolicyEngine.reset()