import asyncio
import os
import logging
import time
from typing import Optional

from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
from sqlalchemy.orm import Session

from app.core import database
from app.core.security import decode_access_token
from app.core.redis import get_redis_sync
from app.models import User, Notification, Project
from app.services.websocket_manager import manager
from app.core.redis_pubsub import NotificationSubscriber, ProjectTaskSubscriber
from app.middleware.auth import check_project_access

logger = logging.getLogger(__name__)
router = APIRouter(tags=["websocket"])

# Heartbeat configuration
PING_INTERVAL = 60.0  # Send ping after this many seconds of no messages
PONG_TIMEOUT = 30.0  # Disconnect if no pong received within this time after ping

# Authentication timeout (10 seconds; shortened to 1 second when TESTING=true)
AUTH_TIMEOUT = 10.0
if os.getenv("TESTING") == "true":
    AUTH_TIMEOUT = 1.0


async def get_user_from_token(token: str) -> str | None:
    """
    Validate a token and return the user_id it belongs to.

    The token must decode successfully, carry a "sub" claim, match the value
    stored under the ``session:{user_id}`` Redis key, and belong to an
    existing, active user.

    Args:
        token: The access token presented by the client.

    Returns:
        user_id if valid, None otherwise.

    Note:
        The database session is always closed before returning, so only the
        plain user_id is returned (ORM objects would be detached).
    """
    # NOTE(review): the Redis and database calls below are made without
    # ``await``, so they appear to run synchronously on the event loop.
    # Consider offloading them if they prove slow - confirm with the owners
    # of get_redis_sync / SessionLocal.
    payload = decode_access_token(token)
    if payload is None:
        return None

    user_id = payload.get("sub")
    if user_id is None:
        return None

    # Verify session in Redis
    redis_client = get_redis_sync()
    stored_token = redis_client.get(f"session:{user_id}")
    if stored_token is None or stored_token != token:
        return None

    # Verify user exists and is active
    db = database.SessionLocal()
    try:
        user = db.query(User).filter(User.id == user_id).first()
        if user is None or not user.is_active:
            return None
        return user_id
    finally:
        db.close()


async def authenticate_websocket(
    websocket: WebSocket,
    query_token: Optional[str] = None
) -> tuple[str | None, Optional[str]]:
    """
    Authenticate a WebSocket connection.

    Supports two authentication methods:
    1. First message authentication (preferred, more secure)
       - Client sends: {"type": "auth", "token": "<jwt>"}
    2. Query parameter authentication (deprecated, for backward compatibility)
       - Client connects with: ?token=<jwt>

    Args:
        websocket: The already-accepted WebSocket connection.
        query_token: Token taken from the query string, if any.

    Returns:
        Tuple of (user_id, error_reason). user_id is None if authentication
        fails; error_reason is then one of "invalid_token",
        "invalid_message", "missing_token", "timeout" or "error".
    """
    # If token provided via query parameter (backward compatibility)
    if query_token:
        logger.warning(
            "WebSocket authentication via query parameter is deprecated. "
            "Please use first-message authentication for better security."
        )
        user_id = await get_user_from_token(query_token)
        if user_id is None:
            return None, "invalid_token"
        return user_id, None

    # Wait for authentication message with timeout
    try:
        data = await asyncio.wait_for(
            websocket.receive_json(),
            timeout=AUTH_TIMEOUT
        )

        msg_type = data.get("type")
        if msg_type != "auth":
            logger.warning("Expected 'auth' message type, got: %s", msg_type)
            return None, "invalid_message"

        token = data.get("token")
        if not token:
            logger.warning("No token provided in auth message")
            return None, "missing_token"

        user_id = await get_user_from_token(token)
        if user_id is None:
            return None, "invalid_token"
        return user_id, None

    except asyncio.TimeoutError:
        logger.warning("WebSocket authentication timeout after %.1f seconds", AUTH_TIMEOUT)
        return None, "timeout"
    except Exception as e:
        # Covers malformed JSON, non-object payloads and a client that
        # disconnects before authenticating.
        logger.error("Error during WebSocket authentication: %s", e)
        return None, "error"


