"""WebSocket connection pool management"""
from fastapi import WebSocket
from typing import Dict, List, Optional, Set, Tuple, Any
from datetime import datetime, timezone
import asyncio
import json
from collections import defaultdict

from app.core.config import get_settings

settings = get_settings()


def json_serializer(obj: Any) -> str:
    """Custom JSON serializer for objects not serializable by default json code.

    Datetimes are emitted as ISO-8601 strings with a trailing ``Z`` so that
    JavaScript parses them as UTC.

    - Naive datetimes are assumed to already be UTC.
      NOTE(review): confirm callers never pass naive local times.
    - Timezone-aware datetimes are converted to UTC first. Appending ``Z``
      to their raw ``isoformat()`` would yield an invalid ``...+00:00Z``.

    Args:
        obj: Object that ``json.dumps`` could not serialize natively.

    Returns:
        The string representation to embed in the JSON output.

    Raises:
        TypeError: For any unsupported type, as ``json.dumps`` expects.
    """
    if isinstance(obj, datetime):
        if obj.utcoffset() is not None:
            # Normalise to naive UTC so 'Z' is the only offset marker emitted.
            obj = obj.astimezone(timezone.utc).replace(tzinfo=None)
        return obj.isoformat() + 'Z'
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


class ConnectionInfo:
    """Information about a single WebSocket connection."""

    def __init__(self, websocket: WebSocket, user_id: str, room_id: str):
        self.websocket = websocket
        self.user_id = user_id
        self.room_id = room_id
        # Naive UTC timestamp of when this connection was registered.
        self.connected_at = datetime.utcnow()
        # Track last received sequence number for reconnection.
        self.last_sequence = 0


