import asyncio
import logging
import time
from typing import Optional

from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
from sqlalchemy.orm import Session

from app.core.database import SessionLocal
from app.core.security import decode_access_token
from app.core.redis import get_redis_sync
from app.models import User, Notification, Project
from app.services.websocket_manager import manager
from app.core.redis_pubsub import NotificationSubscriber, ProjectTaskSubscriber
from app.middleware.auth import check_project_access

logger = logging.getLogger(__name__)

router = APIRouter(tags=["websocket"])

# Heartbeat configuration
PING_INTERVAL = 60.0  # Send ping after this many seconds of no messages
PONG_TIMEOUT = 30.0  # Disconnect if no pong received within this time after ping

# Authentication timeout (10 seconds)
AUTH_TIMEOUT = 10.0


async def get_user_from_token(token: str) -> tuple[str | None, User | None]:
    """Validate a JWT and return ``(user_id, user)``, or ``(None, None)``.

    The token is accepted only if it decodes, carries a ``sub`` claim, equals
    the token stored in Redis under ``session:<user_id>``, and the referenced
    user exists and is active.
    """
    # NOTE(review): the Redis and SQLAlchemy calls below are synchronous and
    # block the event loop while they run.
    payload = decode_access_token(token)
    if payload is None:
        return None, None

    user_id = payload.get("sub")
    if user_id is None:
        return None, None

    # Only the token currently stored for this user's session is honoured.
    redis_client = get_redis_sync()
    stored_token = redis_client.get(f"session:{user_id}")
    if stored_token is None or stored_token != token:
        return None, None

    db = SessionLocal()
    try:
        user = db.query(User).filter(User.id == user_id).first()
        if user is None or not user.is_active:
            return None, None
        return user_id, user
    finally:
        db.close()


async def authenticate_websocket(
    websocket: WebSocket, query_token: Optional[str] = None
) -> tuple[str | None, User | None]:
    """Authenticate an already-accepted WebSocket connection.

    Two methods are supported:

    1. First-message authentication (preferred, more secure): the client
       sends ``{"type": "auth", "token": "<jwt>"}`` within ``AUTH_TIMEOUT``
       seconds.
    2. Query-parameter authentication (deprecated, kept for backward
       compatibility): the client connects with ``?token=<jwt>``.

    Args:
        websocket: The accepted WebSocket to read the auth message from.
        query_token: Token taken from the ``token`` query parameter, if any.
            When given, no auth message is read from the socket.

    Returns:
        ``(user_id, user)`` if authenticated, ``(None, None)`` otherwise.
        On the first-message path, errors (timeout, disconnect, malformed
        message, token validation failures) are logged and not raised.
    """
    if query_token:
        logger.warning(
            "WebSocket authentication via query parameter is deprecated. "
            "Please use first-message authentication for better security."
        )
        return await get_user_from_token(query_token)

    try:
        data = await asyncio.wait_for(websocket.receive_json(), timeout=AUTH_TIMEOUT)
    except asyncio.TimeoutError:
        logger.warning("WebSocket authentication timeout after %.1f seconds", AUTH_TIMEOUT)
        return None, None
    except WebSocketDisconnect:
        logger.info("WebSocket disconnected before authenticating")
        return None, None
    except Exception as e:
        logger.exception("Error during WebSocket authentication: %s", e)
        return None, None

    # receive_json() yields any JSON value; only an object can be an auth message.
    if not isinstance(data, dict):
        logger.warning("Expected a JSON object as auth message, got: %s", type(data).__name__)
        return None, None

    msg_type = data.get("type")
    if msg_type != "auth":
        logger.warning("Expected 'auth' message type, got: %s", msg_type)
        return None, None

    token = data.get("token")
    if not token or not isinstance(token, str):
        logger.warning("No token provided in auth message")
        return None, None

    try:
        return await get_user_from_token(token)
    except Exception as e:
        logger.exception("Error during WebSocket authentication: %s", e)
        return None, None