async def _authenticate_or_close(websocket: WebSocket, token: Optional[str]) -> str | None:
    """Authenticate the connection; on failure reject it with code 4001 and return None."""
    user_id, error_reason = await authenticate_websocket(websocket, token)
    if user_id is not None:
        return user_id

    try:
        if error_reason == "invalid_token":
            await websocket.send_json({"type": "error", "message": "Invalid or expired token"})
        await websocket.close(code=4001, reason="Invalid or expired token")
    except (WebSocketDisconnect, RuntimeError) as e:
        # Authentication may have failed precisely because the client went
        # away; the rejection is best-effort and must not crash the endpoint.
        logger.debug("Could not deliver authentication rejection: %s", e)
    return None


async def _run_heartbeat_loop(websocket: WebSocket, client_label: str) -> None:
    """
    Receive client messages and keep the connection alive with ping/pong.

    Sends {"type": "ping"} after PING_INTERVAL seconds without client
    messages and returns if no {"type": "pong"} arrives within PONG_TIMEOUT
    seconds. Client {"type": "ping"} messages are answered with
    {"type": "pong"}.

    Args:
        websocket: The authenticated WebSocket connection.
        client_label: Human-readable identity used in log messages.

    Returns:
        None, when the pong timeout expires or the ping cannot be sent.

    Raises:
        WebSocketDisconnect (and any other receive/send error) propagates to
        the caller, which is responsible for cleanup.
    """
    waiting_for_pong = False
    ping_sent_at = 0.0
    last_activity = time.time()

    while True:
        # Calculate appropriate timeout based on state
        if waiting_for_pong:
            # When waiting for pong, use remaining pong timeout
            remaining = PONG_TIMEOUT - (time.time() - ping_sent_at)
            if remaining <= 0:
                logger.warning("Pong timeout for %s, disconnecting", client_label)
                return
            timeout = remaining
        else:
            # When not waiting, use remaining ping interval
            remaining = PING_INTERVAL - (time.time() - last_activity)
            if remaining <= 0:
                # Time to send ping immediately
                try:
                    await websocket.send_json({"type": "ping"})
                except Exception:
                    # The connection is unusable; let the caller clean up.
                    return
                waiting_for_pong = True
                ping_sent_at = time.time()
                last_activity = ping_sent_at
                timeout = PONG_TIMEOUT
            else:
                timeout = remaining

        try:
            # Wait for messages from client
            data = await asyncio.wait_for(
                websocket.receive_json(),
                timeout=timeout
            )
        except asyncio.TimeoutError:
            # Strict timeout check; if not waiting for a pong, the top of the
            # loop sends the ping on the next iteration.
            if waiting_for_pong and time.time() - ping_sent_at >= PONG_TIMEOUT:
                logger.warning("Pong timeout for %s, disconnecting", client_label)
                return
            continue

        last_activity = time.time()
        msg_type = data.get("type")

        if msg_type == "ping":
            # Handle ping message from client
            await websocket.send_json({"type": "pong"})
        elif msg_type == "pong":
            # Handle pong message from client (response to our ping)
            waiting_for_pong = False
            logger.debug("Pong received from %s", client_label)


async def _stop_subscription(redis_task: Optional[asyncio.Task], subscriber) -> None:
    """
    Cancel the background listener task (if started) and stop the subscriber.

    A listener that already finished with an error is logged instead of
    re-raised, so that ``subscriber.stop()`` is still reached.
    """
    if redis_task is not None:
        redis_task.cancel()
        try:
            await redis_task
        except asyncio.CancelledError:
            pass
        except Exception as e:
            logger.error("Redis listener task ended with error: %s", e)
    await subscriber.stop()


async def get_unread_notifications(user_id: str) -> list[dict]:
    """
    Query all unread notifications for a user, newest first.

    Returns:
        A list of plain dicts (id, type, reference_type, reference_id, title,
        message, is_read, created_at as ISO string or None), built while the
        session is still open.
    """
    db = database.SessionLocal()
    try:
        notifications = (
            db.query(Notification)
            .filter(Notification.user_id == user_id, Notification.is_read == False)
            .order_by(Notification.created_at.desc())
            .all()
        )
        return [
            {
                "id": n.id,
                "type": n.type,
                "reference_type": n.reference_type,
                "reference_id": n.reference_id,
                "title": n.title,
                "message": n.message,
                "is_read": n.is_read,
                "created_at": n.created_at.isoformat() if n.created_at else None,
            }
            for n in notifications
        ]
    finally:
        db.close()


