import asyncio
import json
import logging
from typing import Dict, Optional, Set, Tuple

from fastapi import WebSocket

from app.core.config import settings
from app.core.redis import get_redis_sync

logger = logging.getLogger(__name__)


class ConnectionManager:
    """In-process registry of WebSocket connections.

    Two independent indexes are kept:

    * ``active_connections`` maps a user ID to that user's sockets and is
      used for per-user notifications.
    * ``project_connections`` maps a project ID to ``(user_id, websocket)``
      pairs and is used for project-room broadcasts.

    Each index is mutated only while holding its own ``asyncio.Lock``.
    """

    def __init__(self):
        # user_id -> set of WebSocket connections (for notifications)
        self.active_connections: Dict[str, Set[WebSocket]] = {}
        # project_id -> set of (user_id, WebSocket) tuples (for project sync)
        self.project_connections: Dict[str, Set[Tuple[str, WebSocket]]] = {}
        self._lock = asyncio.Lock()
        self._project_lock = asyncio.Lock()

    async def check_connection_limit(self, user_id: str) -> Tuple[bool, Optional[str]]:
        """Check if user can create a new WebSocket connection.

        The check only reads the current count; it does not reserve a slot.
        The lock is released before returning, so a concurrent ``connect()``
        may still land between this check and the caller's own ``connect()``.

        Args:
            user_id: The user's ID

        Returns:
            Tuple of (can_connect: bool, reject_reason: str | None)
            - can_connect: True if user is within connection limit
            - reject_reason: Error message if connection should be rejected
        """
        max_connections = settings.MAX_WEBSOCKET_CONNECTIONS_PER_USER

        async with self._lock:
            current_count = len(self.active_connections.get(user_id, ()))

        if current_count >= max_connections:
            logger.warning(
                f"User {user_id} exceeded WebSocket connection limit "
                f"({current_count}/{max_connections})"
            )
            return False, "Too many connections"

        return True, None

    def get_user_connection_count(self, user_id: str) -> int:
        """Get the current number of WebSocket connections for a user."""
        return len(self.active_connections.get(user_id, ()))

    async def connect(self, websocket: WebSocket, user_id: str):
        """Track a new WebSocket connection.

        Note: WebSocket must already be accepted before calling this method.
        Connection limit should be checked via check_connection_limit()
        before calling.

        Args:
            websocket: The accepted WebSocket connection
            user_id: The user's ID
        """
        async with self._lock:
            user_connections = self.active_connections.setdefault(user_id, set())
            user_connections.add(websocket)
            logger.debug(
                f"User {user_id} connected. Total connections: "
                f"{len(user_connections)}"
            )

    async def disconnect(self, websocket: WebSocket, user_id: str):
        """Remove a WebSocket connection.

        Unknown users and unknown sockets are ignored. The user's entry is
        dropped once their last connection is removed.

        Args:
            websocket: The WebSocket connection to forget
            user_id: The user's ID
        """
        async with self._lock:
            user_connections = self.active_connections.get(user_id)
            if user_connections is None:
                return
            user_connections.discard(websocket)
            if not user_connections:
                del self.active_connections[user_id]

    async def send_personal_message(self, message: dict, user_id: str):
        """Send a message to all connections of a specific user.

        Delivery is best-effort: a connection whose send raises is logged
        and then removed via ``disconnect()``; nothing is raised to the
        caller.

        Args:
            message: JSON-serializable payload passed to ``send_json``
            user_id: The user's ID
        """
        # Iterate over a snapshot, never the live set: every ``await`` below
        # yields to the event loop, and a concurrent connect()/disconnect()
        # mutating the set mid-iteration would raise RuntimeError. Taking the
        # copy needs no lock because no await happens while it is built.
        connections_snapshot = list(self.active_connections.get(user_id, ()))
        if not connections_snapshot:
            return

        disconnected = []
        for connection in connections_snapshot:
            try:
                await connection.send_json(message)
            except Exception as e:
                logger.warning(f"Failed to send message to user {user_id}: {e}")
                disconnected.append(connection)

        # Clean up disconnected connections. disconnect() acquires
        # self._lock itself, so it must not be called while holding it.
        for connection in disconnected:
            await self.disconnect(connection, user_id)

    async def broadcast(self, message: dict):
        """Broadcast a message to all connected users."""
        # Copy the keys: sending awaits, and users may come and go meanwhile.
        for user_id in list(self.active_connections):
            await self.send_personal_message(message, user_id)

    def is_connected(self, user_id: str) -> bool:
        """Check if a user has any active connections."""
        return bool(self.active_connections.get(user_id))

    # Project room management methods

    async def join_project(self, websocket: WebSocket, user_id: str, project_id: str):
        """Add user to a project room for real-time task sync.

        Args:
            websocket: The WebSocket connection
            user_id: The user's ID
            project_id: The project to join
        """
        async with self._project_lock:
            room = self.project_connections.setdefault(project_id, set())
            room.add((user_id, websocket))
            logger.debug(f"User {user_id} joined project room {project_id}")

    async def leave_project(self, websocket: WebSocket, user_id: str, project_id: str):
        """Remove user from a project room.

        The room entry is dropped once it becomes empty. The debug line is
        logged only when the room existed.

        Args:
            websocket: The WebSocket connection
            user_id: The user's ID
            project_id: The project to leave
        """
        async with self._project_lock:
            room = self.project_connections.get(project_id)
            if room is None:
                return
            room.discard((user_id, websocket))
            if not room:
                del self.project_connections[project_id]
            logger.debug(f"User {user_id} left project room {project_id}")

    async def broadcast_to_project(
        self,
        project_id: str,
        message: dict,
        exclude_user_id: Optional[str] = None
    ):
        """Broadcast message to all users in a project room.

        Delivery is best-effort: connections whose send raises are logged
        and removed from the room; nothing is raised to the caller.

        Args:
            project_id: The project room to broadcast to
            message: The message to send
            exclude_user_id: Optional user ID to exclude from broadcast
                (e.g., the sender)
        """
        # Create snapshot while holding lock to prevent race condition
        async with self._project_lock:
            if project_id not in self.project_connections:
                return
            connections_snapshot = list(self.project_connections[project_id])

        disconnected = set()
        for user_id, websocket in connections_snapshot:
            # Skip excluded user (sender)
            if exclude_user_id and user_id == exclude_user_id:
                continue
            try:
                await websocket.send_json(message)
            except Exception as e:
                logger.warning(
                    f"Failed to send message to user {user_id} in project {project_id}: {e}"
                )
                disconnected.add((user_id, websocket))

        # Clean up disconnected connections
        if disconnected:
            async with self._project_lock:
                # The room may have been emptied and removed while sending.
                room = self.project_connections.get(project_id)
                if room is not None:
                    room.difference_update(disconnected)
                    if not room:
                        del self.project_connections[project_id]

    def get_project_user_count(self, project_id: str) -> int:
        """Get the number of unique users in a project room."""
        room = self.project_connections.get(project_id, ())
        return len({user_id for user_id, _ in room})

    def is_user_in_project(self, user_id: str, project_id: str) -> bool:
        """Check if a user has any active connections to a project room."""
        room = self.project_connections.get(project_id, ())
        return any(uid == user_id for uid, _ in room)


# Global connection manager instance
manager = ConnectionManager()


async def publish_notification(user_id: str, notification_data: dict):
    """Publish a notification to a user via WebSocket.

    This can be called from anywhere in the application to send
    real-time notifications to connected users. The payload is wrapped as
    ``{"type": "notification", "data": notification_data}`` and sent through
    the global ``manager``.

    Args:
        user_id: The recipient user's ID
        notification_data: Payload placed under the ``"data"`` key
    """
    message = {
        "type": "notification",
        "data": notification_data,
    }
    await manager.send_personal_message(message, user_id)


async def publish_notification_count_update(user_id: str, unread_count: int):
    """Publish an unread count update to a user.

    The message is sent through the global ``manager`` as
    ``{"type": "unread_count", "data": {"unread_count": unread_count}}``.

    Args:
        user_id: The recipient user's ID
        unread_count: The value sent under ``"unread_count"``
    """
    message = {
        "type": "unread_count",
        "data": {"unread_count": unread_count},
    }
    await manager.send_personal_message(message, user_id)