"""
Tool_OCR - Memory Management System

Provides centralized model lifecycle management with reference counting,
idle timeout, and GPU memory monitoring.
"""

import asyncio
import gc
import logging
import threading
import time
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple
from weakref import WeakValueDictionary

import paddle

# Optional torch import for additional GPU memory management
try:
    import torch
    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False

# Optional pynvml import for NVIDIA GPU monitoring
try:
    import pynvml
    PYNVML_AVAILABLE = True
except ImportError:
    PYNVML_AVAILABLE = False

logger = logging.getLogger(__name__)


class MemoryBackend(Enum):
    """Available memory query backends"""
    PADDLE = "paddle"
    TORCH = "torch"
    PYNVML = "pynvml"
    NONE = "none"


@dataclass
class MemoryStats:
    """GPU/CPU memory statistics"""
    gpu_used_mb: float = 0.0
    gpu_free_mb: float = 0.0
    gpu_total_mb: float = 0.0
    gpu_used_ratio: float = 0.0
    cpu_used_mb: float = 0.0
    cpu_available_mb: float = 0.0
    timestamp: float = field(default_factory=time.time)
    backend: MemoryBackend = MemoryBackend.NONE


@dataclass
class ModelEntry:
    """Entry for a managed model"""
    model: Any
    model_id: str
    ref_count: int = 0
    last_used: float = field(default_factory=time.time)
    created_at: float = field(default_factory=time.time)
    estimated_memory_mb: float = 0.0
    is_loading: bool = False
    cleanup_callback: Optional[Callable] = None


class MemoryConfig:
    """Configuration for memory management"""

    def __init__(
        self,
        warning_threshold: float = 0.80,
        critical_threshold: float = 0.95,
        emergency_threshold: float = 0.98,
        model_idle_timeout_seconds: int = 300,
        memory_check_interval_seconds: int = 30,
        enable_auto_cleanup: bool = True,
        enable_emergency_cleanup: bool = True,
        max_concurrent_predictions: int = 2,
        enable_cpu_fallback: bool = True,
        gpu_memory_limit_mb: int = 6144,
    ):
        self.warning_threshold = warning_threshold
        self.critical_threshold = critical_threshold
        self.emergency_threshold = emergency_threshold
        self.model_idle_timeout_seconds = model_idle_timeout_seconds
        self.memory_check_interval_seconds = memory_check_interval_seconds
        self.enable_auto_cleanup = enable_auto_cleanup
        self.enable_emergency_cleanup = enable_emergency_cleanup
        self.max_concurrent_predictions = max_concurrent_predictions
        self.enable_cpu_fallback = enable_cpu_fallback
        self.gpu_memory_limit_mb = gpu_memory_limit_mb

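# Illustrative sketch only (not referenced elsewhere in this module): how a
# caller might tune MemoryConfig for a smaller GPU. The 4 GiB limit and the
# tighter thresholds below are assumed example values, not project defaults.
def _example_small_gpu_config() -> MemoryConfig:
    return MemoryConfig(
        warning_threshold=0.70,
        critical_threshold=0.85,
        emergency_threshold=0.95,
        model_idle_timeout_seconds=120,
        gpu_memory_limit_mb=4096,
    )
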
class MemoryGuard:
    """
    Monitor GPU/CPU memory usage and trigger preventive actions.

    Supports multiple backends: paddle.device.cuda, pynvml, torch
    """

    def __init__(self, config: Optional[MemoryConfig] = None):
        self.config = config or MemoryConfig()
        self.backend = self._detect_backend()
        self._history: List[MemoryStats] = []
        self._max_history = 100
        self._alerts: List[Dict] = []
        self._lock = threading.Lock()

        # Initialize pynvml if available
        self._nvml_handle = None
        if self.backend == MemoryBackend.PYNVML:
            self._init_pynvml()

        logger.info(f"MemoryGuard initialized with backend: {self.backend.value}")

    def _detect_backend(self) -> MemoryBackend:
        """Detect the best available memory query backend"""
        # Prefer pynvml for accurate GPU memory info
        if PYNVML_AVAILABLE:
            try:
                pynvml.nvmlInit()
                pynvml.nvmlShutdown()
                return MemoryBackend.PYNVML
            except Exception:
                pass

        # Fall back to torch if available
        if TORCH_AVAILABLE and torch.cuda.is_available():
            return MemoryBackend.TORCH

        # Fall back to paddle
        if paddle.is_compiled_with_cuda():
            try:
                if paddle.device.cuda.device_count() > 0:
                    return MemoryBackend.PADDLE
            except Exception:
                pass

        return MemoryBackend.NONE

    def _init_pynvml(self):
        """Initialize pynvml for GPU monitoring"""
        try:
            pynvml.nvmlInit()
            self._nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(0)
            logger.info("pynvml initialized for GPU monitoring")
        except Exception as e:
            logger.warning(f"Failed to initialize pynvml: {e}")
            self.backend = MemoryBackend.PADDLE if paddle.is_compiled_with_cuda() else MemoryBackend.NONE

    def get_memory_stats(self, device_id: int = 0) -> MemoryStats:
        """
        Get current memory statistics.

        Args:
            device_id: GPU device ID (default 0)

        Returns:
            MemoryStats with current memory usage
        """
        stats = MemoryStats(backend=self.backend)

        try:
            if self.backend == MemoryBackend.PYNVML and self._nvml_handle:
                mem_info = pynvml.nvmlDeviceGetMemoryInfo(self._nvml_handle)
                stats.gpu_total_mb = mem_info.total / (1024**2)
                stats.gpu_used_mb = mem_info.used / (1024**2)
                stats.gpu_free_mb = mem_info.free / (1024**2)
                stats.gpu_used_ratio = mem_info.used / mem_info.total if mem_info.total > 0 else 0

            elif self.backend == MemoryBackend.TORCH:
                stats.gpu_total_mb = torch.cuda.get_device_properties(device_id).total_memory / (1024**2)
                stats.gpu_used_mb = torch.cuda.memory_allocated(device_id) / (1024**2)
                stats.gpu_free_mb = (torch.cuda.get_device_properties(device_id).total_memory -
                                     torch.cuda.memory_allocated(device_id)) / (1024**2)
                stats.gpu_used_ratio = stats.gpu_used_mb / stats.gpu_total_mb if stats.gpu_total_mb > 0 else 0

            elif self.backend == MemoryBackend.PADDLE:
                # Paddle doesn't provide total memory directly, use allocated/reserved
                stats.gpu_used_mb = paddle.device.cuda.memory_allocated(device_id) / (1024**2)
                stats.gpu_free_mb = 0  # Not directly available
                stats.gpu_total_mb = self.config.gpu_memory_limit_mb
                stats.gpu_used_ratio = stats.gpu_used_mb / stats.gpu_total_mb if stats.gpu_total_mb > 0 else 0

            # Get CPU memory info
            try:
                import psutil
                mem = psutil.virtual_memory()
                stats.cpu_used_mb = mem.used / (1024**2)
                stats.cpu_available_mb = mem.available / (1024**2)
            except ImportError:
                pass

        except Exception as e:
            logger.warning(f"Failed to get memory stats: {e}")

        # Store in history
        with self._lock:
            self._history.append(stats)
            if len(self._history) > self._max_history:
                self._history.pop(0)

        return stats

    def check_memory(self, required_mb: int = 0, device_id: int = 0) -> Tuple[bool, MemoryStats]:
        """
        Check if sufficient GPU memory is available.

        Args:
            required_mb: Required memory in MB (0 for just checking thresholds)
            device_id: GPU device ID

        Returns:
            Tuple of (is_available, current_stats)
        """
        stats = self.get_memory_stats(device_id)

        # Check if we have enough free memory
        if required_mb > 0 and stats.gpu_free_mb > 0:
            if stats.gpu_free_mb < required_mb:
                self._add_alert("insufficient_memory",
                                f"Required {required_mb}MB but only {stats.gpu_free_mb:.0f}MB available")
                return False, stats

        # Check threshold levels
        if stats.gpu_used_ratio > self.config.emergency_threshold:
            self._add_alert("emergency",
                            f"GPU memory at {stats.gpu_used_ratio*100:.1f}% (emergency threshold)")
            return False, stats

        if stats.gpu_used_ratio > self.config.critical_threshold:
            self._add_alert("critical",
                            f"GPU memory at {stats.gpu_used_ratio*100:.1f}% (critical threshold)")
            return False, stats

        if stats.gpu_used_ratio > self.config.warning_threshold:
            self._add_alert("warning",
                            f"GPU memory at {stats.gpu_used_ratio*100:.1f}% (warning threshold)")

        return True, stats

    def _add_alert(self, level: str, message: str):
        """Add an alert to the alert history"""
        alert = {
            "level": level,
            "message": message,
            "timestamp": time.time()
        }
        with self._lock:
            self._alerts.append(alert)
            # Keep last 50 alerts
            if len(self._alerts) > 50:
                self._alerts.pop(0)

        if level == "emergency":
            logger.error(f"MEMORY ALERT [{level}]: {message}")
        elif level == "critical":
            logger.warning(f"MEMORY ALERT [{level}]: {message}")
        else:
            logger.info(f"MEMORY ALERT [{level}]: {message}")

    def get_alerts(self, since_timestamp: float = 0) -> List[Dict]:
        """Get alerts since a given timestamp"""
        with self._lock:
            return [a for a in self._alerts if a["timestamp"] > since_timestamp]

    def clear_gpu_cache(self):
        """Clear GPU memory cache"""
        try:
            if TORCH_AVAILABLE and torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
                logger.debug("Cleared PyTorch GPU cache")

            if paddle.is_compiled_with_cuda():
                paddle.device.cuda.empty_cache()
                logger.debug("Cleared PaddlePaddle GPU cache")

            gc.collect()

        except Exception as e:
            logger.warning(f"GPU cache clear failed: {e}")

    def shutdown(self):
        """Clean up resources"""
        if PYNVML_AVAILABLE and self._nvml_handle:
            try:
                pynvml.nvmlShutdown()
            except Exception:
                pass

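# Illustrative sketch only (not referenced elsewhere in this module): the
# intended check-then-act pattern around a memory-heavy step. `estimated_mb`
# is an assumed figure and the heavy work itself is left as a placeholder.
def _example_guarded_step(guard: MemoryGuard, estimated_mb: int = 1500) -> bool:
    ok, stats = guard.check_memory(required_mb=estimated_mb)
    if not ok:
        # One cache-clear retry before giving up.
        guard.clear_gpu_cache()
        ok, stats = guard.check_memory(required_mb=estimated_mb)
    if not ok:
        logger.warning(f"Skipping step: only {stats.gpu_free_mb:.0f}MB free")
        return False
    # ... run the memory-heavy work here ...
    return True
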
class PredictionSemaphore:
    """
    Semaphore for controlling concurrent PP-StructureV3 predictions.

    PP-StructureV3.predict() is memory-intensive. Running multiple predictions
    simultaneously can cause OOM errors. This class limits concurrent predictions
    and provides timeout handling.
    """

    _instance = None
    _lock = threading.Lock()

    def __new__(cls, *args, **kwargs):
        """Singleton pattern - ensure only one PredictionSemaphore exists"""
        with cls._lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
                cls._instance._initialized = False
            return cls._instance

    def __init__(self, max_concurrent: int = 2, default_timeout: float = 300.0):
        if self._initialized:
            return

        self._max_concurrent = max_concurrent
        self._default_timeout = default_timeout
        self._semaphore = threading.Semaphore(max_concurrent)
        self._condition = threading.Condition()
        self._queue_depth = 0
        self._active_predictions = 0

        # Metrics
        self._total_predictions = 0
        self._total_timeouts = 0
        self._total_wait_time = 0.0
        self._metrics_lock = threading.Lock()

        self._initialized = True
        logger.info(f"PredictionSemaphore initialized (max_concurrent={max_concurrent})")

    def acquire(self, timeout: Optional[float] = None, task_id: Optional[str] = None) -> bool:
        """
        Acquire a prediction slot.

        Args:
            timeout: Timeout in seconds (None for default, 0 for non-blocking)
            task_id: Optional task identifier for logging

        Returns:
            True if acquired, False if timed out
        """
        timeout = timeout if timeout is not None else self._default_timeout
        start_time = time.time()

        with self._condition:
            self._queue_depth += 1

        task_str = f" for task {task_id}" if task_id else ""
        logger.debug(f"Waiting for prediction slot{task_str} (queue_depth={self._queue_depth})")

        try:
            if timeout > 0:
                acquired = self._semaphore.acquire(timeout=timeout)
            else:
                # timeout == 0 means a non-blocking attempt (see docstring)
                acquired = self._semaphore.acquire(blocking=False)

            wait_time = time.time() - start_time
            with self._metrics_lock:
                self._total_wait_time += wait_time
                if acquired:
                    self._total_predictions += 1
                    self._active_predictions += 1
                else:
                    self._total_timeouts += 1

            with self._condition:
                self._queue_depth -= 1

            if acquired:
                logger.debug(f"Prediction slot acquired{task_str} (waited {wait_time:.2f}s, active={self._active_predictions})")
            else:
                logger.warning(f"Prediction slot timeout{task_str} after {timeout}s")

            return acquired

        except Exception as e:
            with self._condition:
                self._queue_depth -= 1
            logger.error(f"Error acquiring prediction slot: {e}")
            return False

    def release(self, task_id: Optional[str] = None):
        """
        Release a prediction slot.

        Args:
            task_id: Optional task identifier for logging
        """
        self._semaphore.release()

        with self._metrics_lock:
            self._active_predictions = max(0, self._active_predictions - 1)

        task_str = f" for task {task_id}" if task_id else ""
        logger.debug(f"Prediction slot released{task_str} (active={self._active_predictions})")

    def get_stats(self) -> Dict:
        """Get prediction semaphore statistics"""
        with self._metrics_lock:
            avg_wait = self._total_wait_time / max(1, self._total_predictions)
            return {
                "max_concurrent": self._max_concurrent,
                "active_predictions": self._active_predictions,
                "queue_depth": self._queue_depth,
                "total_predictions": self._total_predictions,
                "total_timeouts": self._total_timeouts,
                "average_wait_seconds": round(avg_wait, 3),
            }

    def reset_metrics(self):
        """Reset metrics counters"""
        with self._metrics_lock:
            self._total_predictions = 0
            self._total_timeouts = 0
            self._total_wait_time = 0.0

class PredictionContext:
    """
    Context manager for PP-StructureV3 predictions with semaphore control.

    Usage:
        with prediction_context(task_id="task_123") as acquired:
            if acquired:
                result = structure_engine.predict(image_path)
            else:
                # Handle timeout
    """

    def __init__(
        self,
        semaphore: PredictionSemaphore,
        timeout: Optional[float] = None,
        task_id: Optional[str] = None
    ):
        self._semaphore = semaphore
        self._timeout = timeout
        self._task_id = task_id
        self._acquired = False

    def __enter__(self) -> bool:
        self._acquired = self._semaphore.acquire(
            timeout=self._timeout,
            task_id=self._task_id
        )
        return self._acquired

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self._acquired:
            self._semaphore.release(task_id=self._task_id)
        return False  # Don't suppress exceptions

# Global prediction semaphore instance
_prediction_semaphore: Optional[PredictionSemaphore] = None


def get_prediction_semaphore(max_concurrent: Optional[int] = None) -> PredictionSemaphore:
    """
    Get the global PredictionSemaphore instance.

    Args:
        max_concurrent: Max concurrent predictions (only used on first call)

    Returns:
        PredictionSemaphore singleton instance
    """
    global _prediction_semaphore
    if _prediction_semaphore is None:
        from app.core.config import settings
        max_conc = max_concurrent or settings.max_concurrent_predictions
        _prediction_semaphore = PredictionSemaphore(max_concurrent=max_conc)
    return _prediction_semaphore


def shutdown_prediction_semaphore():
    """Reset the global PredictionSemaphore instance"""
    global _prediction_semaphore
    if _prediction_semaphore is not None:
        # Reset singleton for clean state
        PredictionSemaphore._instance = None
        PredictionSemaphore._lock = threading.Lock()
        _prediction_semaphore = None


def prediction_context(
    timeout: Optional[float] = None,
    task_id: Optional[str] = None
) -> PredictionContext:
    """
    Create a prediction context manager.

    Args:
        timeout: Timeout in seconds for acquiring slot
        task_id: Optional task identifier for logging

    Returns:
        PredictionContext context manager
    """
    semaphore = get_prediction_semaphore()
    return PredictionContext(semaphore, timeout, task_id)

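# Illustrative sketch only (not referenced elsewhere in this module): guarding
# a predict() call with the semaphore. `engine` stands in for a PP-StructureV3
# pipeline object; the 120s timeout is an assumed value.
def _example_guarded_predict(engine: Any, image_path: str, task_id: str) -> Optional[Any]:
    with prediction_context(timeout=120.0, task_id=task_id) as acquired:
        if not acquired:
            logger.warning(f"No prediction slot available for task {task_id}")
            return None
        return engine.predict(image_path)
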
class ModelManager:
|
|
"""
|
|
Centralized model lifecycle management with reference counting and idle timeout.
|
|
|
|
Features:
|
|
- Reference counting for shared model instances
|
|
- Idle timeout for automatic unloading
|
|
- LRU eviction when memory pressure
|
|
- Thread-safe operations
|
|
"""
|
|
|
|
_instance = None
|
|
_lock = threading.Lock()
|
|
|
|
def __new__(cls, *args, **kwargs):
|
|
"""Singleton pattern - ensure only one ModelManager exists"""
|
|
with cls._lock:
|
|
if cls._instance is None:
|
|
cls._instance = super().__new__(cls)
|
|
cls._instance._initialized = False
|
|
return cls._instance
|
|
|
|
def __init__(self, config: Optional[MemoryConfig] = None):
|
|
if self._initialized:
|
|
return
|
|
|
|
self.config = config or MemoryConfig()
|
|
self.models: Dict[str, ModelEntry] = {}
|
|
self.memory_guard = MemoryGuard(self.config)
|
|
self._model_lock = threading.RLock()
|
|
self._loading_locks: Dict[str, threading.Lock] = {}
|
|
|
|
# Start background timeout monitor
|
|
self._monitor_running = True
|
|
self._monitor_thread = threading.Thread(
|
|
target=self._timeout_monitor_loop,
|
|
daemon=True,
|
|
name="ModelManager-TimeoutMonitor"
|
|
)
|
|
self._monitor_thread.start()
|
|
|
|
self._initialized = True
|
|
logger.info("ModelManager initialized")
|
|
|
|
def _timeout_monitor_loop(self):
|
|
"""Background thread to monitor and unload idle models"""
|
|
while self._monitor_running:
|
|
try:
|
|
time.sleep(self.config.memory_check_interval_seconds)
|
|
if self.config.enable_auto_cleanup:
|
|
self._cleanup_idle_models()
|
|
except Exception as e:
|
|
logger.error(f"Error in timeout monitor: {e}")
|
|
|
|
def _cleanup_idle_models(self):
|
|
"""Unload models that have been idle longer than the timeout"""
|
|
current_time = time.time()
|
|
models_to_unload = []
|
|
|
|
with self._model_lock:
|
|
for model_id, entry in self.models.items():
|
|
# Only unload if no active references and idle timeout exceeded
|
|
if entry.ref_count <= 0:
|
|
idle_time = current_time - entry.last_used
|
|
if idle_time > self.config.model_idle_timeout_seconds:
|
|
models_to_unload.append(model_id)
|
|
|
|
for model_id in models_to_unload:
|
|
self.unload_model(model_id, force=False)
|
|
|
|
def get_or_load_model(
|
|
self,
|
|
model_id: str,
|
|
loader_func: Callable[[], Any],
|
|
estimated_memory_mb: float = 0,
|
|
cleanup_callback: Optional[Callable] = None
|
|
) -> Any:
|
|
"""
|
|
Get a model by ID, loading it if not already loaded.
|
|
|
|
Args:
|
|
model_id: Unique identifier for the model
|
|
loader_func: Function to call to load the model if not cached
|
|
estimated_memory_mb: Estimated memory usage for this model
|
|
cleanup_callback: Optional callback to run before unloading
|
|
|
|
Returns:
|
|
The model instance
|
|
"""
|
|
with self._model_lock:
|
|
# Check if model is already loaded
|
|
if model_id in self.models:
|
|
entry = self.models[model_id]
|
|
if not entry.is_loading:
|
|
entry.ref_count += 1
|
|
entry.last_used = time.time()
|
|
logger.debug(f"Model {model_id} acquired (ref_count={entry.ref_count})")
|
|
return entry.model
|
|
|
|
# Create loading lock for this model if not exists
|
|
if model_id not in self._loading_locks:
|
|
self._loading_locks[model_id] = threading.Lock()
|
|
|
|
# Load model outside the main lock to allow concurrent operations
|
|
loading_lock = self._loading_locks[model_id]
|
|
|
|
with loading_lock:
|
|
# Double-check after acquiring loading lock
|
|
with self._model_lock:
|
|
if model_id in self.models and not self.models[model_id].is_loading:
|
|
entry = self.models[model_id]
|
|
entry.ref_count += 1
|
|
entry.last_used = time.time()
|
|
return entry.model
|
|
|
|
# Mark as loading
|
|
self.models[model_id] = ModelEntry(
|
|
model=None,
|
|
model_id=model_id,
|
|
is_loading=True,
|
|
estimated_memory_mb=estimated_memory_mb,
|
|
cleanup_callback=cleanup_callback
|
|
)
|
|
|
|
try:
|
|
# Check memory before loading
|
|
if estimated_memory_mb > 0:
|
|
is_available, stats = self.memory_guard.check_memory(int(estimated_memory_mb))
|
|
if not is_available and self.config.enable_emergency_cleanup:
|
|
logger.warning(f"Memory low, attempting cleanup before loading {model_id}")
|
|
self._evict_lru_models(required_mb=estimated_memory_mb)
|
|
|
|
# Load the model
|
|
logger.info(f"Loading model {model_id} (estimated {estimated_memory_mb}MB)")
|
|
start_time = time.time()
|
|
model = loader_func()
|
|
load_time = time.time() - start_time
|
|
logger.info(f"Model {model_id} loaded in {load_time:.2f}s")
|
|
|
|
# Update entry
|
|
with self._model_lock:
|
|
self.models[model_id] = ModelEntry(
|
|
model=model,
|
|
model_id=model_id,
|
|
ref_count=1,
|
|
estimated_memory_mb=estimated_memory_mb,
|
|
cleanup_callback=cleanup_callback,
|
|
is_loading=False
|
|
)
|
|
|
|
return model
|
|
|
|
except Exception as e:
|
|
# Clean up failed entry
|
|
with self._model_lock:
|
|
if model_id in self.models:
|
|
del self.models[model_id]
|
|
logger.error(f"Failed to load model {model_id}: {e}")
|
|
raise
|
|
|
|
def release_model(self, model_id: str):
|
|
"""
|
|
Release a reference to a model.
|
|
|
|
Args:
|
|
model_id: Model identifier
|
|
"""
|
|
with self._model_lock:
|
|
if model_id in self.models:
|
|
entry = self.models[model_id]
|
|
entry.ref_count = max(0, entry.ref_count - 1)
|
|
entry.last_used = time.time()
|
|
logger.debug(f"Model {model_id} released (ref_count={entry.ref_count})")
|
|
|
|
def unload_model(self, model_id: str, force: bool = False) -> bool:
|
|
"""
|
|
Unload a model from memory.
|
|
|
|
Args:
|
|
model_id: Model identifier
|
|
force: Force unload even if references exist
|
|
|
|
Returns:
|
|
True if model was unloaded
|
|
"""
|
|
with self._model_lock:
|
|
if model_id not in self.models:
|
|
return False
|
|
|
|
entry = self.models[model_id]
|
|
|
|
# Don't unload if there are active references (unless forced)
|
|
if entry.ref_count > 0 and not force:
|
|
logger.warning(f"Cannot unload {model_id}: {entry.ref_count} active references")
|
|
return False
|
|
|
|
# Run cleanup callback if provided
|
|
if entry.cleanup_callback:
|
|
try:
|
|
entry.cleanup_callback()
|
|
except Exception as e:
|
|
logger.warning(f"Cleanup callback failed for {model_id}: {e}")
|
|
|
|
# Delete model
|
|
del self.models[model_id]
|
|
logger.info(f"Model {model_id} unloaded")
|
|
|
|
# Clear GPU cache after unloading
|
|
self.memory_guard.clear_gpu_cache()
|
|
return True
|
|
|
|
def _evict_lru_models(self, required_mb: float = 0):
|
|
"""
|
|
Evict least recently used models to free memory.
|
|
|
|
Args:
|
|
required_mb: Target amount of memory to free
|
|
"""
|
|
with self._model_lock:
|
|
# Sort models by last_used (oldest first), excluding those with references
|
|
eviction_candidates = [
|
|
(model_id, entry)
|
|
for model_id, entry in self.models.items()
|
|
if entry.ref_count <= 0 and not entry.is_loading
|
|
]
|
|
eviction_candidates.sort(key=lambda x: x[1].last_used)
|
|
|
|
freed_mb = 0
|
|
for model_id, entry in eviction_candidates:
|
|
if self.unload_model(model_id, force=False):
|
|
freed_mb += entry.estimated_memory_mb
|
|
logger.info(f"Evicted LRU model {model_id}, freed ~{entry.estimated_memory_mb}MB")
|
|
|
|
if required_mb > 0 and freed_mb >= required_mb:
|
|
break
|
|
|
|
def get_model_stats(self) -> Dict:
|
|
"""Get statistics about loaded models"""
|
|
with self._model_lock:
|
|
return {
|
|
"total_models": len(self.models),
|
|
"models": {
|
|
model_id: {
|
|
"ref_count": entry.ref_count,
|
|
"last_used": entry.last_used,
|
|
"estimated_memory_mb": entry.estimated_memory_mb,
|
|
"is_loading": entry.is_loading,
|
|
"idle_seconds": time.time() - entry.last_used
|
|
}
|
|
for model_id, entry in self.models.items()
|
|
},
|
|
"total_estimated_memory_mb": sum(
|
|
e.estimated_memory_mb for e in self.models.values()
|
|
),
|
|
"memory_stats": self.memory_guard.get_memory_stats().__dict__
|
|
}
|
|
|
|
def teardown(self):
|
|
"""
|
|
Clean up all models and resources.
|
|
Called during application shutdown.
|
|
"""
|
|
logger.info("ModelManager teardown started")
|
|
|
|
# Stop monitor thread
|
|
self._monitor_running = False
|
|
|
|
# Unload all models
|
|
with self._model_lock:
|
|
model_ids = list(self.models.keys())
|
|
|
|
for model_id in model_ids:
|
|
self.unload_model(model_id, force=True)
|
|
|
|
# Clean up memory guard
|
|
self.memory_guard.shutdown()
|
|
|
|
logger.info("ModelManager teardown completed")
|
|
|
|
|
|
# Global singleton instance
|
|
_model_manager: Optional[ModelManager] = None
|
|
|
|
|
|
def get_model_manager(config: Optional[MemoryConfig] = None) -> ModelManager:
|
|
"""
|
|
Get the global ModelManager instance.
|
|
|
|
Args:
|
|
config: Optional configuration (only used on first call)
|
|
|
|
Returns:
|
|
ModelManager singleton instance
|
|
"""
|
|
global _model_manager
|
|
if _model_manager is None:
|
|
_model_manager = ModelManager(config)
|
|
return _model_manager
|
|
|
|
|
|
def shutdown_model_manager():
|
|
"""Shutdown the global ModelManager instance"""
|
|
global _model_manager
|
|
if _model_manager is not None:
|
|
_model_manager.teardown()
|
|
_model_manager = None
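# Illustrative sketch only (not referenced elsewhere in this module): the
# acquire/release pairing expected around get_or_load_model(). The model id
# "pp-structure-v3", the 2500MB estimate and `load_engine` are assumptions.
def _example_managed_model(load_engine: Callable[[], Any]) -> None:
    manager = get_model_manager()
    model = manager.get_or_load_model(
        model_id="pp-structure-v3",
        loader_func=load_engine,
        estimated_memory_mb=2500,
    )
    try:
        pass  # ... use `model` here ...
    finally:
        manager.release_model("pp-structure-v3")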
|
|
|
|
|
|
# =============================================================================
|
|
# Section 4.2: Batch Processing and Progressive Loading
|
|
# =============================================================================
|
|
|
|
class BatchPriority(Enum):
|
|
"""Priority levels for batch operations"""
|
|
LOW = 0
|
|
NORMAL = 1
|
|
HIGH = 2
|
|
CRITICAL = 3
|
|
|
|
|
|
@dataclass
|
|
class BatchItem:
|
|
"""Item in a processing batch"""
|
|
item_id: str
|
|
data: Any
|
|
priority: BatchPriority = BatchPriority.NORMAL
|
|
created_at: float = field(default_factory=time.time)
|
|
estimated_memory_mb: float = 0.0
|
|
metadata: Dict = field(default_factory=dict)
|
|
|
|
|
|
@dataclass
|
|
class BatchResult:
|
|
"""Result of batch processing"""
|
|
item_id: str
|
|
success: bool
|
|
result: Any = None
|
|
error: Optional[str] = None
|
|
processing_time_ms: float = 0.0
|
|
|
|
|
|
class BatchProcessor:
|
|
"""
|
|
Process items in batches to optimize memory usage for large documents.
|
|
|
|
Features:
|
|
- Memory-aware batch sizing
|
|
- Priority-based processing
|
|
- Progress tracking
|
|
- Automatic memory cleanup between batches
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
max_batch_size: int = 5,
|
|
max_memory_per_batch_mb: float = 2000.0,
|
|
memory_guard: Optional[MemoryGuard] = None,
|
|
cleanup_between_batches: bool = True
|
|
):
|
|
"""
|
|
Initialize BatchProcessor.
|
|
|
|
Args:
|
|
max_batch_size: Maximum items per batch
|
|
max_memory_per_batch_mb: Maximum memory allowed per batch
|
|
memory_guard: MemoryGuard instance for memory monitoring
|
|
cleanup_between_batches: Whether to clear GPU cache between batches
|
|
"""
|
|
self.max_batch_size = max_batch_size
|
|
self.max_memory_per_batch_mb = max_memory_per_batch_mb
|
|
self.memory_guard = memory_guard or MemoryGuard()
|
|
self.cleanup_between_batches = cleanup_between_batches
|
|
|
|
self._queue: List[BatchItem] = []
|
|
self._lock = threading.Lock()
|
|
self._processing = False
|
|
|
|
# Statistics
|
|
self._total_processed = 0
|
|
self._total_batches = 0
|
|
self._total_failures = 0
|
|
|
|
logger.info(f"BatchProcessor initialized (max_batch_size={max_batch_size}, max_memory={max_memory_per_batch_mb}MB)")
|
|
|
|
def add_item(self, item: BatchItem):
|
|
"""Add an item to the processing queue"""
|
|
with self._lock:
|
|
self._queue.append(item)
|
|
# Sort by priority (highest first), then by creation time (oldest first)
|
|
self._queue.sort(key=lambda x: (-x.priority.value, x.created_at))
|
|
logger.debug(f"Added item {item.item_id} to batch queue (queue_size={len(self._queue)})")
|
|
|
|
def add_items(self, items: List[BatchItem]):
|
|
"""Add multiple items to the processing queue"""
|
|
with self._lock:
|
|
self._queue.extend(items)
|
|
self._queue.sort(key=lambda x: (-x.priority.value, x.created_at))
|
|
logger.debug(f"Added {len(items)} items to batch queue (queue_size={len(self._queue)})")
|
|
|
|
def _create_batch(self) -> List[BatchItem]:
|
|
"""Create a batch from the queue based on size and memory constraints"""
|
|
batch = []
|
|
batch_memory = 0.0
|
|
|
|
with self._lock:
|
|
remaining = []
|
|
for item in self._queue:
|
|
# Check if adding this item would exceed limits
|
|
if len(batch) >= self.max_batch_size:
|
|
remaining.append(item)
|
|
continue
|
|
|
|
if batch_memory + item.estimated_memory_mb > self.max_memory_per_batch_mb and batch:
|
|
remaining.append(item)
|
|
continue
|
|
|
|
batch.append(item)
|
|
batch_memory += item.estimated_memory_mb
|
|
|
|
self._queue = remaining
|
|
|
|
return batch
|
|
|
|
def process_batch(
|
|
self,
|
|
processor_func: Callable[[Any], Any],
|
|
progress_callback: Optional[Callable[[int, int, BatchResult], None]] = None
|
|
) -> List[BatchResult]:
|
|
"""
|
|
Process a single batch of items.
|
|
|
|
Args:
|
|
processor_func: Function to process each item (receives item.data)
|
|
progress_callback: Optional callback(current, total, result)
|
|
|
|
Returns:
|
|
List of BatchResult for each item in the batch
|
|
"""
|
|
batch = self._create_batch()
|
|
if not batch:
|
|
return []
|
|
|
|
self._processing = True
|
|
results = []
|
|
total = len(batch)
|
|
|
|
try:
|
|
for i, item in enumerate(batch):
|
|
start_time = time.time()
|
|
result = BatchResult(item_id=item.item_id, success=False)
|
|
|
|
try:
|
|
# Check memory before processing
|
|
is_available, stats = self.memory_guard.check_memory(
|
|
int(item.estimated_memory_mb)
|
|
)
|
|
if not is_available:
|
|
logger.warning(f"Insufficient memory for item {item.item_id}, cleaning up...")
|
|
self.memory_guard.clear_gpu_cache()
|
|
gc.collect()
|
|
|
|
# Process item
|
|
result.result = processor_func(item.data)
|
|
result.success = True
|
|
|
|
except Exception as e:
|
|
result.error = str(e)
|
|
self._total_failures += 1
|
|
logger.error(f"Failed to process item {item.item_id}: {e}")
|
|
|
|
result.processing_time_ms = (time.time() - start_time) * 1000
|
|
results.append(result)
|
|
self._total_processed += 1
|
|
|
|
# Call progress callback
|
|
if progress_callback:
|
|
progress_callback(i + 1, total, result)
|
|
|
|
self._total_batches += 1
|
|
|
|
finally:
|
|
self._processing = False
|
|
|
|
# Clean up after batch
|
|
if self.cleanup_between_batches:
|
|
self.memory_guard.clear_gpu_cache()
|
|
gc.collect()
|
|
|
|
return results
|
|
|
|
def process_all(
|
|
self,
|
|
processor_func: Callable[[Any], Any],
|
|
progress_callback: Optional[Callable[[int, int, BatchResult], None]] = None
|
|
) -> List[BatchResult]:
|
|
"""
|
|
Process all items in the queue.
|
|
|
|
Args:
|
|
processor_func: Function to process each item
|
|
progress_callback: Optional progress callback
|
|
|
|
Returns:
|
|
List of all BatchResults
|
|
"""
|
|
all_results = []
|
|
|
|
while True:
|
|
with self._lock:
|
|
if not self._queue:
|
|
break
|
|
|
|
batch_results = self.process_batch(processor_func, progress_callback)
|
|
all_results.extend(batch_results)
|
|
|
|
return all_results
|
|
|
|
def get_queue_size(self) -> int:
|
|
"""Get current queue size"""
|
|
with self._lock:
|
|
return len(self._queue)
|
|
|
|
def get_stats(self) -> Dict:
|
|
"""Get processing statistics"""
|
|
with self._lock:
|
|
return {
|
|
"queue_size": len(self._queue),
|
|
"total_processed": self._total_processed,
|
|
"total_batches": self._total_batches,
|
|
"total_failures": self._total_failures,
|
|
"is_processing": self._processing,
|
|
"max_batch_size": self.max_batch_size,
|
|
"max_memory_per_batch_mb": self.max_memory_per_batch_mb,
|
|
}
|
|
|
|
def clear_queue(self):
|
|
"""Clear the processing queue"""
|
|
with self._lock:
|
|
self._queue.clear()
|
|
logger.info("Batch queue cleared")
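# Illustrative sketch only (not referenced elsewhere in this module): feeding
# page-level work through BatchProcessor. `ocr_page` and the 300MB per-page
# estimate are assumptions made for the example.
def _example_batch_ocr(page_paths: List[str], ocr_page: Callable[[str], Any]) -> List[BatchResult]:
    processor = BatchProcessor(max_batch_size=4, max_memory_per_batch_mb=1500.0)
    processor.add_items([
        BatchItem(item_id=f"page-{i}", data=path, estimated_memory_mb=300.0)
        for i, path in enumerate(page_paths)
    ])
    return processor.process_all(
        processor_func=ocr_page,
        progress_callback=lambda done, total, res: logger.debug(f"batch progress {done}/{total}"),
    )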
|
|
|
|
|
|
class ProgressiveLoader:
|
|
"""
|
|
Progressive page loader for multi-page documents.
|
|
|
|
Loads and processes pages incrementally to minimize memory usage.
|
|
Supports lookahead loading for better performance.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
lookahead_pages: int = 2,
|
|
memory_guard: Optional[MemoryGuard] = None,
|
|
cleanup_after_pages: int = 5
|
|
):
|
|
"""
|
|
Initialize ProgressiveLoader.
|
|
|
|
Args:
|
|
lookahead_pages: Number of pages to load ahead
|
|
memory_guard: MemoryGuard instance
|
|
cleanup_after_pages: Trigger cleanup after this many pages
|
|
"""
|
|
self.lookahead_pages = lookahead_pages
|
|
self.memory_guard = memory_guard or MemoryGuard()
|
|
self.cleanup_after_pages = cleanup_after_pages
|
|
|
|
self._loaded_pages: Dict[int, Any] = {}
|
|
self._lock = threading.Lock()
|
|
self._current_page = 0
|
|
self._total_pages = 0
|
|
self._pages_since_cleanup = 0
|
|
|
|
logger.info(f"ProgressiveLoader initialized (lookahead={lookahead_pages})")
|
|
|
|
def initialize(self, total_pages: int):
|
|
"""Initialize loader with total page count"""
|
|
with self._lock:
|
|
self._total_pages = total_pages
|
|
self._current_page = 0
|
|
self._loaded_pages.clear()
|
|
self._pages_since_cleanup = 0
|
|
logger.info(f"ProgressiveLoader initialized for {total_pages} pages")
|
|
|
|
def load_page(
|
|
self,
|
|
page_num: int,
|
|
loader_func: Callable[[int], Any],
|
|
unload_distant: bool = True
|
|
) -> Any:
|
|
"""
|
|
Load a specific page.
|
|
|
|
Args:
|
|
page_num: Page number to load (0-indexed)
|
|
loader_func: Function to load page (receives page_num)
|
|
unload_distant: Unload pages far from current position
|
|
|
|
Returns:
|
|
Loaded page data
|
|
"""
|
|
with self._lock:
|
|
# Check if already loaded
|
|
if page_num in self._loaded_pages:
|
|
self._current_page = page_num
|
|
return self._loaded_pages[page_num]
|
|
|
|
# Load the page
|
|
logger.debug(f"Loading page {page_num}")
|
|
page_data = loader_func(page_num)
|
|
|
|
with self._lock:
|
|
self._loaded_pages[page_num] = page_data
|
|
self._current_page = page_num
|
|
self._pages_since_cleanup += 1
|
|
|
|
# Unload distant pages to save memory
|
|
if unload_distant:
|
|
self._unload_distant_pages()
|
|
|
|
# Trigger cleanup if needed
|
|
if self._pages_since_cleanup >= self.cleanup_after_pages:
|
|
self.memory_guard.clear_gpu_cache()
|
|
gc.collect()
|
|
self._pages_since_cleanup = 0
|
|
|
|
return page_data
|
|
|
|
def _unload_distant_pages(self):
|
|
"""Unload pages far from current position"""
|
|
keep_range = range(
|
|
max(0, self._current_page - 1),
|
|
min(self._total_pages, self._current_page + self.lookahead_pages + 1)
|
|
)
|
|
|
|
pages_to_unload = [
|
|
p for p in self._loaded_pages.keys()
|
|
if p not in keep_range
|
|
]
|
|
|
|
for page_num in pages_to_unload:
|
|
del self._loaded_pages[page_num]
|
|
logger.debug(f"Unloaded distant page {page_num}")
|
|
|
|
def prefetch_pages(
|
|
self,
|
|
start_page: int,
|
|
loader_func: Callable[[int], Any]
|
|
):
|
|
"""
|
|
Prefetch upcoming pages in background.
|
|
|
|
Args:
|
|
start_page: Starting page number
|
|
loader_func: Function to load page
|
|
"""
|
|
for i in range(self.lookahead_pages):
|
|
page_num = start_page + i + 1
|
|
if page_num >= self._total_pages:
|
|
break
|
|
|
|
with self._lock:
|
|
if page_num in self._loaded_pages:
|
|
continue
|
|
|
|
try:
|
|
self.load_page(page_num, loader_func, unload_distant=False)
|
|
except Exception as e:
|
|
logger.warning(f"Prefetch failed for page {page_num}: {e}")
|
|
|
|
def iterate_pages(
|
|
self,
|
|
loader_func: Callable[[int], Any],
|
|
processor_func: Callable[[int, Any], Any],
|
|
progress_callback: Optional[Callable[[int, int], None]] = None
|
|
) -> List[Any]:
|
|
"""
|
|
Iterate through all pages with progressive loading.
|
|
|
|
Args:
|
|
loader_func: Function to load a page
|
|
processor_func: Function to process page (receives page_num, data)
|
|
progress_callback: Optional callback(current_page, total_pages)
|
|
|
|
Returns:
|
|
List of results from processor_func
|
|
"""
|
|
results = []
|
|
|
|
for page_num in range(self._total_pages):
|
|
# Load page
|
|
page_data = self.load_page(page_num, loader_func)
|
|
|
|
# Process page
|
|
result = processor_func(page_num, page_data)
|
|
results.append(result)
|
|
|
|
# Report progress
|
|
if progress_callback:
|
|
progress_callback(page_num + 1, self._total_pages)
|
|
|
|
# Start prefetching next pages in background
|
|
if self.lookahead_pages > 0:
|
|
# Use thread for prefetching to not block
|
|
prefetch_thread = threading.Thread(
|
|
target=self.prefetch_pages,
|
|
args=(page_num, loader_func),
|
|
daemon=True
|
|
)
|
|
prefetch_thread.start()
|
|
|
|
return results
|
|
|
|
def get_loaded_pages(self) -> List[int]:
|
|
"""Get list of currently loaded page numbers"""
|
|
with self._lock:
|
|
return list(self._loaded_pages.keys())
|
|
|
|
def get_stats(self) -> Dict:
|
|
"""Get loader statistics"""
|
|
with self._lock:
|
|
return {
|
|
"total_pages": self._total_pages,
|
|
"current_page": self._current_page,
|
|
"loaded_pages_count": len(self._loaded_pages),
|
|
"loaded_pages": list(self._loaded_pages.keys()),
|
|
"lookahead_pages": self.lookahead_pages,
|
|
"pages_since_cleanup": self._pages_since_cleanup,
|
|
}
|
|
|
|
def clear(self):
|
|
"""Clear all loaded pages"""
|
|
with self._lock:
|
|
self._loaded_pages.clear()
|
|
self._current_page = 0
|
|
self._pages_since_cleanup = 0
|
|
self.memory_guard.clear_gpu_cache()
|
|
gc.collect()
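# Illustrative sketch only (not referenced elsewhere in this module): walking a
# document page by page with ProgressiveLoader. `render_page` and
# `extract_text` are hypothetical helpers supplied by the caller.
def _example_progressive_pages(
    total_pages: int,
    render_page: Callable[[int], Any],
    extract_text: Callable[[int, Any], str],
) -> List[str]:
    loader = ProgressiveLoader(lookahead_pages=2, cleanup_after_pages=5)
    loader.initialize(total_pages)
    return loader.iterate_pages(
        loader_func=render_page,
        processor_func=extract_text,
        progress_callback=lambda page, total: logger.debug(f"page {page}/{total}"),
    )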
|
|
|
|
|
|
class PriorityOperationQueue:
|
|
"""
|
|
Priority queue for OCR operations.
|
|
|
|
Higher priority operations are processed first.
|
|
Supports timeout and cancellation.
|
|
"""
|
|
|
|
def __init__(self, max_size: int = 100):
|
|
"""
|
|
Initialize priority queue.
|
|
|
|
Args:
|
|
max_size: Maximum queue size (0 for unlimited)
|
|
"""
|
|
self.max_size = max_size
|
|
self._queue: List[Tuple[BatchPriority, float, str, Any]] = [] # (priority, timestamp, id, data)
|
|
self._lock = threading.Lock()
|
|
self._condition = threading.Condition(self._lock)
|
|
self._cancelled: set = set()
|
|
|
|
# Statistics
|
|
self._total_enqueued = 0
|
|
self._total_dequeued = 0
|
|
self._total_cancelled = 0
|
|
|
|
logger.info(f"PriorityOperationQueue initialized (max_size={max_size})")
|
|
|
|
def enqueue(
|
|
self,
|
|
item_id: str,
|
|
data: Any,
|
|
priority: BatchPriority = BatchPriority.NORMAL,
|
|
timeout: Optional[float] = None
|
|
) -> bool:
|
|
"""
|
|
Add an operation to the queue.
|
|
|
|
Args:
|
|
item_id: Unique identifier for the operation
|
|
data: Operation data
|
|
priority: Operation priority
|
|
timeout: Optional timeout to wait for space in queue
|
|
|
|
Returns:
|
|
True if enqueued successfully
|
|
"""
|
|
with self._condition:
|
|
# Wait for space if queue is full
|
|
if self.max_size > 0 and len(self._queue) >= self.max_size:
|
|
if timeout is not None:
|
|
result = self._condition.wait_for(
|
|
lambda: len(self._queue) < self.max_size,
|
|
timeout=timeout
|
|
)
|
|
if not result:
|
|
logger.warning(f"Queue full, timeout waiting to enqueue {item_id}")
|
|
return False
|
|
else:
|
|
logger.warning(f"Queue full, cannot enqueue {item_id}")
|
|
return False
|
|
|
|
# Add to queue (negative priority for max-heap behavior)
|
|
import heapq
|
|
heapq.heappush(
|
|
self._queue,
|
|
(-priority.value, time.time(), item_id, data)
|
|
)
|
|
self._total_enqueued += 1
|
|
self._condition.notify()
|
|
|
|
logger.debug(f"Enqueued operation {item_id} with priority {priority.name}")
|
|
return True
|
|
|
|
def dequeue(self, timeout: Optional[float] = None) -> Optional[Tuple[str, Any, BatchPriority]]:
|
|
"""
|
|
Get the highest priority operation from the queue.
|
|
|
|
Args:
|
|
timeout: Optional timeout to wait for an item
|
|
|
|
Returns:
|
|
Tuple of (item_id, data, priority) or None if timeout
|
|
"""
|
|
import heapq
|
|
|
|
with self._condition:
|
|
# Wait for an item
|
|
if not self._queue:
|
|
if timeout is not None:
|
|
result = self._condition.wait_for(
|
|
lambda: len(self._queue) > 0,
|
|
timeout=timeout
|
|
)
|
|
if not result:
|
|
return None
|
|
else:
|
|
return None
|
|
|
|
# Get highest priority item
|
|
neg_priority, _, item_id, data = heapq.heappop(self._queue)
|
|
priority = BatchPriority(-neg_priority)
|
|
|
|
# Skip if cancelled
|
|
if item_id in self._cancelled:
|
|
self._cancelled.discard(item_id)
|
|
self._total_cancelled += 1
|
|
self._condition.notify()
|
|
return self.dequeue(timeout=0) # Try next item
|
|
|
|
self._total_dequeued += 1
|
|
self._condition.notify()
|
|
|
|
logger.debug(f"Dequeued operation {item_id} with priority {priority.name}")
|
|
return item_id, data, priority
|
|
|
|
def cancel(self, item_id: str) -> bool:
|
|
"""
|
|
Cancel a pending operation.
|
|
|
|
Args:
|
|
item_id: Operation identifier to cancel
|
|
|
|
Returns:
|
|
True if the operation was found and marked for cancellation
|
|
"""
|
|
with self._lock:
|
|
# Check if item is in queue
|
|
for _, _, qid, _ in self._queue:
|
|
if qid == item_id:
|
|
self._cancelled.add(item_id)
|
|
logger.info(f"Operation {item_id} marked for cancellation")
|
|
return True
|
|
return False
|
|
|
|
def get_size(self) -> int:
|
|
"""Get current queue size"""
|
|
with self._lock:
|
|
return len(self._queue)
|
|
|
|
def get_stats(self) -> Dict:
|
|
"""Get queue statistics"""
|
|
with self._lock:
|
|
# Count by priority
|
|
priority_counts = {p.name: 0 for p in BatchPriority}
|
|
for neg_priority, _, _, _ in self._queue:
|
|
priority = BatchPriority(-neg_priority)
|
|
priority_counts[priority.name] += 1
|
|
|
|
return {
|
|
"queue_size": len(self._queue),
|
|
"max_size": self.max_size,
|
|
"total_enqueued": self._total_enqueued,
|
|
"total_dequeued": self._total_dequeued,
|
|
"total_cancelled": self._total_cancelled,
|
|
"pending_cancellations": len(self._cancelled),
|
|
"by_priority": priority_counts,
|
|
}
|
|
|
|
def clear(self):
|
|
"""Clear the queue"""
|
|
with self._lock:
|
|
self._queue.clear()
|
|
self._cancelled.clear()
|
|
logger.info("Priority queue cleared")
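# Illustrative sketch only (not referenced elsewhere in this module): a worker
# draining PriorityOperationQueue. `handle` and the sample payloads are
# placeholders for real task data.
def _example_drain_queue(queue: PriorityOperationQueue, handle: Callable[[str, Any], None]) -> None:
    queue.enqueue("urgent-doc", {"path": "scan_001.pdf"}, priority=BatchPriority.HIGH)
    queue.enqueue("routine-doc", {"path": "scan_002.pdf"})
    while True:
        item = queue.dequeue(timeout=1.0)
        if item is None:
            break  # Nothing left (or nothing arrived within the timeout)
        item_id, data, priority = item
        handle(item_id, data)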
|
|
|
|
|
|
# =============================================================================
|
|
# Section 5.2: Recovery Mechanisms
|
|
# =============================================================================
|
|
|
|
@dataclass
|
|
class RecoveryState:
|
|
"""State of recovery mechanism"""
|
|
last_recovery_time: float = 0.0
|
|
recovery_count: int = 0
|
|
in_cooldown: bool = False
|
|
cooldown_until: float = 0.0
|
|
last_error: Optional[str] = None
|
|
|
|
|
|
class RecoveryManager:
|
|
"""
|
|
Manages recovery mechanisms for memory issues and failures.
|
|
|
|
Features:
|
|
- Emergency memory release
|
|
- Cooldown period after recovery
|
|
- Recovery attempt limits
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
cooldown_seconds: float = 30.0,
|
|
max_recovery_attempts: int = 3,
|
|
recovery_window_seconds: float = 300.0,
|
|
memory_guard: Optional[MemoryGuard] = None
|
|
):
|
|
"""
|
|
Initialize RecoveryManager.
|
|
|
|
Args:
|
|
cooldown_seconds: Cooldown period after recovery
|
|
max_recovery_attempts: Max recovery attempts within window
|
|
recovery_window_seconds: Window for counting recovery attempts
|
|
memory_guard: MemoryGuard instance
|
|
"""
|
|
self.cooldown_seconds = cooldown_seconds
|
|
self.max_recovery_attempts = max_recovery_attempts
|
|
self.recovery_window_seconds = recovery_window_seconds
|
|
self.memory_guard = memory_guard or MemoryGuard()
|
|
|
|
self._state = RecoveryState()
|
|
self._lock = threading.Lock()
|
|
self._recovery_times: List[float] = []
|
|
|
|
# Callbacks
|
|
self._on_recovery_start: List[Callable] = []
|
|
self._on_recovery_complete: List[Callable[[bool], None]] = []
|
|
|
|
logger.info(f"RecoveryManager initialized (cooldown={cooldown_seconds}s)")
|
|
|
|
def register_callbacks(
|
|
self,
|
|
on_start: Optional[Callable] = None,
|
|
on_complete: Optional[Callable[[bool], None]] = None
|
|
):
|
|
"""Register recovery event callbacks"""
|
|
if on_start:
|
|
self._on_recovery_start.append(on_start)
|
|
if on_complete:
|
|
self._on_recovery_complete.append(on_complete)
|
|
|
|
def is_in_cooldown(self) -> bool:
|
|
"""Check if currently in cooldown period"""
|
|
with self._lock:
|
|
if not self._state.in_cooldown:
|
|
return False
|
|
|
|
if time.time() >= self._state.cooldown_until:
|
|
self._state.in_cooldown = False
|
|
return False
|
|
|
|
return True
|
|
|
|
def get_cooldown_remaining(self) -> float:
|
|
"""Get remaining cooldown time in seconds"""
|
|
with self._lock:
|
|
if not self._state.in_cooldown:
|
|
return 0.0
|
|
return max(0, self._state.cooldown_until - time.time())
|
|
|
|
def _count_recent_recoveries(self) -> int:
|
|
"""Count recovery attempts within the window"""
|
|
cutoff = time.time() - self.recovery_window_seconds
|
|
with self._lock:
|
|
# Clean old entries
|
|
self._recovery_times = [t for t in self._recovery_times if t > cutoff]
|
|
return len(self._recovery_times)
|
|
|
|
def can_attempt_recovery(self) -> Tuple[bool, str]:
|
|
"""
|
|
Check if recovery can be attempted.
|
|
|
|
Returns:
|
|
Tuple of (can_recover, reason)
|
|
"""
|
|
if self.is_in_cooldown():
|
|
remaining = self.get_cooldown_remaining()
|
|
return False, f"In cooldown period ({remaining:.1f}s remaining)"
|
|
|
|
recent_count = self._count_recent_recoveries()
|
|
if recent_count >= self.max_recovery_attempts:
|
|
return False, f"Max recovery attempts ({self.max_recovery_attempts}) reached"
|
|
|
|
return True, "Recovery allowed"
|
|
|
|
def attempt_recovery(self, error: Optional[str] = None) -> bool:
|
|
"""
|
|
Attempt memory recovery.
|
|
|
|
Args:
|
|
error: Optional error message that triggered recovery
|
|
|
|
Returns:
|
|
True if recovery was successful
|
|
"""
|
|
can_recover, reason = self.can_attempt_recovery()
|
|
if not can_recover:
|
|
logger.warning(f"Cannot attempt recovery: {reason}")
|
|
return False
|
|
|
|
logger.info("Starting memory recovery...")
|
|
|
|
# Notify callbacks
|
|
for callback in self._on_recovery_start:
|
|
try:
|
|
callback()
|
|
except Exception as e:
|
|
logger.warning(f"Recovery start callback failed: {e}")
|
|
|
|
success = False
|
|
try:
|
|
# Step 1: Clear GPU cache
|
|
self.memory_guard.clear_gpu_cache()
|
|
|
|
# Step 2: Force garbage collection
|
|
gc.collect()
|
|
|
|
# Step 3: Check memory status
|
|
is_available, stats = self.memory_guard.check_memory()
|
|
success = is_available or stats.gpu_used_ratio < 0.9
|
|
|
|
if success:
|
|
logger.info(
|
|
f"Memory recovery successful. GPU: {stats.gpu_used_ratio*100:.1f}% used"
|
|
)
|
|
else:
|
|
logger.warning(
|
|
f"Memory recovery incomplete. GPU still at {stats.gpu_used_ratio*100:.1f}%"
|
|
)
|
|
|
|
except Exception as e:
|
|
logger.error(f"Recovery failed with error: {e}")
|
|
success = False
|
|
|
|
# Update state
|
|
with self._lock:
|
|
self._state.last_recovery_time = time.time()
|
|
self._state.recovery_count += 1
|
|
self._state.last_error = error
|
|
self._recovery_times.append(time.time())
|
|
|
|
# Enter cooldown
|
|
self._state.in_cooldown = True
|
|
self._state.cooldown_until = time.time() + self.cooldown_seconds
|
|
|
|
logger.info(f"Entering cooldown period ({self.cooldown_seconds}s)")
|
|
|
|
# Notify callbacks
|
|
for callback in self._on_recovery_complete:
|
|
try:
|
|
callback(success)
|
|
except Exception as e:
|
|
logger.warning(f"Recovery complete callback failed: {e}")
|
|
|
|
return success
|
|
|
|
def emergency_release(self, model_manager: Optional['ModelManager'] = None) -> bool:
|
|
"""
|
|
Emergency memory release - more aggressive than normal recovery.
|
|
|
|
Args:
|
|
model_manager: Optional ModelManager to unload models from
|
|
|
|
Returns:
|
|
True if significant memory was freed
|
|
"""
|
|
logger.warning("Initiating EMERGENCY memory release")
|
|
|
|
initial_stats = self.memory_guard.get_memory_stats()
|
|
|
|
# Step 1: Unload all models if model_manager provided
|
|
if model_manager:
|
|
logger.info("Unloading all models...")
|
|
try:
|
|
model_ids = list(model_manager.models.keys())
|
|
for model_id in model_ids:
|
|
model_manager.unload_model(model_id, force=True)
|
|
except Exception as e:
|
|
logger.error(f"Failed to unload models: {e}")
|
|
|
|
# Step 2: Clear all caches
|
|
self.memory_guard.clear_gpu_cache()
|
|
|
|
# Step 3: Multiple rounds of garbage collection
|
|
for i in range(3):
|
|
gc.collect()
|
|
time.sleep(0.1)
|
|
|
|
# Step 4: Check improvement
|
|
final_stats = self.memory_guard.get_memory_stats()
|
|
freed_mb = initial_stats.gpu_used_mb - final_stats.gpu_used_mb
|
|
|
|
logger.info(
|
|
f"Emergency release complete. Freed ~{freed_mb:.0f}MB. "
|
|
f"GPU: {final_stats.gpu_used_mb:.0f}MB / {final_stats.gpu_total_mb:.0f}MB "
|
|
f"({final_stats.gpu_used_ratio*100:.1f}%)"
|
|
)
|
|
|
|
return freed_mb > 100 # Consider success if freed >100MB
|
|
|
|
def get_state(self) -> Dict:
|
|
"""Get current recovery state"""
|
|
with self._lock:
|
|
return {
|
|
"last_recovery_time": self._state.last_recovery_time,
|
|
"recovery_count": self._state.recovery_count,
|
|
"in_cooldown": self._state.in_cooldown,
|
|
"cooldown_remaining_seconds": self.get_cooldown_remaining(),
|
|
"recent_recoveries": self._count_recent_recoveries(),
|
|
"max_recovery_attempts": self.max_recovery_attempts,
|
|
"last_error": self._state.last_error,
|
|
}
|
|
|
|
|
|
# Global recovery manager instance
|
|
_recovery_manager: Optional[RecoveryManager] = None
|
|
|
|
|
|
def get_recovery_manager(
|
|
cooldown_seconds: float = 30.0,
|
|
max_recovery_attempts: int = 3
|
|
) -> RecoveryManager:
|
|
"""
|
|
Get the global RecoveryManager instance.
|
|
|
|
Args:
|
|
cooldown_seconds: Cooldown period after recovery
|
|
max_recovery_attempts: Max recovery attempts within window
|
|
|
|
Returns:
|
|
RecoveryManager singleton instance
|
|
"""
|
|
global _recovery_manager
|
|
if _recovery_manager is None:
|
|
_recovery_manager = RecoveryManager(
|
|
cooldown_seconds=cooldown_seconds,
|
|
max_recovery_attempts=max_recovery_attempts
|
|
)
|
|
return _recovery_manager
|
|
|
|
|
|
def shutdown_recovery_manager():
|
|
"""Shutdown the global RecoveryManager instance"""
|
|
global _recovery_manager
|
|
_recovery_manager = None
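# Illustrative sketch only (not referenced elsewhere in this module): a typical
# reaction to a CUDA out-of-memory error. attempt_recovery() already refuses
# during cooldown or after repeated attempts, so the escalation below is the
# aggressive fallback, not the default path.
def _example_handle_oom(error_message: str) -> bool:
    recovery = get_recovery_manager()
    if recovery.attempt_recovery(error=error_message):
        return True
    logger.warning("Normal recovery did not free enough memory; escalating")
    return recovery.emergency_release(model_manager=get_model_manager())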
|
|
|
|
|
|
# =============================================================================
|
|
# Section 5.2: Memory Dump for Debugging
|
|
# =============================================================================
|
|
|
|
@dataclass
|
|
class MemoryDumpEntry:
|
|
"""Entry in a memory dump"""
|
|
object_type: str
|
|
object_id: str
|
|
size_bytes: int
|
|
ref_count: int
|
|
details: Dict = field(default_factory=dict)
|
|
|
|
|
|
@dataclass
|
|
class MemoryDump:
|
|
"""Complete memory dump for debugging"""
|
|
timestamp: float
|
|
total_gpu_memory_mb: float
|
|
used_gpu_memory_mb: float
|
|
free_gpu_memory_mb: float
|
|
total_cpu_memory_mb: float
|
|
used_cpu_memory_mb: float
|
|
loaded_models: List[Dict]
|
|
active_predictions: int
|
|
queue_depth: int
|
|
service_pool_stats: Dict
|
|
recovery_state: Dict
|
|
python_objects: List[MemoryDumpEntry] = field(default_factory=list)
|
|
gc_stats: Dict = field(default_factory=dict)
|
|
|
|
|
|
class MemoryDumper:
|
|
"""
|
|
Creates memory dumps for debugging memory issues.
|
|
|
|
Captures comprehensive memory state including:
|
|
- GPU/CPU memory usage
|
|
- Loaded models and their references
|
|
- Active predictions and queue state
|
|
- Python garbage collector statistics
|
|
- Large object tracking
|
|
"""
|
|
|
|
def __init__(self, memory_guard: Optional[MemoryGuard] = None):
|
|
"""
|
|
Initialize MemoryDumper.
|
|
|
|
Args:
|
|
memory_guard: MemoryGuard instance for memory queries
|
|
"""
|
|
self.memory_guard = memory_guard or MemoryGuard()
|
|
self._dump_history: List[MemoryDump] = []
|
|
self._max_history = 10
|
|
self._lock = threading.Lock()
|
|
|
|
logger.info("MemoryDumper initialized")
|
|
|
|
def create_dump(
|
|
self,
|
|
include_python_objects: bool = False,
|
|
min_object_size: int = 1048576 # 1MB
|
|
) -> MemoryDump:
|
|
"""
|
|
Create a memory dump capturing current state.
|
|
|
|
Args:
|
|
include_python_objects: Include large Python objects in dump
|
|
min_object_size: Minimum object size to include (bytes)
|
|
|
|
Returns:
|
|
MemoryDump with current memory state
|
|
"""
|
|
logger.info("Creating memory dump...")
|
|
|
|
# Get memory stats
|
|
stats = self.memory_guard.get_memory_stats()
|
|
|
|
# Get model manager stats
|
|
loaded_models = []
|
|
try:
|
|
model_manager = get_model_manager()
|
|
model_stats = model_manager.get_model_stats()
|
|
loaded_models = [
|
|
{
|
|
"model_id": model_id,
|
|
"ref_count": info["ref_count"],
|
|
"estimated_memory_mb": info["estimated_memory_mb"],
|
|
"idle_seconds": info["idle_seconds"],
|
|
"is_loading": info["is_loading"],
|
|
}
|
|
for model_id, info in model_stats.get("models", {}).items()
|
|
]
|
|
except Exception as e:
|
|
logger.debug(f"Could not get model stats: {e}")
|
|
|
|
# Get prediction semaphore stats
|
|
active_predictions = 0
|
|
queue_depth = 0
|
|
try:
|
|
semaphore = get_prediction_semaphore()
|
|
sem_stats = semaphore.get_stats()
|
|
active_predictions = sem_stats.get("active_predictions", 0)
|
|
queue_depth = sem_stats.get("queue_depth", 0)
|
|
except Exception as e:
|
|
logger.debug(f"Could not get semaphore stats: {e}")
|
|
|
|
# Get service pool stats
|
|
service_pool_stats = {}
|
|
try:
|
|
from app.services.service_pool import get_service_pool
|
|
pool = get_service_pool()
|
|
service_pool_stats = pool.get_pool_stats()
|
|
except Exception as e:
|
|
logger.debug(f"Could not get service pool stats: {e}")
|
|
|
|
# Get recovery state
|
|
recovery_state = {}
|
|
try:
|
|
recovery_manager = get_recovery_manager()
|
|
recovery_state = recovery_manager.get_state()
|
|
except Exception as e:
|
|
logger.debug(f"Could not get recovery state: {e}")
|
|
|
|
# Get GC stats
|
|
gc_stats = {
|
|
"counts": gc.get_count(),
|
|
"threshold": gc.get_threshold(),
|
|
"is_tracking": gc.isenabled(),
|
|
}
|
|
|
|
# Create dump
|
|
dump = MemoryDump(
|
|
timestamp=time.time(),
|
|
total_gpu_memory_mb=stats.gpu_total_mb,
|
|
used_gpu_memory_mb=stats.gpu_used_mb,
|
|
free_gpu_memory_mb=stats.gpu_free_mb,
|
|
total_cpu_memory_mb=stats.cpu_used_mb + stats.cpu_available_mb,
|
|
used_cpu_memory_mb=stats.cpu_used_mb,
|
|
loaded_models=loaded_models,
|
|
active_predictions=active_predictions,
|
|
queue_depth=queue_depth,
|
|
service_pool_stats=service_pool_stats,
|
|
recovery_state=recovery_state,
|
|
gc_stats=gc_stats,
|
|
)
|
|
|
|
# Optionally include large Python objects
|
|
if include_python_objects:
|
|
dump.python_objects = self._get_large_objects(min_object_size)
|
|
|
|
# Store in history
|
|
with self._lock:
|
|
self._dump_history.append(dump)
|
|
if len(self._dump_history) > self._max_history:
|
|
self._dump_history.pop(0)
|
|
|
|
logger.info(
|
|
f"Memory dump created: GPU {stats.gpu_used_mb:.0f}/{stats.gpu_total_mb:.0f}MB, "
|
|
f"{len(loaded_models)} models, {active_predictions} active predictions"
|
|
)
|
|
|
|
return dump
|
|
|
|
def _get_large_objects(self, min_size: int) -> List[MemoryDumpEntry]:
|
|
"""Get list of large Python objects for debugging"""
|
|
large_objects = []
|
|
|
|
try:
|
|
import sys
|
|
|
|
# Get all objects tracked by GC
|
|
for obj in gc.get_objects():
|
|
try:
|
|
size = sys.getsizeof(obj)
|
|
if size >= min_size:
|
|
entry = MemoryDumpEntry(
|
|
object_type=type(obj).__name__,
|
|
object_id=str(id(obj)),
|
|
size_bytes=size,
|
|
ref_count=sys.getrefcount(obj),
|
|
details={
|
|
"module": getattr(type(obj), "__module__", "unknown"),
|
|
}
|
|
)
|
|
large_objects.append(entry)
|
|
except Exception:
|
|
pass # Skip objects that can't be measured
|
|
|
|
# Sort by size descending
|
|
large_objects.sort(key=lambda x: x.size_bytes, reverse=True)
|
|
|
|
# Limit to top 100
|
|
return large_objects[:100]
|
|
|
|
except Exception as e:
|
|
logger.warning(f"Failed to get large objects: {e}")
|
|
return []
|
|
|
|
def get_dump_history(self) -> List[MemoryDump]:
|
|
"""Get recent dump history"""
|
|
with self._lock:
|
|
return list(self._dump_history)
|
|
|
|
def get_latest_dump(self) -> Optional[MemoryDump]:
|
|
"""Get the most recent dump"""
|
|
with self._lock:
|
|
return self._dump_history[-1] if self._dump_history else None

    def compare_dumps(
        self,
        dump1: MemoryDump,
        dump2: MemoryDump
    ) -> Dict:
        """
        Compare two memory dumps to identify changes.

        Args:
            dump1: First (earlier) dump
            dump2: Second (later) dump

        Returns:
            Dictionary with comparison results
        """
        return {
            "time_delta_seconds": dump2.timestamp - dump1.timestamp,
            "gpu_memory_change_mb": dump2.used_gpu_memory_mb - dump1.used_gpu_memory_mb,
            "cpu_memory_change_mb": dump2.used_cpu_memory_mb - dump1.used_cpu_memory_mb,
            "model_count_change": len(dump2.loaded_models) - len(dump1.loaded_models),
            "prediction_count_change": dump2.active_predictions - dump1.active_predictions,
            "dump1_timestamp": dump1.timestamp,
            "dump2_timestamp": dump2.timestamp,
        }
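
    # Illustrative sketch: bracketing a workload with two dumps to spot memory
    # growth. run_ocr_job() is a hypothetical workload; everything else comes
    # from this class.
    #
    #     dumper = get_memory_dumper()
    #     before = dumper.create_dump()
    #     run_ocr_job()
    #     after = dumper.create_dump()
    #     diff = dumper.compare_dumps(before, after)
    #     if diff["gpu_memory_change_mb"] > 0:
    #         logger.warning("GPU memory grew by %.1f MB", diff["gpu_memory_change_mb"])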

    def to_dict(self, dump: MemoryDump) -> Dict:
        """Convert a MemoryDump to a dictionary for JSON serialization"""
        return {
            "timestamp": dump.timestamp,
            "gpu": {
                "total_mb": dump.total_gpu_memory_mb,
                "used_mb": dump.used_gpu_memory_mb,
                "free_mb": dump.free_gpu_memory_mb,
                "utilization_percent": (
                    dump.used_gpu_memory_mb / dump.total_gpu_memory_mb * 100
                    if dump.total_gpu_memory_mb > 0 else 0
                ),
            },
            "cpu": {
                "total_mb": dump.total_cpu_memory_mb,
                "used_mb": dump.used_cpu_memory_mb,
            },
            "models": dump.loaded_models,
            "predictions": {
                "active": dump.active_predictions,
                "queue_depth": dump.queue_depth,
            },
            "service_pool": dump.service_pool_stats,
            "recovery": dump.recovery_state,
            "gc": dump.gc_stats,
            "large_objects_count": len(dump.python_objects),
        }
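
    # Illustrative sketch: serializing the latest dump for a debug endpoint or
    # a log line. The json usage is an assumption about how callers consume it.
    #
    #     import json
    #     dumper = get_memory_dumper()
    #     latest = dumper.get_latest_dump()
    #     if latest is not None:
    #         payload = json.dumps(dumper.to_dict(latest), indent=2)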


# Global memory dumper instance
_memory_dumper: Optional[MemoryDumper] = None


def get_memory_dumper() -> MemoryDumper:
    """Get the global MemoryDumper instance"""
    global _memory_dumper
    if _memory_dumper is None:
        _memory_dumper = MemoryDumper()
    return _memory_dumper


def shutdown_memory_dumper():
    """Shutdown the global MemoryDumper instance"""
    global _memory_dumper
    _memory_dumper = None
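
# Illustrative sketch: the accessor returns a shared instance, and
# shutdown_memory_dumper() drops it so the next call builds a fresh one
# (useful for test isolation).
#
#     assert get_memory_dumper() is get_memory_dumper()
#     shutdown_memory_dumper()
#     fresh_dumper = get_memory_dumper()  # re-created on demand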


# =============================================================================
# Section 7.2: Prometheus Metrics Export
# =============================================================================

class PrometheusMetrics:
    """
    Prometheus metrics exporter for memory management.

    Exposes metrics in Prometheus text format for monitoring:
    - GPU/CPU memory usage
    - Model lifecycle metrics
    - Prediction semaphore metrics
    - Service pool metrics
    - Recovery metrics
    """

    # Prefix applied to all exported metric names
    METRIC_PREFIX = "tool_ocr_memory_"

    def __init__(self):
        """Initialize PrometheusMetrics"""
        self._custom_metrics: Dict[str, float] = {}
        self._lock = threading.Lock()
        logger.info("PrometheusMetrics initialized")

    def set_custom_metric(self, name: str, value: float, labels: Optional[Dict[str, str]] = None):
        """
        Set a custom metric value.

        Args:
            name: Metric name
            value: Metric value
            labels: Optional labels for the metric
        """
        with self._lock:
            key = name if not labels else f"{name}{{{self._format_labels(labels)}}}"
            self._custom_metrics[key] = value
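
    # Illustrative sketch: a labelled custom metric is stored under
    # 'name{label="value"}' and later emitted verbatim with METRIC_PREFIX
    # applied. The metric name and label below are hypothetical.
    #
    #     metrics = get_prometheus_metrics()
    #     metrics.set_custom_metric("tasks_in_flight", 3, labels={"track": "hybrid"})
    #     # exported later as: tool_ocr_memory_tasks_in_flight{track="hybrid"} 3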

    def _format_labels(self, labels: Dict[str, str]) -> str:
        """Format labels for Prometheus"""
        return ",".join(f'{k}="{v}"' for k, v in sorted(labels.items()))

    def _format_metric(self, name: str, value: float, help_text: str, metric_type: str = "gauge") -> str:
        """Format a single metric in Prometheus format"""
        lines = [
            f"# HELP {self.METRIC_PREFIX}{name} {help_text}",
            f"# TYPE {self.METRIC_PREFIX}{name} {metric_type}",
            f"{self.METRIC_PREFIX}{name} {value}",
        ]
        return "\n".join(lines)
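
    # Illustrative output: _format_metric("gpu_used_bytes", 1048576.0,
    # "Used GPU memory in bytes") returns the three exposition lines:
    #
    #     # HELP tool_ocr_memory_gpu_used_bytes Used GPU memory in bytes
    #     # TYPE tool_ocr_memory_gpu_used_bytes gauge
    #     tool_ocr_memory_gpu_used_bytes 1048576.0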

    def _format_metric_with_labels(
        self,
        name: str,
        values: List[Tuple[Dict[str, str], float]],
        help_text: str,
        metric_type: str = "gauge"
    ) -> str:
        """Format a metric with labels in Prometheus format"""
        lines = [
            f"# HELP {self.METRIC_PREFIX}{name} {help_text}",
            f"# TYPE {self.METRIC_PREFIX}{name} {metric_type}",
        ]
        for labels, value in values:
            label_str = self._format_labels(labels)
            lines.append(f"{self.METRIC_PREFIX}{name}{{{label_str}}} {value}")
        return "\n".join(lines)

    def export_metrics(self) -> str:
        """
        Export all metrics in Prometheus text format.

        Returns:
            String containing metrics in Prometheus exposition format
        """
        metrics = []

        # GPU Memory metrics
        try:
            guard = MemoryGuard()
            stats = guard.get_memory_stats()

            metrics.append(self._format_metric(
                "gpu_total_bytes",
                stats.gpu_total_mb * 1024 * 1024,
                "Total GPU memory in bytes"
            ))
            metrics.append(self._format_metric(
                "gpu_used_bytes",
                stats.gpu_used_mb * 1024 * 1024,
                "Used GPU memory in bytes"
            ))
            metrics.append(self._format_metric(
                "gpu_free_bytes",
                stats.gpu_free_mb * 1024 * 1024,
                "Free GPU memory in bytes"
            ))
            metrics.append(self._format_metric(
                "gpu_utilization_ratio",
                stats.gpu_used_ratio,
                "GPU memory utilization ratio (0-1)"
            ))
            metrics.append(self._format_metric(
                "cpu_used_bytes",
                stats.cpu_used_mb * 1024 * 1024,
                "Used CPU memory in bytes"
            ))
            metrics.append(self._format_metric(
                "cpu_available_bytes",
                stats.cpu_available_mb * 1024 * 1024,
                "Available CPU memory in bytes"
            ))

            guard.shutdown()
        except Exception as e:
            logger.debug(f"Could not get memory stats for metrics: {e}")

        # Model Manager metrics
        try:
            model_manager = get_model_manager()
            model_stats = model_manager.get_model_stats()

            metrics.append(self._format_metric(
                "models_loaded_total",
                model_stats.get("total_models", 0),
                "Total number of loaded models"
            ))
            metrics.append(self._format_metric(
                "models_memory_bytes",
                model_stats.get("total_estimated_memory_mb", 0) * 1024 * 1024,
                "Estimated total memory used by loaded models in bytes"
            ))

            # Per-model metrics
            model_values = []
            for model_id, info in model_stats.get("models", {}).items():
                model_values.append((
                    {"model_id": model_id},
                    info.get("ref_count", 0)
                ))
            if model_values:
                metrics.append(self._format_metric_with_labels(
                    "model_ref_count",
                    model_values,
                    "Reference count per model"
                ))

        except Exception as e:
            logger.debug(f"Could not get model stats for metrics: {e}")

        # Prediction Semaphore metrics
        try:
            semaphore = get_prediction_semaphore()
            sem_stats = semaphore.get_stats()

            metrics.append(self._format_metric(
                "predictions_active",
                sem_stats.get("active_predictions", 0),
                "Number of currently active predictions"
            ))
            metrics.append(self._format_metric(
                "predictions_queue_depth",
                sem_stats.get("queue_depth", 0),
                "Number of predictions waiting in queue"
            ))
            metrics.append(self._format_metric(
                "predictions_total",
                sem_stats.get("total_predictions", 0),
                "Total number of predictions processed",
                metric_type="counter"
            ))
            metrics.append(self._format_metric(
                "predictions_timeouts_total",
                sem_stats.get("total_timeouts", 0),
                "Total number of prediction timeouts",
                metric_type="counter"
            ))
            metrics.append(self._format_metric(
                "predictions_avg_wait_seconds",
                sem_stats.get("average_wait_seconds", 0),
                "Average wait time for predictions in seconds"
            ))
            metrics.append(self._format_metric(
                "predictions_max_concurrent",
                sem_stats.get("max_concurrent", 2),
                "Maximum concurrent predictions allowed"
            ))

        except Exception as e:
            logger.debug(f"Could not get semaphore stats for metrics: {e}")

        # Service Pool metrics
        try:
            from app.services.service_pool import get_service_pool
            pool = get_service_pool()
            pool_stats = pool.get_pool_stats()

            metrics.append(self._format_metric(
                "pool_services_total",
                pool_stats.get("total_services", 0),
                "Total number of services in pool"
            ))
            metrics.append(self._format_metric(
                "pool_services_available",
                pool_stats.get("available_services", 0),
                "Number of available services in pool"
            ))
            metrics.append(self._format_metric(
                "pool_services_in_use",
                pool_stats.get("in_use_services", 0),
                "Number of services currently in use"
            ))

            pool_metrics = pool_stats.get("metrics", {})
            metrics.append(self._format_metric(
                "pool_acquisitions_total",
                pool_metrics.get("total_acquisitions", 0),
                "Total number of service acquisitions",
                metric_type="counter"
            ))
            metrics.append(self._format_metric(
                "pool_releases_total",
                pool_metrics.get("total_releases", 0),
                "Total number of service releases",
                metric_type="counter"
            ))
            metrics.append(self._format_metric(
                "pool_timeouts_total",
                pool_metrics.get("total_timeouts", 0),
                "Total number of acquisition timeouts",
                metric_type="counter"
            ))
            metrics.append(self._format_metric(
                "pool_errors_total",
                pool_metrics.get("total_errors", 0),
                "Total number of pool errors",
                metric_type="counter"
            ))

        except Exception as e:
            logger.debug(f"Could not get pool stats for metrics: {e}")

        # Recovery Manager metrics
        try:
            recovery_manager = get_recovery_manager()
            recovery_state = recovery_manager.get_state()

            metrics.append(self._format_metric(
                "recovery_count_total",
                recovery_state.get("recovery_count", 0),
                "Total number of recovery attempts",
                metric_type="counter"
            ))
            metrics.append(self._format_metric(
                "recovery_in_cooldown",
                1 if recovery_state.get("in_cooldown", False) else 0,
                "Whether recovery is currently in cooldown (1=yes, 0=no)"
            ))
            metrics.append(self._format_metric(
                "recovery_cooldown_remaining_seconds",
                recovery_state.get("cooldown_remaining_seconds", 0),
                "Remaining cooldown time in seconds"
            ))
            metrics.append(self._format_metric(
                "recovery_recent_count",
                recovery_state.get("recent_recoveries", 0),
                "Number of recent recovery attempts within window"
            ))

        except Exception as e:
            logger.debug(f"Could not get recovery stats for metrics: {e}")

        # Custom metrics (keys may already carry a {label="value"} suffix
        # appended by set_custom_metric(), so they are emitted verbatim)
        with self._lock:
            for name, value in self._custom_metrics.items():
                metrics.append(f"{self.METRIC_PREFIX}{name} {value}")

        return "\n\n".join(metrics) + "\n"
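
    # Illustrative sketch: exposing the exported text over HTTP. The FastAPI
    # route below is hypothetical; only get_prometheus_metrics() and
    # export_metrics() come from this module.
    #
    #     from fastapi import APIRouter, Response
    #
    #     router = APIRouter()
    #
    #     @router.get("/metrics")
    #     def metrics_endpoint() -> Response:
    #         text = get_prometheus_metrics().export_metrics()
    #         return Response(content=text, media_type="text/plain; version=0.0.4")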


# Global Prometheus metrics instance
_prometheus_metrics: Optional[PrometheusMetrics] = None


def get_prometheus_metrics() -> PrometheusMetrics:
    """Get the global PrometheusMetrics instance"""
    global _prometheus_metrics
    if _prometheus_metrics is None:
        _prometheus_metrics = PrometheusMetrics()
    return _prometheus_metrics


def shutdown_prometheus_metrics():
    """Shutdown the global PrometheusMetrics instance"""
    global _prometheus_metrics
    _prometheus_metrics = None
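
# Illustrative sketch: releasing the module-level singletons during application
# shutdown. Only the two shutdown helpers come from this file; the hook name is
# hypothetical.
#
#     def on_app_shutdown() -> None:
#         shutdown_prometheus_metrics()
#         shutdown_memory_dumper()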