""" Tests for Memory Management Components Tests ModelManager, MemoryGuard, and related functionality. """ import gc import pytest import threading import time from unittest.mock import Mock, patch, MagicMock import sys # Mock paddle before importing memory_manager to avoid import errors # when paddle is not installed in the test environment paddle_mock = MagicMock() paddle_mock.is_compiled_with_cuda.return_value = False paddle_mock.device.cuda.device_count.return_value = 0 paddle_mock.device.cuda.memory_allocated.return_value = 0 paddle_mock.device.cuda.memory_reserved.return_value = 0 paddle_mock.device.cuda.empty_cache = MagicMock() sys.modules['paddle'] = paddle_mock from app.services.memory_manager import ( ModelManager, ModelEntry, MemoryGuard, MemoryConfig, MemoryStats, MemoryBackend, get_model_manager, shutdown_model_manager, ) class TestMemoryConfig: """Tests for MemoryConfig class""" def test_default_values(self): """Test default configuration values""" config = MemoryConfig() assert config.warning_threshold == 0.80 assert config.critical_threshold == 0.95 assert config.emergency_threshold == 0.98 assert config.model_idle_timeout_seconds == 300 assert config.enable_auto_cleanup is True assert config.max_concurrent_predictions == 2 def test_custom_values(self): """Test custom configuration values""" config = MemoryConfig( warning_threshold=0.70, critical_threshold=0.85, model_idle_timeout_seconds=600, ) assert config.warning_threshold == 0.70 assert config.critical_threshold == 0.85 assert config.model_idle_timeout_seconds == 600 class TestMemoryGuard: """Tests for MemoryGuard class""" def setup_method(self): """Setup for each test""" self.config = MemoryConfig( warning_threshold=0.80, critical_threshold=0.95, emergency_threshold=0.98, ) def test_initialization(self): """Test MemoryGuard initialization""" guard = MemoryGuard(self.config) assert guard.config == self.config assert guard.backend is not None guard.shutdown() def test_get_memory_stats(self): """Test getting memory statistics""" guard = MemoryGuard(self.config) stats = guard.get_memory_stats() assert isinstance(stats, MemoryStats) assert stats.timestamp > 0 guard.shutdown() def test_check_memory_below_warning(self): """Test memory check when below warning threshold""" guard = MemoryGuard(self.config) # Mock stats to be below warning with patch.object(guard, 'get_memory_stats') as mock_stats: mock_stats.return_value = MemoryStats( gpu_used_ratio=0.50, gpu_free_mb=4000, gpu_total_mb=8000, ) is_available, stats = guard.check_memory(required_mb=1000) assert is_available is True guard.shutdown() def test_check_memory_above_warning(self): """Test memory check when above warning threshold""" guard = MemoryGuard(self.config) # Mock stats to be above warning with patch.object(guard, 'get_memory_stats') as mock_stats: mock_stats.return_value = MemoryStats( gpu_used_ratio=0.85, gpu_free_mb=1200, gpu_total_mb=8000, ) is_available, stats = guard.check_memory(required_mb=500) # Should still return True (warning, not critical) assert is_available is True guard.shutdown() def test_check_memory_above_critical(self): """Test memory check when above critical threshold""" guard = MemoryGuard(self.config) # Mock stats to be above critical with patch.object(guard, 'get_memory_stats') as mock_stats: mock_stats.return_value = MemoryStats( gpu_used_ratio=0.96, gpu_free_mb=320, gpu_total_mb=8000, ) is_available, stats = guard.check_memory(required_mb=100) # Should return False (critical) assert is_available is False guard.shutdown() def 
    def test_check_memory_insufficient_free(self):
        """Test memory check when insufficient free memory"""
        guard = MemoryGuard(self.config)
        # Mock stats with insufficient free memory
        with patch.object(guard, 'get_memory_stats') as mock_stats:
            mock_stats.return_value = MemoryStats(
                gpu_used_ratio=0.70,
                gpu_free_mb=500,
                gpu_total_mb=8000,
            )
            is_available, stats = guard.check_memory(required_mb=1000)
            # Should return False (not enough free)
            assert is_available is False
        guard.shutdown()

    def test_alert_history(self):
        """Test alert history functionality"""
        guard = MemoryGuard(self.config)
        # Trigger some alerts
        guard._add_alert("warning", "Test warning")
        guard._add_alert("critical", "Test critical")

        alerts = guard.get_alerts()
        assert len(alerts) == 2
        assert alerts[0]["level"] == "warning"
        assert alerts[1]["level"] == "critical"
        guard.shutdown()

    def test_clear_gpu_cache(self):
        """Test GPU cache clearing"""
        guard = MemoryGuard(self.config)
        # Should not raise even if no GPU
        guard.clear_gpu_cache()
        guard.shutdown()


class TestModelManager:
    """Tests for ModelManager class"""

    def setup_method(self):
        """Reset singleton before each test"""
        shutdown_model_manager()
        ModelManager._instance = None
        ModelManager._lock = threading.Lock()

    def teardown_method(self):
        """Cleanup after each test"""
        shutdown_model_manager()
        ModelManager._instance = None

    def test_singleton_pattern(self):
        """Test that ModelManager is a singleton"""
        config = MemoryConfig()
        manager1 = ModelManager(config)
        manager2 = ModelManager()
        assert manager1 is manager2
        manager1.teardown()

    def test_get_or_load_model_new(self):
        """Test loading a new model"""
        config = MemoryConfig()
        manager = ModelManager(config)

        mock_model = Mock()
        loader_called = False

        def loader():
            nonlocal loader_called
            loader_called = True
            return mock_model

        model = manager.get_or_load_model(
            model_id="test_model",
            loader_func=loader,
            estimated_memory_mb=100
        )

        assert model is mock_model
        assert loader_called is True
        assert "test_model" in manager.models
        assert manager.models["test_model"].ref_count == 1
        manager.teardown()

    def test_get_or_load_model_cached(self):
        """Test getting a cached model"""
        config = MemoryConfig()
        manager = ModelManager(config)

        mock_model = Mock()
        load_count = 0

        def loader():
            nonlocal load_count
            load_count += 1
            return mock_model

        # First load
        model1 = manager.get_or_load_model("test_model", loader)
        # Second load (should return cached)
        model2 = manager.get_or_load_model("test_model", loader)

        assert model1 is model2
        assert load_count == 1  # Loader should only be called once
        assert manager.models["test_model"].ref_count == 2
        manager.teardown()

    def test_release_model(self):
        """Test releasing model references"""
        config = MemoryConfig()
        manager = ModelManager(config)

        model = manager.get_or_load_model("test_model", lambda: Mock())
        assert manager.models["test_model"].ref_count == 1

        manager.release_model("test_model")
        assert manager.models["test_model"].ref_count == 0
        manager.teardown()

    def test_unload_model(self):
        """Test unloading a model"""
        config = MemoryConfig()
        manager = ModelManager(config)

        manager.get_or_load_model("test_model", lambda: Mock())
        manager.release_model("test_model")

        success = manager.unload_model("test_model")
        assert success is True
        assert "test_model" not in manager.models
        manager.teardown()

    def test_unload_model_with_references(self):
        """Test that model with active references is not unloaded"""
        config = MemoryConfig()
        manager = ModelManager(config)

        manager.get_or_load_model("test_model", lambda: Mock())
        # Don't release - ref_count is still 1
        success = manager.unload_model("test_model", force=False)
        assert success is False
        assert "test_model" in manager.models
        manager.teardown()
