- Add ActionBar component with expandable toolbar for mobile - Add @mention functionality with autocomplete dropdown - Add browser notification system (push, sound, vibration) - Add NotificationSettings modal for user preferences - Add mention badges on room list cards - Add ReportPreview with Markdown rendering and copy/download - Add message copy functionality with hover actions - Add backend mentions field to messages with Alembic migration - Add lots field to rooms, remove templates - Optimize WebSocket database session handling - Various UX polish (animations, accessibility) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
297 lines
10 KiB
Python
297 lines
10 KiB
Python
"""
|
||
Concurrent Message Stress Test

This script tests that:
1. The database connection pool doesn't exhaust under load
2. Message sequence numbers are unique per room
3. No deadlocks occur during concurrent operations

Usage:
    # Run with pytest (uses the test database)
    pytest tests/test_concurrent_messages.py -v

    # Run standalone against a real database
    # (add --sequential or --pool-test to run only that test)
    python tests/test_concurrent_messages.py --concurrent 50 --messages 10

Requirements:
    - MySQL database (SQLite has no SELECT ... FOR UPDATE row locking)
    - pip install aiohttp websockets
|
||
"""
|
||
import asyncio
|
||
import argparse
|
||
import sys
|
||
import os
|
||
from collections import Counter
|
||
from typing import List, Dict, Any
|
||
import time
|
||
|
||
# Add project root to path
|
||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
|
||
# Import all models to ensure proper relationship initialization
|
||
from app.modules.auth.models import UserSession
|
||
from app.modules.chat_room.models import IncidentRoom, RoomMember
|
||
from app.modules.file_storage.models import RoomFile
|
||
from app.modules.realtime.models import Message, MessageReaction, MessageEditHistory
|
||
from app.modules.report_generation.models import GeneratedReport
|
||
|
||
|
||
def create_test_room(db, room_id: str) -> IncidentRoom:
    """Insert a minimal ``IncidentRoom`` fixture row and return it.

    The row is committed immediately. The instance is then refreshed so
    its attributes reflect what was stored in the database.

    Args:
        db: An open SQLAlchemy session (callers pass the one yielded by
            ``get_db_context``).
        room_id: Identifier for the new room.

    Returns:
        The persisted, refreshed ``IncidentRoom``.
    """
    from datetime import datetime, timezone

    from app.modules.chat_room.schemas import IncidentType, SeverityLevel

    fixture_fields = {
        "room_id": room_id,
        "title": f"Test Room {room_id}",
        "incident_type": IncidentType.OTHER,
        "severity": SeverityLevel.LOW,
        "lots": ["TEST-LOT"],
        "created_by": "test-admin",
        "created_at": datetime.now(timezone.utc),
    }
    room = IncidentRoom(**fixture_fields)

    db.add(room)
    db.commit()
    db.refresh(room)
    return room
|
||
|
||
|
||
def cleanup_test_room(room_id: str):
    """Delete the test room identified by *room_id* (no-op if it is absent).

    Opens its own session, so it can be called from ``finally`` blocks
    independently of whichever session created the room.
    """
    from app.core.database import get_db_context

    with get_db_context() as session:
        matching_rooms = session.query(IncidentRoom).filter(IncidentRoom.room_id == room_id)
        # Bulk Query.delete() issues a single DELETE and does not run
        # ORM-level relationship cascades. Removal of the room's messages
        # therefore relies on an ON DELETE CASCADE foreign key in the schema
        # (assumed here, not verified by this helper).
        matching_rooms.delete()
        session.commit()
|
||
|
||
|
||
def test_sequence_number_uniqueness_sync():
    """Check that sequentially created messages get unique, consecutive sequence numbers.

    Creates a throwaway room and posts ``num_messages`` messages one at a
    time, each in its own session. It then asserts that the recorded
    sequence numbers are exactly ``1..num_messages``. The room is removed
    afterwards whether or not the assertions pass.

    Raises:
        AssertionError: on duplicate, out-of-order or non-consecutive
            sequence numbers.
    """
    import uuid

    from app.core.database import get_db_context
    from app.modules.realtime.models import MessageType
    from app.modules.realtime.services.message_service import MessageService

    # Use short UUID to fit in room_id column (20 chars max based on model):
    # "ts-" + 6 chars = 9 chars.
    room_id = f"ts-{str(uuid.uuid4())[:6]}"
    num_messages = 20
    sequence_numbers = []

    try:
        # Create test room first
        with get_db_context() as db:
            create_test_room(db, room_id)

        for i in range(num_messages):
            with get_db_context() as db:
                message = MessageService.create_message(
                    db=db,
                    room_id=room_id,
                    sender_id=f"user-{i % 5}",
                    content=f"Test message {i}",
                    message_type=MessageType.TEXT,
                )
                # Read the attribute while the session is still open, so it is
                # available even if the context manager expires/closes the
                # instance on exit.
                sequence_numbers.append(message.sequence_number)

        # Verify uniqueness
        duplicates = [seq for seq, count in Counter(sequence_numbers).items() if count > 1]
        assert not duplicates, f"Duplicate sequence numbers found: {duplicates}"

        # Verify monotonic increase, then the exact expected run 1..N
        assert sequence_numbers == sorted(sequence_numbers), "Sequence numbers not in order"
        assert sequence_numbers == list(range(1, num_messages + 1)), \
            f"Sequence numbers not consecutive: {sequence_numbers}"

        print(f"✓ Sequential test passed: {num_messages} messages with unique sequences")

    finally:
        # Cleanup (room deletion cascades to messages)
        cleanup_test_room(room_id)