async def get_unread_notifications(user_id: str) -> list[dict]:
    """Return all unread notifications for *user_id*, newest first, as JSON-ready dicts."""
    db = SessionLocal()
    try:
        notifications = (
            db.query(Notification)
            # `== False` builds a SQL expression; `is False` would not.
            .filter(Notification.user_id == user_id, Notification.is_read == False)  # noqa: E712
            .order_by(Notification.created_at.desc())
            .all()
        )
        return [
            {
                "id": n.id,
                "type": n.type,
                "reference_type": n.reference_type,
                "reference_id": n.reference_id,
                "title": n.title,
                "message": n.message,
                "is_read": n.is_read,
                "created_at": n.created_at.isoformat() if n.created_at else None,
            }
            for n in notifications
        ]
    finally:
        db.close()


async def get_unread_count(user_id: str) -> int:
    """Return the number of unread notifications for *user_id*."""
    db = SessionLocal()
    try:
        return (
            db.query(Notification)
            .filter(Notification.user_id == user_id, Notification.is_read == False)  # noqa: E712
            .count()
        )
    finally:
        db.close()


async def _close_quietly(websocket: WebSocket, code: int, reason: str) -> None:
    """Close *websocket* with *code*/*reason*; ignore errors if the peer is already gone."""
    try:
        await websocket.close(code=code, reason=reason)
    except Exception as e:
        logger.debug("Could not close WebSocket cleanly (code %d): %s", code, e)


async def _accept_and_authenticate(
    websocket: WebSocket, token: Optional[str]
) -> tuple[str | None, User | None]:
    """Accept *websocket* and authenticate it, closing with 4001 on failure.

    Returns ``(user_id, user)`` on success, ``(None, None)`` after the socket
    has been closed.
    """
    await websocket.accept()

    # Without a query token the client must authenticate via its first message.
    if not token:
        await websocket.send_json({"type": "auth_required"})

    user_id, user = await authenticate_websocket(websocket, token)
    if user_id is None:
        await _close_quietly(websocket, 4001, "Invalid or expired token")
        return None, None
    return user_id, user


async def _run_heartbeat(websocket: WebSocket, context: str) -> None:
    """Serve client messages and keep the connection alive until it should end.

    After ``PING_INTERVAL`` seconds without any client message, a
    ``{"type": "ping"}`` is sent. If no ``{"type": "pong"}`` arrives within
    ``PONG_TIMEOUT`` seconds, the function returns. Client ``ping`` messages
    are answered with ``pong``. Non-object JSON messages are ignored.

    Args:
        websocket: The connected, authenticated WebSocket.
        context: Human-readable description used in log messages,
            e.g. ``"user 42"``.

    Returns normally on pong timeout or when a ping cannot be sent.
    ``WebSocketDisconnect`` and other receive/send errors propagate to the
    caller.
    """
    waiting_for_pong = False
    ping_sent_at = 0.0
    # Monotonic clock: interval measurement must not jump with wall-clock changes.
    last_activity = time.monotonic()

    while True:
        now = time.monotonic()
        if waiting_for_pong:
            timeout = PONG_TIMEOUT - (now - ping_sent_at)
            if timeout <= 0:
                logger.warning("Pong timeout for %s, disconnecting", context)
                return
        else:
            timeout = PING_INTERVAL - (now - last_activity)
            if timeout <= 0:
                try:
                    await websocket.send_json({"type": "ping"})
                except Exception as e:
                    logger.debug("Failed to send ping to %s: %s", context, e)
                    return
                waiting_for_pong = True
                ping_sent_at = last_activity = time.monotonic()
                timeout = PONG_TIMEOUT

        try:
            data = await asyncio.wait_for(websocket.receive_json(), timeout=timeout)
        except asyncio.TimeoutError:
            # The top of the loop decides whether to ping or give up on the pong.
            continue

        last_activity = time.monotonic()
        msg_type = data.get("type") if isinstance(data, dict) else None

        if msg_type == "ping":
            await websocket.send_json({"type": "pong"})
        elif msg_type == "pong":
            waiting_for_pong = False
            logger.debug("Pong received from %s", context)


