import asyncio
import logging
import time

from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
from sqlalchemy.orm import Session

from app.core.database import SessionLocal
from app.core.security import decode_access_token
from app.core.redis import get_redis_sync
from app.models import User, Notification
from app.services.websocket_manager import manager
from app.core.redis_pubsub import NotificationSubscriber

logger = logging.getLogger(__name__)

router = APIRouter(tags=["websocket"])

# Heartbeat configuration
PING_INTERVAL = 60.0  # Send ping after this many seconds of no messages
PONG_TIMEOUT = 30.0  # Disconnect if no pong received within this time after ping


def _get_user_from_token_sync(token: str) -> tuple[str | None, User | None]:
    """Blocking implementation of ``get_user_from_token``.

    Decodes the JWT, checks it against the Redis session entry and loads the
    user from the database. Intended to run in a worker thread.
    """
    payload = decode_access_token(token)
    if payload is None:
        return None, None

    user_id = payload.get("sub")
    if user_id is None:
        return None, None

    # The token must exactly match the value stored under session:<user_id>;
    # a missing key or a different token is rejected.
    redis_client = get_redis_sync()
    stored_token = redis_client.get(f"session:{user_id}")
    if stored_token is None or stored_token != token:
        return None, None

    db = SessionLocal()
    try:
        user = db.query(User).filter(User.id == user_id).first()
        if user is None or not user.is_active:
            return None, None
        return user_id, user
    finally:
        db.close()


async def get_user_from_token(token: str) -> tuple[str | None, User | None]:
    """Validate token and return ``(user_id, user)``.

    Returns ``(None, None)`` when the token cannot be decoded, has no ``sub``
    claim, does not match the Redis session entry, or the user is missing or
    inactive. The Redis and database calls are synchronous, so they are run in
    a worker thread to avoid blocking the event loop.
    """
    return await asyncio.to_thread(_get_user_from_token_sync, token)


def _serialize_notification(n: Notification) -> dict:
    """Convert a Notification row into the JSON payload sent to clients."""
    return {
        "id": n.id,
        "type": n.type,
        "reference_type": n.reference_type,
        "reference_id": n.reference_id,
        "title": n.title,
        "message": n.message,
        "is_read": n.is_read,
        "created_at": n.created_at.isoformat() if n.created_at else None,
    }


def _get_unread_notifications_sync(user_id: str) -> list[dict]:
    """Blocking implementation of ``get_unread_notifications``."""
    db = SessionLocal()
    try:
        notifications = (
            db.query(Notification)
            # `== False` is intentional: SQLAlchemy overloads `==` to build a
            # SQL expression; `is False` would be a plain Python identity test.
            .filter(Notification.user_id == user_id, Notification.is_read == False)  # noqa: E712
            .order_by(Notification.created_at.desc())
            .all()
        )
        # Serialize while the session is still open.
        return [_serialize_notification(n) for n in notifications]
    finally:
        db.close()


async def get_unread_notifications(user_id: str) -> list[dict]:
    """Query all unread notifications for a user, newest first.

    Returns a list of JSON-serializable dicts. The synchronous query runs in a
    worker thread.
    """
    return await asyncio.to_thread(_get_unread_notifications_sync, user_id)


def _get_unread_count_sync(user_id: str) -> int:
    """Blocking implementation of ``get_unread_count``."""
    db = SessionLocal()
    try:
        return (
            db.query(Notification)
            .filter(Notification.user_id == user_id, Notification.is_read == False)  # noqa: E712
            .count()
        )
    finally:
        db.close()


async def get_unread_count(user_id: str) -> int:
    """Get the count of unread notifications for a user.

    The synchronous query runs in a worker thread.
    """
    return await asyncio.to_thread(_get_unread_count_sync, user_id)