|
||
|
||
|
||
def _create_message_blocking(room_id: str, user_id: str, msg_index: int) -> Dict[str, Any]:
    """Synchronously create one message in its own session and summarize it."""
    from app.core.database import get_db_context
    from app.modules.realtime.models import MessageType
    from app.modules.realtime.services.message_service import MessageService

    with get_db_context() as db:
        message = MessageService.create_message(
            db=db,
            room_id=room_id,
            sender_id=user_id,
            content=f"Concurrent message {msg_index}",
            message_type=MessageType.TEXT,
        )
        # Build the summary while the session is still open, so attribute
        # access cannot hit an expired/detached instance.
        return {
            "message_id": message.message_id,
            "sequence_number": message.sequence_number,
            "user_id": user_id,
            "index": msg_index,
        }


async def create_message_async(room_id: str, user_id: str, msg_index: int) -> Dict[str, Any]:
    """Create one message in *room_id* and return a summary of the result.

    ``get_db_context`` and ``MessageService.create_message`` are synchronous
    (blocking) calls. Running them directly inside this coroutine would block
    the event loop, so ``asyncio.gather`` over many of these coroutines would
    execute them strictly one after another and never overlap requests. The
    blocking work is therefore dispatched to the running loop's default
    thread-pool executor. Each worker call opens its own database session.

    Args:
        room_id: Room to post into.
        user_id: Sender id recorded on the message.
        msg_index: Caller-chosen index, embedded in the content and echoed back.

    Returns:
        dict with ``message_id``, ``sequence_number``, ``user_id`` and ``index``.
    """
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(
        None, _create_message_blocking, room_id, user_id, msg_index
    )
|
||
|
||
|
||
async def test_concurrent_message_creation(num_concurrent: int = 50, messages_per_user: int = 5):
    """Create messages from many simulated users at once and verify sequencing.

    Launches ``num_concurrent * messages_per_user`` ``create_message_async``
    calls against a fresh room and gathers them. It prints a summary and
    asserts that every call succeeded and that no sequence number was handed
    out twice. The test room is removed even if the run raises.

    Args:
        num_concurrent: Number of simulated users.
        messages_per_user: Messages each user sends.

    Returns:
        Summary dict with ``total``, ``successful``, ``failed``,
        ``duplicates`` and ``elapsed`` (seconds).

    Raises:
        AssertionError: if any message failed, a sequence number repeated,
            or fewer messages than expected were created.
    """
    import uuid

    from app.core.database import get_db_context

    # Use short UUID to fit in room_id column (20 chars max based on model)
    room_id = f"cc-{str(uuid.uuid4())[:6]}"
    total_messages = num_concurrent * messages_per_user

    # Create test room first
    with get_db_context() as db:
        create_test_room(db, room_id)

    try:
        print(f"\nStarting concurrent test: {num_concurrent} users × {messages_per_user} messages = {total_messages} total")
        start_time = time.perf_counter()

        # One coroutine per (user, message) pair; the index is globally unique.
        tasks = [
            create_message_async(room_id, f"user-{user_idx}", user_idx * messages_per_user + msg_idx)
            for user_idx in range(num_concurrent)
            for msg_idx in range(messages_per_user)
        ]

        # return_exceptions=True returns each failure as a result entry
        # instead of aborting the whole gather, so all failures are counted
        # below. Anything else gather raises propagates (after cleanup)
        # rather than being silently swallowed.
        results = await asyncio.gather(*tasks, return_exceptions=True)
        elapsed = time.perf_counter() - start_time

        # Analyze results. Anything that is not a result dict is a failure,
        # including BaseException subclasses such as asyncio.CancelledError.
        successful_results = [r for r in results if isinstance(r, dict)]
        failed_results = [r for r in results if not isinstance(r, dict)]

        sequence_numbers = [r["sequence_number"] for r in successful_results]
        duplicates = [seq for seq, count in Counter(sequence_numbers).items() if count > 1]
        # Guard against a zero-length interval (e.g. when there are no tasks).
        throughput = len(successful_results) / elapsed if elapsed > 0 else 0.0

        print(f"\nResults after {elapsed:.2f}s:")
        print(f" - Successful messages: {len(successful_results)}/{total_messages}")
        print(f" - Failed messages: {len(failed_results)}")
        print(f" - Unique sequence numbers: {len(set(sequence_numbers))}")
        print(f" - Duplicate sequences: {duplicates if duplicates else 'None'}")
        print(f" - Messages per second: {throughput:.1f}")
    finally:
        # Cleanup (room deletion cascades to messages). This is best-effort,
        # so a cleanup problem is reported but does not mask the test outcome.
        try:
            cleanup_test_room(room_id)
            print(" - Cleaned up test room and messages")
        except Exception as e:
            print(f" - Cleanup error: {e}")

    # Assertions for pytest
    assert not failed_results, f"Some messages failed: {failed_results[:5]}"
    assert not duplicates, f"Duplicate sequence numbers: {duplicates}"
    assert len(successful_results) == total_messages, \
        f"Not all messages created: {len(successful_results)}/{total_messages}"

    print("\n✓ Concurrent test passed!")
    return {
        "total": total_messages,
        "successful": len(successful_results),
        "failed": len(failed_results),
        "duplicates": duplicates,
        "elapsed": elapsed,
    }