async def get_unread_count(user_id: str) -> int:
    """Get the count of unread notifications for a user."""
    db = database.SessionLocal()
    try:
        return (
            db.query(Notification)
            .filter(Notification.user_id == user_id, Notification.is_read == False)
            .count()
        )
    finally:
        db.close()


@router.websocket("/ws/notifications")
async def websocket_notifications(
    websocket: WebSocket,
    token: Optional[str] = Query(None, description="JWT token (deprecated, use first-message auth)"),
):
    """
    WebSocket endpoint for real-time notifications.

    Authentication methods (in order of preference):
    1. First message authentication (recommended):
       - Connect without token: ws://host/ws/notifications
       - Send: {"type": "auth", "token": "<jwt>"}
       - Must authenticate within 10 seconds or connection will be closed
    2. Query parameter (deprecated, for backward compatibility):
       - Connect with: ws://host/ws/notifications?token=<jwt>

    Messages sent by server:
    - {"type": "connected", "data": {"user_id": "...", "message": "..."}} - Connection success
    - {"type": "unread_sync", "data": {"notifications": [...], "unread_count": N}} - All unread on connect
    - {"type": "notification", "data": {...}} - New notification
    - {"type": "unread_count", "data": {"unread_count": N}} - Unread count update
    - {"type": "ping"} - Server keepalive ping
    - {"type": "pong"} - Response to client ping

    Messages accepted from client:
    - {"type": "auth", "token": "..."} - Authentication (must be first message if no query token)
    - {"type": "ping"} - Client keepalive ping
    - {"type": "pong"} - Response to server ping
    """
    # Accept WebSocket connection first
    await websocket.accept()

    # Authenticate (rejects with close code 4001 on failure)
    user_id = await _authenticate_or_close(websocket, token)
    if user_id is None:
        return

    await manager.connect(websocket, user_id)
    subscriber = NotificationSubscriber(user_id)

    async def handle_redis_message(notification_data: dict):
        """Forward Redis pub/sub messages to WebSocket."""
        try:
            await websocket.send_json({
                "type": "notification",
                "data": notification_data,
            })
            # Also send updated unread count
            unread_count = await get_unread_count(user_id)
            await websocket.send_json({
                "type": "unread_count",
                "data": {"unread_count": unread_count},
            })
        except Exception as e:
            logger.error("Error forwarding notification to WebSocket: %s", e)

    redis_task = None
    try:
        # Send initial connection success message
        await websocket.send_json({
            "type": "connected",
            "data": {"user_id": user_id, "message": "Connected to notification service"},
        })

        # Send all unread notifications on connect (unread_sync)
        unread_notifications = await get_unread_notifications(user_id)
        await websocket.send_json({
            "type": "unread_sync",
            "data": {
                "notifications": unread_notifications,
                "unread_count": len(unread_notifications),
            },
        })

        # Start Redis pub/sub subscription in background
        await subscriber.start()
        redis_task = asyncio.create_task(subscriber.listen(handle_redis_message))

        # Serve client messages and heartbeat until timeout or disconnect
        await _run_heartbeat_loop(websocket, f"user {user_id}")

    except WebSocketDisconnect:
        pass
    except Exception as e:
        logger.error("WebSocket error: %s", e)
    finally:
        # Clean up the Redis subscription, and always unregister the
        # connection even if that cleanup itself fails.
        try:
            await _stop_subscription(redis_task, subscriber)
        finally:
            await manager.disconnect(websocket, user_id)


async def verify_project_access(user_id: str, project_id: str) -> tuple[bool, str | None, str | None]:
    """
    Check if user has access to the project.

    Args:
        user_id: The user's ID
        project_id: The project's ID

    Returns:
        Tuple of (has_access: bool, project_title: str | None, error: str | None)
        - has_access: True if user can access the project
        - project_title: The project title (only if access granted)
        - error: Error code if access denied
          ("user_not_found", "project_not_found", "access_denied")

    Note:
        This function extracts needed data before closing the session to
        avoid detached instance errors when accessing ORM object attributes.
    """
    db = database.SessionLocal()
    try:
        # Get the user (inactive users are reported as not found)
        user = db.query(User).filter(User.id == user_id).first()
        if user is None or not user.is_active:
            return False, None, "user_not_found"

        # Get the project
        project = db.query(Project).filter(Project.id == project_id).first()
        if project is None:
            return False, None, "project_not_found"

        # Check access using existing middleware function
        has_access = check_project_access(user, project)
        if not has_access:
            return False, None, "access_denied"

        # Extract title while session is still open
        project_title = project.title
        return True, project_title, None
    finally:
        db.close()