async def _stop_subscription(
    redis_task: asyncio.Task | None,
    subscriber: NotificationSubscriber | ProjectTaskSubscriber,
) -> None:
    """Cancel the pub/sub listener task (if it was started), then stop *subscriber*.

    If the listener has already died with an error, that error is logged
    rather than re-raised, so that ``subscriber.stop()`` still runs.
    """
    if redis_task is not None:
        redis_task.cancel()
        try:
            await redis_task
        except asyncio.CancelledError:
            pass
        except Exception:
            logger.exception("Redis listener task failed")
    await subscriber.stop()


@router.websocket("/ws/notifications")
async def websocket_notifications(
    websocket: WebSocket,
    token: Optional[str] = Query(None, description="JWT token (deprecated, use first-message auth)"),
):
    """WebSocket endpoint for real-time notifications.

    Authentication methods (in order of preference):

    1. First message authentication (recommended):
       - Connect without token: ws://host/ws/notifications
       - Send: {"type": "auth", "token": "<jwt>"}
       - Must authenticate within 10 seconds or connection will be closed
    2. Query parameter (deprecated, for backward compatibility):
       - Connect with: ws://host/ws/notifications?token=<jwt>

    Close codes: 4001 if authentication fails.

    Messages sent by server:
    - {"type": "auth_required"} - Sent when waiting for auth message
    - {"type": "connected", "data": {"user_id": "...", "message": "..."}} - Connection success
    - {"type": "unread_sync", "data": {"notifications": [...], "unread_count": N}} - All unread on connect
    - {"type": "notification", "data": {...}} - New notification
    - {"type": "unread_count", "data": {"unread_count": N}} - Unread count update
    - {"type": "ping"} - Server keepalive ping
    - {"type": "pong"} - Response to client ping

    Messages accepted from client:
    - {"type": "auth", "token": "..."} - Authentication (must be first message if no query token)
    - {"type": "ping"} - Client keepalive ping
    - {"type": "pong"} - Reply to a server ping
    """
    user_id, _user = await _accept_and_authenticate(websocket, token)
    if user_id is None:
        return

    # Build the subscriber before registering with the manager, so a failing
    # constructor leaves nothing to clean up.
    subscriber = NotificationSubscriber(user_id)
    await manager.connect(websocket, user_id)

    async def handle_redis_message(notification_data: dict) -> None:
        """Forward a Redis pub/sub notification, then the refreshed unread count."""
        try:
            await websocket.send_json({
                "type": "notification",
                "data": notification_data,
            })
            unread_count = await get_unread_count(user_id)
            await websocket.send_json({
                "type": "unread_count",
                "data": {"unread_count": unread_count},
            })
        except Exception as e:
            logger.error("Error forwarding notification to WebSocket: %s", e)

    redis_task = None
    try:
        await websocket.send_json({
            "type": "connected",
            "data": {"user_id": user_id, "message": "Connected to notification service"},
        })

        # Send all unread notifications on connect (unread_sync).
        unread_notifications = await get_unread_notifications(user_id)
        await websocket.send_json({
            "type": "unread_sync",
            "data": {
                "notifications": unread_notifications,
                "unread_count": len(unread_notifications),
            },
        })

        # Start Redis pub/sub subscription in the background.
        await subscriber.start()
        redis_task = asyncio.create_task(subscriber.listen(handle_redis_message))

        await _run_heartbeat(websocket, f"user {user_id}")
    except WebSocketDisconnect:
        pass
    except Exception as e:
        logger.exception("WebSocket error: %s", e)
    finally:
        # Manager cleanup must run even if stopping the subscription fails.
        try:
            await _stop_subscription(redis_task, subscriber)
        finally:
            await manager.disconnect(websocket, user_id)