# Under pytest's default collection rules, the ``test_`` prefix would make
# pytest collect this coroutine function directly. It is instead driven
# through the synchronous wrapper ``test_concurrent_messages_pytest``, so
# opt it out of collection.
test_concurrent_message_creation.__test__ = False
|
||
|
||
|
||
def test_concurrent_messages_pytest():
    """Synchronous pytest entry point for the async concurrency stress test.

    Drives the coroutine with ``asyncio.run`` at a reduced load of
    10 users x 5 messages each.
    """
    reduced_load = {"num_concurrent": 10, "messages_per_user": 5}
    stress_run = test_concurrent_message_creation(**reduced_load)
    asyncio.run(stress_run)
|
||
|
||
|
||
def test_connection_pool_stress():
    """Hammer the connection pool with many short-lived sessions from threads.

    Starts ``num_threads`` threads. Each opens ``queries_per_thread``
    sessions in turn via ``get_db_context`` and runs ``SELECT 1`` in each.
    Exceptions are recorded rather than propagated. Afterwards the test
    prints throughput and the engine's pool status, then asserts that no
    errors occurred and every query completed.
    """
    import threading
    import time

    from sqlalchemy import text

    from app.core.database import engine, get_db_context

    num_threads = 100
    queries_per_thread = 10
    total_queries = num_threads * queries_per_thread

    errors = []
    completed = 0
    guard = threading.Lock()  # protects `completed` and `errors`

    def worker(thread_id: int):
        nonlocal completed
        for query_no in range(queries_per_thread):
            try:
                with get_db_context() as session:
                    # Simple query to test connection
                    session.execute(text("SELECT 1"))
                    with guard:
                        completed += 1
            except Exception as exc:
                with guard:
                    errors.append(f"Thread {thread_id}, query {query_no}: {exc}")

    print(f"\nConnection pool stress test: {num_threads} threads × {queries_per_thread} queries")
    began = time.time()

    workers = [threading.Thread(target=worker, args=(n,)) for n in range(num_threads)]
    for th in workers:
        th.start()
    for th in workers:
        th.join()

    elapsed = time.time() - began

    print(f" - Completed: {completed}/{total_queries} queries in {elapsed:.2f}s")
    print(f" - Queries per second: {completed/elapsed:.1f}")
    print(f" - Errors: {len(errors)}")

    # Show pool status
    print(f" - Pool status: {engine.pool.status()}")

    assert not errors, f"Pool exhaustion errors: {errors[:5]}"
    assert completed == total_queries, "Not all queries completed"

    print("\n✓ Connection pool stress test passed!")
|
||
|
||
|
||
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Concurrent message stress test")
    parser.add_argument("--concurrent", type=int, default=50,
                        help="Number of concurrent users (default: 50)")
    parser.add_argument("--messages", type=int, default=5,
                        help="Messages per user (default: 5)")
    parser.add_argument("--pool-test", action="store_true",
                        help="Run connection pool stress test")
    parser.add_argument("--sequential", action="store_true",
                        help="Run sequential test only")
    args = parser.parse_args()

    banner = "=" * 60
    print(banner)
    print("WebSocket Database Session Optimization - Stress Tests")
    print(banner)

    # --sequential takes precedence over --pool-test; with neither flag,
    # every test runs in order.
    if args.sequential:
        selected = [test_sequence_number_uniqueness_sync]
    elif args.pool_test:
        selected = [test_connection_pool_stress]
    else:
        selected = [
            test_sequence_number_uniqueness_sync,
            test_connection_pool_stress,
            lambda: asyncio.run(test_concurrent_message_creation(
                num_concurrent=args.concurrent,
                messages_per_user=args.messages,
            )),
        ]

    for run_test in selected:
        run_test()

    print("\n" + banner)
    print("All tests completed successfully!")
    print(banner)
|