import asyncio
import json
import logging
from typing import Dict, Optional, Set

from fastapi import WebSocket

from app.core.redis import get_redis_sync

logger = logging.getLogger(__name__)


class ConnectionManager:
    """Track open WebSocket connections per user and fan messages out to them.

    A single user id may map to several simultaneous connections; every
    message addressed to that user is sent to each of them.
    """

    def __init__(self) -> None:
        # user_id -> set of that user's currently registered WebSocket connections.
        self.active_connections: Dict[str, Set[WebSocket]] = {}
        # Serializes mutations of ``active_connections`` across coroutines.
        self._lock = asyncio.Lock()

    async def connect(self, websocket: WebSocket, user_id: str) -> None:
        """Accept *websocket* and register it under *user_id*."""
        await websocket.accept()
        async with self._lock:
            self.active_connections.setdefault(user_id, set()).add(websocket)

    async def disconnect(self, websocket: WebSocket, user_id: str) -> None:
        """Unregister *websocket* from *user_id*, dropping the user once empty.

        Unknown users and already-removed connections are ignored, so calling
        this more than once for the same connection is harmless.
        """
        async with self._lock:
            connections = self.active_connections.get(user_id)
            if connections is None:
                return
            connections.discard(websocket)
            if not connections:
                del self.active_connections[user_id]

    async def send_personal_message(self, message: dict, user_id: str) -> None:
        """Send *message* as JSON to every registered connection of *user_id*.

        Connections whose send raises are treated as dead and unregistered.
        Does nothing if the user has no registered connections.
        """
        # Iterate over a snapshot: each ``send_json`` awaits, and while this
        # coroutine is suspended another one may connect/disconnect for the
        # same user and mutate the live set, which would otherwise raise
        # "RuntimeError: Set changed size during iteration".
        connections = list(self.active_connections.get(user_id, ()))
        disconnected = []
        for connection in connections:
            try:
                await connection.send_json(message)
            except Exception:
                # Best-effort delivery: one failing socket must not stop
                # delivery to the user's other connections.
                logger.debug(
                    "Dropping WebSocket for user %s after send failure",
                    user_id,
                    exc_info=True,
                )
                disconnected.append(connection)

        # Clean up connections that failed to receive the message.
        for connection in disconnected:
            await self.disconnect(connection, user_id)

    async def broadcast(self, message: dict) -> None:
        """Send *message* to every connected user, one user at a time."""
        # Copy the keys: sends can disconnect users and shrink the dict.
        for user_id in list(self.active_connections):
            await self.send_personal_message(message, user_id)

    def is_connected(self, user_id: str) -> bool:
        """Return True if *user_id* has at least one registered connection."""
        return bool(self.active_connections.get(user_id))


# Global connection manager instance shared by the publish_* helpers below.
manager = ConnectionManager()


async def publish_notification(user_id: str, notification_data: dict) -> None:
    """
    Publish a notification to a user via WebSocket.

    This can be called from anywhere in the application to send real-time
    notifications to connected users. The payload is wrapped as
    ``{"type": "notification", "data": notification_data}``.

    NOTE(review): delivery only reaches connections registered with this
    process's ``manager``; ``get_redis_sync`` is imported but unused here.
    If several workers serve WebSockets, users connected to another worker
    will not receive the message — confirm whether cross-process fan-out
    is expected.
    """
    message = {
        "type": "notification",
        "data": notification_data,
    }
    await manager.send_personal_message(message, user_id)


async def publish_notification_count_update(user_id: str, unread_count: int) -> None:
    """
    Publish an unread count update to a user.

    The payload is ``{"type": "unread_count", "data": {"unread_count": unread_count}}``.
    Like ``publish_notification``, delivery is limited to this process's
    ``manager``.
    """
    message = {
        "type": "unread_count",
        "data": {"unread_count": unread_count},
    }
    await manager.send_personal_message(message, user_id)