manager.unload_model("test_model", force=False) assert success is False assert "test_model" in manager.models manager.teardown() def test_unload_model_force(self): """Test force unloading a model with references""" config = MemoryConfig() manager = ModelManager(config) manager.get_or_load_model("test_model", lambda: Mock()) success = manager.unload_model("test_model", force=True) assert success is True assert "test_model" not in manager.models manager.teardown() def test_cleanup_callback(self): """Test cleanup callback is called on unload""" config = MemoryConfig() manager = ModelManager(config) cleanup_called = False def cleanup(): nonlocal cleanup_called cleanup_called = True manager.get_or_load_model( "test_model", lambda: Mock(), cleanup_callback=cleanup ) manager.release_model("test_model") manager.unload_model("test_model") assert cleanup_called is True manager.teardown() def test_get_model_stats(self): """Test getting model statistics""" config = MemoryConfig() manager = ModelManager(config) manager.get_or_load_model("model1", lambda: Mock(), estimated_memory_mb=100) manager.get_or_load_model("model2", lambda: Mock(), estimated_memory_mb=200) stats = manager.get_model_stats() assert stats["total_models"] == 2 assert "model1" in stats["models"] assert "model2" in stats["models"] assert stats["total_estimated_memory_mb"] == 300 manager.teardown() def test_idle_cleanup(self): """Test idle model cleanup""" config = MemoryConfig( model_idle_timeout_seconds=1, # Short timeout for testing memory_check_interval_seconds=60, # Don't auto-cleanup ) manager = ModelManager(config) manager.get_or_load_model("test_model", lambda: Mock()) manager.release_model("test_model") # Manually set last_used to simulate idle manager.models["test_model"].last_used = time.time() - 10 # Manually trigger cleanup manager._cleanup_idle_models() assert "test_model" not in manager.models manager.teardown() class TestGetModelManager: """Tests for get_model_manager helper function""" def setup_method(self): """Reset singleton before each test""" shutdown_model_manager() ModelManager._instance = None def teardown_method(self): """Cleanup after each test""" shutdown_model_manager() ModelManager._instance = None def test_get_model_manager_creates_singleton(self): """Test that get_model_manager creates a singleton""" manager1 = get_model_manager() manager2 = get_model_manager() assert manager1 is manager2 shutdown_model_manager() def test_shutdown_model_manager(self): """Test shutdown_model_manager cleans up""" manager = get_model_manager() manager.get_or_load_model("test", lambda: Mock()) shutdown_model_manager() # Should be able to create new manager new_manager = get_model_manager() assert "test" not in new_manager.models shutdown_model_manager() class TestConcurrency: """Tests for concurrent access""" def setup_method(self): shutdown_model_manager() ModelManager._instance = None def teardown_method(self): shutdown_model_manager() ModelManager._instance = None def test_concurrent_model_access(self): """Test concurrent model loading""" config = MemoryConfig() manager = ModelManager(config) load_count = 0 lock = threading.Lock() def loader(): nonlocal load_count with lock: load_count += 1 time.sleep(0.1) # Simulate slow load return Mock() results = [] def worker(): model = manager.get_or_load_model("shared_model", loader) results.append(model) threads = [threading.Thread(target=worker) for _ in range(5)] for t in threads: t.start() for t in threads: t.join() # All threads should get the same model assert len(set(id(r) 
        # Loader should only be called once
        assert load_count == 1
        # Ref count should match thread count
        assert manager.models["shared_model"].ref_count == 5
        manager.teardown()


class TestPredictionSemaphore:
    """Tests for PredictionSemaphore class"""

    def setup_method(self):
        """Reset singleton before each test"""
        from app.services.memory_manager import shutdown_prediction_semaphore, PredictionSemaphore
        shutdown_prediction_semaphore()
        PredictionSemaphore._instance = None
        PredictionSemaphore._lock = threading.Lock()

    def teardown_method(self):
        """Cleanup after each test"""
        from app.services.memory_manager import shutdown_prediction_semaphore, PredictionSemaphore
        shutdown_prediction_semaphore()
        PredictionSemaphore._instance = None

    def test_singleton_pattern(self):
        """Test that PredictionSemaphore is a singleton"""
        from app.services.memory_manager import PredictionSemaphore

        sem1 = PredictionSemaphore(max_concurrent=2)
        sem2 = PredictionSemaphore(max_concurrent=4)  # Different config should be ignored

        assert sem1 is sem2
        assert sem1._max_concurrent == 2

    def test_acquire_release(self):
        """Test basic acquire and release"""
        from app.services.memory_manager import PredictionSemaphore

        sem = PredictionSemaphore(max_concurrent=2)

        # Acquire first slot
        assert sem.acquire(task_id="task1") is True
        assert sem._active_predictions == 1

        # Acquire second slot
        assert sem.acquire(task_id="task2") is True
        assert sem._active_predictions == 2

        # Release one
        sem.release(task_id="task1")
        assert sem._active_predictions == 1

        # Release another
        sem.release(task_id="task2")
        assert sem._active_predictions == 0

    def test_acquire_blocks_when_full(self):
        """Test that acquire blocks when all slots are taken"""
        from app.services.memory_manager import PredictionSemaphore

        sem = PredictionSemaphore(max_concurrent=1)

        # Acquire the only slot
        assert sem.acquire(task_id="task1") is True

        # Try to acquire another with short timeout - should fail
        result = sem.acquire(timeout=0.1, task_id="task2")
        assert result is False
        assert sem._total_timeouts == 1

        # Release first slot
        sem.release(task_id="task1")

        # Now should succeed
        assert sem.acquire(task_id="task3") is True

    def test_get_stats(self):
        """Test statistics tracking"""
        from app.services.memory_manager import PredictionSemaphore

        sem = PredictionSemaphore(max_concurrent=2)

        sem.acquire(task_id="task1")
        sem.acquire(task_id="task2")
        sem.release(task_id="task1")

        stats = sem.get_stats()
        assert stats["max_concurrent"] == 2
        assert stats["active_predictions"] == 1
        assert stats["total_predictions"] == 2
        assert stats["total_timeouts"] == 0

    def test_concurrent_acquire(self):
        """Test concurrent access to semaphore"""
        from app.services.memory_manager import PredictionSemaphore

        sem = PredictionSemaphore(max_concurrent=2)
        results = []
        acquired_count = 0
        lock = threading.Lock()

        def worker(worker_id):
            nonlocal acquired_count
            if sem.acquire(timeout=1.0, task_id=f"task_{worker_id}"):
                with lock:
                    acquired_count += 1
                time.sleep(0.1)  # Simulate work
                sem.release(task_id=f"task_{worker_id}")
                results.append(worker_id)

        # Start 4 workers but only 2 slots
        threads = [threading.Thread(target=worker, args=(i,)) for i in range(4)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        # All should complete eventually
        assert len(results) == 4


class TestPredictionContext:
    """Tests for prediction_context helper function"""

    def setup_method(self):
        """Reset singleton before each test"""
        from app.services.memory_manager import shutdown_prediction_semaphore, PredictionSemaphore
        shutdown_prediction_semaphore()
        PredictionSemaphore._instance = None
        PredictionSemaphore._lock = threading.Lock()

    def teardown_method(self):
        """Cleanup after each test"""
        from app.services.memory_manager import shutdown_prediction_semaphore, PredictionSemaphore
        shutdown_prediction_semaphore()
        PredictionSemaphore._instance = None

    def test_context_manager_success(self):
        """Test context manager for successful prediction"""
        from app.services.memory_manager import prediction_context, get_prediction_semaphore

        # Initialize semaphore
        sem = get_prediction_semaphore(max_concurrent=2)

        with prediction_context(timeout=5.0, task_id="test") as acquired:
            assert acquired is True
            assert sem._active_predictions == 1

        # After exiting context, slot should be released
        assert sem._active_predictions == 0

    def test_context_manager_with_exception(self):
        """Test context manager releases on exception"""
        from app.services.memory_manager import prediction_context, get_prediction_semaphore

        sem = get_prediction_semaphore(max_concurrent=2)

        with pytest.raises(ValueError):
            with prediction_context(task_id="test") as acquired:
                assert acquired is True
                raise ValueError("Test error")

        # Should still release the slot
        assert sem._active_predictions == 0

    def test_context_manager_timeout(self):
        """Test context manager when timeout occurs"""
        from app.services.memory_manager import prediction_context, get_prediction_semaphore

        sem = get_prediction_semaphore(max_concurrent=1)

        # Acquire the only slot
        sem.acquire(task_id="blocker")

        # Context manager should timeout
        with prediction_context(timeout=0.1, task_id="waiter") as acquired:
            assert acquired is False

        # Release blocker
        sem.release(task_id="blocker")


class TestBatchProcessor:
    """Tests for BatchProcessor class"""

    def test_add_item(self):
        """Test adding items to batch queue"""
        from app.services.memory_manager import BatchProcessor, BatchItem, BatchPriority

        processor = BatchProcessor(max_batch_size=5)
        item = BatchItem(item_id="test1", data="data1", priority=BatchPriority.NORMAL)
        processor.add_item(item)

        assert processor.get_queue_size() == 1

    def test_add_items_sorted_by_priority(self):
        """Test that items are sorted by priority"""
        from app.services.memory_manager import BatchProcessor, BatchItem, BatchPriority

        processor = BatchProcessor(max_batch_size=5)
        processor.add_item(BatchItem(item_id="low", data="low", priority=BatchPriority.LOW))
        processor.add_item(BatchItem(item_id="high", data="high", priority=BatchPriority.HIGH))
        processor.add_item(BatchItem(item_id="normal", data="normal", priority=BatchPriority.NORMAL))

        # High priority should be first
        assert processor._queue[0].item_id == "high"
        assert processor._queue[1].item_id == "normal"
        assert processor._queue[2].item_id == "low"

    def test_process_batch(self):
        """Test processing a batch of items"""
        from app.services.memory_manager import BatchProcessor, BatchItem, BatchPriority

        processor = BatchProcessor(max_batch_size=3, cleanup_between_batches=False)
        for i in range(5):
            processor.add_item(BatchItem(item_id=f"item{i}", data=i))

        results = processor.process_batch(lambda x: x * 2)

        # Should process max_batch_size items
        assert len(results) == 3
        assert all(r.success for r in results)
        assert processor.get_queue_size() == 2  # 2 remaining

    def test_process_all(self):
        """Test processing all items"""
        from app.services.memory_manager import BatchProcessor, BatchItem

        processor = BatchProcessor(max_batch_size=2, cleanup_between_batches=False)
        for i in range(5):
            processor.add_item(BatchItem(item_id=f"item{i}", data=i))

        results = processor.process_all(lambda x: x * 2)

        assert len(results) == 5
        assert processor.get_queue_size() == 0

    def test_process_with_failure(self):
        """Test handling of processing failures"""
        from app.services.memory_manager import BatchProcessor, BatchItem

        processor = BatchProcessor(max_batch_size=3, cleanup_between_batches=False)
        processor.add_item(BatchItem(item_id="good", data=1))
        processor.add_item(BatchItem(item_id="bad", data="error"))

        def processor_func(data):
            if data == "error":
                raise ValueError("Test error")
            return data * 2

        results = processor.process_all(processor_func)

        assert len(results) == 2
        assert results[0].success is True
        assert results[1].success is False
        assert "Test error" in results[1].error

    def test_memory_constraint_batching(self):
        """Test that batches respect memory constraints"""
        from app.services.memory_manager import BatchProcessor, BatchItem

        processor = BatchProcessor(
            max_batch_size=10,
            max_memory_per_batch_mb=100.0,
            cleanup_between_batches=False
        )

        # Add items that exceed memory limit
        processor.add_item(BatchItem(item_id="item1", data=1, estimated_memory_mb=60.0))
        processor.add_item(BatchItem(item_id="item2", data=2, estimated_memory_mb=60.0))
        processor.add_item(BatchItem(item_id="item3", data=3, estimated_memory_mb=60.0))

        results = processor.process_batch(lambda x: x)

        # Should only process items that fit in the memory limit:
        # the first item (60MB) fits on its own, but adding the second
        # brings the batch to 120MB, which exceeds the 100MB limit, so
        # one or two items are taken depending on how the boundary is handled
        assert len(results) == 1 or len(results) == 2
        assert processor.get_queue_size() >= 1

    def test_get_stats(self):
        """Test statistics tracking"""
        from app.services.memory_manager import BatchProcessor, BatchItem

        processor = BatchProcessor(max_batch_size=2, cleanup_between_batches=False)
        processor.add_item(BatchItem(item_id="item1", data=1))
        processor.add_item(BatchItem(item_id="item2", data=2))

        processor.process_all(lambda x: x)

        stats = processor.get_stats()
        assert stats["total_processed"] == 2
        assert stats["total_batches"] == 1
        assert stats["total_failures"] == 0


class TestProgressiveLoader:
    """Tests for ProgressiveLoader class"""

    def test_initialize(self):
        """Test initializing loader with page count"""
        from app.services.memory_manager import ProgressiveLoader

        loader = ProgressiveLoader(lookahead_pages=2)
        loader.initialize(total_pages=10)

        stats = loader.get_stats()
        assert stats["total_pages"] == 10
        assert stats["loaded_pages_count"] == 0

    def test_load_page(self):
        """Test loading a single page"""
        from app.services.memory_manager import ProgressiveLoader

        loader = ProgressiveLoader(lookahead_pages=2, cleanup_after_pages=10)
        loader.initialize(total_pages=5)

        page_data = loader.load_page(0, lambda p: f"page_{p}_data")

        assert page_data == "page_0_data"
        assert 0 in loader.get_loaded_pages()

    def test_load_page_caching(self):
        """Test that loaded pages are cached"""
        from app.services.memory_manager import ProgressiveLoader

        load_count = 0

        def loader_func(page_num):
            nonlocal load_count
            load_count += 1
            return f"page_{page_num}"

        loader = ProgressiveLoader(lookahead_pages=2, cleanup_after_pages=10)
        loader.initialize(total_pages=5)

        # Load same page twice
        loader.load_page(0, loader_func)
        loader.load_page(0, loader_func)

        assert load_count == 1  # Should only load once

    def test_unload_distant_pages(self):
        """Test that distant pages are unloaded"""
        from app.services.memory_manager import ProgressiveLoader

        loader = ProgressiveLoader(lookahead_pages=1, cleanup_after_pages=100)
        loader.initialize(total_pages=10)

        # Load several pages
        for i in range(5):
            loader.load_page(i, lambda p: f"page_{p}")

        # After loading page 4, distant pages should be unloaded
        loaded = loader.get_loaded_pages()
        # Should keep only pages near current (4): pages 3, 4, and potentially 5
        assert 0 not in loaded  # Page 0 should be unloaded

    def test_iterate_pages(self):
        """Test iterating through all pages"""
        from app.services.memory_manager import ProgressiveLoader

        loader = ProgressiveLoader(lookahead_pages=0, cleanup_after_pages=100)
        loader.initialize(total_pages=5)

        results = loader.iterate_pages(
            loader_func=lambda p: f"page_{p}",
            processor_func=lambda p, data: f"processed_{data}"
        )

        assert len(results) == 5
        assert results[0] == "processed_page_0"
        assert results[4] == "processed_page_4"

    def test_progress_callback(self):
        """Test progress callback during iteration"""
        from app.services.memory_manager import ProgressiveLoader

        loader = ProgressiveLoader(lookahead_pages=0, cleanup_after_pages=100)
        loader.initialize(total_pages=3)

        progress_reports = []

        def callback(current, total):
            progress_reports.append((current, total))

        loader.iterate_pages(
            loader_func=lambda p: p,
            processor_func=lambda p, d: d,
            progress_callback=callback
        )

        assert progress_reports == [(1, 3), (2, 3), (3, 3)]

    def test_clear(self):
        """Test clearing loaded pages"""
        from app.services.memory_manager import ProgressiveLoader

        loader = ProgressiveLoader(cleanup_after_pages=100)
        loader.initialize(total_pages=5)

        for i in range(3):
            loader.load_page(i, lambda p: p)

        loader.clear()
        assert loader.get_loaded_pages() == []


class TestPriorityOperationQueue:
    """Tests for PriorityOperationQueue class"""

    def test_enqueue_dequeue(self):
        """Test basic enqueue and dequeue"""
        from app.services.memory_manager import PriorityOperationQueue, BatchPriority

        queue = PriorityOperationQueue(max_size=10)
        queue.enqueue("item1", "data1", BatchPriority.NORMAL)

        result = queue.dequeue()
        assert result is not None

        item_id, data, priority = result
        assert item_id == "item1"
        assert data == "data1"
        assert priority == BatchPriority.NORMAL

    def test_priority_ordering(self):
        """Test that higher priority items are dequeued first"""
        from app.services.memory_manager import PriorityOperationQueue, BatchPriority

        queue = PriorityOperationQueue()
        queue.enqueue("low", "low_data", BatchPriority.LOW)
        queue.enqueue("high", "high_data", BatchPriority.HIGH)
        queue.enqueue("normal", "normal_data", BatchPriority.NORMAL)
        queue.enqueue("critical", "critical_data", BatchPriority.CRITICAL)

        # Dequeue in priority order
        item_id, _, _ = queue.dequeue()
        assert item_id == "critical"
        item_id, _, _ = queue.dequeue()
        assert item_id == "high"
        item_id, _, _ = queue.dequeue()
        assert item_id == "normal"
        item_id, _, _ = queue.dequeue()
        assert item_id == "low"

    def test_cancel(self):
        """Test cancelling an operation"""
        from app.services.memory_manager import PriorityOperationQueue, BatchPriority

        queue = PriorityOperationQueue()
        queue.enqueue("item1", "data1", BatchPriority.NORMAL)
        queue.enqueue("item2", "data2", BatchPriority.NORMAL)

        # Cancel item1
        assert queue.cancel("item1") is True

        # Dequeue should skip cancelled item
        result = queue.dequeue()
        assert result[0] == "item2"

    def test_dequeue_empty_returns_none(self):
        """Test that dequeue on empty queue returns None"""
        from app.services.memory_manager import PriorityOperationQueue

        queue = PriorityOperationQueue()
        result = queue.dequeue(timeout=0.1)
        assert result is None

    def test_max_size_limit(self):
        """Test that queue respects max size"""
        from app.services.memory_manager import PriorityOperationQueue, BatchPriority

        queue = PriorityOperationQueue(max_size=2)
        assert queue.enqueue("item1", "data1") is True
        assert queue.enqueue("item2", "data2") is True
        # Third item should fail after the short timeout
        assert queue.enqueue("item3", "data3", timeout=0.1) is False

    def test_get_stats(self):
        """Test queue statistics"""
        from app.services.memory_manager import PriorityOperationQueue, BatchPriority

        queue = PriorityOperationQueue()
        queue.enqueue("item1", "data1", BatchPriority.HIGH)
        queue.enqueue("item2", "data2", BatchPriority.LOW)
        queue.dequeue()

        stats = queue.get_stats()
        assert stats["total_enqueued"] == 2
        assert stats["total_dequeued"] == 1
        assert stats["queue_size"] == 1


class TestRecoveryManager:
    """Tests for RecoveryManager class"""

    def test_can_attempt_recovery(self):
        """Test recovery attempt checking"""
        from app.services.memory_manager import RecoveryManager

        manager = RecoveryManager(
            cooldown_seconds=1.0,
            max_recovery_attempts=3
        )

        can_recover, reason = manager.can_attempt_recovery()
        assert can_recover is True
        assert "allowed" in reason.lower()

    def test_cooldown_period(self):
        """Test that cooldown period is enforced"""
        from app.services.memory_manager import RecoveryManager

        manager = RecoveryManager(
            cooldown_seconds=60.0,
            max_recovery_attempts=10
        )

        # First recovery
        manager.attempt_recovery()

        # Should be in cooldown
        assert manager.is_in_cooldown() is True

        can_recover, reason = manager.can_attempt_recovery()
        assert can_recover is False
        assert "cooldown" in reason.lower()

    def test_max_recovery_attempts(self):
        """Test that max recovery attempts are enforced"""
        from app.services.memory_manager import RecoveryManager

        manager = RecoveryManager(
            cooldown_seconds=0.01,  # Very short cooldown for testing
            max_recovery_attempts=2,
            recovery_window_seconds=60.0
        )

        # Perform max attempts
        for _ in range(2):
            manager.attempt_recovery()
            time.sleep(0.02)  # Wait for cooldown

        # Next attempt should be blocked
        can_recover, reason = manager.can_attempt_recovery()
        assert can_recover is False
        assert "max" in reason.lower()

    def test_recovery_callbacks(self):
        """Test recovery event callbacks"""
        from app.services.memory_manager import RecoveryManager

        manager = RecoveryManager(cooldown_seconds=0.01)

        start_called = False
        complete_called = False
        complete_success = None

        def on_start():
            nonlocal start_called
            start_called = True

        def on_complete(success):
            nonlocal complete_called, complete_success
            complete_called = True
            complete_success = success

        manager.register_callbacks(on_start=on_start, on_complete=on_complete)
        manager.attempt_recovery()

        assert start_called is True
        assert complete_called is True
        assert complete_success is not None

    def test_get_state(self):
        """Test getting recovery state"""
        from app.services.memory_manager import RecoveryManager

        manager = RecoveryManager(cooldown_seconds=60.0)
        manager.attempt_recovery(error="Test error")

        state = manager.get_state()
        assert state["recovery_count"] == 1
        assert state["in_cooldown"] is True
        assert state["last_error"] == "Test error"
        assert state["cooldown_remaining_seconds"] > 0

    def test_emergency_release(self):
        """Test emergency memory release"""
        from app.services.memory_manager import RecoveryManager

        manager = RecoveryManager()

        # Emergency release without model manager
        result = manager.emergency_release(model_manager=None)

        # Should complete without error (may or may not free memory in test env)
        assert isinstance(result, bool)


# =============================================================================
# Section 1.2: Test model reload after unload
# =============================================================================

class TestModelReloadAfterUnload:
    """Tests for model reload after unload functionality"""
    def setup_method(self):
        """Reset singleton before each test"""
        shutdown_model_manager()
        ModelManager._instance = None
        ModelManager._lock = threading.Lock()

    def teardown_method(self):
        """Cleanup after each test"""
        shutdown_model_manager()
        ModelManager._instance = None

    def test_reload_after_unload(self):
        """Test that a model can be reloaded after being unloaded"""
        config = MemoryConfig()
        manager = ModelManager(config)

        load_count = 0

        def loader():
            nonlocal load_count
            load_count += 1
            return Mock(name=f"model_instance_{load_count}")

        # First load
        model1 = manager.get_or_load_model("test_model", loader, estimated_memory_mb=100)
        assert load_count == 1
        assert "test_model" in manager.models

        # Release and unload
        manager.release_model("test_model")
        success = manager.unload_model("test_model")
        assert success is True
        assert "test_model" not in manager.models

        # Reload
        model2 = manager.get_or_load_model("test_model", loader, estimated_memory_mb=100)
        assert load_count == 2  # Loader called again
        assert "test_model" in manager.models

        # Models should be different instances
        assert model1 is not model2
        manager.teardown()

    def test_reload_preserves_config(self):
        """Test that reloaded model uses the same configuration"""
        config = MemoryConfig()
        manager = ModelManager(config)

        cleanup_count = 0

        def cleanup():
            nonlocal cleanup_count
            cleanup_count += 1

        # Load with cleanup callback
        manager.get_or_load_model(
            "test_model",
            lambda: Mock(),
            estimated_memory_mb=200,
            cleanup_callback=cleanup
        )

        # Release, unload (cleanup should be called)
        manager.release_model("test_model")
        manager.unload_model("test_model")
        assert cleanup_count == 1

        # Reload with new cleanup callback
        def new_cleanup():
            nonlocal cleanup_count
            cleanup_count += 10

        manager.get_or_load_model(
            "test_model",
            lambda: Mock(),
            estimated_memory_mb=300,
            cleanup_callback=new_cleanup
        )

        # Verify new estimated memory
        assert manager.models["test_model"].estimated_memory_mb == 300

        # Release and unload again
        manager.release_model("test_model")
        manager.unload_model("test_model")
        assert cleanup_count == 11  # New cleanup was called
        manager.teardown()

    def test_concurrent_reload(self):
        """Test concurrent reload operations"""
        config = MemoryConfig()
        manager = ModelManager(config)

        load_count = 0
        lock = threading.Lock()

        def loader():
            nonlocal load_count
            with lock:
                load_count += 1
            time.sleep(0.05)  # Simulate slow load
            return Mock()

        # Load, release, unload
        manager.get_or_load_model("test_model", loader)
        manager.release_model("test_model")
        manager.unload_model("test_model")

        # Concurrent reload attempts
        results = []

        def worker():
            model = manager.get_or_load_model("test_model", loader)
            results.append(model)

        threads = [threading.Thread(target=worker) for _ in range(5)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        # All threads should get the same model instance
        assert len(results) == 5
        assert len(set(id(r) for r in results)) == 1

        # Only one additional load should have occurred
        assert load_count == 2
        manager.teardown()


# =============================================================================
# Section 4.2: Test memory savings with selective processing
# =============================================================================

class TestSelectiveProcessingMemorySavings:
    """Tests for memory savings with selective processing"""

    def test_batch_processor_memory_constraints(self):
        """Test that batch processor respects memory constraints"""
        from app.services.memory_manager import BatchProcessor, BatchItem

        # Create processor with strict memory limit
        processor = BatchProcessor(
            max_batch_size=10,
            max_memory_per_batch_mb=100.0,
            cleanup_between_batches=False
        )

        # Add items with known memory estimates
        processor.add_item(BatchItem(item_id="small1", data=1, estimated_memory_mb=20.0))
        processor.add_item(BatchItem(item_id="small2", data=2, estimated_memory_mb=20.0))
        processor.add_item(BatchItem(item_id="small3", data=3, estimated_memory_mb=20.0))
        processor.add_item(BatchItem(item_id="large1", data=4, estimated_memory_mb=80.0))

        # First batch should include items that fit
        results = processor.process_batch(lambda x: x)

        # Items should be processed respecting memory limit
        total_memory_in_batch = sum(
            item.estimated_memory_mb for item in processor._queue
        ) + sum(20.0 for _ in results)  # Processed items had 20MB each

        # Remaining items should include the large one
        assert processor.get_queue_size() >= 1

    def test_progressive_loader_memory_efficiency(self):
        """Test that progressive loader manages memory efficiently"""
        from app.services.memory_manager import ProgressiveLoader

        loader = ProgressiveLoader(lookahead_pages=1, cleanup_after_pages=100)
        loader.initialize(total_pages=10)

        pages_loaded = []

        def loader_func(page_num):
            pages_loaded.append(page_num)
            return f"page_data_{page_num}"

        # Load pages sequentially
        for i in range(10):
            loader.load_page(i, loader_func)

        # Only recent pages should be kept in memory
        loaded = loader.get_loaded_pages()

        # Should have unloaded distant pages
        assert 0 not in loaded  # First page should be unloaded
        assert len(loaded) <= 3  # Current + lookahead

        loader.clear()

    def test_priority_queue_processing_order(self):
        """Test that priority queue processes high priority items first"""
        from app.services.memory_manager import PriorityOperationQueue, BatchPriority

        queue = PriorityOperationQueue()

        # Add items with different priorities
        queue.enqueue("low1", "data", BatchPriority.LOW)
        queue.enqueue("critical1", "data", BatchPriority.CRITICAL)
        queue.enqueue("normal1", "data", BatchPriority.NORMAL)
        queue.enqueue("high1", "data", BatchPriority.HIGH)

        # Process in priority order
        order = []
        while True:
            result = queue.dequeue(timeout=0.01)
            if result is None:
                break
            order.append(result[0])

        assert order[0] == "critical1"
        assert order[1] == "high1"
        assert order[2] == "normal1"
        assert order[3] == "low1"


# =============================================================================
# Section 5.2: Test recovery under various scenarios
# =============================================================================

class TestRecoveryScenarios:
    """Tests for recovery under various scenarios"""

    def test_recovery_after_oom_simulation(self):
        """Test recovery behavior after simulated OOM"""
        from app.services.memory_manager import RecoveryManager

        manager = RecoveryManager(
            cooldown_seconds=0.1,
            max_recovery_attempts=5
        )

        # Simulate OOM recovery
        success = manager.attempt_recovery(error="CUDA out of memory")
        assert success is not None  # Recovery attempted

        # Check state
        state = manager.get_state()
        assert state["recovery_count"] == 1
        assert state["last_error"] == "CUDA out of memory"

    def test_recovery_cooldown_enforcement(self):
        """Test that cooldown period is strictly enforced"""
        from app.services.memory_manager import RecoveryManager

        manager = RecoveryManager(
            cooldown_seconds=1.0,
            max_recovery_attempts=10
        )

        # First recovery
        manager.attempt_recovery()
        assert manager.is_in_cooldown() is True

        # Try immediate second recovery
        can_recover, reason = manager.can_attempt_recovery()
        assert can_recover is False
        assert "cooldown" in reason.lower()

        # Wait for cooldown
        time.sleep(1.1)
        can_recover, reason = manager.can_attempt_recovery()
        assert can_recover is True

    def test_recovery_max_attempts_window(self):
        """Test that max recovery attempts are enforced within window"""
        from app.services.memory_manager import RecoveryManager

        manager = RecoveryManager(
            cooldown_seconds=0.01,
            max_recovery_attempts=3,
            recovery_window_seconds=60.0
        )

        # Perform max attempts
        for i in range(3):
            manager.attempt_recovery(error=f"Error {i}")
            time.sleep(0.02)  # Wait for cooldown

        # Next attempt should be blocked
        can_recover, reason = manager.can_attempt_recovery()
        assert can_recover is False
        assert "max" in reason.lower()

    def test_emergency_release_with_model_manager(self):
        """Test emergency release unloads models"""
        from app.services.memory_manager import RecoveryManager

        # Create a model manager with test models
        shutdown_model_manager()
        ModelManager._instance = None
        config = MemoryConfig()
        model_manager = ModelManager(config)

        # Load some test models
        model_manager.get_or_load_model("model1", lambda: Mock(), estimated_memory_mb=100)
        model_manager.get_or_load_model("model2", lambda: Mock(), estimated_memory_mb=200)

        # Release references
        model_manager.release_model("model1")
        model_manager.release_model("model2")

        assert len(model_manager.models) == 2

        # Emergency release
        recovery_manager = RecoveryManager()
        result = recovery_manager.emergency_release(model_manager=model_manager)

        # Models should be unloaded
        assert len(model_manager.models) == 0

        model_manager.teardown()
        shutdown_model_manager()
        ModelManager._instance = None


# =============================================================================
# Section 6.1: Test shutdown sequence
# =============================================================================

class TestShutdownSequence:
    """Tests for shutdown sequence"""

    def test_model_manager_teardown(self):
        """Test that teardown properly cleans up models"""
        shutdown_model_manager()
        ModelManager._instance = None
        config = MemoryConfig()
        manager = ModelManager(config)

        # Load models
        manager.get_or_load_model("model1", lambda: Mock())
        manager.get_or_load_model("model2", lambda: Mock())
        assert len(manager.models) == 2

        # Teardown
        manager.teardown()

        assert len(manager.models) == 0
        assert manager._monitor_running is False

        shutdown_model_manager()
        ModelManager._instance = None

    def test_cleanup_callbacks_called_on_teardown(self):
        """Test that cleanup callbacks are called during teardown"""
        shutdown_model_manager()
        ModelManager._instance = None
        config = MemoryConfig()
        manager = ModelManager(config)

        cleanup_calls = []

        def cleanup1():
            cleanup_calls.append("model1")

        def cleanup2():
            cleanup_calls.append("model2")

        manager.get_or_load_model("model1", lambda: Mock(), cleanup_callback=cleanup1)
        manager.get_or_load_model("model2", lambda: Mock(), cleanup_callback=cleanup2)

        # Teardown with force unload
        manager.teardown()

        # Both callbacks should have been called
        assert "model1" in cleanup_calls
        assert "model2" in cleanup_calls

        shutdown_model_manager()
        ModelManager._instance = None

    def test_prediction_semaphore_shutdown(self):
        """Test prediction semaphore shutdown"""
        from app.services.memory_manager import (
            get_prediction_semaphore,
            shutdown_prediction_semaphore,
            PredictionSemaphore
        )

        shutdown_prediction_semaphore()
        PredictionSemaphore._instance = None
        PredictionSemaphore._lock = threading.Lock()

        sem = PredictionSemaphore(max_concurrent=2)
        sem.acquire(task_id="test1")

        # Shutdown should reset the semaphore
        shutdown_prediction_semaphore()

        # New instance should be fresh
        new_sem = get_prediction_semaphore(max_concurrent=3)
        assert new_sem._active_predictions == 0

        shutdown_prediction_semaphore()
        PredictionSemaphore._instance = None


# =============================================================================
# Section 6.2: Test cleanup in error scenarios
# =============================================================================

class TestCleanupInErrorScenarios:
    """Tests for cleanup in error scenarios"""

    def test_cleanup_after_loader_exception(self):
        """Test cleanup when model loader raises exception"""
        shutdown_model_manager()
        ModelManager._instance = None
        config = MemoryConfig()
        manager = ModelManager(config)

        def failing_loader():
            raise RuntimeError("Loader failed")

        with pytest.raises(RuntimeError):
            manager.get_or_load_model("failing_model", failing_loader)

        # Model should not be in the manager
        assert "failing_model" not in manager.models

        manager.teardown()
        shutdown_model_manager()
        ModelManager._instance = None

    def test_cleanup_after_processing_error(self):
        """Test cleanup after processing error in batch processor"""
        from app.services.memory_manager import BatchProcessor, BatchItem

        processor = BatchProcessor(max_batch_size=3, cleanup_between_batches=False)
        processor.add_item(BatchItem(item_id="good1", data=1))
        processor.add_item(BatchItem(item_id="bad", data="error"))
        processor.add_item(BatchItem(item_id="good2", data=2))

        def processor_func(data):
            if data == "error":
                raise ValueError("Processing error")
            return data * 2

        results = processor.process_all(processor_func)

        # Good items should succeed, bad item should fail
        assert len(results) == 3
        assert results[0].success is True
        assert results[1].success is False
        assert results[2].success is True

        # Stats should reflect failure
        stats = processor.get_stats()
        assert stats["total_failures"] == 1

    def test_pool_release_with_error(self):
        """Test that pool properly handles release with error"""
        from app.services.service_pool import (
            OCRServicePool,
            PoolConfig,
            PooledService,
            ServiceState,
            shutdown_service_pool
        )

        shutdown_service_pool()
        OCRServicePool._instance = None
        OCRServicePool._lock = threading.Lock()

        config = PoolConfig(max_consecutive_errors=2)
        pool = OCRServicePool(config)

        # Pre-populate with a mock service
        mock_service = Mock()
        pooled_service = PooledService(service=mock_service, device="GPU:0")
        pool.services["GPU:0"].append(pooled_service)

        # Acquire and release with errors
        pooled = pool.acquire(device="GPU:0")
        pool.release(pooled, error=Exception("Error 1"))
        assert pooled.error_count == 1
        assert pooled.state == ServiceState.AVAILABLE

        pooled = pool.acquire(device="GPU:0")
        pool.release(pooled, error=Exception("Error 2"))
        assert pooled.error_count == 2
        assert pooled.state == ServiceState.UNHEALTHY

        pool.shutdown()
        shutdown_service_pool()
        OCRServicePool._instance = None


# =============================================================================
# Section 8.1: Memory leak detection tests
# =============================================================================

class TestMemoryLeakDetection:
    """Tests for memory leak detection"""

    def test_no_leak_on_model_cycle(self):
        """Test that loading and unloading models doesn't leak"""
        shutdown_model_manager()
        ModelManager._instance = None
        config = MemoryConfig()
        manager = ModelManager(config)

        initial_model_count = len(manager.models)

        # Perform multiple load/unload cycles
        for i in range(10):
            manager.get_or_load_model(f"temp_model_{i}", lambda: Mock())
            manager.release_model(f"temp_model_{i}")
            manager.unload_model(f"temp_model_{i}")

        # Should be back to initial state
        assert len(manager.models) == initial_model_count

        # Force gc and check
        gc.collect()
        manager.teardown()
        shutdown_model_manager()
        ModelManager._instance = None

    def test_no_leak_on_semaphore_cycle(self):
        """Test that semaphore acquire/release doesn't leak"""
        from app.services.memory_manager import (
            PredictionSemaphore,
            shutdown_prediction_semaphore
        )

        shutdown_prediction_semaphore()
        PredictionSemaphore._instance = None
        PredictionSemaphore._lock = threading.Lock()

        sem = PredictionSemaphore(max_concurrent=2)

        # Perform many acquire/release cycles
        for i in range(100):
            sem.acquire(task_id=f"task_{i}")
            sem.release(task_id=f"task_{i}")

        # Active predictions should be 0
        assert sem._active_predictions == 0
        assert sem._queue_depth == 0

        shutdown_prediction_semaphore()
        PredictionSemaphore._instance = None

    def test_no_leak_in_batch_processor(self):
        """Test that batch processor doesn't leak items"""
        from app.services.memory_manager import BatchProcessor, BatchItem

        processor = BatchProcessor(max_batch_size=5, cleanup_between_batches=False)

        # Add and process many items
        for i in range(50):
            processor.add_item(BatchItem(item_id=f"item_{i}", data=i))

        processor.process_all(lambda x: x * 2)

        # Queue should be empty
        assert processor.get_queue_size() == 0

        # Stats should be accurate
        stats = processor.get_stats()
        assert stats["total_processed"] == 50


# =============================================================================
# Section 8.1: Stress tests with concurrent requests
# =============================================================================

class TestStressConcurrentRequests:
    """Stress tests with concurrent requests"""

    def test_concurrent_model_access_stress(self):
        """Stress test concurrent model access"""
        shutdown_model_manager()
        ModelManager._instance = None
        config = MemoryConfig()
        manager = ModelManager(config)

        load_count = 0
        lock = threading.Lock()

        def loader():
            nonlocal load_count
            with lock:
                load_count += 1
            return Mock()

        results = []
        errors = []

        def worker(worker_id):
            try:
                model = manager.get_or_load_model("shared_model", loader)
                time.sleep(0.01)  # Simulate work
                manager.release_model("shared_model")
                results.append(worker_id)
            except Exception as e:
                errors.append(str(e))

        # Launch many concurrent workers
        threads = [threading.Thread(target=worker, args=(i,)) for i in range(20)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        # All workers should complete without errors
        assert len(errors) == 0
        assert len(results) == 20

        # Loader should only be called once
        assert load_count == 1

        manager.teardown()
        shutdown_model_manager()
        ModelManager._instance = None

    def test_concurrent_semaphore_stress(self):
        """Stress test concurrent semaphore operations"""
        from app.services.memory_manager import (
            PredictionSemaphore,
            shutdown_prediction_semaphore
        )

        shutdown_prediction_semaphore()
        PredictionSemaphore._instance = None
        PredictionSemaphore._lock = threading.Lock()

        sem = PredictionSemaphore(max_concurrent=3)

        results = []
        max_concurrent_observed = 0
        current_count = 0
        lock = threading.Lock()

        def worker(worker_id):
            nonlocal max_concurrent_observed, current_count
            if sem.acquire(timeout=10.0, task_id=f"task_{worker_id}"):
                with lock:
                    current_count += 1
                    max_concurrent_observed = max(max_concurrent_observed, current_count)
                time.sleep(0.02)  # Simulate work
                with lock:
                    current_count -= 1
                sem.release(task_id=f"task_{worker_id}")
                results.append(worker_id)

        # Launch many concurrent workers
        threads = [threading.Thread(target=worker, args=(i,)) for i in range(15)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        # All should complete
        assert len(results) == 15

        # Max concurrent should not exceed limit
        assert max_concurrent_observed <= 3
        shutdown_prediction_semaphore()
        PredictionSemaphore._instance = None

    def test_concurrent_batch_processing(self):
        """Stress test concurrent batch processing"""
        from app.services.memory_manager import BatchProcessor, BatchItem

        processor = BatchProcessor(max_batch_size=3, cleanup_between_batches=False)

        # Add items from multiple threads
        def add_items(start_id):
            for i in range(10):
                processor.add_item(BatchItem(item_id=f"item_{start_id}_{i}", data=i))

        threads = [threading.Thread(target=add_items, args=(i,)) for i in range(5)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        # Should have 50 items
        assert processor.get_queue_size() == 50

        # Process all
        results = processor.process_all(lambda x: x)
        assert len(results) == 50


# =============================================================================
# Section 8.1: Performance benchmarks
# =============================================================================

class TestPerformanceBenchmarks:
    """Performance benchmark tests"""

    def test_model_load_performance(self):
        """Benchmark model loading performance"""
        shutdown_model_manager()
        ModelManager._instance = None
        config = MemoryConfig()
        manager = ModelManager(config)

        load_times = []
        for i in range(5):
            start = time.time()
            manager.get_or_load_model(f"model_{i}", lambda: Mock())
            load_times.append(time.time() - start)

        # Average load time should be reasonable (< 100ms for mock)
        avg_load_time = sum(load_times) / len(load_times)
        assert avg_load_time < 0.1  # 100ms

        manager.teardown()
        shutdown_model_manager()
        ModelManager._instance = None

    def test_semaphore_throughput(self):
        """Benchmark semaphore throughput"""
        from app.services.memory_manager import (
            PredictionSemaphore,
            shutdown_prediction_semaphore
        )

        shutdown_prediction_semaphore()
        PredictionSemaphore._instance = None
        PredictionSemaphore._lock = threading.Lock()

        sem = PredictionSemaphore(max_concurrent=10)

        start = time.time()
        iterations = 1000
        for i in range(iterations):
            sem.acquire(timeout=1.0)
            sem.release()
        elapsed = time.time() - start

        # Should handle at least 1000 ops/sec
        ops_per_sec = iterations / elapsed
        assert ops_per_sec > 1000

        shutdown_prediction_semaphore()
        PredictionSemaphore._instance = None

    def test_batch_processor_throughput(self):
        """Benchmark batch processor throughput"""
        from app.services.memory_manager import BatchProcessor, BatchItem

        processor = BatchProcessor(max_batch_size=100, cleanup_between_batches=False)

        # Add many items
        for i in range(1000):
            processor.add_item(BatchItem(item_id=f"item_{i}", data=i))

        start = time.time()
        results = processor.process_all(lambda x: x * 2)
        elapsed = time.time() - start

        # Should process at least 1000 items/sec
        items_per_sec = len(results) / elapsed
        assert items_per_sec > 1000

        stats = processor.get_stats()
        assert stats["total_processed"] == 1000


# =============================================================================
# Tests for Memory Dump and Prometheus Metrics (Section 5.2 & 7.2)
# =============================================================================

class TestMemoryDumper:
    """Tests for MemoryDumper class"""

    def setup_method(self):
        """Reset singletons before each test"""
        from app.services.memory_manager import shutdown_memory_dumper
        shutdown_memory_dumper()
        shutdown_model_manager()
        ModelManager._instance = None

    def teardown_method(self):
        """Cleanup after each test"""
        from app.services.memory_manager import shutdown_memory_dumper
        shutdown_memory_dumper()
        shutdown_model_manager()
        ModelManager._instance = None
    def test_create_dump(self):
        """Test creating a memory dump"""
        from app.services.memory_manager import MemoryDumper, MemoryDump

        dumper = MemoryDumper()
        dump = dumper.create_dump()

        assert isinstance(dump, MemoryDump)
        assert dump.timestamp > 0
        assert isinstance(dump.loaded_models, list)
        assert isinstance(dump.gc_stats, dict)

    def test_dump_history(self):
        """Test dump history tracking"""
        from app.services.memory_manager import MemoryDumper

        dumper = MemoryDumper()

        # Create multiple dumps
        for _ in range(5):
            dumper.create_dump()

        history = dumper.get_dump_history()
        assert len(history) == 5

        latest = dumper.get_latest_dump()
        assert latest is history[-1]

    def test_dump_comparison(self):
        """Test comparing two dumps"""
        from app.services.memory_manager import MemoryDumper

        dumper = MemoryDumper()

        dump1 = dumper.create_dump()
        time.sleep(0.1)
        dump2 = dumper.create_dump()

        comparison = dumper.compare_dumps(dump1, dump2)

        assert "time_delta_seconds" in comparison
        assert comparison["time_delta_seconds"] > 0
        assert "gpu_memory_change_mb" in comparison
        assert "cpu_memory_change_mb" in comparison

    def test_dump_to_dict(self):
        """Test converting dump to dictionary"""
        from app.services.memory_manager import MemoryDumper

        dumper = MemoryDumper()
        dump = dumper.create_dump()
        dump_dict = dumper.to_dict(dump)

        assert "timestamp" in dump_dict
        assert "gpu" in dump_dict
        assert "cpu" in dump_dict
        assert "models" in dump_dict
        assert "predictions" in dump_dict


class TestPrometheusMetrics:
    """Tests for PrometheusMetrics class"""

    def setup_method(self):
        """Reset singletons before each test"""
        from app.services.memory_manager import shutdown_prometheus_metrics
        shutdown_prometheus_metrics()
        shutdown_model_manager()
        ModelManager._instance = None

    def teardown_method(self):
        """Cleanup after each test"""
        from app.services.memory_manager import shutdown_prometheus_metrics
        shutdown_prometheus_metrics()
        shutdown_model_manager()
        ModelManager._instance = None

    def test_export_metrics(self):
        """Test exporting metrics in Prometheus format"""
        from app.services.memory_manager import PrometheusMetrics

        prometheus = PrometheusMetrics()
        metrics = prometheus.export_metrics()

        # Should be a non-empty string
        assert isinstance(metrics, str)
        assert len(metrics) > 0

        # Should contain expected metric prefixes
        assert "tool_ocr_memory_" in metrics

    def test_metric_format(self):
        """Test that metrics follow Prometheus format"""
        from app.services.memory_manager import PrometheusMetrics

        prometheus = PrometheusMetrics()
        metrics = prometheus.export_metrics()

        lines = metrics.split("\n")

        # Check for HELP and TYPE comments
        help_lines = [l for l in lines if l.startswith("# HELP")]
        type_lines = [l for l in lines if l.startswith("# TYPE")]
        assert len(help_lines) > 0
        assert len(type_lines) > 0

    def test_custom_metrics(self):
        """Test setting custom metrics"""
        from app.services.memory_manager import PrometheusMetrics

        prometheus = PrometheusMetrics()
        prometheus.set_custom_metric("custom_value", 42.0)
        prometheus.set_custom_metric("labeled_value", 100.0, {"env": "test"})

        metrics = prometheus.export_metrics()
        assert "custom_value" in metrics or "42" in metrics

    def test_get_prometheus_metrics_singleton(self):
        """Test prometheus metrics singleton"""
        from app.services.memory_manager import get_prometheus_metrics, shutdown_prometheus_metrics

        metrics1 = get_prometheus_metrics()
        metrics2 = get_prometheus_metrics()

        assert metrics1 is metrics2

        shutdown_prometheus_metrics()


if __name__ == "__main__":
    pytest.main([__file__, "-v"])