import asyncio
import hmac
import logging

from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
from sqlalchemy.orm import Session

from app.core.database import SessionLocal
from app.core.security import decode_access_token
from app.core.redis import get_redis_sync
from app.models import User
from app.services.websocket_manager import manager

logger = logging.getLogger(__name__)

router = APIRouter(tags=["websocket"])

# Seconds to wait for a client message before the server sends its own ping.
KEEPALIVE_TIMEOUT_SECONDS = 60.0

# Close code sent to the client when token authentication fails.
WS_CLOSE_INVALID_TOKEN = 4001


def _authenticate_sync(token: str) -> tuple[str | None, User | None]:
    """Blocking implementation of ``get_user_from_token``.

    Performs the synchronous Redis and database lookups; callers on the event
    loop should run it in a worker thread (see ``get_user_from_token``).

    Args:
        token: Raw JWT supplied by the client.

    Returns:
        ``(user_id, user)`` on success, ``(None, None)`` if the token does not
        decode, has no ``sub`` claim, does not match the stored session token,
        or the user is missing or inactive.
    """
    payload = decode_access_token(token)
    if payload is None:
        return None, None
    user_id = payload.get("sub")
    if user_id is None:
        return None, None

    # The token must equal the one stored under ``session:<user_id>``; any other
    # JWT for the same user, even if it decodes, is rejected.
    redis_client = get_redis_sync()
    stored_token = redis_client.get(f"session:{user_id}")
    if stored_token is None:
        return None, None
    # Depending on client configuration (e.g. redis-py's decode_responses),
    # get() may return bytes or str; normalise both sides to bytes. Use a
    # constant-time comparison because ``token`` is attacker-controlled.
    if isinstance(stored_token, str):
        stored_token = stored_token.encode("utf-8")
    if not hmac.compare_digest(stored_token, token.encode("utf-8")):
        return None, None

    # Short-lived session dedicated to this lookup; always closed.
    db = SessionLocal()
    try:
        user = db.query(User).filter(User.id == user_id).first()
        if user is None or not user.is_active:
            return None, None
        return user_id, user
    finally:
        db.close()


async def get_user_from_token(token: str) -> tuple[str | None, User | None]:
    """Validate token and return user_id and user object.

    The Redis and SQLAlchemy calls are synchronous, so they are executed in a
    worker thread to avoid blocking the event loop for other connections.

    Returns:
        ``(user_id, user)`` on success, otherwise ``(None, None)``.
    """
    return await asyncio.to_thread(_authenticate_sync, token)


@router.websocket("/ws/notifications")
async def websocket_notifications(
    websocket: WebSocket,
    token: str = Query(..., description="JWT token for authentication"),
):
    """
    WebSocket endpoint for real-time notifications.

    Connect with: ws://host/ws/notifications?token=<JWT>

    Messages sent by server:
    - {"type": "connected", "data": {"user_id": ..., "message": ...}} - Sent once on connect
    - {"type": "notification", "data": {...}} - New notification
    - {"type": "unread_count", "data": {"unread_count": N}} - Unread count update
    - {"type": "pong"} - Response to ping
    - {"type": "ping"} - Keepalive, sent when the client is silent for 60 seconds

    Messages accepted from client:
    - {"type": "ping"} - Keepalive ping

    If the token is invalid or expired the socket is closed with code 4001.
    """
    user_id, _user = await get_user_from_token(token)
    if user_id is None:
        # Per the ASGI spec, sending a close before accept() rejects the
        # handshake with HTTP 403 and the custom code/reason are lost. Accept
        # first so the client actually receives 4001 and the reason text.
        await websocket.accept()
        await websocket.close(code=WS_CLOSE_INVALID_TOKEN, reason="Invalid or expired token")
        return

    await manager.connect(websocket, user_id)
    try:
        # Send initial connection success message
        await websocket.send_json({
            "type": "connected",
            "data": {"user_id": user_id, "message": "Connected to notification service"},
        })

        while True:
            try:
                data = await asyncio.wait_for(
                    websocket.receive_json(),
                    timeout=KEEPALIVE_TIMEOUT_SECONDS,
                )
            except asyncio.TimeoutError:
                # Client has been silent; probe it. A failed send means the
                # connection is gone, so stop the loop.
                try:
                    await websocket.send_json({"type": "ping"})
                except Exception:
                    break
                continue
            except ValueError:
                # Malformed JSON frame: ignore it instead of dropping the
                # whole connection.
                continue

            # Valid JSON need not be an object; only dicts carry a "type".
            if isinstance(data, dict) and data.get("type") == "ping":
                await websocket.send_json({"type": "pong"})
    except WebSocketDisconnect:
        pass
    except Exception:
        logger.exception("Notification WebSocket for user %s failed", user_id)
    finally:
        await manager.disconnect(websocket, user_id)