""" Concurrent Message Stress Test This script tests: 1. Connection pool doesn't exhaust under load 2. Sequence numbers are unique per room 3. No deadlocks occur during concurrent operations Usage: # Run with pytest (uses test database) pytest tests/test_concurrent_messages.py -v # Run standalone against real database python tests/test_concurrent_messages.py --concurrent 50 --messages 10 Requirements: - MySQL database (SQLite doesn't support FOR UPDATE properly) - pip install aiohttp websockets """ import asyncio import argparse import sys import os from collections import Counter from typing import List, Dict, Any import time # Add project root to path sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) # Import all models to ensure proper relationship initialization from app.modules.auth.models import UserSession from app.modules.chat_room.models import IncidentRoom, RoomMember from app.modules.file_storage.models import RoomFile from app.modules.realtime.models import Message, MessageReaction, MessageEditHistory from app.modules.report_generation.models import GeneratedReport def create_test_room(db, room_id: str) -> IncidentRoom: """Create a test room for testing""" from datetime import datetime, timezone from app.modules.chat_room.schemas import IncidentType, SeverityLevel room = IncidentRoom( room_id=room_id, title=f"Test Room {room_id}", incident_type=IncidentType.OTHER, severity=SeverityLevel.LOW, lots=["TEST-LOT"], created_by="test-admin", created_at=datetime.now(timezone.utc) ) db.add(room) db.commit() db.refresh(room) return room def cleanup_test_room(room_id: str): """Clean up test room and its messages""" from app.core.database import get_db_context with get_db_context() as db: # Messages will be deleted by CASCADE db.query(IncidentRoom).filter(IncidentRoom.room_id == room_id).delete() db.commit() def test_sequence_number_uniqueness_sync(): """Test that sequence numbers are unique when creating messages sequentially""" from sqlalchemy.orm import Session from app.core.database import get_db_context from app.modules.realtime.services.message_service import MessageService from app.modules.realtime.models import MessageType, Message import uuid # Use short UUID to fit in room_id column (20 chars max based on model) room_id = f"ts-{str(uuid.uuid4())[:6]}" num_messages = 20 sequence_numbers = [] try: # Create test room first with get_db_context() as db: create_test_room(db, room_id) for i in range(num_messages): with get_db_context() as db: message = MessageService.create_message( db=db, room_id=room_id, sender_id=f"user-{i % 5}", content=f"Test message {i}", message_type=MessageType.TEXT ) sequence_numbers.append(message.sequence_number) # Verify uniqueness duplicates = [seq for seq, count in Counter(sequence_numbers).items() if count > 1] assert len(duplicates) == 0, f"Duplicate sequence numbers found: {duplicates}" # Verify monotonic increase assert sequence_numbers == sorted(sequence_numbers), "Sequence numbers not in order" assert sequence_numbers == list(range(1, num_messages + 1)), \ f"Sequence numbers not consecutive: {sequence_numbers}" print(f"✓ Sequential test passed: {num_messages} messages with unique sequences") finally: # Cleanup (room deletion cascades to messages) cleanup_test_room(room_id) async def create_message_async(room_id: str, user_id: str, msg_index: int) -> Dict[str, Any]: """Create a message and return its sequence number""" from app.core.database import get_db_context from app.modules.realtime.services.message_service import MessageService from app.modules.realtime.models import MessageType with get_db_context() as db: message = MessageService.create_message( db=db, room_id=room_id, sender_id=user_id, content=f"Concurrent message {msg_index}", message_type=MessageType.TEXT ) return { "message_id": message.message_id, "sequence_number": message.sequence_number, "user_id": user_id, "index": msg_index } async def test_concurrent_message_creation(num_concurrent: int = 50, messages_per_user: int = 5): """Test concurrent message creation from multiple users""" from app.modules.realtime.models import Message from app.core.database import get_db_context import uuid # Use short UUID to fit in room_id column (20 chars max based on model) room_id = f"cc-{str(uuid.uuid4())[:6]}" total_messages = num_concurrent * messages_per_user results: List[Dict[str, Any]] = [] errors: List[str] = [] # Create test room first with get_db_context() as db: create_test_room(db, room_id) print(f"\nStarting concurrent test: {num_concurrent} users × {messages_per_user} messages = {total_messages} total") start_time = time.time() # Create tasks for all concurrent message creations tasks = [] for user_idx in range(num_concurrent): user_id = f"user-{user_idx}" for msg_idx in range(messages_per_user): task = create_message_async(room_id, user_id, user_idx * messages_per_user + msg_idx) tasks.append(task) # Execute all tasks concurrently try: results = await asyncio.gather(*tasks, return_exceptions=True) except Exception as e: errors.append(str(e)) elapsed = time.time() - start_time # Analyze results successful_results = [r for r in results if isinstance(r, dict)] failed_results = [r for r in results if isinstance(r, Exception)] sequence_numbers = [r["sequence_number"] for r in successful_results] duplicates = [seq for seq, count in Counter(sequence_numbers).items() if count > 1] print(f"\nResults after {elapsed:.2f}s:") print(f" - Successful messages: {len(successful_results)}/{total_messages}") print(f" - Failed messages: {len(failed_results)}") print(f" - Unique sequence numbers: {len(set(sequence_numbers))}") print(f" - Duplicate sequences: {duplicates if duplicates else 'None'}") print(f" - Messages per second: {len(successful_results)/elapsed:.1f}") # Cleanup (room deletion cascades to messages) try: cleanup_test_room(room_id) print(f" - Cleaned up test room and messages") except Exception as e: print(f" - Cleanup error: {e}") # Assertions for pytest assert len(failed_results) == 0, f"Some messages failed: {failed_results[:5]}" assert len(duplicates) == 0, f"Duplicate sequence numbers: {duplicates}" assert len(successful_results) == total_messages, \ f"Not all messages created: {len(successful_results)}/{total_messages}" print(f"\n✓ Concurrent test passed!") return { "total": total_messages, "successful": len(successful_results), "failed": len(failed_results), "duplicates": duplicates, "elapsed": elapsed } def test_concurrent_messages_pytest(): """Pytest wrapper for concurrent message test""" asyncio.run(test_concurrent_message_creation(num_concurrent=10, messages_per_user=5)) def test_connection_pool_stress(): """Test that connection pool handles many short sessions""" from app.core.database import get_db_context, engine from sqlalchemy import text import threading import time num_threads = 100 queries_per_thread = 10 errors = [] success_count = [0] # Use list to allow modification in nested function lock = threading.Lock() def worker(thread_id: int): for i in range(queries_per_thread): try: with get_db_context() as db: # Simple query to test connection db.execute(text("SELECT 1")) with lock: success_count[0] += 1 except Exception as e: with lock: errors.append(f"Thread {thread_id}, query {i}: {e}") print(f"\nConnection pool stress test: {num_threads} threads × {queries_per_thread} queries") start_time = time.time() threads = [threading.Thread(target=worker, args=(i,)) for i in range(num_threads)] for t in threads: t.start() for t in threads: t.join() elapsed = time.time() - start_time total_queries = num_threads * queries_per_thread print(f" - Completed: {success_count[0]}/{total_queries} queries in {elapsed:.2f}s") print(f" - Queries per second: {success_count[0]/elapsed:.1f}") print(f" - Errors: {len(errors)}") # Show pool status pool_status = engine.pool.status() print(f" - Pool status: {pool_status}") assert len(errors) == 0, f"Pool exhaustion errors: {errors[:5]}" assert success_count[0] == total_queries, "Not all queries completed" print(f"\n✓ Connection pool stress test passed!") if __name__ == "__main__": parser = argparse.ArgumentParser(description="Concurrent message stress test") parser.add_argument("--concurrent", type=int, default=50, help="Number of concurrent users (default: 50)") parser.add_argument("--messages", type=int, default=5, help="Messages per user (default: 5)") parser.add_argument("--pool-test", action="store_true", help="Run connection pool stress test") parser.add_argument("--sequential", action="store_true", help="Run sequential test only") args = parser.parse_args() print("=" * 60) print("WebSocket Database Session Optimization - Stress Tests") print("=" * 60) if args.sequential: test_sequence_number_uniqueness_sync() elif args.pool_test: test_connection_pool_stress() else: # Run all tests test_sequence_number_uniqueness_sync() test_connection_pool_stress() asyncio.run(test_concurrent_message_creation( num_concurrent=args.concurrent, messages_per_user=args.messages )) print("\n" + "=" * 60) print("All tests completed successfully!") print("=" * 60)