@router.websocket("/ws/projects/{project_id}")
async def websocket_project_sync(
    websocket: WebSocket,
    project_id: str,
    token: Optional[str] = Query(None, description="JWT token (deprecated, use first-message auth)"),
):
    """
    WebSocket endpoint for project task real-time sync.

    Authentication methods (in order of preference):
    1. First message authentication (recommended):
       - Connect without token: ws://host/ws/projects/{project_id}
       - Send: {"type": "auth", "token": "<jwt>"}
       - Must authenticate within 10 seconds or connection will be closed
    2. Query parameter (deprecated, for backward compatibility):
       - Connect with: ws://host/ws/projects/{project_id}?token=<jwt>

    Close codes used on rejection: 4001 (authentication failed),
    4004 (project not found), 4003 (access denied).

    Messages sent by server:
    - {"type": "connected", "data": {"project_id": "...", "user_id": "...", "project_title": "..."}}
    - {"type": "task_created", "data": {...}, "triggered_by": "..."}
    - {"type": "task_updated", "data": {...}, "triggered_by": "..."}
    - {"type": "task_status_changed", "data": {...}, "triggered_by": "..."}
    - {"type": "task_deleted", "data": {...}, "triggered_by": "..."}
    - {"type": "task_assigned", "data": {...}, "triggered_by": "..."}
    - {"type": "ping"} / {"type": "pong"}

    Messages accepted from client:
    - {"type": "auth", "token": "..."} - Authentication (must be first message if no query token)
    - {"type": "ping"} - Client keepalive ping
    - {"type": "pong"} - Response to server ping
    """
    # Accept WebSocket connection first
    await websocket.accept()

    # Authenticate user (rejects with close code 4001 on failure)
    user_id = await _authenticate_or_close(websocket, token)
    if user_id is None:
        return

    # Verify user has access to the project
    has_access, project_title, access_error = await verify_project_access(user_id, project_id)
    if not has_access:
        if access_error == "project_not_found":
            await websocket.close(code=4004, reason="Project not found")
        else:
            await websocket.close(code=4003, reason="Access denied to this project")
        return

    # Join project room
    await manager.join_project(websocket, user_id, project_id)

    # Create Redis subscriber for project task events
    subscriber = ProjectTaskSubscriber(project_id)

    async def handle_redis_message(event_data: dict):
        """Forward Redis pub/sub task events to WebSocket."""
        try:
            # Forward the event directly (it already contains type, data, triggered_by)
            await websocket.send_json(event_data)
        except Exception as e:
            logger.error("Error forwarding task event to WebSocket: %s", e)

    redis_task = None
    try:
        # Send initial connection success message
        await websocket.send_json({
            "type": "connected",
            "data": {
                "project_id": project_id,
                "user_id": user_id,
                "project_title": project_title,
            },
        })
        logger.info("User %s connected to project %s WebSocket", user_id, project_id)

        # Start Redis pub/sub subscription in background
        await subscriber.start()
        redis_task = asyncio.create_task(subscriber.listen(handle_redis_message))

        # Serve client messages and heartbeat (same configuration as notifications)
        await _run_heartbeat_loop(websocket, f"user {user_id} in project {project_id}")

    except WebSocketDisconnect:
        logger.info("User %s disconnected from project %s WebSocket", user_id, project_id)
    except Exception as e:
        logger.error("WebSocket error for project %s: %s", project_id, e)
    finally:
        # Clean up the Redis subscription, and always leave the project room
        # even if that cleanup itself fails.
        try:
            await _stop_subscription(redis_task, subscriber)
        finally:
            await manager.leave_project(websocket, user_id, project_id)
            logger.info("User %s left project %s room", user_id, project_id)