## Database Migration (SQLite → MySQL) - Add Alembic migration framework - Add 'tr_' prefix to all tables to avoid conflicts in shared database - Remove SQLite support, use MySQL exclusively - Add pymysql driver dependency - Change ad_token column to Text type for long JWT tokens ## Unified Environment Configuration - Centralize all hardcoded settings to environment variables - Backend: Extend Settings class in app/core/config.py - Frontend: Use Vite environment variables (import.meta.env) - Docker: Move credentials to environment variables - Update .env.example files with comprehensive documentation ## Test Organization - Move root-level test files to tests/ directory: - test_chat_room.py → tests/test_chat_room.py - test_websocket.py → tests/test_websocket.py - test_realtime_implementation.py → tests/test_realtime_implementation.py - Fix path references in test_realtime_implementation.py Breaking Changes: - CORS now requires explicit origins (no more wildcard) - All database tables renamed with 'tr_' prefix - SQLite no longer supported 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
245 lines
7.4 KiB
Python
245 lines
7.4 KiB
Python
"""WebSocket connection pool management"""
|
|
from fastapi import WebSocket
|
|
from typing import Dict, List, Set, Any
|
|
from datetime import datetime
|
|
import asyncio
|
|
import json
|
|
from collections import defaultdict
|
|
|
|
from app.core.config import get_settings
|
|
|
|
# Application settings, resolved once when this module is imported.
# WebSocketManager.set_typing reads TYPING_TIMEOUT_SECONDS from this object.
settings = get_settings()
|
|
|
|
|
|
def json_serializer(obj: Any) -> str:
    """Fallback hook for ``json.dumps(..., default=json_serializer)``.

    Converts ``datetime`` instances to ISO-8601 strings; every other type is
    rejected so that ``json.dumps`` reports it as unserializable.

    Args:
        obj: The object the standard JSON encoder could not handle.

    Returns:
        ``obj.isoformat()`` for datetime values.

    Raises:
        TypeError: If ``obj`` is not a ``datetime``.
    """
    if not isinstance(obj, datetime):
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
    return obj.isoformat()
|
|
|
|
|
|
class ConnectionInfo:
    """Bookkeeping record for one accepted WebSocket connection.

    WebSocketManager keeps these objects in sets, so they intentionally rely
    on the default identity-based ``__eq__``/``__hash__``.
    """

    def __init__(self, websocket: WebSocket, user_id: str, room_id: str):
        # The socket itself, plus who owns it and which room it joined.
        self.websocket, self.user_id, self.room_id = websocket, user_id, room_id

        # Naive UTC timestamp taken when the record is created.
        self.connected_at = datetime.utcnow()

        # Last received sequence number; starts at zero and is used for
        # reconnection bookkeeping.
        self.last_sequence = 0
