Files
Task_Reporter/tests/test_concurrent_messages.py
egg 599802b818 feat: Add Chat UX improvements with notifications and @mention support
- Add ActionBar component with expandable toolbar for mobile
- Add @mention functionality with autocomplete dropdown
- Add browser notification system (push, sound, vibration)
- Add NotificationSettings modal for user preferences
- Add mention badges on room list cards
- Add ReportPreview with Markdown rendering and copy/download
- Add message copy functionality with hover actions
- Add backend mentions field to messages with Alembic migration
- Add lots field to rooms, remove templates
- Optimize WebSocket database session handling
- Various UX polish (animations, accessibility)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-12-08 08:20:37 +08:00

297 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Concurrent Message Stress Test
This script tests:
1. Connection pool doesn't exhaust under load
2. Sequence numbers are unique per room
3. No deadlocks occur during concurrent operations
Usage:
# Run with pytest (uses test database)
pytest tests/test_concurrent_messages.py -v
# Run standalone against real database
python tests/test_concurrent_messages.py --concurrent 50 --messages 10
Requirements:
- MySQL database (SQLite doesn't support FOR UPDATE properly)
- pip install aiohttp websockets
"""
import asyncio
import argparse
import sys
import os
from collections import Counter
from typing import List, Dict, Any
import time
# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Import all models to ensure proper relationship initialization
from app.modules.auth.models import UserSession
from app.modules.chat_room.models import IncidentRoom, RoomMember
from app.modules.file_storage.models import RoomFile
from app.modules.realtime.models import Message, MessageReaction, MessageEditHistory
from app.modules.report_generation.models import GeneratedReport
def create_test_room(db, room_id: str) -> IncidentRoom:
    """Insert a throwaway incident room and return the persisted row.

    Args:
        db: Database session object; this function calls ``add``, ``commit``
            and ``refresh`` on it.
        room_id: Value stored as the room's ``room_id`` and echoed in its title.

    Returns:
        The ``IncidentRoom`` instance after it has been committed and refreshed.
    """
    from datetime import datetime, timezone
    from app.modules.chat_room.schemas import IncidentType, SeverityLevel

    # Fixed placeholder values: the tests only care about the room existing.
    room_fields = {
        "room_id": room_id,
        "title": f"Test Room {room_id}",
        "incident_type": IncidentType.OTHER,
        "severity": SeverityLevel.LOW,
        "lots": ["TEST-LOT"],
        "created_by": "test-admin",
        "created_at": datetime.now(timezone.utc),
    }
    test_room = IncidentRoom(**room_fields)

    db.add(test_room)
    db.commit()
    # Reload so the returned object reflects what was actually stored.
    db.refresh(test_room)
    return test_room
def cleanup_test_room(room_id: str):
    """Delete the test room identified by *room_id* and commit the deletion.

    Only the ``IncidentRoom`` row is deleted explicitly. Removal of the room's
    messages is left to a cascade (presumably a database-level ON DELETE
    CASCADE, since this is a bulk query delete -- confirm against the models).
    """
    from app.core.database import get_db_context

    with get_db_context() as session:
        matching_rooms = session.query(IncidentRoom).filter(
            IncidentRoom.room_id == room_id
        )
        # Messages will be deleted by CASCADE
        matching_rooms.delete()
        session.commit()
def test_sequence_number_uniqueness_sync():
    """Check that sequentially created messages get sequence numbers 1..N.

    Creates a temporary room, writes 20 messages one at a time (each in its own
    database context, rotating over five sender ids), and asserts that the
    recorded sequence numbers contain no duplicates, are sorted, and are exactly
    ``1..20``. The room is removed in ``finally`` whether or not the checks pass.
    """
    from sqlalchemy.orm import Session
    from app.core.database import get_db_context
    from app.modules.realtime.services.message_service import MessageService
    from app.modules.realtime.models import MessageType, Message
    import uuid

    # Use short UUID to fit in room_id column (20 chars max based on model)
    room_id = f"ts-{str(uuid.uuid4())[:6]}"
    num_messages = 20
    sequence_numbers = []

    try:
        # The room has to exist before any message can be attached to it.
        with get_db_context() as db:
            create_test_room(db, room_id)

        for index in range(num_messages):
            with get_db_context() as db:
                created = MessageService.create_message(
                    db=db,
                    room_id=room_id,
                    sender_id=f"user-{index % 5}",
                    content=f"Test message {index}",
                    message_type=MessageType.TEXT,
                )
                # Read the attribute while the context that produced it is open.
                sequence_numbers.append(created.sequence_number)

        # Uniqueness: any value seen more than once is a duplicate.
        tally = Counter(sequence_numbers)
        duplicates = [seq for seq in tally if tally[seq] > 1]
        assert not duplicates, f"Duplicate sequence numbers found: {duplicates}"

        # Ordering: values must already be ascending in creation order.
        in_order = sorted(sequence_numbers)
        assert sequence_numbers == in_order, "Sequence numbers not in order"

        # Density: a fresh room must yield exactly 1..num_messages.
        expected = list(range(1, num_messages + 1))
        assert sequence_numbers == expected, \
            f"Sequence numbers not consecutive: {sequence_numbers}"

        print(f"✓ Sequential test passed: {num_messages} messages with unique sequences")
    finally:
        # Cleanup (room deletion cascades to messages)
        cleanup_test_room(room_id)
async def create_message_async(room_id: str, user_id: str, msg_index: int) -> Dict[str, Any]:
    """Create one message off the event loop and return its identifying fields.

    ``MessageService.create_message`` and ``get_db_context`` are synchronous.
    Calling them directly inside the coroutine would block the event loop, so
    coroutines passed to ``asyncio.gather`` would run strictly one after
    another and never overlap. The blocking section is therefore pushed to the
    loop's default executor so that several creations can be in flight at once
    (bounded by the executor's worker count).

    Args:
        room_id: Room the message is created in.
        user_id: Passed as the message's ``sender_id``.
        msg_index: Index embedded in the message content and echoed back.

    Returns:
        A dict with keys ``message_id``, ``sequence_number``, ``user_id`` and
        ``index``.

    Raises:
        Whatever the database context or ``MessageService.create_message``
        raises; the exception propagates to the awaiting caller.
    """
    # Import on the calling thread so worker threads do not race on first import.
    from app.core.database import get_db_context
    from app.modules.realtime.services.message_service import MessageService
    from app.modules.realtime.models import MessageType

    def _create_blocking() -> Dict[str, Any]:
        # Each call opens its own database context, so no session is shared
        # between worker threads.
        with get_db_context() as db:
            message = MessageService.create_message(
                db=db,
                room_id=room_id,
                sender_id=user_id,
                content=f"Concurrent message {msg_index}",
                message_type=MessageType.TEXT,
            )
            # Build the plain dict while the context is still open.
            return {
                "message_id": message.message_id,
                "sequence_number": message.sequence_number,
                "user_id": user_id,
                "index": msg_index,
            }

    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _create_blocking)
async def test_concurrent_message_creation(num_concurrent: int = 50, messages_per_user: int = 5) -> Dict[str, Any]:
    """Create many messages in one room via ``asyncio.gather`` and verify them.

    A temporary room is created, one ``create_message_async`` coroutine is
    scheduled per (user, message) pair, and the gathered results are checked:
    no call may fail, no sequence number may repeat, and every message must
    have been created. The room is always cleaned up, even if the run is
    interrupted part-way.

    Args:
        num_concurrent: Number of simulated users.
        messages_per_user: Messages created per simulated user.

    Returns:
        A summary dict with keys ``total``, ``successful``, ``failed``,
        ``duplicates`` and ``elapsed`` (seconds).

    Raises:
        AssertionError: If gathering failed, any message failed, sequence
            numbers were duplicated, or fewer messages than expected succeeded.
    """
    from app.core.database import get_db_context
    import uuid

    # Use short UUID to fit in room_id column (20 chars max based on model)
    room_id = f"cc-{str(uuid.uuid4())[:6]}"
    total_messages = num_concurrent * messages_per_user
    results: List[Any] = []
    errors: List[str] = []

    # Create test room first
    with get_db_context() as db:
        create_test_room(db, room_id)

    try:
        print(f"\nStarting concurrent test: {num_concurrent} users × {messages_per_user} messages = {total_messages} total")
        start_time = time.perf_counter()

        # One coroutine per (user, message); the index is unique across users.
        tasks = [
            create_message_async(room_id, f"user-{user_idx}", user_idx * messages_per_user + msg_idx)
            for user_idx in range(num_concurrent)
            for msg_idx in range(messages_per_user)
        ]

        # return_exceptions=True turns per-task failures into result entries,
        # so this handler only sees failures of gather itself.
        try:
            results = await asyncio.gather(*tasks, return_exceptions=True)
        except Exception as e:
            errors.append(str(e))

        elapsed = time.perf_counter() - start_time

        # Analyze results
        successful_results = [r for r in results if isinstance(r, dict)]
        # BaseException also covers cancellations, which are not Exception subclasses.
        failed_results = [r for r in results if isinstance(r, BaseException)]
        sequence_numbers = [r["sequence_number"] for r in successful_results]
        duplicates = [seq for seq, count in Counter(sequence_numbers).items() if count > 1]
        # Guard against a zero interval (e.g. an empty run on a coarse clock).
        rate = len(successful_results) / elapsed if elapsed > 0 else 0.0

        print(f"\nResults after {elapsed:.2f}s:")
        print(f" - Successful messages: {len(successful_results)}/{total_messages}")
        print(f" - Failed messages: {len(failed_results)}")
        print(f" - Unique sequence numbers: {len(set(sequence_numbers))}")
        print(f" - Duplicate sequences: {duplicates if duplicates else 'None'}")
        print(f" - Messages per second: {rate:.1f}")
    finally:
        # Cleanup (room deletion cascades to messages). Best effort: a cleanup
        # failure is reported but must not mask the real test outcome.
        try:
            cleanup_test_room(room_id)
            print(f" - Cleaned up test room and messages")
        except Exception as e:
            print(f" - Cleanup error: {e}")

    # Assertions for pytest
    assert not errors, f"Gathering tasks failed: {errors}"
    assert len(failed_results) == 0, f"Some messages failed: {failed_results[:5]}"
    assert len(duplicates) == 0, f"Duplicate sequence numbers: {duplicates}"
    assert len(successful_results) == total_messages, \
        f"Not all messages created: {len(successful_results)}/{total_messages}"

    print(f"\n✓ Concurrent test passed!")
    return {
        "total": total_messages,
        "successful": len(successful_results),
        "failed": len(failed_results),
        "duplicates": duplicates,
        "elapsed": elapsed,
    }
def test_concurrent_messages_pytest():
    """Synchronous pytest entry point for the async message-creation test.

    Runs ``test_concurrent_message_creation`` with 10 users and 5 messages per
    user on a fresh event loop via ``asyncio.run``.
    """
    stress_run = test_concurrent_message_creation(
        num_concurrent=10,
        messages_per_user=5,
    )
    asyncio.run(stress_run)
def test_connection_pool_stress():
    """Hammer the database context from many threads and expect no failures.

    Starts 100 threads; each opens 10 short-lived database contexts and runs
    ``SELECT 1`` in each. Afterwards it prints throughput and the engine pool's
    status string, then asserts that no query raised and that all
    100 x 10 queries completed.
    """
    from app.core.database import get_db_context, engine
    from sqlalchemy import text
    import threading
    import time

    num_threads = 100
    queries_per_thread = 10
    total_queries = num_threads * queries_per_thread

    failures = []
    completed = 0
    # Serializes updates to the shared counter and failure list.
    guard = threading.Lock()

    def worker(thread_id: int):
        nonlocal completed
        for query_no in range(queries_per_thread):
            try:
                with get_db_context() as db:
                    # Simple query to test connection
                    db.execute(text("SELECT 1"))
            except Exception as exc:
                with guard:
                    failures.append(f"Thread {thread_id}, query {query_no}: {exc}")
            else:
                with guard:
                    completed += 1

    print(f"\nConnection pool stress test: {num_threads} threads × {queries_per_thread} queries")
    started_at = time.time()

    pool_workers = []
    for thread_id in range(num_threads):
        pool_workers.append(threading.Thread(target=worker, args=(thread_id,)))
    for thread in pool_workers:
        thread.start()
    for thread in pool_workers:
        thread.join()

    elapsed = time.time() - started_at

    print(f" - Completed: {completed}/{total_queries} queries in {elapsed:.2f}s")
    print(f" - Queries per second: {completed/elapsed:.1f}")
    print(f" - Errors: {len(failures)}")

    # Show pool status
    pool_status = engine.pool.status()
    print(f" - Pool status: {pool_status}")

    assert len(failures) == 0, f"Pool exhaustion errors: {failures[:5]}"
    assert completed == total_queries, "Not all queries completed"
    print("\n✓ Connection pool stress test passed!")
if __name__ == "__main__":
    # Standalone entry point: pick which stress tests to run from CLI flags.
    banner = "=" * 60

    cli = argparse.ArgumentParser(description="Concurrent message stress test")
    cli.add_argument(
        "--concurrent",
        type=int,
        default=50,
        help="Number of concurrent users (default: 50)",
    )
    cli.add_argument(
        "--messages",
        type=int,
        default=5,
        help="Messages per user (default: 5)",
    )
    cli.add_argument(
        "--pool-test",
        action="store_true",
        help="Run connection pool stress test",
    )
    cli.add_argument(
        "--sequential",
        action="store_true",
        help="Run sequential test only",
    )
    args = cli.parse_args()

    print(banner)
    print("WebSocket Database Session Optimization - Stress Tests")
    print(banner)

    # --sequential takes precedence over --pool-test when both are given.
    if args.sequential:
        test_sequence_number_uniqueness_sync()
    elif args.pool_test:
        test_connection_pool_stress()
    else:
        # Run all tests
        test_sequence_number_uniqueness_sync()
        test_connection_pool_stress()
        full_run = test_concurrent_message_creation(
            num_concurrent=args.concurrent,
            messages_per_user=args.messages,
        )
        asyncio.run(full_run)

    # Only reached when no test raised.
    print("\n" + banner)
    print("All tests completed successfully!")
    print(banner)