async def verify_project_access(user_id: str, project_id: str) -> tuple[bool, Project | None]:
    """Check whether a user may access a project.

    Args:
        user_id: The user's ID.
        project_id: The project's ID.

    Returns:
        ``(has_access, project)``. ``project`` is ``None`` only when no
        project with *project_id* exists. ``has_access`` is ``False`` when
        the project is missing, the user is missing or inactive, or
        ``check_project_access`` denies access.
    """
    db = SessionLocal()
    try:
        # Look up the project first so callers can tell "not found" from
        # "access denied".
        project = db.query(Project).filter(Project.id == project_id).first()
        if project is None:
            return False, None

        user = db.query(User).filter(User.id == user_id).first()
        if user is None or not user.is_active:
            return False, project

        return check_project_access(user, project), project
    finally:
        db.close()


@router.websocket("/ws/projects/{project_id}")
async def websocket_project_sync(
    websocket: WebSocket,
    project_id: str,
    token: Optional[str] = Query(None, description="JWT token (deprecated, use first-message auth)"),
):
    """WebSocket endpoint for project task real-time sync.

    Authentication methods (in order of preference):

    1. First message authentication (recommended):
       - Connect without token: ws://host/ws/projects/{project_id}
       - Send: {"type": "auth", "token": "<jwt>"}
       - Must authenticate within 10 seconds or connection will be closed
    2. Query parameter (deprecated, for backward compatibility):
       - Connect with: ws://host/ws/projects/{project_id}?token=<jwt>

    Close codes: 4001 authentication failed, 4004 project not found,
    4003 access denied.

    Messages sent by server:
    - {"type": "auth_required"} - Sent when waiting for auth message
    - {"type": "connected", "data": {"project_id": "...", "user_id": "...", "project_title": "..."}}
    - {"type": "task_created", "data": {...}, "triggered_by": "..."}
    - {"type": "task_updated", "data": {...}, "triggered_by": "..."}
    - {"type": "task_status_changed", "data": {...}, "triggered_by": "..."}
    - {"type": "task_deleted", "data": {...}, "triggered_by": "..."}
    - {"type": "task_assigned", "data": {...}, "triggered_by": "..."}
    - {"type": "ping"} / {"type": "pong"}

    Messages accepted from client:
    - {"type": "auth", "token": "..."} - Authentication (must be first message if no query token)
    - {"type": "ping"} - Client keepalive ping
    - {"type": "pong"} - Reply to a server ping
    """
    user_id, _user = await _accept_and_authenticate(websocket, token)
    if user_id is None:
        return

    # Check existence first; otherwise a missing project would be reported
    # as "access denied".
    has_access, project = await verify_project_access(user_id, project_id)
    if project is None:
        await _close_quietly(websocket, 4004, "Project not found")
        return
    if not has_access:
        await _close_quietly(websocket, 4003, "Access denied to this project")
        return

    # Subscriber for this project's task events (built before joining the room
    # so a failing constructor leaves nothing to clean up).
    subscriber = ProjectTaskSubscriber(project_id)
    await manager.join_project(websocket, user_id, project_id)

    async def handle_redis_message(event_data: dict) -> None:
        """Forward a Redis pub/sub task event to the client unchanged."""
        try:
            await websocket.send_json(event_data)
        except Exception as e:
            logger.error("Error forwarding task event to WebSocket: %s", e)

    redis_task = None
    try:
        await websocket.send_json({
            "type": "connected",
            "data": {
                "project_id": project_id,
                "user_id": user_id,
                "project_title": project.title,
            },
        })
        logger.info("User %s connected to project %s WebSocket", user_id, project_id)

        # Start Redis pub/sub subscription in the background.
        await subscriber.start()
        redis_task = asyncio.create_task(subscriber.listen(handle_redis_message))

        await _run_heartbeat(websocket, f"user {user_id} in project {project_id}")
    except WebSocketDisconnect:
        logger.info("User %s disconnected from project %s WebSocket", user_id, project_id)
    except Exception as e:
        logger.exception("WebSocket error for project %s: %s", project_id, e)
    finally:
        # Room cleanup must run even if stopping the subscription fails.
        try:
            await _stop_subscription(redis_task, subscriber)
        finally:
            await manager.leave_project(websocket, user_id, project_id)
            logger.info("User %s left project %s room", user_id, project_id)