|
|
|
|
|
|
class WebSocketManager:
    """Manages WebSocket connections and message broadcasting.

    All state is held in in-memory dictionaries:

    * ``_room_connections``: room_id -> set of live ConnectionInfo objects
    * ``_user_connections``: user_id -> the user's most recently opened
      ConnectionInfo (used for direct messages)
    * ``_typing_users``: room_id -> set of user_ids currently flagged as typing
    * ``_typing_tasks``: user_id -> asyncio.Task that auto-clears the flag
    """

    def __init__(self):
        # room_id -> Set of ConnectionInfo
        self._room_connections: Dict[str, Set[ConnectionInfo]] = defaultdict(set)

        # user_id -> ConnectionInfo (for direct messaging); only the user's
        # most recently opened connection is kept here
        self._user_connections: Dict[str, ConnectionInfo] = {}

        # room_id -> Set of user_ids (typing users)
        self._typing_users: Dict[str, Set[str]] = defaultdict(set)

        # user_id -> asyncio.Task (typing timeout tasks)
        # NOTE(review): keyed by user_id only, so a user typing in a second
        # room replaces the first room's auto-clear task — confirm users are
        # only ever in one room at a time.
        self._typing_tasks: Dict[str, asyncio.Task] = {}

    async def connect(self, websocket: WebSocket, room_id: str, user_id: str) -> ConnectionInfo:
        """
        Accept a WebSocket connection and add it to the pool

        Args:
            websocket: The WebSocket connection
            room_id: Room ID the user is connecting to
            user_id: User ID

        Returns:
            ConnectionInfo object
        """
        await websocket.accept()

        conn_info = ConnectionInfo(websocket, user_id, room_id)
        self._room_connections[room_id].add(conn_info)
        # A newer connection for the same user becomes the direct-message
        # target; any older connection stays registered in its room set.
        self._user_connections[user_id] = conn_info

        return conn_info

    async def disconnect(self, conn_info: ConnectionInfo):
        """
        Remove a WebSocket connection from the pool

        Safe to call more than once for the same connection, and safe to call
        for a stale connection after the same user has reconnected.

        Args:
            conn_info: Connection info to remove
        """
        room_id = conn_info.room_id
        user_id = conn_info.user_id

        # Remove from room connections, dropping the room entry once empty.
        # .get() avoids materialising an empty set in the defaultdict.
        room_conns = self._room_connections.get(room_id)
        if room_conns is not None:
            room_conns.discard(conn_info)
            if not room_conns:
                del self._room_connections[room_id]

        # Only unregister the user's direct-messaging entry if it still points
        # at THIS connection. If the user has since opened another connection
        # (second tab, fast reconnect), cleaning up the old socket must not
        # make the live one look offline.
        if self._user_connections.get(user_id) is conn_info:
            del self._user_connections[user_id]

        # Clear typing status
        task = self._typing_tasks.pop(user_id, None)
        if task is not None:
            task.cancel()
        self._discard_typing_user(room_id, user_id)

    async def broadcast_to_room(self, room_id: str, message: dict, exclude_user: str = None):
        """
        Broadcast a message to all connections in a room

        Connections whose send fails are disconnected after the broadcast.

        Args:
            room_id: Room ID to broadcast to
            message: Message dictionary to broadcast
            exclude_user: Optional user ID to exclude from broadcast
        """
        connections = self._room_connections.get(room_id)
        if connections is None:
            return

        message_json = json.dumps(message, default=json_serializer)

        # Iterate over a snapshot: every send awaits, and other coroutines may
        # connect/disconnect in this room meanwhile, which would otherwise
        # raise "RuntimeError: Set changed size during iteration".
        disconnected = []
        for conn_info in list(connections):
            if exclude_user and conn_info.user_id == exclude_user:
                continue

            try:
                await conn_info.websocket.send_text(message_json)
            except Exception:
                # Best-effort delivery: treat any send failure as a dead
                # connection and clean it up below.
                disconnected.append(conn_info)

        # Clean up disconnected connections
        for conn_info in disconnected:
            await self.disconnect(conn_info)

    async def send_personal(self, user_id: str, message: dict):
        """
        Send a message to a specific user

        Does nothing if the user has no registered connection; disconnects
        the connection if the send fails.

        Args:
            user_id: User ID to send to
            message: Message dictionary to send
        """
        conn_info = self._user_connections.get(user_id)
        if conn_info is None:
            return

        message_json = json.dumps(message, default=json_serializer)

        try:
            await conn_info.websocket.send_text(message_json)
        except Exception:
            # Connection failed, disconnect
            await self.disconnect(conn_info)

    def get_room_connections(self, room_id: str) -> List[ConnectionInfo]:
        """
        Get all active connections for a room

        Args:
            room_id: Room ID

        Returns:
            List of ConnectionInfo objects (a copy; safe to iterate while
            the pool changes)
        """
        connections = self._room_connections.get(room_id)
        return list(connections) if connections is not None else []

    def get_online_users(self, room_id: str) -> List[str]:
        """
        Get list of online user IDs in a room

        Args:
            room_id: Room ID

        Returns:
            List of distinct user IDs
        """
        # A user may hold several connections in one room (e.g. multiple
        # tabs); report each user once, in first-seen order.
        return list(dict.fromkeys(conn.user_id for conn in self.get_room_connections(room_id)))

    def is_user_online(self, user_id: str) -> bool:
        """
        Check if a user is currently connected

        Args:
            user_id: User ID to check

        Returns:
            True if user is connected
        """
        return user_id in self._user_connections

    async def set_typing(self, room_id: str, user_id: str, is_typing: bool):
        """
        Set typing status for a user in a room

        When ``is_typing`` is true the flag is cleared automatically after
        ``settings.TYPING_TIMEOUT_SECONDS`` unless refreshed or cleared first.

        Args:
            room_id: Room ID
            user_id: User ID
            is_typing: Whether user is typing
        """
        # Any pending auto-clear for this user is superseded either way.
        existing = self._typing_tasks.pop(user_id, None)
        if existing is not None:
            existing.cancel()

        if not is_typing:
            self._discard_typing_user(room_id, user_id)
            return

        self._typing_users[room_id].add(user_id)

        # Set new timeout (configurable via TYPING_TIMEOUT_SECONDS)
        typing_timeout = settings.TYPING_TIMEOUT_SECONDS

        async def clear_typing():
            await asyncio.sleep(typing_timeout)
            self._discard_typing_user(room_id, user_id)
            # Only drop the registry entry if it is still this task; never
            # delete a task registered by a later set_typing call.
            if self._typing_tasks.get(user_id) is asyncio.current_task():
                del self._typing_tasks[user_id]

        self._typing_tasks[user_id] = asyncio.create_task(clear_typing())

    def get_typing_users(self, room_id: str) -> List[str]:
        """
        Get list of users currently typing in a room

        Args:
            room_id: Room ID

        Returns:
            List of user IDs
        """
        typing = self._typing_users.get(room_id)
        return list(typing) if typing is not None else []

    async def send_heartbeat(self, conn_info: ConnectionInfo):
        """
        Send a ping to check connection health

        Disconnects the connection if the ping cannot be sent.

        Args:
            conn_info: Connection to ping
        """
        try:
            await conn_info.websocket.send_json({"type": "ping"})
        except Exception:
            await self.disconnect(conn_info)

    def _discard_typing_user(self, room_id: str, user_id: str) -> None:
        """Remove user_id from room_id's typing set, dropping the set once empty."""
        # .get() so that clearing never creates an empty defaultdict entry.
        typing = self._typing_users.get(room_id)
        if typing is None:
            return
        typing.discard(user_id)
        if not typing:
            del self._typing_users[room_id]
|
|
|
|
|
|
# Global WebSocket manager instance
|
|
# Shared singleton: every importer of this module uses this one instance, and
# all connection/room/typing state lives in its in-memory dictionaries.
# NOTE(review): that state is process-local — presumably not shared across
# multiple server worker processes; confirm the deployment setup.
manager = WebSocketManager()
|