"""
|
|
Tests for Memory Management Components
|
|
|
|
Tests ModelManager, MemoryGuard, and related functionality.
|
|
"""
|
|
|
|
import gc
|
|
import pytest
|
|
import threading
|
|
import time
|
|
from unittest.mock import Mock, patch, MagicMock
|
|
import sys
|
|
|
|
# Mock paddle before importing memory_manager to avoid import errors
|
|
# when paddle is not installed in the test environment
|
|
paddle_mock = MagicMock()
|
|
paddle_mock.is_compiled_with_cuda.return_value = False
|
|
paddle_mock.device.cuda.device_count.return_value = 0
|
|
paddle_mock.device.cuda.memory_allocated.return_value = 0
|
|
paddle_mock.device.cuda.memory_reserved.return_value = 0
|
|
paddle_mock.device.cuda.empty_cache = MagicMock()
|
|
sys.modules['paddle'] = paddle_mock
|
|
|
|
from app.services.memory_manager import (
|
|
ModelManager,
|
|
ModelEntry,
|
|
MemoryGuard,
|
|
MemoryConfig,
|
|
MemoryStats,
|
|
MemoryBackend,
|
|
get_model_manager,
|
|
shutdown_model_manager,
|
|
)
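

# A minimal sketch (not exercised by any test): several classes below repeat
# the same singleton reset in setup_method/teardown_method. A shared pytest
# fixture like this one could factor that out; it assumes only the names
# already imported above.
@pytest.fixture
def fresh_model_manager():
    """Yield a ModelManager backed by a clean singleton, then tear it down."""
    shutdown_model_manager()
    ModelManager._instance = None
    manager = ModelManager(MemoryConfig())
    try:
        yield manager
    finally:
        manager.teardown()
        shutdown_model_manager()
        ModelManager._instance = None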
|
|
|
|
|
|
class TestMemoryConfig:
|
|
"""Tests for MemoryConfig class"""
|
|
|
|
def test_default_values(self):
|
|
"""Test default configuration values"""
|
|
config = MemoryConfig()
|
|
assert config.warning_threshold == 0.80
|
|
assert config.critical_threshold == 0.95
|
|
assert config.emergency_threshold == 0.98
|
|
assert config.model_idle_timeout_seconds == 300
|
|
assert config.enable_auto_cleanup is True
|
|
assert config.max_concurrent_predictions == 2
|
|
|
|
def test_custom_values(self):
|
|
"""Test custom configuration values"""
|
|
config = MemoryConfig(
|
|
warning_threshold=0.70,
|
|
critical_threshold=0.85,
|
|
model_idle_timeout_seconds=600,
|
|
)
|
|
assert config.warning_threshold == 0.70
|
|
assert config.critical_threshold == 0.85
|
|
assert config.model_idle_timeout_seconds == 600
|
|
|
|
|
|
class TestMemoryGuard:
|
|
"""Tests for MemoryGuard class"""
|
|
|
|
def setup_method(self):
|
|
"""Setup for each test"""
|
|
self.config = MemoryConfig(
|
|
warning_threshold=0.80,
|
|
critical_threshold=0.95,
|
|
emergency_threshold=0.98,
|
|
)
|
|
|
|
def test_initialization(self):
|
|
"""Test MemoryGuard initialization"""
|
|
guard = MemoryGuard(self.config)
|
|
assert guard.config == self.config
|
|
assert guard.backend is not None
|
|
guard.shutdown()
|
|
|
|
def test_get_memory_stats(self):
|
|
"""Test getting memory statistics"""
|
|
guard = MemoryGuard(self.config)
|
|
stats = guard.get_memory_stats()
|
|
assert isinstance(stats, MemoryStats)
|
|
assert stats.timestamp > 0
|
|
guard.shutdown()
|
|
|
|
def test_check_memory_below_warning(self):
|
|
"""Test memory check when below warning threshold"""
|
|
guard = MemoryGuard(self.config)
|
|
|
|
# Mock stats to be below warning
|
|
with patch.object(guard, 'get_memory_stats') as mock_stats:
|
|
mock_stats.return_value = MemoryStats(
|
|
gpu_used_ratio=0.50,
|
|
gpu_free_mb=4000,
|
|
gpu_total_mb=8000,
|
|
)
|
|
is_available, stats = guard.check_memory(required_mb=1000)
|
|
assert is_available is True
|
|
|
|
guard.shutdown()
|
|
|
|
def test_check_memory_above_warning(self):
|
|
"""Test memory check when above warning threshold"""
|
|
guard = MemoryGuard(self.config)
|
|
|
|
# Mock stats to be above warning
|
|
with patch.object(guard, 'get_memory_stats') as mock_stats:
|
|
mock_stats.return_value = MemoryStats(
|
|
gpu_used_ratio=0.85,
|
|
gpu_free_mb=1200,
|
|
gpu_total_mb=8000,
|
|
)
|
|
is_available, stats = guard.check_memory(required_mb=500)
|
|
# Should still return True (warning, not critical)
|
|
assert is_available is True
|
|
|
|
guard.shutdown()
|
|
|
|
def test_check_memory_above_critical(self):
|
|
"""Test memory check when above critical threshold"""
|
|
guard = MemoryGuard(self.config)
|
|
|
|
# Mock stats to be above critical
|
|
with patch.object(guard, 'get_memory_stats') as mock_stats:
|
|
mock_stats.return_value = MemoryStats(
|
|
gpu_used_ratio=0.96,
|
|
gpu_free_mb=320,
|
|
gpu_total_mb=8000,
|
|
)
|
|
is_available, stats = guard.check_memory(required_mb=100)
|
|
# Should return False (critical)
|
|
assert is_available is False
|
|
|
|
guard.shutdown()
|
|
|
|
def test_check_memory_insufficient_free(self):
|
|
"""Test memory check when insufficient free memory"""
|
|
guard = MemoryGuard(self.config)
|
|
|
|
# Mock stats with insufficient free memory
|
|
with patch.object(guard, 'get_memory_stats') as mock_stats:
|
|
mock_stats.return_value = MemoryStats(
|
|
gpu_used_ratio=0.70,
|
|
gpu_free_mb=500,
|
|
gpu_total_mb=8000,
|
|
)
|
|
is_available, stats = guard.check_memory(required_mb=1000)
|
|
# Should return False (not enough free)
|
|
assert is_available is False
|
|
|
|
guard.shutdown()
|
|
|
|
def test_alert_history(self):
|
|
"""Test alert history functionality"""
|
|
guard = MemoryGuard(self.config)
|
|
|
|
# Trigger some alerts
|
|
guard._add_alert("warning", "Test warning")
|
|
guard._add_alert("critical", "Test critical")
|
|
|
|
alerts = guard.get_alerts()
|
|
assert len(alerts) == 2
|
|
assert alerts[0]["level"] == "warning"
|
|
assert alerts[1]["level"] == "critical"
|
|
|
|
guard.shutdown()
|
|
|
|
def test_clear_gpu_cache(self):
|
|
"""Test GPU cache clearing"""
|
|
guard = MemoryGuard(self.config)
|
|
# Should not raise even if no GPU
|
|
guard.clear_gpu_cache()
|
|
guard.shutdown()
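

# A minimal usage sketch of the check-then-run pattern the tests above cover:
# ask MemoryGuard.check_memory() for headroom before a memory-hungry call and
# degrade when it says no. `run_inference` is a hypothetical callable, not
# part of the module under test.
def _guarded_inference_sketch(guard: MemoryGuard, run_inference, required_mb: int = 500):
    is_available, stats = guard.check_memory(required_mb=required_mb)
    if not is_available:
        # Callers decide how to degrade: queue the work, shrink it, or fail.
        raise MemoryError(f"insufficient GPU memory: {stats}")
    return run_inference()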
|
|
|
|
|
|
class TestModelManager:
|
|
"""Tests for ModelManager class"""
|
|
|
|
def setup_method(self):
|
|
"""Reset singleton before each test"""
|
|
shutdown_model_manager()
|
|
ModelManager._instance = None
|
|
ModelManager._lock = threading.Lock()
|
|
|
|
def teardown_method(self):
|
|
"""Cleanup after each test"""
|
|
shutdown_model_manager()
|
|
ModelManager._instance = None
|
|
|
|
def test_singleton_pattern(self):
|
|
"""Test that ModelManager is a singleton"""
|
|
config = MemoryConfig()
|
|
manager1 = ModelManager(config)
|
|
manager2 = ModelManager()
|
|
assert manager1 is manager2
|
|
manager1.teardown()
|
|
|
|
def test_get_or_load_model_new(self):
|
|
"""Test loading a new model"""
|
|
config = MemoryConfig()
|
|
manager = ModelManager(config)
|
|
|
|
mock_model = Mock()
|
|
loader_called = False
|
|
|
|
def loader():
|
|
nonlocal loader_called
|
|
loader_called = True
|
|
return mock_model
|
|
|
|
model = manager.get_or_load_model(
|
|
model_id="test_model",
|
|
loader_func=loader,
|
|
estimated_memory_mb=100
|
|
)
|
|
|
|
assert model is mock_model
|
|
assert loader_called is True
|
|
assert "test_model" in manager.models
|
|
assert manager.models["test_model"].ref_count == 1
|
|
|
|
manager.teardown()
|
|
|
|
def test_get_or_load_model_cached(self):
|
|
"""Test getting a cached model"""
|
|
config = MemoryConfig()
|
|
manager = ModelManager(config)
|
|
|
|
mock_model = Mock()
|
|
load_count = 0
|
|
|
|
def loader():
|
|
nonlocal load_count
|
|
load_count += 1
|
|
return mock_model
|
|
|
|
# First load
|
|
model1 = manager.get_or_load_model("test_model", loader)
|
|
# Second load (should return cached)
|
|
model2 = manager.get_or_load_model("test_model", loader)
|
|
|
|
assert model1 is model2
|
|
assert load_count == 1 # Loader should only be called once
|
|
assert manager.models["test_model"].ref_count == 2
|
|
|
|
manager.teardown()
|
|
|
|
def test_release_model(self):
|
|
"""Test releasing model references"""
|
|
config = MemoryConfig()
|
|
manager = ModelManager(config)
|
|
|
|
model = manager.get_or_load_model("test_model", lambda: Mock())
|
|
assert manager.models["test_model"].ref_count == 1
|
|
|
|
manager.release_model("test_model")
|
|
assert manager.models["test_model"].ref_count == 0
|
|
|
|
manager.teardown()
|
|
|
|
def test_unload_model(self):
|
|
"""Test unloading a model"""
|
|
config = MemoryConfig()
|
|
manager = ModelManager(config)
|
|
|
|
manager.get_or_load_model("test_model", lambda: Mock())
|
|
manager.release_model("test_model")
|
|
|
|
success = manager.unload_model("test_model")
|
|
assert success is True
|
|
assert "test_model" not in manager.models
|
|
|
|
manager.teardown()
|
|
|
|
def test_unload_model_with_references(self):
|
|
"""Test that model with active references is not unloaded"""
|
|
config = MemoryConfig()
|
|
manager = ModelManager(config)
|
|
|
|
manager.get_or_load_model("test_model", lambda: Mock())
|
|
# Don't release - ref_count is still 1
|
|
|
|
success = manager.unload_model("test_model", force=False)
|
|
assert success is False
|
|
assert "test_model" in manager.models
|
|
|
|
manager.teardown()
|
|
|
|
def test_unload_model_force(self):
|
|
"""Test force unloading a model with references"""
|
|
config = MemoryConfig()
|
|
manager = ModelManager(config)
|
|
|
|
manager.get_or_load_model("test_model", lambda: Mock())
|
|
|
|
success = manager.unload_model("test_model", force=True)
|
|
assert success is True
|
|
assert "test_model" not in manager.models
|
|
|
|
manager.teardown()
|
|
|
|
def test_cleanup_callback(self):
|
|
"""Test cleanup callback is called on unload"""
|
|
config = MemoryConfig()
|
|
manager = ModelManager(config)
|
|
|
|
cleanup_called = False
|
|
def cleanup():
|
|
nonlocal cleanup_called
|
|
cleanup_called = True
|
|
|
|
manager.get_or_load_model(
|
|
"test_model",
|
|
lambda: Mock(),
|
|
cleanup_callback=cleanup
|
|
)
|
|
manager.release_model("test_model")
|
|
manager.unload_model("test_model")
|
|
|
|
assert cleanup_called is True
|
|
manager.teardown()
|
|
|
|
def test_get_model_stats(self):
|
|
"""Test getting model statistics"""
|
|
config = MemoryConfig()
|
|
manager = ModelManager(config)
|
|
|
|
manager.get_or_load_model("model1", lambda: Mock(), estimated_memory_mb=100)
|
|
manager.get_or_load_model("model2", lambda: Mock(), estimated_memory_mb=200)
|
|
|
|
stats = manager.get_model_stats()
|
|
assert stats["total_models"] == 2
|
|
assert "model1" in stats["models"]
|
|
assert "model2" in stats["models"]
|
|
assert stats["total_estimated_memory_mb"] == 300
|
|
|
|
manager.teardown()
|
|
|
|
def test_idle_cleanup(self):
|
|
"""Test idle model cleanup"""
|
|
config = MemoryConfig(
|
|
model_idle_timeout_seconds=1, # Short timeout for testing
|
|
memory_check_interval_seconds=60, # Don't auto-cleanup
|
|
)
|
|
manager = ModelManager(config)
|
|
|
|
manager.get_or_load_model("test_model", lambda: Mock())
|
|
manager.release_model("test_model")
|
|
|
|
# Manually set last_used to simulate idle
|
|
manager.models["test_model"].last_used = time.time() - 10
|
|
|
|
# Manually trigger cleanup
|
|
manager._cleanup_idle_models()
|
|
|
|
assert "test_model" not in manager.models
|
|
manager.teardown()
|
|
|
|
|
|
class TestGetModelManager:
|
|
"""Tests for get_model_manager helper function"""
|
|
|
|
def setup_method(self):
|
|
"""Reset singleton before each test"""
|
|
shutdown_model_manager()
|
|
ModelManager._instance = None
|
|
|
|
def teardown_method(self):
|
|
"""Cleanup after each test"""
|
|
shutdown_model_manager()
|
|
ModelManager._instance = None
|
|
|
|
def test_get_model_manager_creates_singleton(self):
|
|
"""Test that get_model_manager creates a singleton"""
|
|
manager1 = get_model_manager()
|
|
manager2 = get_model_manager()
|
|
assert manager1 is manager2
|
|
shutdown_model_manager()
|
|
|
|
def test_shutdown_model_manager(self):
|
|
"""Test shutdown_model_manager cleans up"""
|
|
manager = get_model_manager()
|
|
manager.get_or_load_model("test", lambda: Mock())
|
|
|
|
shutdown_model_manager()
|
|
|
|
# Should be able to create new manager
|
|
new_manager = get_model_manager()
|
|
assert "test" not in new_manager.models
|
|
shutdown_model_manager()
|
|
|
|
|
|
class TestConcurrency:
|
|
"""Tests for concurrent access"""
|
|
|
|
def setup_method(self):
|
|
shutdown_model_manager()
|
|
ModelManager._instance = None
|
|
|
|
def teardown_method(self):
|
|
shutdown_model_manager()
|
|
ModelManager._instance = None
|
|
|
|
def test_concurrent_model_access(self):
|
|
"""Test concurrent model loading"""
|
|
config = MemoryConfig()
|
|
manager = ModelManager(config)
|
|
|
|
load_count = 0
|
|
lock = threading.Lock()
|
|
|
|
def loader():
|
|
nonlocal load_count
|
|
with lock:
|
|
load_count += 1
|
|
time.sleep(0.1) # Simulate slow load
|
|
return Mock()
|
|
|
|
results = []
|
|
|
|
def worker():
|
|
model = manager.get_or_load_model("shared_model", loader)
|
|
results.append(model)
|
|
|
|
threads = [threading.Thread(target=worker) for _ in range(5)]
|
|
for t in threads:
|
|
t.start()
|
|
for t in threads:
|
|
t.join()
|
|
|
|
# All threads should get the same model
|
|
assert len(set(id(r) for r in results)) == 1
|
|
# Loader should only be called once
|
|
assert load_count == 1
|
|
# Ref count should match thread count
|
|
assert manager.models["shared_model"].ref_count == 5
|
|
|
|
manager.teardown()
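

# A minimal sketch of the ref-counting discipline the concurrency test relies
# on: every get_or_load_model() call is paired with release_model() in a
# finally block so idle cleanup can later reclaim the model. `build_model`
# and `run` are hypothetical callables.
def _with_model_sketch(manager: ModelManager, build_model, run):
    model = manager.get_or_load_model("ocr_model", build_model, estimated_memory_mb=512)
    try:
        return run(model)
    finally:
        manager.release_model("ocr_model")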
|
|
|
|
|
|
class TestPredictionSemaphore:
|
|
"""Tests for PredictionSemaphore class"""
|
|
|
|
def setup_method(self):
|
|
"""Reset singleton before each test"""
|
|
from app.services.memory_manager import shutdown_prediction_semaphore, PredictionSemaphore
|
|
shutdown_prediction_semaphore()
|
|
PredictionSemaphore._instance = None
|
|
PredictionSemaphore._lock = threading.Lock()
|
|
|
|
def teardown_method(self):
|
|
"""Cleanup after each test"""
|
|
from app.services.memory_manager import shutdown_prediction_semaphore, PredictionSemaphore
|
|
shutdown_prediction_semaphore()
|
|
PredictionSemaphore._instance = None
|
|
|
|
def test_singleton_pattern(self):
|
|
"""Test that PredictionSemaphore is a singleton"""
|
|
from app.services.memory_manager import PredictionSemaphore
|
|
sem1 = PredictionSemaphore(max_concurrent=2)
|
|
sem2 = PredictionSemaphore(max_concurrent=4) # Different config should be ignored
|
|
assert sem1 is sem2
|
|
assert sem1._max_concurrent == 2
|
|
|
|
def test_acquire_release(self):
|
|
"""Test basic acquire and release"""
|
|
from app.services.memory_manager import PredictionSemaphore
|
|
sem = PredictionSemaphore(max_concurrent=2)
|
|
|
|
# Acquire first slot
|
|
assert sem.acquire(task_id="task1") is True
|
|
assert sem._active_predictions == 1
|
|
|
|
# Acquire second slot
|
|
assert sem.acquire(task_id="task2") is True
|
|
assert sem._active_predictions == 2
|
|
|
|
# Release one
|
|
sem.release(task_id="task1")
|
|
assert sem._active_predictions == 1
|
|
|
|
# Release another
|
|
sem.release(task_id="task2")
|
|
assert sem._active_predictions == 0
|
|
|
|
def test_acquire_blocks_when_full(self):
|
|
"""Test that acquire blocks when all slots are taken"""
|
|
from app.services.memory_manager import PredictionSemaphore
|
|
sem = PredictionSemaphore(max_concurrent=1)
|
|
|
|
# Acquire the only slot
|
|
assert sem.acquire(task_id="task1") is True
|
|
|
|
# Try to acquire another with short timeout - should fail
|
|
result = sem.acquire(timeout=0.1, task_id="task2")
|
|
assert result is False
|
|
assert sem._total_timeouts == 1
|
|
|
|
# Release first slot
|
|
sem.release(task_id="task1")
|
|
|
|
# Now should succeed
|
|
assert sem.acquire(task_id="task3") is True
|
|
|
|
def test_get_stats(self):
|
|
"""Test statistics tracking"""
|
|
from app.services.memory_manager import PredictionSemaphore
|
|
sem = PredictionSemaphore(max_concurrent=2)
|
|
|
|
sem.acquire(task_id="task1")
|
|
sem.acquire(task_id="task2")
|
|
sem.release(task_id="task1")
|
|
|
|
stats = sem.get_stats()
|
|
assert stats["max_concurrent"] == 2
|
|
assert stats["active_predictions"] == 1
|
|
assert stats["total_predictions"] == 2
|
|
assert stats["total_timeouts"] == 0
|
|
|
|
def test_concurrent_acquire(self):
|
|
"""Test concurrent access to semaphore"""
|
|
from app.services.memory_manager import PredictionSemaphore
|
|
sem = PredictionSemaphore(max_concurrent=2)
|
|
|
|
results = []
|
|
acquired_count = 0
|
|
lock = threading.Lock()
|
|
|
|
def worker(worker_id):
|
|
nonlocal acquired_count
|
|
if sem.acquire(timeout=1.0, task_id=f"task_{worker_id}"):
|
|
with lock:
|
|
acquired_count += 1
|
|
time.sleep(0.1) # Simulate work
|
|
sem.release(task_id=f"task_{worker_id}")
|
|
results.append(worker_id)
|
|
|
|
# Start 4 workers but only 2 slots
|
|
threads = [threading.Thread(target=worker, args=(i,)) for i in range(4)]
|
|
for t in threads:
|
|
t.start()
|
|
for t in threads:
|
|
t.join()
|
|
|
|
# All should complete eventually
|
|
assert len(results) == 4
|
|
|
|
|
|
class TestPredictionContext:
|
|
"""Tests for prediction_context helper function"""
|
|
|
|
def setup_method(self):
|
|
"""Reset singleton before each test"""
|
|
from app.services.memory_manager import shutdown_prediction_semaphore, PredictionSemaphore
|
|
shutdown_prediction_semaphore()
|
|
PredictionSemaphore._instance = None
|
|
PredictionSemaphore._lock = threading.Lock()
|
|
|
|
def teardown_method(self):
|
|
"""Cleanup after each test"""
|
|
from app.services.memory_manager import shutdown_prediction_semaphore, PredictionSemaphore
|
|
shutdown_prediction_semaphore()
|
|
PredictionSemaphore._instance = None
|
|
|
|
def test_context_manager_success(self):
|
|
"""Test context manager for successful prediction"""
|
|
from app.services.memory_manager import prediction_context, get_prediction_semaphore
|
|
|
|
# Initialize semaphore
|
|
sem = get_prediction_semaphore(max_concurrent=2)
|
|
|
|
with prediction_context(timeout=5.0, task_id="test") as acquired:
|
|
assert acquired is True
|
|
assert sem._active_predictions == 1
|
|
|
|
# After exiting context, slot should be released
|
|
assert sem._active_predictions == 0
|
|
|
|
def test_context_manager_with_exception(self):
|
|
"""Test context manager releases on exception"""
|
|
from app.services.memory_manager import prediction_context, get_prediction_semaphore
|
|
|
|
sem = get_prediction_semaphore(max_concurrent=2)
|
|
|
|
with pytest.raises(ValueError):
|
|
with prediction_context(task_id="test") as acquired:
|
|
assert acquired is True
|
|
raise ValueError("Test error")
|
|
|
|
# Should still release the slot
|
|
assert sem._active_predictions == 0
|
|
|
|
def test_context_manager_timeout(self):
|
|
"""Test context manager when timeout occurs"""
|
|
from app.services.memory_manager import prediction_context, get_prediction_semaphore
|
|
|
|
sem = get_prediction_semaphore(max_concurrent=1)
|
|
|
|
# Acquire the only slot
|
|
sem.acquire(task_id="blocker")
|
|
|
|
# Context manager should timeout
|
|
with prediction_context(timeout=0.1, task_id="waiter") as acquired:
|
|
assert acquired is False
|
|
|
|
# Release blocker
|
|
sem.release(task_id="blocker")
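

# A minimal sketch of how prediction_context is meant to wrap real work:
# acquire a semaphore slot, run the prediction, and let the context manager
# release the slot even if the prediction raises. `predict_fn` is a
# hypothetical callable.
def _predict_with_slot_sketch(predict_fn, task_id: str, timeout: float = 30.0):
    from app.services.memory_manager import prediction_context

    with prediction_context(timeout=timeout, task_id=task_id) as acquired:
        if not acquired:
            return None  # timed out waiting for a free slot
        return predict_fn()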
|
|
|
|
|
|
class TestBatchProcessor:
|
|
"""Tests for BatchProcessor class"""
|
|
|
|
def test_add_item(self):
|
|
"""Test adding items to batch queue"""
|
|
from app.services.memory_manager import BatchProcessor, BatchItem, BatchPriority
|
|
|
|
processor = BatchProcessor(max_batch_size=5)
|
|
item = BatchItem(item_id="test1", data="data1", priority=BatchPriority.NORMAL)
|
|
processor.add_item(item)
|
|
|
|
assert processor.get_queue_size() == 1
|
|
|
|
def test_add_items_sorted_by_priority(self):
|
|
"""Test that items are sorted by priority"""
|
|
from app.services.memory_manager import BatchProcessor, BatchItem, BatchPriority
|
|
|
|
processor = BatchProcessor(max_batch_size=5)
|
|
|
|
processor.add_item(BatchItem(item_id="low", data="low", priority=BatchPriority.LOW))
|
|
processor.add_item(BatchItem(item_id="high", data="high", priority=BatchPriority.HIGH))
|
|
processor.add_item(BatchItem(item_id="normal", data="normal", priority=BatchPriority.NORMAL))
|
|
|
|
# High priority should be first
|
|
assert processor._queue[0].item_id == "high"
|
|
assert processor._queue[1].item_id == "normal"
|
|
assert processor._queue[2].item_id == "low"
|
|
|
|
def test_process_batch(self):
|
|
"""Test processing a batch of items"""
|
|
from app.services.memory_manager import BatchProcessor, BatchItem, BatchPriority
|
|
|
|
processor = BatchProcessor(max_batch_size=3, cleanup_between_batches=False)
|
|
|
|
for i in range(5):
|
|
processor.add_item(BatchItem(item_id=f"item{i}", data=i))
|
|
|
|
results = processor.process_batch(lambda x: x * 2)
|
|
|
|
# Should process max_batch_size items
|
|
assert len(results) == 3
|
|
assert all(r.success for r in results)
|
|
assert processor.get_queue_size() == 2 # 2 remaining
|
|
|
|
def test_process_all(self):
|
|
"""Test processing all items"""
|
|
from app.services.memory_manager import BatchProcessor, BatchItem
|
|
|
|
processor = BatchProcessor(max_batch_size=2, cleanup_between_batches=False)
|
|
|
|
for i in range(5):
|
|
processor.add_item(BatchItem(item_id=f"item{i}", data=i))
|
|
|
|
results = processor.process_all(lambda x: x * 2)
|
|
|
|
assert len(results) == 5
|
|
assert processor.get_queue_size() == 0
|
|
|
|
def test_process_with_failure(self):
|
|
"""Test handling of processing failures"""
|
|
from app.services.memory_manager import BatchProcessor, BatchItem
|
|
|
|
processor = BatchProcessor(max_batch_size=3, cleanup_between_batches=False)
|
|
|
|
processor.add_item(BatchItem(item_id="good", data=1))
|
|
processor.add_item(BatchItem(item_id="bad", data="error"))
|
|
|
|
def processor_func(data):
|
|
if data == "error":
|
|
raise ValueError("Test error")
|
|
return data * 2
|
|
|
|
results = processor.process_all(processor_func)
|
|
|
|
assert len(results) == 2
|
|
assert results[0].success is True
|
|
assert results[1].success is False
|
|
assert "Test error" in results[1].error
|
|
|
|
def test_memory_constraint_batching(self):
|
|
"""Test that batches respect memory constraints"""
|
|
from app.services.memory_manager import BatchProcessor, BatchItem
|
|
|
|
processor = BatchProcessor(
|
|
max_batch_size=10,
|
|
max_memory_per_batch_mb=100.0,
|
|
cleanup_between_batches=False
|
|
)
|
|
|
|
# Add items that exceed memory limit
|
|
processor.add_item(BatchItem(item_id="item1", data=1, estimated_memory_mb=60.0))
|
|
processor.add_item(BatchItem(item_id="item2", data=2, estimated_memory_mb=60.0))
|
|
processor.add_item(BatchItem(item_id="item3", data=3, estimated_memory_mb=60.0))
|
|
|
|
results = processor.process_batch(lambda x: x)
|
|
|
|
        # Should only process items that fit within the memory limit: the
        # first 60 MB item always fits, while admitting a second 60 MB item
        # would put the batch at 120 MB, over the 100 MB cap, so the test
        # accepts either 1 or 2 processed items
        assert len(results) == 1 or len(results) == 2
|
|
assert processor.get_queue_size() >= 1
|
|
|
|
def test_get_stats(self):
|
|
"""Test statistics tracking"""
|
|
from app.services.memory_manager import BatchProcessor, BatchItem
|
|
|
|
processor = BatchProcessor(max_batch_size=2, cleanup_between_batches=False)
|
|
processor.add_item(BatchItem(item_id="item1", data=1))
|
|
processor.add_item(BatchItem(item_id="item2", data=2))
|
|
|
|
processor.process_all(lambda x: x)
|
|
|
|
stats = processor.get_stats()
|
|
assert stats["total_processed"] == 2
|
|
assert stats["total_batches"] == 1
|
|
assert stats["total_failures"] == 0
|
|
|
|
|
|
class TestProgressiveLoader:
|
|
"""Tests for ProgressiveLoader class"""
|
|
|
|
def test_initialize(self):
|
|
"""Test initializing loader with page count"""
|
|
from app.services.memory_manager import ProgressiveLoader
|
|
|
|
loader = ProgressiveLoader(lookahead_pages=2)
|
|
loader.initialize(total_pages=10)
|
|
|
|
stats = loader.get_stats()
|
|
assert stats["total_pages"] == 10
|
|
assert stats["loaded_pages_count"] == 0
|
|
|
|
def test_load_page(self):
|
|
"""Test loading a single page"""
|
|
from app.services.memory_manager import ProgressiveLoader
|
|
|
|
loader = ProgressiveLoader(lookahead_pages=2, cleanup_after_pages=10)
|
|
loader.initialize(total_pages=5)
|
|
|
|
page_data = loader.load_page(0, lambda p: f"page_{p}_data")
|
|
|
|
assert page_data == "page_0_data"
|
|
assert 0 in loader.get_loaded_pages()
|
|
|
|
def test_load_page_caching(self):
|
|
"""Test that loaded pages are cached"""
|
|
from app.services.memory_manager import ProgressiveLoader
|
|
|
|
load_count = 0
|
|
|
|
def loader_func(page_num):
|
|
nonlocal load_count
|
|
load_count += 1
|
|
return f"page_{page_num}"
|
|
|
|
loader = ProgressiveLoader(lookahead_pages=2, cleanup_after_pages=10)
|
|
loader.initialize(total_pages=5)
|
|
|
|
# Load same page twice
|
|
loader.load_page(0, loader_func)
|
|
loader.load_page(0, loader_func)
|
|
|
|
assert load_count == 1 # Should only load once
|
|
|
|
def test_unload_distant_pages(self):
|
|
"""Test that distant pages are unloaded"""
|
|
from app.services.memory_manager import ProgressiveLoader
|
|
|
|
loader = ProgressiveLoader(lookahead_pages=1, cleanup_after_pages=100)
|
|
loader.initialize(total_pages=10)
|
|
|
|
# Load several pages
|
|
for i in range(5):
|
|
loader.load_page(i, lambda p: f"page_{p}")
|
|
|
|
# After loading page 4, distant pages should be unloaded
|
|
loaded = loader.get_loaded_pages()
|
|
# Should keep only pages near current (4): pages 3, 4, and potentially 5
|
|
assert 0 not in loaded # Page 0 should be unloaded
|
|
|
|
def test_iterate_pages(self):
|
|
"""Test iterating through all pages"""
|
|
from app.services.memory_manager import ProgressiveLoader
|
|
|
|
loader = ProgressiveLoader(lookahead_pages=0, cleanup_after_pages=100)
|
|
loader.initialize(total_pages=5)
|
|
|
|
results = loader.iterate_pages(
|
|
loader_func=lambda p: f"page_{p}",
|
|
processor_func=lambda p, data: f"processed_{data}"
|
|
)
|
|
|
|
assert len(results) == 5
|
|
assert results[0] == "processed_page_0"
|
|
assert results[4] == "processed_page_4"
|
|
|
|
def test_progress_callback(self):
|
|
"""Test progress callback during iteration"""
|
|
from app.services.memory_manager import ProgressiveLoader
|
|
|
|
loader = ProgressiveLoader(lookahead_pages=0, cleanup_after_pages=100)
|
|
loader.initialize(total_pages=3)
|
|
|
|
progress_reports = []
|
|
|
|
def callback(current, total):
|
|
progress_reports.append((current, total))
|
|
|
|
loader.iterate_pages(
|
|
loader_func=lambda p: p,
|
|
processor_func=lambda p, d: d,
|
|
progress_callback=callback
|
|
)
|
|
|
|
assert progress_reports == [(1, 3), (2, 3), (3, 3)]
|
|
|
|
def test_clear(self):
|
|
"""Test clearing loaded pages"""
|
|
from app.services.memory_manager import ProgressiveLoader
|
|
|
|
loader = ProgressiveLoader(cleanup_after_pages=100)
|
|
loader.initialize(total_pages=5)
|
|
|
|
for i in range(3):
|
|
loader.load_page(i, lambda p: p)
|
|
|
|
loader.clear()
|
|
|
|
assert loader.get_loaded_pages() == []
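

# A minimal sketch of the page-streaming pattern tested above: initialize the
# loader with a page count, then let iterate_pages() load, process, and evict
# pages so only a small window stays resident. The per-page callables stand
# in for real work such as rasterize-then-OCR.
def _progressive_pages_sketch(total_pages: int = 20):
    from app.services.memory_manager import ProgressiveLoader

    loader = ProgressiveLoader(lookahead_pages=1, cleanup_after_pages=5)
    loader.initialize(total_pages=total_pages)
    return loader.iterate_pages(
        loader_func=lambda page: f"raw page {page}",
        processor_func=lambda page, data: len(data),
    )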
|
|
|
|
|
|
class TestPriorityOperationQueue:
|
|
"""Tests for PriorityOperationQueue class"""
|
|
|
|
def test_enqueue_dequeue(self):
|
|
"""Test basic enqueue and dequeue"""
|
|
from app.services.memory_manager import PriorityOperationQueue, BatchPriority
|
|
|
|
queue = PriorityOperationQueue(max_size=10)
|
|
|
|
queue.enqueue("item1", "data1", BatchPriority.NORMAL)
|
|
result = queue.dequeue()
|
|
|
|
assert result is not None
|
|
item_id, data, priority = result
|
|
assert item_id == "item1"
|
|
assert data == "data1"
|
|
assert priority == BatchPriority.NORMAL
|
|
|
|
def test_priority_ordering(self):
|
|
"""Test that higher priority items are dequeued first"""
|
|
from app.services.memory_manager import PriorityOperationQueue, BatchPriority
|
|
|
|
queue = PriorityOperationQueue()
|
|
|
|
queue.enqueue("low", "low_data", BatchPriority.LOW)
|
|
queue.enqueue("high", "high_data", BatchPriority.HIGH)
|
|
queue.enqueue("normal", "normal_data", BatchPriority.NORMAL)
|
|
queue.enqueue("critical", "critical_data", BatchPriority.CRITICAL)
|
|
|
|
# Dequeue in priority order
|
|
item_id, _, _ = queue.dequeue()
|
|
assert item_id == "critical"
|
|
|
|
item_id, _, _ = queue.dequeue()
|
|
assert item_id == "high"
|
|
|
|
item_id, _, _ = queue.dequeue()
|
|
assert item_id == "normal"
|
|
|
|
item_id, _, _ = queue.dequeue()
|
|
assert item_id == "low"
|
|
|
|
def test_cancel(self):
|
|
"""Test cancelling an operation"""
|
|
from app.services.memory_manager import PriorityOperationQueue, BatchPriority
|
|
|
|
queue = PriorityOperationQueue()
|
|
|
|
queue.enqueue("item1", "data1", BatchPriority.NORMAL)
|
|
queue.enqueue("item2", "data2", BatchPriority.NORMAL)
|
|
|
|
# Cancel item1
|
|
assert queue.cancel("item1") is True
|
|
|
|
# Dequeue should skip cancelled item
|
|
result = queue.dequeue()
|
|
assert result[0] == "item2"
|
|
|
|
def test_dequeue_empty_returns_none(self):
|
|
"""Test that dequeue on empty queue returns None"""
|
|
from app.services.memory_manager import PriorityOperationQueue
|
|
|
|
queue = PriorityOperationQueue()
|
|
result = queue.dequeue(timeout=0.1)
|
|
assert result is None
|
|
|
|
def test_max_size_limit(self):
|
|
"""Test that queue respects max size"""
|
|
from app.services.memory_manager import PriorityOperationQueue, BatchPriority
|
|
|
|
queue = PriorityOperationQueue(max_size=2)
|
|
|
|
assert queue.enqueue("item1", "data1") is True
|
|
assert queue.enqueue("item2", "data2") is True
|
|
# Third item should fail without timeout
|
|
assert queue.enqueue("item3", "data3", timeout=0.1) is False
|
|
|
|
def test_get_stats(self):
|
|
"""Test queue statistics"""
|
|
from app.services.memory_manager import PriorityOperationQueue, BatchPriority
|
|
|
|
queue = PriorityOperationQueue()
|
|
|
|
queue.enqueue("item1", "data1", BatchPriority.HIGH)
|
|
queue.enqueue("item2", "data2", BatchPriority.LOW)
|
|
queue.dequeue()
|
|
|
|
stats = queue.get_stats()
|
|
assert stats["total_enqueued"] == 2
|
|
assert stats["total_dequeued"] == 1
|
|
assert stats["queue_size"] == 1
|
|
|
|
|
|
class TestRecoveryManager:
|
|
"""Tests for RecoveryManager class"""
|
|
|
|
def test_can_attempt_recovery(self):
|
|
"""Test recovery attempt checking"""
|
|
from app.services.memory_manager import RecoveryManager
|
|
|
|
manager = RecoveryManager(
|
|
cooldown_seconds=1.0,
|
|
max_recovery_attempts=3
|
|
)
|
|
|
|
can_recover, reason = manager.can_attempt_recovery()
|
|
assert can_recover is True
|
|
assert "allowed" in reason.lower()
|
|
|
|
def test_cooldown_period(self):
|
|
"""Test that cooldown period is enforced"""
|
|
from app.services.memory_manager import RecoveryManager
|
|
|
|
manager = RecoveryManager(
|
|
cooldown_seconds=60.0,
|
|
max_recovery_attempts=10
|
|
)
|
|
|
|
# First recovery
|
|
manager.attempt_recovery()
|
|
|
|
# Should be in cooldown
|
|
assert manager.is_in_cooldown() is True
|
|
|
|
can_recover, reason = manager.can_attempt_recovery()
|
|
assert can_recover is False
|
|
assert "cooldown" in reason.lower()
|
|
|
|
def test_max_recovery_attempts(self):
|
|
"""Test that max recovery attempts are enforced"""
|
|
from app.services.memory_manager import RecoveryManager
|
|
|
|
manager = RecoveryManager(
|
|
cooldown_seconds=0.01, # Very short cooldown for testing
|
|
max_recovery_attempts=2,
|
|
recovery_window_seconds=60.0
|
|
)
|
|
|
|
# Perform max attempts
|
|
for _ in range(2):
|
|
manager.attempt_recovery()
|
|
time.sleep(0.02) # Wait for cooldown
|
|
|
|
# Next attempt should be blocked
|
|
can_recover, reason = manager.can_attempt_recovery()
|
|
assert can_recover is False
|
|
assert "max" in reason.lower()
|
|
|
|
def test_recovery_callbacks(self):
|
|
"""Test recovery event callbacks"""
|
|
from app.services.memory_manager import RecoveryManager
|
|
|
|
manager = RecoveryManager(cooldown_seconds=0.01)
|
|
|
|
start_called = False
|
|
complete_called = False
|
|
complete_success = None
|
|
|
|
def on_start():
|
|
nonlocal start_called
|
|
start_called = True
|
|
|
|
def on_complete(success):
|
|
nonlocal complete_called, complete_success
|
|
complete_called = True
|
|
complete_success = success
|
|
|
|
manager.register_callbacks(on_start=on_start, on_complete=on_complete)
|
|
manager.attempt_recovery()
|
|
|
|
assert start_called is True
|
|
assert complete_called is True
|
|
assert complete_success is not None
|
|
|
|
def test_get_state(self):
|
|
"""Test getting recovery state"""
|
|
from app.services.memory_manager import RecoveryManager
|
|
|
|
manager = RecoveryManager(cooldown_seconds=60.0)
|
|
manager.attempt_recovery(error="Test error")
|
|
|
|
state = manager.get_state()
|
|
assert state["recovery_count"] == 1
|
|
assert state["in_cooldown"] is True
|
|
assert state["last_error"] == "Test error"
|
|
assert state["cooldown_remaining_seconds"] > 0
|
|
|
|
def test_emergency_release(self):
|
|
"""Test emergency memory release"""
|
|
from app.services.memory_manager import RecoveryManager
|
|
|
|
manager = RecoveryManager()
|
|
|
|
# Emergency release without model manager
|
|
result = manager.emergency_release(model_manager=None)
|
|
# Should complete without error (may or may not free memory in test env)
|
|
assert isinstance(result, bool)
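

# A minimal sketch of the recovery flow exercised above: on a suspected OOM,
# check the cooldown/attempt budget, then run a single recovery attempt and
# record the triggering error. The retry policy around it belongs to the
# caller.
def _recover_from_oom_sketch(error_message: str):
    from app.services.memory_manager import RecoveryManager

    recovery = RecoveryManager(cooldown_seconds=30.0, max_recovery_attempts=3)
    can_recover, reason = recovery.can_attempt_recovery()
    if not can_recover:
        return False, reason
    recovery.attempt_recovery(error=error_message)
    return True, "recovery attempted"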
|
|
|
|
|
|
# =============================================================================
|
|
# Section 1.2: Test model reload after unload
|
|
# =============================================================================
|
|
|
|
class TestModelReloadAfterUnload:
|
|
"""Tests for model reload after unload functionality"""
|
|
|
|
def setup_method(self):
|
|
"""Reset singleton before each test"""
|
|
shutdown_model_manager()
|
|
ModelManager._instance = None
|
|
ModelManager._lock = threading.Lock()
|
|
|
|
def teardown_method(self):
|
|
"""Cleanup after each test"""
|
|
shutdown_model_manager()
|
|
ModelManager._instance = None
|
|
|
|
def test_reload_after_unload(self):
|
|
"""Test that a model can be reloaded after being unloaded"""
|
|
config = MemoryConfig()
|
|
manager = ModelManager(config)
|
|
|
|
load_count = 0
|
|
|
|
def loader():
|
|
nonlocal load_count
|
|
load_count += 1
|
|
return Mock(name=f"model_instance_{load_count}")
|
|
|
|
# First load
|
|
model1 = manager.get_or_load_model("test_model", loader, estimated_memory_mb=100)
|
|
assert load_count == 1
|
|
assert "test_model" in manager.models
|
|
|
|
# Release and unload
|
|
manager.release_model("test_model")
|
|
success = manager.unload_model("test_model")
|
|
assert success is True
|
|
assert "test_model" not in manager.models
|
|
|
|
# Reload
|
|
model2 = manager.get_or_load_model("test_model", loader, estimated_memory_mb=100)
|
|
assert load_count == 2 # Loader called again
|
|
assert "test_model" in manager.models
|
|
|
|
# Models should be different instances
|
|
assert model1 is not model2
|
|
|
|
manager.teardown()
|
|
|
|
def test_reload_preserves_config(self):
|
|
"""Test that reloaded model uses the same configuration"""
|
|
config = MemoryConfig()
|
|
manager = ModelManager(config)
|
|
|
|
cleanup_count = 0
|
|
|
|
def cleanup():
|
|
nonlocal cleanup_count
|
|
cleanup_count += 1
|
|
|
|
# Load with cleanup callback
|
|
manager.get_or_load_model(
|
|
"test_model",
|
|
lambda: Mock(),
|
|
estimated_memory_mb=200,
|
|
cleanup_callback=cleanup
|
|
)
|
|
|
|
# Release, unload (cleanup should be called)
|
|
manager.release_model("test_model")
|
|
manager.unload_model("test_model")
|
|
assert cleanup_count == 1
|
|
|
|
# Reload with new cleanup callback
|
|
def new_cleanup():
|
|
nonlocal cleanup_count
|
|
cleanup_count += 10
|
|
|
|
manager.get_or_load_model(
|
|
"test_model",
|
|
lambda: Mock(),
|
|
estimated_memory_mb=300,
|
|
cleanup_callback=new_cleanup
|
|
)
|
|
|
|
# Verify new estimated memory
|
|
assert manager.models["test_model"].estimated_memory_mb == 300
|
|
|
|
# Release and unload again
|
|
manager.release_model("test_model")
|
|
manager.unload_model("test_model")
|
|
assert cleanup_count == 11 # New cleanup was called
|
|
|
|
manager.teardown()
|
|
|
|
def test_concurrent_reload(self):
|
|
"""Test concurrent reload operations"""
|
|
config = MemoryConfig()
|
|
manager = ModelManager(config)
|
|
|
|
load_count = 0
|
|
lock = threading.Lock()
|
|
|
|
def loader():
|
|
nonlocal load_count
|
|
with lock:
|
|
load_count += 1
|
|
time.sleep(0.05) # Simulate slow load
|
|
return Mock()
|
|
|
|
# Load, release, unload
|
|
manager.get_or_load_model("test_model", loader)
|
|
manager.release_model("test_model")
|
|
manager.unload_model("test_model")
|
|
|
|
# Concurrent reload attempts
|
|
results = []
|
|
|
|
def worker():
|
|
model = manager.get_or_load_model("test_model", loader)
|
|
results.append(model)
|
|
|
|
threads = [threading.Thread(target=worker) for _ in range(5)]
|
|
for t in threads:
|
|
t.start()
|
|
for t in threads:
|
|
t.join()
|
|
|
|
# All threads should get the same model instance
|
|
assert len(results) == 5
|
|
assert len(set(id(r) for r in results)) == 1
|
|
# Only one additional load should have occurred
|
|
assert load_count == 2
|
|
|
|
manager.teardown()
|
|
|
|
|
|
# =============================================================================
|
|
# Section 4.2: Test memory savings with selective processing
|
|
# =============================================================================
|
|
|
|
class TestSelectiveProcessingMemorySavings:
|
|
"""Tests for memory savings with selective processing"""
|
|
|
|
def test_batch_processor_memory_constraints(self):
|
|
"""Test that batch processor respects memory constraints"""
|
|
from app.services.memory_manager import BatchProcessor, BatchItem
|
|
|
|
# Create processor with strict memory limit
|
|
processor = BatchProcessor(
|
|
max_batch_size=10,
|
|
max_memory_per_batch_mb=100.0,
|
|
cleanup_between_batches=False
|
|
)
|
|
|
|
# Add items with known memory estimates
|
|
processor.add_item(BatchItem(item_id="small1", data=1, estimated_memory_mb=20.0))
|
|
processor.add_item(BatchItem(item_id="small2", data=2, estimated_memory_mb=20.0))
|
|
processor.add_item(BatchItem(item_id="small3", data=3, estimated_memory_mb=20.0))
|
|
processor.add_item(BatchItem(item_id="large1", data=4, estimated_memory_mb=80.0))
|
|
|
|
# First batch should include items that fit
|
|
results = processor.process_batch(lambda x: x)
|
|
|
|
        # Items should be processed respecting the 100 MB per-batch limit
        # Remaining items should include the large 80 MB one that did not fit
        assert processor.get_queue_size() >= 1
|
|
|
|
def test_progressive_loader_memory_efficiency(self):
|
|
"""Test that progressive loader manages memory efficiently"""
|
|
from app.services.memory_manager import ProgressiveLoader
|
|
|
|
loader = ProgressiveLoader(lookahead_pages=1, cleanup_after_pages=100)
|
|
loader.initialize(total_pages=10)
|
|
|
|
pages_loaded = []
|
|
|
|
def loader_func(page_num):
|
|
pages_loaded.append(page_num)
|
|
return f"page_data_{page_num}"
|
|
|
|
# Load pages sequentially
|
|
for i in range(10):
|
|
loader.load_page(i, loader_func)
|
|
|
|
# Only recent pages should be kept in memory
|
|
loaded = loader.get_loaded_pages()
|
|
|
|
# Should have unloaded distant pages
|
|
assert 0 not in loaded # First page should be unloaded
|
|
assert len(loaded) <= 3 # Current + lookahead
|
|
|
|
loader.clear()
|
|
|
|
def test_priority_queue_processing_order(self):
|
|
"""Test that priority queue processes high priority items first"""
|
|
from app.services.memory_manager import PriorityOperationQueue, BatchPriority
|
|
|
|
queue = PriorityOperationQueue()
|
|
|
|
# Add items with different priorities
|
|
queue.enqueue("low1", "data", BatchPriority.LOW)
|
|
queue.enqueue("critical1", "data", BatchPriority.CRITICAL)
|
|
queue.enqueue("normal1", "data", BatchPriority.NORMAL)
|
|
queue.enqueue("high1", "data", BatchPriority.HIGH)
|
|
|
|
# Process in priority order
|
|
order = []
|
|
while True:
|
|
result = queue.dequeue(timeout=0.01)
|
|
if result is None:
|
|
break
|
|
order.append(result[0])
|
|
|
|
assert order[0] == "critical1"
|
|
assert order[1] == "high1"
|
|
assert order[2] == "normal1"
|
|
assert order[3] == "low1"
|
|
|
|
|
|
# =============================================================================
|
|
# Section 5.2: Test recovery under various scenarios
|
|
# =============================================================================
|
|
|
|
class TestRecoveryScenarios:
|
|
"""Tests for recovery under various scenarios"""
|
|
|
|
def test_recovery_after_oom_simulation(self):
|
|
"""Test recovery behavior after simulated OOM"""
|
|
from app.services.memory_manager import RecoveryManager
|
|
|
|
manager = RecoveryManager(
|
|
cooldown_seconds=0.1,
|
|
max_recovery_attempts=5
|
|
)
|
|
|
|
# Simulate OOM recovery
|
|
success = manager.attempt_recovery(error="CUDA out of memory")
|
|
assert success is not None # Recovery attempted
|
|
|
|
# Check state
|
|
state = manager.get_state()
|
|
assert state["recovery_count"] == 1
|
|
assert state["last_error"] == "CUDA out of memory"
|
|
|
|
def test_recovery_cooldown_enforcement(self):
|
|
"""Test that cooldown period is strictly enforced"""
|
|
from app.services.memory_manager import RecoveryManager
|
|
|
|
manager = RecoveryManager(
|
|
cooldown_seconds=1.0,
|
|
max_recovery_attempts=10
|
|
)
|
|
|
|
# First recovery
|
|
manager.attempt_recovery()
|
|
assert manager.is_in_cooldown() is True
|
|
|
|
# Try immediate second recovery
|
|
can_recover, reason = manager.can_attempt_recovery()
|
|
assert can_recover is False
|
|
assert "cooldown" in reason.lower()
|
|
|
|
# Wait for cooldown
|
|
time.sleep(1.1)
|
|
can_recover, reason = manager.can_attempt_recovery()
|
|
assert can_recover is True
|
|
|
|
def test_recovery_max_attempts_window(self):
|
|
"""Test that max recovery attempts are enforced within window"""
|
|
from app.services.memory_manager import RecoveryManager
|
|
|
|
manager = RecoveryManager(
|
|
cooldown_seconds=0.01,
|
|
max_recovery_attempts=3,
|
|
recovery_window_seconds=60.0
|
|
)
|
|
|
|
# Perform max attempts
|
|
for i in range(3):
|
|
manager.attempt_recovery(error=f"Error {i}")
|
|
time.sleep(0.02) # Wait for cooldown
|
|
|
|
# Next attempt should be blocked
|
|
can_recover, reason = manager.can_attempt_recovery()
|
|
assert can_recover is False
|
|
assert "max" in reason.lower()
|
|
|
|
def test_emergency_release_with_model_manager(self):
|
|
"""Test emergency release unloads models"""
|
|
from app.services.memory_manager import RecoveryManager
|
|
|
|
# Create a model manager with test models
|
|
shutdown_model_manager()
|
|
ModelManager._instance = None
|
|
config = MemoryConfig()
|
|
model_manager = ModelManager(config)
|
|
|
|
# Load some test models
|
|
model_manager.get_or_load_model("model1", lambda: Mock(), estimated_memory_mb=100)
|
|
model_manager.get_or_load_model("model2", lambda: Mock(), estimated_memory_mb=200)
|
|
|
|
# Release references
|
|
model_manager.release_model("model1")
|
|
model_manager.release_model("model2")
|
|
|
|
assert len(model_manager.models) == 2
|
|
|
|
# Emergency release
|
|
recovery_manager = RecoveryManager()
|
|
result = recovery_manager.emergency_release(model_manager=model_manager)
|
|
|
|
# Models should be unloaded
|
|
assert len(model_manager.models) == 0
|
|
|
|
model_manager.teardown()
|
|
shutdown_model_manager()
|
|
ModelManager._instance = None
|
|
|
|
|
|
# =============================================================================
|
|
# Section 6.1: Test shutdown sequence
|
|
# =============================================================================
|
|
|
|
class TestShutdownSequence:
|
|
"""Tests for shutdown sequence"""
|
|
|
|
def test_model_manager_teardown(self):
|
|
"""Test that teardown properly cleans up models"""
|
|
shutdown_model_manager()
|
|
ModelManager._instance = None
|
|
|
|
config = MemoryConfig()
|
|
manager = ModelManager(config)
|
|
|
|
# Load models
|
|
manager.get_or_load_model("model1", lambda: Mock())
|
|
manager.get_or_load_model("model2", lambda: Mock())
|
|
|
|
assert len(manager.models) == 2
|
|
|
|
# Teardown
|
|
manager.teardown()
|
|
|
|
assert len(manager.models) == 0
|
|
assert manager._monitor_running is False
|
|
|
|
shutdown_model_manager()
|
|
ModelManager._instance = None
|
|
|
|
def test_cleanup_callbacks_called_on_teardown(self):
|
|
"""Test that cleanup callbacks are called during teardown"""
|
|
shutdown_model_manager()
|
|
ModelManager._instance = None
|
|
|
|
config = MemoryConfig()
|
|
manager = ModelManager(config)
|
|
|
|
cleanup_calls = []
|
|
|
|
def cleanup1():
|
|
cleanup_calls.append("model1")
|
|
|
|
def cleanup2():
|
|
cleanup_calls.append("model2")
|
|
|
|
manager.get_or_load_model("model1", lambda: Mock(), cleanup_callback=cleanup1)
|
|
manager.get_or_load_model("model2", lambda: Mock(), cleanup_callback=cleanup2)
|
|
|
|
# Teardown with force unload
|
|
manager.teardown()
|
|
|
|
# Both callbacks should have been called
|
|
assert "model1" in cleanup_calls
|
|
assert "model2" in cleanup_calls
|
|
|
|
shutdown_model_manager()
|
|
ModelManager._instance = None
|
|
|
|
def test_prediction_semaphore_shutdown(self):
|
|
"""Test prediction semaphore shutdown"""
|
|
from app.services.memory_manager import (
|
|
get_prediction_semaphore,
|
|
shutdown_prediction_semaphore,
|
|
PredictionSemaphore
|
|
)
|
|
|
|
shutdown_prediction_semaphore()
|
|
PredictionSemaphore._instance = None
|
|
PredictionSemaphore._lock = threading.Lock()
|
|
|
|
sem = get_prediction_semaphore(max_concurrent=2)
|
|
sem.acquire(task_id="test1")
|
|
|
|
# Shutdown should reset the semaphore
|
|
shutdown_prediction_semaphore()
|
|
|
|
# New instance should be fresh
|
|
new_sem = get_prediction_semaphore(max_concurrent=3)
|
|
assert new_sem._active_predictions == 0
|
|
|
|
shutdown_prediction_semaphore()
|
|
PredictionSemaphore._instance = None
|
|
|
|
|
|
# =============================================================================
|
|
# Section 6.2: Test cleanup in error scenarios
|
|
# =============================================================================
|
|
|
|
class TestCleanupInErrorScenarios:
|
|
"""Tests for cleanup in error scenarios"""
|
|
|
|
def test_cleanup_after_loader_exception(self):
|
|
"""Test cleanup when model loader raises exception"""
|
|
shutdown_model_manager()
|
|
ModelManager._instance = None
|
|
|
|
config = MemoryConfig()
|
|
manager = ModelManager(config)
|
|
|
|
def failing_loader():
|
|
raise RuntimeError("Loader failed")
|
|
|
|
with pytest.raises(RuntimeError):
|
|
manager.get_or_load_model("failing_model", failing_loader)
|
|
|
|
# Model should not be in the manager
|
|
assert "failing_model" not in manager.models
|
|
|
|
manager.teardown()
|
|
shutdown_model_manager()
|
|
ModelManager._instance = None
|
|
|
|
def test_cleanup_after_processing_error(self):
|
|
"""Test cleanup after processing error in batch processor"""
|
|
from app.services.memory_manager import BatchProcessor, BatchItem
|
|
|
|
processor = BatchProcessor(max_batch_size=3, cleanup_between_batches=False)
|
|
|
|
processor.add_item(BatchItem(item_id="good1", data=1))
|
|
processor.add_item(BatchItem(item_id="bad", data="error"))
|
|
processor.add_item(BatchItem(item_id="good2", data=2))
|
|
|
|
def processor_func(data):
|
|
if data == "error":
|
|
raise ValueError("Processing error")
|
|
return data * 2
|
|
|
|
results = processor.process_all(processor_func)
|
|
|
|
# Good items should succeed, bad item should fail
|
|
assert len(results) == 3
|
|
assert results[0].success is True
|
|
assert results[1].success is False
|
|
assert results[2].success is True
|
|
|
|
# Stats should reflect failure
|
|
stats = processor.get_stats()
|
|
assert stats["total_failures"] == 1
|
|
|
|
def test_pool_release_with_error(self):
|
|
"""Test that pool properly handles release with error"""
|
|
from app.services.service_pool import (
|
|
OCRServicePool,
|
|
PoolConfig,
|
|
PooledService,
|
|
ServiceState,
|
|
shutdown_service_pool
|
|
)
|
|
|
|
shutdown_service_pool()
|
|
OCRServicePool._instance = None
|
|
OCRServicePool._lock = threading.Lock()
|
|
|
|
config = PoolConfig(max_consecutive_errors=2)
|
|
pool = OCRServicePool(config)
|
|
|
|
# Pre-populate with a mock service
|
|
mock_service = Mock()
|
|
pooled_service = PooledService(service=mock_service, device="GPU:0")
|
|
pool.services["GPU:0"].append(pooled_service)
|
|
|
|
# Acquire and release with errors
|
|
pooled = pool.acquire(device="GPU:0")
|
|
pool.release(pooled, error=Exception("Error 1"))
|
|
assert pooled.error_count == 1
|
|
assert pooled.state == ServiceState.AVAILABLE
|
|
|
|
pooled = pool.acquire(device="GPU:0")
|
|
pool.release(pooled, error=Exception("Error 2"))
|
|
assert pooled.error_count == 2
|
|
assert pooled.state == ServiceState.UNHEALTHY
|
|
|
|
pool.shutdown()
|
|
shutdown_service_pool()
|
|
OCRServicePool._instance = None
|
|
|
|
|
|
# =============================================================================
|
|
# Section 8.1: Memory leak detection tests
|
|
# =============================================================================
|
|
|
|
class TestMemoryLeakDetection:
|
|
"""Tests for memory leak detection"""
|
|
|
|
def test_no_leak_on_model_cycle(self):
|
|
"""Test that loading and unloading models doesn't leak"""
|
|
shutdown_model_manager()
|
|
ModelManager._instance = None
|
|
|
|
config = MemoryConfig()
|
|
manager = ModelManager(config)
|
|
|
|
initial_model_count = len(manager.models)
|
|
|
|
# Perform multiple load/unload cycles
|
|
for i in range(10):
|
|
manager.get_or_load_model(f"temp_model_{i}", lambda: Mock())
|
|
manager.release_model(f"temp_model_{i}")
|
|
manager.unload_model(f"temp_model_{i}")
|
|
|
|
# Should be back to initial state
|
|
assert len(manager.models) == initial_model_count
|
|
|
|
# Force gc and check
|
|
gc.collect()
|
|
|
|
manager.teardown()
|
|
shutdown_model_manager()
|
|
ModelManager._instance = None
|
|
|
|
def test_no_leak_on_semaphore_cycle(self):
|
|
"""Test that semaphore acquire/release doesn't leak"""
|
|
from app.services.memory_manager import (
|
|
PredictionSemaphore,
|
|
shutdown_prediction_semaphore
|
|
)
|
|
|
|
shutdown_prediction_semaphore()
|
|
PredictionSemaphore._instance = None
|
|
PredictionSemaphore._lock = threading.Lock()
|
|
|
|
sem = PredictionSemaphore(max_concurrent=2)
|
|
|
|
# Perform many acquire/release cycles
|
|
for i in range(100):
|
|
sem.acquire(task_id=f"task_{i}")
|
|
sem.release(task_id=f"task_{i}")
|
|
|
|
# Active predictions should be 0
|
|
assert sem._active_predictions == 0
|
|
assert sem._queue_depth == 0
|
|
|
|
shutdown_prediction_semaphore()
|
|
PredictionSemaphore._instance = None
|
|
|
|
def test_no_leak_in_batch_processor(self):
|
|
"""Test that batch processor doesn't leak items"""
|
|
from app.services.memory_manager import BatchProcessor, BatchItem
|
|
|
|
processor = BatchProcessor(max_batch_size=5, cleanup_between_batches=False)
|
|
|
|
# Add and process many items
|
|
for i in range(50):
|
|
processor.add_item(BatchItem(item_id=f"item_{i}", data=i))
|
|
|
|
processor.process_all(lambda x: x * 2)
|
|
|
|
# Queue should be empty
|
|
assert processor.get_queue_size() == 0
|
|
|
|
# Stats should be accurate
|
|
stats = processor.get_stats()
|
|
assert stats["total_processed"] == 50
|
|
|
|
|
|
# =============================================================================
|
|
# Section 8.1: Stress tests with concurrent requests
|
|
# =============================================================================
|
|
|
|
class TestStressConcurrentRequests:
|
|
"""Stress tests with concurrent requests"""
|
|
|
|
def test_concurrent_model_access_stress(self):
|
|
"""Stress test concurrent model access"""
|
|
shutdown_model_manager()
|
|
ModelManager._instance = None
|
|
|
|
config = MemoryConfig()
|
|
manager = ModelManager(config)
|
|
|
|
load_count = 0
|
|
lock = threading.Lock()
|
|
|
|
def loader():
|
|
nonlocal load_count
|
|
with lock:
|
|
load_count += 1
|
|
return Mock()
|
|
|
|
results = []
|
|
errors = []
|
|
|
|
def worker(worker_id):
|
|
try:
|
|
model = manager.get_or_load_model("shared_model", loader)
|
|
time.sleep(0.01) # Simulate work
|
|
manager.release_model("shared_model")
|
|
results.append(worker_id)
|
|
except Exception as e:
|
|
errors.append(str(e))
|
|
|
|
# Launch many concurrent workers
|
|
threads = [threading.Thread(target=worker, args=(i,)) for i in range(20)]
|
|
for t in threads:
|
|
t.start()
|
|
for t in threads:
|
|
t.join()
|
|
|
|
# All workers should complete without errors
|
|
assert len(errors) == 0
|
|
assert len(results) == 20
|
|
|
|
# Loader should only be called once
|
|
assert load_count == 1
|
|
|
|
manager.teardown()
|
|
shutdown_model_manager()
|
|
ModelManager._instance = None
|
|
|
|
def test_concurrent_semaphore_stress(self):
|
|
"""Stress test concurrent semaphore operations"""
|
|
from app.services.memory_manager import (
|
|
PredictionSemaphore,
|
|
shutdown_prediction_semaphore
|
|
)
|
|
|
|
shutdown_prediction_semaphore()
|
|
PredictionSemaphore._instance = None
|
|
PredictionSemaphore._lock = threading.Lock()
|
|
|
|
sem = PredictionSemaphore(max_concurrent=3)
|
|
|
|
results = []
|
|
max_concurrent_observed = 0
|
|
current_count = 0
|
|
lock = threading.Lock()
|
|
|
|
def worker(worker_id):
|
|
nonlocal max_concurrent_observed, current_count
|
|
if sem.acquire(timeout=10.0, task_id=f"task_{worker_id}"):
|
|
with lock:
|
|
current_count += 1
|
|
max_concurrent_observed = max(max_concurrent_observed, current_count)
|
|
|
|
time.sleep(0.02) # Simulate work
|
|
|
|
with lock:
|
|
current_count -= 1
|
|
|
|
sem.release(task_id=f"task_{worker_id}")
|
|
results.append(worker_id)
|
|
|
|
# Launch many concurrent workers
|
|
threads = [threading.Thread(target=worker, args=(i,)) for i in range(15)]
|
|
for t in threads:
|
|
t.start()
|
|
for t in threads:
|
|
t.join()
|
|
|
|
# All should complete
|
|
assert len(results) == 15
|
|
|
|
# Max concurrent should not exceed limit
|
|
assert max_concurrent_observed <= 3
|
|
|
|
shutdown_prediction_semaphore()
|
|
PredictionSemaphore._instance = None
|
|
|
|
def test_concurrent_batch_processing(self):
|
|
"""Stress test concurrent batch processing"""
|
|
from app.services.memory_manager import BatchProcessor, BatchItem
|
|
|
|
processor = BatchProcessor(max_batch_size=3, cleanup_between_batches=False)
|
|
|
|
# Add items from multiple threads
|
|
def add_items(start_id):
|
|
for i in range(10):
|
|
processor.add_item(BatchItem(item_id=f"item_{start_id}_{i}", data=i))
|
|
|
|
threads = [threading.Thread(target=add_items, args=(i,)) for i in range(5)]
|
|
for t in threads:
|
|
t.start()
|
|
for t in threads:
|
|
t.join()
|
|
|
|
# Should have 50 items
|
|
assert processor.get_queue_size() == 50
|
|
|
|
# Process all
|
|
results = processor.process_all(lambda x: x)
|
|
assert len(results) == 50
|
|
|
|
|
|
# =============================================================================
|
|
# Section 8.1: Performance benchmarks
|
|
# =============================================================================
|
|
|
|
class TestPerformanceBenchmarks:
|
|
"""Performance benchmark tests"""
|
|
|
|
def test_model_load_performance(self):
|
|
"""Benchmark model loading performance"""
|
|
shutdown_model_manager()
|
|
ModelManager._instance = None
|
|
|
|
config = MemoryConfig()
|
|
manager = ModelManager(config)
|
|
|
|
load_times = []
|
|
|
|
for i in range(5):
|
|
start = time.time()
|
|
manager.get_or_load_model(f"model_{i}", lambda: Mock())
|
|
load_times.append(time.time() - start)
|
|
|
|
# Average load time should be reasonable (< 100ms for mock)
|
|
avg_load_time = sum(load_times) / len(load_times)
|
|
assert avg_load_time < 0.1 # 100ms
|
|
|
|
manager.teardown()
|
|
shutdown_model_manager()
|
|
ModelManager._instance = None
|
|
|
|
def test_semaphore_throughput(self):
|
|
"""Benchmark semaphore throughput"""
|
|
from app.services.memory_manager import (
|
|
PredictionSemaphore,
|
|
shutdown_prediction_semaphore
|
|
)
|
|
|
|
shutdown_prediction_semaphore()
|
|
PredictionSemaphore._instance = None
|
|
PredictionSemaphore._lock = threading.Lock()
|
|
|
|
sem = PredictionSemaphore(max_concurrent=10)
|
|
|
|
start = time.time()
|
|
iterations = 1000
|
|
|
|
for i in range(iterations):
|
|
sem.acquire(timeout=1.0)
|
|
sem.release()
|
|
|
|
elapsed = time.time() - start
|
|
|
|
        # Should handle at least 1,000 ops/sec (the threshold asserted below)
        ops_per_sec = iterations / elapsed
        assert ops_per_sec > 1000
|
|
|
|
shutdown_prediction_semaphore()
|
|
PredictionSemaphore._instance = None
|
|
|
|
def test_batch_processor_throughput(self):
|
|
"""Benchmark batch processor throughput"""
|
|
from app.services.memory_manager import BatchProcessor, BatchItem
|
|
|
|
processor = BatchProcessor(max_batch_size=100, cleanup_between_batches=False)
|
|
|
|
# Add many items
|
|
for i in range(1000):
|
|
processor.add_item(BatchItem(item_id=f"item_{i}", data=i))
|
|
|
|
start = time.time()
|
|
results = processor.process_all(lambda x: x * 2)
|
|
elapsed = time.time() - start
|
|
|
|
        # Should process at least 1,000 items/sec (the threshold asserted below)
        items_per_sec = len(results) / elapsed
        assert items_per_sec > 1000
|
|
|
|
stats = processor.get_stats()
|
|
assert stats["total_processed"] == 1000
|
|
|
|
|
|
# =============================================================================
|
|
# Tests for Memory Dump and Prometheus Metrics (Section 5.2 & 7.2)
|
|
# =============================================================================
|
|
|
|
class TestMemoryDumper:
|
|
"""Tests for MemoryDumper class"""
|
|
|
|
def setup_method(self):
|
|
"""Reset singletons before each test"""
|
|
from app.services.memory_manager import shutdown_memory_dumper
|
|
shutdown_memory_dumper()
|
|
shutdown_model_manager()
|
|
ModelManager._instance = None
|
|
|
|
def teardown_method(self):
|
|
"""Cleanup after each test"""
|
|
from app.services.memory_manager import shutdown_memory_dumper
|
|
shutdown_memory_dumper()
|
|
shutdown_model_manager()
|
|
ModelManager._instance = None
|
|
|
|
def test_create_dump(self):
|
|
"""Test creating a memory dump"""
|
|
from app.services.memory_manager import MemoryDumper, MemoryDump
|
|
|
|
dumper = MemoryDumper()
|
|
dump = dumper.create_dump()
|
|
|
|
assert isinstance(dump, MemoryDump)
|
|
assert dump.timestamp > 0
|
|
assert isinstance(dump.loaded_models, list)
|
|
assert isinstance(dump.gc_stats, dict)
|
|
|
|
def test_dump_history(self):
|
|
"""Test dump history tracking"""
|
|
from app.services.memory_manager import MemoryDumper
|
|
|
|
dumper = MemoryDumper()
|
|
|
|
# Create multiple dumps
|
|
for _ in range(5):
|
|
dumper.create_dump()
|
|
|
|
history = dumper.get_dump_history()
|
|
assert len(history) == 5
|
|
|
|
latest = dumper.get_latest_dump()
|
|
assert latest is history[-1]
|
|
|
|
def test_dump_comparison(self):
|
|
"""Test comparing two dumps"""
|
|
from app.services.memory_manager import MemoryDumper
|
|
|
|
dumper = MemoryDumper()
|
|
|
|
dump1 = dumper.create_dump()
|
|
time.sleep(0.1)
|
|
dump2 = dumper.create_dump()
|
|
|
|
comparison = dumper.compare_dumps(dump1, dump2)
|
|
|
|
assert "time_delta_seconds" in comparison
|
|
assert comparison["time_delta_seconds"] > 0
|
|
assert "gpu_memory_change_mb" in comparison
|
|
assert "cpu_memory_change_mb" in comparison
|
|
|
|
def test_dump_to_dict(self):
|
|
"""Test converting dump to dictionary"""
|
|
from app.services.memory_manager import MemoryDumper
|
|
|
|
dumper = MemoryDumper()
|
|
dump = dumper.create_dump()
|
|
dump_dict = dumper.to_dict(dump)
|
|
|
|
assert "timestamp" in dump_dict
|
|
assert "gpu" in dump_dict
|
|
assert "cpu" in dump_dict
|
|
assert "models" in dump_dict
|
|
assert "predictions" in dump_dict
|
|
|
|
|
|
class TestPrometheusMetrics:
|
|
"""Tests for PrometheusMetrics class"""
|
|
|
|
def setup_method(self):
|
|
"""Reset singletons before each test"""
|
|
from app.services.memory_manager import shutdown_prometheus_metrics
|
|
shutdown_prometheus_metrics()
|
|
shutdown_model_manager()
|
|
ModelManager._instance = None
|
|
|
|
def teardown_method(self):
|
|
"""Cleanup after each test"""
|
|
from app.services.memory_manager import shutdown_prometheus_metrics
|
|
shutdown_prometheus_metrics()
|
|
shutdown_model_manager()
|
|
ModelManager._instance = None
|
|
|
|
def test_export_metrics(self):
|
|
"""Test exporting metrics in Prometheus format"""
|
|
from app.services.memory_manager import PrometheusMetrics
|
|
|
|
prometheus = PrometheusMetrics()
|
|
metrics = prometheus.export_metrics()
|
|
|
|
# Should be a non-empty string
|
|
assert isinstance(metrics, str)
|
|
assert len(metrics) > 0
|
|
|
|
# Should contain expected metric prefixes
|
|
assert "tool_ocr_memory_" in metrics
|
|
|
|
def test_metric_format(self):
|
|
"""Test that metrics follow Prometheus format"""
|
|
from app.services.memory_manager import PrometheusMetrics
|
|
|
|
prometheus = PrometheusMetrics()
|
|
metrics = prometheus.export_metrics()
|
|
|
|
lines = metrics.split("\n")
|
|
|
|
# Check for HELP and TYPE comments
|
|
help_lines = [l for l in lines if l.startswith("# HELP")]
|
|
type_lines = [l for l in lines if l.startswith("# TYPE")]
|
|
|
|
assert len(help_lines) > 0
|
|
assert len(type_lines) > 0
|
|
|
|
def test_custom_metrics(self):
|
|
"""Test setting custom metrics"""
|
|
from app.services.memory_manager import PrometheusMetrics
|
|
|
|
prometheus = PrometheusMetrics()
|
|
|
|
prometheus.set_custom_metric("custom_value", 42.0)
|
|
prometheus.set_custom_metric("labeled_value", 100.0, {"env": "test"})
|
|
|
|
metrics = prometheus.export_metrics()
|
|
|
|
assert "custom_value" in metrics or "42" in metrics
|
|
|
|
def test_get_prometheus_metrics_singleton(self):
|
|
"""Test prometheus metrics singleton"""
|
|
from app.services.memory_manager import get_prometheus_metrics, shutdown_prometheus_metrics
|
|
|
|
metrics1 = get_prometheus_metrics()
|
|
metrics2 = get_prometheus_metrics()
|
|
|
|
assert metrics1 is metrics2
|
|
|
|
shutdown_prometheus_metrics()
|
|
|
|
|
|
if __name__ == "__main__":
|
|
pytest.main([__file__, "-v"])
|