class WebSocketManager:
    """Manages WebSocket connections, typing indicators and message broadcasting.

    State is kept in plain dicts/sets mutated from coroutines without locks.
    NOTE(review): this assumes all calls happen on a single event loop.
    """

    def __init__(self):
        # room_id -> set of ConnectionInfo currently joined to that room
        self._room_connections: Dict[str, Set[ConnectionInfo]] = defaultdict(set)
        # user_id -> most recently connected ConnectionInfo (for direct messaging)
        self._user_connections: Dict[str, ConnectionInfo] = {}
        # room_id -> set of user_ids currently marked as typing
        self._typing_users: Dict[str, Set[str]] = defaultdict(set)
        # (room_id, user_id) -> pending task that auto-clears the typing flag.
        # Keyed per room so typing in one room cannot orphan the flag in another.
        self._typing_tasks: Dict[Tuple[str, str], asyncio.Task] = {}

    async def connect(self, websocket: WebSocket, room_id: str, user_id: str) -> ConnectionInfo:
        """
        Accept a WebSocket and add it to the pool.

        Args:
            websocket: The WebSocket connection
            room_id: Room ID the user is connecting to
            user_id: User ID

        Returns:
            ConnectionInfo object for the new connection
        """
        await websocket.accept()
        conn_info = ConnectionInfo(websocket, user_id, room_id)
        self._room_connections[room_id].add(conn_info)
        # The newest connection for a user becomes their direct-message route.
        self._user_connections[user_id] = conn_info
        return conn_info

    async def disconnect(self, conn_info: ConnectionInfo):
        """
        Remove a WebSocket connection from the pool.

        Safe to call more than once for the same connection.

        Args:
            conn_info: Connection info to remove
        """
        room_id = conn_info.room_id
        user_id = conn_info.user_id

        # Remove from room connections, pruning the room when it empties.
        room = self._room_connections.get(room_id)
        if room is not None:
            room.discard(conn_info)
            if not room:
                del self._room_connections[room_id]

        # Only drop the direct-message route if it still points at *this*
        # connection. A reconnect may already have replaced it, and deleting
        # it would make a connected user look offline.
        if self._user_connections.get(user_id) is conn_info:
            del self._user_connections[user_id]

        # Clear typing status for this room.
        self._cancel_typing_task(room_id, user_id)
        self._discard_typing_user(room_id, user_id)

    async def broadcast_to_room(self, room_id: str, message: dict, exclude_user: Optional[str] = None):
        """
        Broadcast a message to all connections in a room.

        Connections whose send fails are disconnected afterwards.

        Args:
            room_id: Room ID to broadcast to
            message: Message dictionary to broadcast
            exclude_user: Optional user ID to exclude from broadcast
        """
        connections = self._room_connections.get(room_id)
        if not connections:
            return

        message_json = json.dumps(message, default=json_serializer)

        disconnected = []
        # Iterate over a snapshot. Each send awaits, which lets other coroutines
        # connect/disconnect and resize the live set mid-iteration.
        for conn_info in list(connections):
            if exclude_user and conn_info.user_id == exclude_user:
                continue
            try:
                await conn_info.websocket.send_text(message_json)
            except Exception:
                # Connection failed, mark for removal (best-effort delivery).
                disconnected.append(conn_info)

        for conn_info in disconnected:
            await self.disconnect(conn_info)

    async def send_personal(self, user_id: str, message: dict):
        """
        Send a message to a specific user's most recent connection.

        Does nothing if the user is not connected. On send failure the
        connection is disconnected.

        Args:
            user_id: User ID to send to
            message: Message dictionary to send
        """
        conn_info = self._user_connections.get(user_id)
        if conn_info is None:
            return

        message_json = json.dumps(message, default=json_serializer)
        try:
            await conn_info.websocket.send_text(message_json)
        except Exception:
            # Connection failed, disconnect (best-effort delivery).
            await self.disconnect(conn_info)

    def get_room_connections(self, room_id: str) -> List[ConnectionInfo]:
        """
        Get all active connections for a room.

        Args:
            room_id: Room ID

        Returns:
            List of ConnectionInfo objects (a copy; safe to iterate while awaiting)
        """
        return list(self._room_connections.get(room_id, ()))

    def get_online_users(self, room_id: str) -> List[str]:
        """
        Get the unique user IDs with at least one connection in a room.

        Args:
            room_id: Room ID

        Returns:
            List of user IDs, without duplicates
        """
        # A user with several connections (e.g. tabs) is listed once.
        return list(dict.fromkeys(conn.user_id for conn in self.get_room_connections(room_id)))

    def is_user_online(self, user_id: str) -> bool:
        """
        Check if a user is currently connected.

        Args:
            user_id: User ID to check

        Returns:
            True if user is connected
        """
        return user_id in self._user_connections

    async def set_typing(self, room_id: str, user_id: str, is_typing: bool):
        """
        Set typing status for a user in a room.

        When ``is_typing`` is true, the flag is cleared automatically after
        ``settings.TYPING_TIMEOUT_SECONDS`` unless refreshed by another call.

        Args:
            room_id: Room ID
            user_id: User ID
            is_typing: Whether user is typing
        """
        key = (room_id, user_id)
        # Any previous auto-clear timer for this user/room is superseded.
        self._cancel_typing_task(room_id, user_id)

        if not is_typing:
            self._discard_typing_user(room_id, user_id)
            return

        self._typing_users[room_id].add(user_id)

        # Set new timeout (configurable via TYPING_TIMEOUT_SECONDS)
        typing_timeout = settings.TYPING_TIMEOUT_SECONDS

        async def clear_typing():
            await asyncio.sleep(typing_timeout)
            self._discard_typing_user(room_id, user_id)
            # Only forget the entry if it is still ours (not a newer timer).
            if self._typing_tasks.get(key) is asyncio.current_task():
                del self._typing_tasks[key]

        self._typing_tasks[key] = asyncio.create_task(clear_typing())

    def get_typing_users(self, room_id: str) -> List[str]:
        """
        Get list of users currently typing in a room.

        Args:
            room_id: Room ID

        Returns:
            List of user IDs
        """
        return list(self._typing_users.get(room_id, ()))

    async def send_heartbeat(self, conn_info: ConnectionInfo):
        """
        Send a ping to check connection health; disconnect on failure.

        Args:
            conn_info: Connection to ping
        """
        try:
            await conn_info.websocket.send_json({"type": "ping"})
        except Exception:
            await self.disconnect(conn_info)

    def _cancel_typing_task(self, room_id: str, user_id: str) -> None:
        """Cancel and forget the pending typing auto-clear task, if any."""
        task = self._typing_tasks.pop((room_id, user_id), None)
        if task is not None:
            task.cancel()

    def _discard_typing_user(self, room_id: str, user_id: str) -> None:
        """Remove user_id from room_id's typing set, pruning the set when empty."""
        # .get() avoids the defaultdict creating an empty set for unknown rooms.
        users = self._typing_users.get(room_id)
        if users is None:
            return
        users.discard(user_id)
        if not users:
            del self._typing_users[room_id]


# Global WebSocket manager instance
manager = WebSocketManager()