async def _run_heartbeat_loop(websocket: WebSocket, user_id: str) -> None:
    """Read client messages and keep the connection alive with ping/pong.

    Returns when the connection should be closed: either a sent ping got no
    pong within PONG_TIMEOUT, or sending a ping failed. Exceptions from
    receiving or replying (including WebSocketDisconnect) propagate.
    """
    # time.monotonic() is immune to wall-clock adjustments (NTP, DST, manual
    # changes) that could otherwise trigger spurious timeouts.
    waiting_for_pong = False
    ping_sent_at = 0.0
    last_activity = time.monotonic()

    while True:
        if waiting_for_pong:
            # Waiting for a pong: the receive timeout is whatever is left of
            # PONG_TIMEOUT since the ping went out.
            remaining = PONG_TIMEOUT - (time.monotonic() - ping_sent_at)
            if remaining <= 0:
                logger.warning(f"Pong timeout for user {user_id}, disconnecting")
                return
            timeout = remaining
        else:
            remaining = PING_INTERVAL - (time.monotonic() - last_activity)
            if remaining <= 0:
                # Idle for a full interval: probe the client.
                try:
                    await websocket.send_json({"type": "ping"})
                except Exception:
                    return
                waiting_for_pong = True
                ping_sent_at = time.monotonic()
                last_activity = ping_sent_at
                timeout = PONG_TIMEOUT
            else:
                timeout = remaining

        try:
            data = await asyncio.wait_for(websocket.receive_json(), timeout=timeout)
        except asyncio.TimeoutError:
            # The loop head either sends a ping or enforces the pong timeout.
            continue

        last_activity = time.monotonic()

        # Valid JSON that is not an object (e.g. a list) carries no "type";
        # ignore it rather than crashing the connection on `.get`.
        if not isinstance(data, dict):
            continue

        msg_type = data.get("type")
        if msg_type == "ping":
            await websocket.send_json({"type": "pong"})
        elif msg_type == "pong":
            waiting_for_pong = False
            logger.debug(f"Pong received from user {user_id}")


async def _cleanup_connection(
    websocket: WebSocket,
    user_id: str,
    subscriber: NotificationSubscriber,
    redis_task: "asyncio.Task | None",
) -> None:
    """Stop the Redis listener and subscription, then deregister the socket.

    Every step is attempted even if an earlier one fails, so the connection
    manager is always told about the disconnect.
    """
    try:
        if redis_task is not None:
            redis_task.cancel()
            try:
                await redis_task
            except asyncio.CancelledError:
                pass
            except Exception:
                # The listener had already died with an error; awaiting it
                # re-raises that error. Log it instead of aborting cleanup.
                logger.exception(f"Redis listener for user {user_id} failed")
        await subscriber.stop()
    except Exception:
        logger.exception(f"Error stopping notification subscriber for user {user_id}")
    finally:
        await manager.disconnect(websocket, user_id)


@router.websocket("/ws/notifications")
async def websocket_notifications(
    websocket: WebSocket,
    token: str = Query(..., description="JWT token for authentication"),
):
    """
    WebSocket endpoint for real-time notifications.

    Connect with: ws://host/ws/notifications?token=<JWT>

    If the token is invalid, the socket is accepted and immediately closed
    with code 4001.

    Messages sent by server:
    - {"type": "connected", "data": {"user_id": "...", "message": "..."}} - Connection success
    - {"type": "unread_sync", "data": {"notifications": [...], "unread_count": N}} - All unread on connect
    - {"type": "notification", "data": {...}} - New notification
    - {"type": "unread_count", "data": {"unread_count": N}} - Unread count update
    - {"type": "ping"} - Server keepalive ping
    - {"type": "pong"} - Response to client ping

    Messages accepted from client:
    - {"type": "ping"} - Client keepalive ping
    - {"type": "pong"} - Reply to a server ping
    """
    user_id, _user = await get_user_from_token(token)
    if user_id is None:
        # Application close codes (4000-4999) can only be delivered after the
        # handshake. Closing an un-accepted socket is turned into an HTTP 403
        # by the ASGI server, so the client would never see 4001.
        await websocket.accept()
        await websocket.close(code=4001, reason="Invalid or expired token")
        return

    await manager.connect(websocket, user_id)
    subscriber = NotificationSubscriber(user_id)

    async def handle_redis_message(notification_data: dict) -> None:
        """Forward Redis pub/sub messages to WebSocket, then the new unread count."""
        try:
            await websocket.send_json({
                "type": "notification",
                "data": notification_data,
            })
            unread_count = await get_unread_count(user_id)
            await websocket.send_json({
                "type": "unread_count",
                "data": {"unread_count": unread_count},
            })
        except Exception as e:
            logger.error(f"Error forwarding notification to WebSocket: {e}")

    redis_task = None
    try:
        await websocket.send_json({
            "type": "connected",
            "data": {"user_id": user_id, "message": "Connected to notification service"},
        })

        # Send all unread notifications on connect (unread_sync).
        unread_notifications = await get_unread_notifications(user_id)
        await websocket.send_json({
            "type": "unread_sync",
            "data": {
                "notifications": unread_notifications,
                "unread_count": len(unread_notifications),
            },
        })

        # Start Redis pub/sub subscription in background.
        await subscriber.start()
        redis_task = asyncio.create_task(subscriber.listen(handle_redis_message))

        await _run_heartbeat_loop(websocket, user_id)
    except WebSocketDisconnect:
        pass
    except Exception as e:
        logger.exception(f"WebSocket error: {e}")
    finally:
        await _cleanup_connection(websocket, user_id, subscriber, redis_task)