"""WebSocket connection pool management""" from fastapi import WebSocket from typing import Dict, List, Set, Any from datetime import datetime import asyncio import json from collections import defaultdict def json_serializer(obj: Any) -> str: """Custom JSON serializer for objects not serializable by default json code""" if isinstance(obj, datetime): return obj.isoformat() raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable") class ConnectionInfo: """Information about a WebSocket connection""" def __init__(self, websocket: WebSocket, user_id: str, room_id: str): self.websocket = websocket self.user_id = user_id self.room_id = room_id self.connected_at = datetime.utcnow() self.last_sequence = 0 # Track last received sequence number for reconnection class WebSocketManager: """Manages WebSocket connections and message broadcasting""" def __init__(self): # room_id -> Set of ConnectionInfo self._room_connections: Dict[str, Set[ConnectionInfo]] = defaultdict(set) # user_id -> ConnectionInfo (for direct messaging) self._user_connections: Dict[str, ConnectionInfo] = {} # room_id -> Set of user_ids (typing users) self._typing_users: Dict[str, Set[str]] = defaultdict(set) # user_id -> asyncio.Task (typing timeout tasks) self._typing_tasks: Dict[str, asyncio.Task] = {} async def connect(self, websocket: WebSocket, room_id: str, user_id: str) -> ConnectionInfo: """ Add a WebSocket connection to the pool Args: websocket: The WebSocket connection room_id: Room ID the user is connecting to user_id: User ID Returns: ConnectionInfo object """ await websocket.accept() conn_info = ConnectionInfo(websocket, user_id, room_id) self._room_connections[room_id].add(conn_info) self._user_connections[user_id] = conn_info return conn_info async def disconnect(self, conn_info: ConnectionInfo): """ Remove a WebSocket connection from the pool Args: conn_info: Connection info to remove """ room_id = conn_info.room_id user_id = conn_info.user_id # Remove from room connections if room_id in self._room_connections: self._room_connections[room_id].discard(conn_info) if not self._room_connections[room_id]: del self._room_connections[room_id] # Remove from user connections if user_id in self._user_connections: del self._user_connections[user_id] # Clear typing status if user_id in self._typing_tasks: self._typing_tasks[user_id].cancel() del self._typing_tasks[user_id] if room_id in self._typing_users: self._typing_users[room_id].discard(user_id) async def broadcast_to_room(self, room_id: str, message: dict, exclude_user: str = None): """ Broadcast a message to all connections in a room Args: room_id: Room ID to broadcast to message: Message dictionary to broadcast exclude_user: Optional user ID to exclude from broadcast """ if room_id not in self._room_connections: return message_json = json.dumps(message, default=json_serializer) # Collect disconnected connections disconnected = [] for conn_info in self._room_connections[room_id]: if exclude_user and conn_info.user_id == exclude_user: continue try: await conn_info.websocket.send_text(message_json) except Exception as e: # Connection failed, mark for removal disconnected.append(conn_info) # Clean up disconnected connections for conn_info in disconnected: await self.disconnect(conn_info) async def send_personal(self, user_id: str, message: dict): """ Send a message to a specific user Args: user_id: User ID to send to message: Message dictionary to send """ if user_id not in self._user_connections: return conn_info = self._user_connections[user_id] message_json = json.dumps(message, default=json_serializer) try: await conn_info.websocket.send_text(message_json) except Exception: # Connection failed, disconnect await self.disconnect(conn_info) def get_room_connections(self, room_id: str) -> List[ConnectionInfo]: """ Get all active connections for a room Args: room_id: Room ID Returns: List of ConnectionInfo objects """ if room_id not in self._room_connections: return [] return list(self._room_connections[room_id]) def get_online_users(self, room_id: str) -> List[str]: """ Get list of online user IDs in a room Args: room_id: Room ID Returns: List of user IDs """ return [conn.user_id for conn in self.get_room_connections(room_id)] def is_user_online(self, user_id: str) -> bool: """ Check if a user is currently connected Args: user_id: User ID to check Returns: True if user is connected """ return user_id in self._user_connections async def set_typing(self, room_id: str, user_id: str, is_typing: bool): """ Set typing status for a user in a room Args: room_id: Room ID user_id: User ID is_typing: Whether user is typing """ if is_typing: self._typing_users[room_id].add(user_id) # Cancel existing timeout task if user_id in self._typing_tasks: self._typing_tasks[user_id].cancel() # Set new timeout (3 seconds) async def clear_typing(): await asyncio.sleep(3) self._typing_users[room_id].discard(user_id) if user_id in self._typing_tasks: del self._typing_tasks[user_id] self._typing_tasks[user_id] = asyncio.create_task(clear_typing()) else: self._typing_users[room_id].discard(user_id) if user_id in self._typing_tasks: self._typing_tasks[user_id].cancel() del self._typing_tasks[user_id] def get_typing_users(self, room_id: str) -> List[str]: """ Get list of users currently typing in a room Args: room_id: Room ID Returns: List of user IDs """ if room_id not in self._typing_users: return [] return list(self._typing_users[room_id]) async def send_heartbeat(self, conn_info: ConnectionInfo): """ Send a ping to check connection health Args: conn_info: Connection to ping """ try: await conn_info.websocket.send_json({"type": "ping"}) except Exception: await self.disconnect(conn_info) # Global WebSocket manager instance manager = WebSocketManager()