Files
PROJECT-CONTORL/backend/app/api/websocket/router.py
2026-01-11 08:37:21 +08:00

500 lines
18 KiB
Python

import asyncio
import hmac
import logging
import os
import time
from typing import Optional

from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
from sqlalchemy.orm import Session

from app.core import database
from app.core.security import decode_access_token
from app.core.redis import get_redis_sync
from app.models import User, Notification, Project
from app.services.websocket_manager import manager
from app.core.redis_pubsub import NotificationSubscriber, ProjectTaskSubscriber
from app.middleware.auth import check_project_access
logger = logging.getLogger(__name__)
router = APIRouter(tags=["websocket"])

# --- Heartbeat configuration (seconds) ---
# Idle time without any client message before the server sends {"type": "ping"}.
PING_INTERVAL = 60.0
# Time the client has to answer a server ping with {"type": "pong"} before disconnect.
PONG_TIMEOUT = 30.0

# --- Authentication ---
# Seconds a client has to send its first-message auth frame. Shortened to one
# second when the TESTING environment variable is exactly "true".
AUTH_TIMEOUT = 1.0 if os.getenv("TESTING") == "true" else 10.0
def _lookup_session_user(user_id: str, token: str) -> User | None:
    """Return the active user whose Redis session matches *token*, else ``None``.

    Performs blocking Redis and SQLAlchemy I/O, so async callers must run it in
    a worker thread (``asyncio.to_thread``) rather than on the event loop.
    """
    # The token is only accepted if it is exactly the one stored under
    # session:<user_id> in Redis.
    redis_client = get_redis_sync()
    stored_token = redis_client.get(f"session:{user_id}")
    if isinstance(stored_token, str):
        stored_bytes = stored_token.encode("utf-8")
    elif isinstance(stored_token, bytes):
        # Tolerate a Redis client configured without decode_responses.
        stored_bytes = stored_token
    else:
        # Missing session (None) or an unexpected value type: reject.
        return None
    # Constant-time comparison so response timing does not reveal how much of
    # the stored token matched.
    if not hmac.compare_digest(stored_bytes, token.encode("utf-8")):
        return None

    db = database.SessionLocal()
    try:
        user = db.query(User).filter(User.id == user_id).first()
        if user is None or not user.is_active:
            return None
        return user
    finally:
        db.close()


async def get_user_from_token(token: str) -> tuple[str | None, User | None]:
    """Validate *token* and return ``(user_id, user)``, or ``(None, None)``.

    A token is accepted only when all of these hold:

    * it is a non-empty string that ``decode_access_token`` can decode;
    * its payload carries a ``sub`` claim (the user id);
    * Redis holds that exact token under ``session:<user_id>``;
    * the user exists in the database and ``is_active`` is truthy.

    The Redis and database lookups are blocking, so they run in a worker
    thread to keep the event loop (and every other open WebSocket) responsive.
    """
    # Tokens taken from client JSON frames may be any JSON type.
    if not isinstance(token, str) or not token:
        return None, None

    payload = decode_access_token(token)
    if payload is None:
        return None, None
    user_id = payload.get("sub")
    if user_id is None:
        return None, None

    user = await asyncio.to_thread(_lookup_session_user, user_id, token)
    if user is None:
        return None, None
    return user_id, user
async def authenticate_websocket(
    websocket: WebSocket,
    query_token: Optional[str] = None,
) -> tuple[str | None, User | None, Optional[str]]:
    """
    Authenticate an already-accepted WebSocket connection.

    Supports two authentication methods:
    1. First-message authentication (preferred, more secure):
       the client sends ``{"type": "auth", "token": "<jwt_token>"}`` as its
       first frame within ``AUTH_TIMEOUT`` seconds. The server does not send
       any prompt; the client must send this frame on its own.
    2. Query-parameter authentication (deprecated, for backward compatibility):
       the client connects with ``?token=<jwt_token>``.

    Returns:
        ``(user_id, user, None)`` on success, otherwise ``(None, None, reason)``
        where *reason* is one of ``"invalid_token"``, ``"invalid_message"``,
        ``"missing_token"``, ``"timeout"`` or ``"error"``.
    """
    if query_token:
        logger.warning(
            "WebSocket authentication via query parameter is deprecated. "
            "Please use first-message authentication for better security."
        )
        # Mirror the first-message path: backend failures (Redis/DB) become an
        # "error" result instead of escaping into the endpoint.
        try:
            user_id, user = await get_user_from_token(query_token)
        except Exception as e:
            logger.error("Error during WebSocket authentication: %s", e)
            return None, None, "error"
        if user_id is None:
            return None, None, "invalid_token"
        return user_id, user, None

    try:
        data = await asyncio.wait_for(
            websocket.receive_json(),
            timeout=AUTH_TIMEOUT,
        )
        # The frame may be any JSON value (array, string, number); only an
        # object with type == "auth" is an auth message.
        msg_type = data.get("type") if isinstance(data, dict) else None
        if msg_type != "auth":
            logger.warning("Expected 'auth' message type, got: %s", msg_type)
            return None, None, "invalid_message"

        token = data.get("token")
        if not token:
            logger.warning("No token provided in auth message")
            return None, None, "missing_token"
        if not isinstance(token, str):
            logger.warning("Auth token must be a string, got: %s", type(token).__name__)
            return None, None, "invalid_token"

        user_id, user = await get_user_from_token(token)
        if user_id is None:
            return None, None, "invalid_token"
        return user_id, user, None
    except asyncio.TimeoutError:
        logger.warning("WebSocket authentication timeout after %.1f seconds", AUTH_TIMEOUT)
        return None, None, "timeout"
    except Exception as e:
        # Covers malformed JSON, client disconnects during auth and backend errors.
        logger.error("Error during WebSocket authentication: %s", e)
        return None, None, "error"
def _fetch_unread_notifications(user_id: str) -> list[dict]:
    """Synchronously load the user's unread notifications (newest first) as dicts."""
    db = database.SessionLocal()
    try:
        rows = (
            db.query(Notification)
            .filter(Notification.user_id == user_id, Notification.is_read == False)  # noqa: E712 (SQL expression)
            .order_by(Notification.created_at.desc())
            .all()
        )
        return [
            {
                "id": n.id,
                "type": n.type,
                "reference_type": n.reference_type,
                "reference_id": n.reference_id,
                "title": n.title,
                "message": n.message,
                "is_read": n.is_read,
                "created_at": n.created_at.isoformat() if n.created_at else None,
            }
            for n in rows
        ]
    finally:
        db.close()


async def get_unread_notifications(user_id: str) -> list[dict]:
    """Return all unread notifications for *user_id*, newest first.

    Each item is a JSON-serialisable dict; ``created_at`` is an ISO-8601
    string or ``None``. The blocking database query runs in a worker thread
    so it does not stall the event loop serving other WebSockets.
    """
    return await asyncio.to_thread(_fetch_unread_notifications, user_id)
def _count_unread_notifications(user_id: str) -> int:
    """Synchronously count the user's unread notifications."""
    db = database.SessionLocal()
    try:
        return (
            db.query(Notification)
            .filter(Notification.user_id == user_id, Notification.is_read == False)  # noqa: E712 (SQL expression)
            .count()
        )
    finally:
        db.close()


async def get_unread_count(user_id: str) -> int:
    """Return the number of unread notifications for *user_id*.

    The blocking ``COUNT`` query runs in a worker thread so the event loop
    stays free for other WebSocket traffic.
    """
    return await asyncio.to_thread(_count_unread_notifications, user_id)
@router.websocket("/ws/notifications")
async def websocket_notifications(
    websocket: WebSocket,
    token: Optional[str] = Query(None, description="JWT token (deprecated, use first-message auth)"),
):
    """
    WebSocket endpoint for real-time notifications.

    Authentication methods (in order of preference):
    1. First message authentication (recommended):
       - Connect without token: ws://host/ws/notifications
       - Send: {"type": "auth", "token": "<jwt_token>"}
       - Must authenticate within AUTH_TIMEOUT seconds (10 by default) or the
         connection is closed. The server does not prompt for this message.
    2. Query parameter (deprecated, for backward compatibility):
       - Connect with: ws://host/ws/notifications?token=<jwt_token>

    Any authentication failure closes the socket with code 4001. For an invalid
    or expired token, {"type": "error", "message": "..."} is sent first.

    Messages sent by server:
    - {"type": "connected", "data": {"user_id": "...", "message": "..."}} - Connection success
    - {"type": "unread_sync", "data": {"notifications": [...], "unread_count": N}} - All unread on connect
    - {"type": "notification", "data": {...}} - New notification
    - {"type": "unread_count", "data": {"unread_count": N}} - Unread count update
    - {"type": "ping"} - Server keepalive ping after PING_INTERVAL idle seconds
    - {"type": "pong"} - Response to client ping

    Messages accepted from client:
    - {"type": "auth", "token": "..."} - Authentication (must be first message if no query token)
    - {"type": "ping"} - Client keepalive ping
    - {"type": "pong"} - Reply to a server ping; must arrive within PONG_TIMEOUT seconds
    Frames that are not valid JSON objects are logged and ignored.
    """
    # Accept first so that first-message authentication can read from the socket.
    await websocket.accept()

    user_id, _user, error_reason = await authenticate_websocket(websocket, token)
    if user_id is None:
        if error_reason == "invalid_token":
            await websocket.send_json({"type": "error", "message": "Invalid or expired token"})
        await websocket.close(code=4001, reason="Invalid or expired token")
        return

    await manager.connect(websocket, user_id)
    subscriber = NotificationSubscriber(user_id)

    async def handle_redis_message(notification_data: dict):
        """Forward a Redis pub/sub notification, then push the fresh unread count."""
        try:
            await websocket.send_json({
                "type": "notification",
                "data": notification_data,
            })
            unread_count = await get_unread_count(user_id)
            await websocket.send_json({
                "type": "unread_count",
                "data": {"unread_count": unread_count},
            })
        except Exception as e:
            logger.error(f"Error forwarding notification to WebSocket: {e}")

    redis_task = None
    try:
        await websocket.send_json({
            "type": "connected",
            "data": {"user_id": user_id, "message": "Connected to notification service"},
        })

        # Replay every unread notification so the client starts in sync.
        unread_notifications = await get_unread_notifications(user_id)
        await websocket.send_json({
            "type": "unread_sync",
            "data": {
                "notifications": unread_notifications,
                "unread_count": len(unread_notifications),
            },
        })

        # Forward Redis pub/sub messages concurrently with the receive loop below.
        await subscriber.start()
        redis_task = asyncio.create_task(subscriber.listen(handle_redis_message))

        # Heartbeat state. A monotonic clock is used so wall-clock adjustments
        # (NTP, manual changes) cannot trigger spurious pings or pong timeouts.
        waiting_for_pong = False
        ping_sent_at = 0.0
        last_activity = time.monotonic()

        while True:
            now = time.monotonic()
            if waiting_for_pong:
                timeout = PONG_TIMEOUT - (now - ping_sent_at)
                if timeout <= 0:
                    logger.warning(f"Pong timeout for user {user_id}, disconnecting")
                    break
            else:
                timeout = PING_INTERVAL - (now - last_activity)
                if timeout <= 0:
                    # Idle for PING_INTERVAL: probe the client.
                    try:
                        await websocket.send_json({"type": "ping"})
                    except Exception:
                        break
                    waiting_for_pong = True
                    ping_sent_at = time.monotonic()
                    last_activity = ping_sent_at
                    timeout = PONG_TIMEOUT

            try:
                data = await asyncio.wait_for(websocket.receive_json(), timeout=timeout)
            except asyncio.TimeoutError:
                # Re-evaluate at the top of the loop: it either sends the due
                # ping or detects that the pong deadline has passed.
                continue
            except ValueError:
                # Frame was not valid JSON: it still proves the client is alive,
                # but one bad frame should not drop the connection.
                last_activity = time.monotonic()
                logger.warning(f"Ignoring malformed WebSocket frame from user {user_id}")
                continue

            last_activity = time.monotonic()
            if not isinstance(data, dict):
                logger.warning(f"Ignoring non-object WebSocket frame from user {user_id}")
                continue

            msg_type = data.get("type")
            if msg_type == "ping":
                await websocket.send_json({"type": "pong"})
            elif msg_type == "pong":
                waiting_for_pong = False
                logger.debug(f"Pong received from user {user_id}")
    except WebSocketDisconnect:
        pass
    except Exception as e:
        logger.error(f"WebSocket error: {e}")
    finally:
        # Stop forwarding before deregistering the socket.
        if redis_task:
            redis_task.cancel()
            try:
                await redis_task
            except asyncio.CancelledError:
                pass
        await subscriber.stop()
        await manager.disconnect(websocket, user_id)
def _load_project_access(user_id: str, project_id: str) -> tuple[bool, Project | None]:
    """Synchronous body of ``verify_project_access``; performs blocking DB queries."""
    db = database.SessionLocal()
    try:
        user = db.query(User).filter(User.id == user_id).first()
        if user is None or not user.is_active:
            return False, None

        project = db.query(Project).filter(Project.id == project_id).first()
        if project is None:
            return False, None

        # Evaluate the shared access rule while the session is still open, in
        # case it needs to lazy-load related rows.
        has_access = check_project_access(user, project)
        return has_access, project
    finally:
        db.close()


async def verify_project_access(user_id: str, project_id: str) -> tuple[bool, Project | None]:
    """
    Check whether a user may access a project.

    Args:
        user_id: The user's ID.
        project_id: The project's ID.

    Returns:
        ``(has_access, project)``. ``project`` is ``None`` (and ``has_access`` is
        ``False``) when the project does not exist or when the user record is
        missing or inactive. Otherwise ``has_access`` is the result of
        ``check_project_access``.

    The blocking database work runs in a worker thread so the event loop is
    not stalled.
    """
    return await asyncio.to_thread(_load_project_access, user_id, project_id)
@router.websocket("/ws/projects/{project_id}")
async def websocket_project_sync(
    websocket: WebSocket,
    project_id: str,
    token: Optional[str] = Query(None, description="JWT token (deprecated, use first-message auth)"),
):
    """
    WebSocket endpoint for project task real-time sync.

    Authentication methods (in order of preference):
    1. First message authentication (recommended):
       - Connect without token: ws://host/ws/projects/{project_id}
       - Send: {"type": "auth", "token": "<jwt_token>"}
       - Must authenticate within AUTH_TIMEOUT seconds (10 by default) or the
         connection is closed. The server does not prompt for this message.
    2. Query parameter (deprecated, for backward compatibility):
       - Connect with: ws://host/ws/projects/{project_id}?token=<jwt_token>

    Close codes:
    - 4001: authentication failed. For an invalid or expired token,
      {"type": "error", "message": "..."} is sent first.
    - 4004: project not found. Also used when the user record is missing or
      inactive, because the access check returns no project in that case.
    - 4003: project exists but the user has no access to it.

    Messages sent by server:
    - {"type": "connected", "data": {"project_id": "...", "user_id": "...", "project_title": "..."}}
    - {"type": "task_created", "data": {...}, "triggered_by": "..."}
    - {"type": "task_updated", "data": {...}, "triggered_by": "..."}
    - {"type": "task_status_changed", "data": {...}, "triggered_by": "..."}
    - {"type": "task_deleted", "data": {...}, "triggered_by": "..."}
    - {"type": "task_assigned", "data": {...}, "triggered_by": "..."}
    - {"type": "ping"} / {"type": "pong"}

    Messages accepted from client:
    - {"type": "auth", "token": "..."} - Authentication (must be first message if no query token)
    - {"type": "ping"} - Client keepalive ping
    - {"type": "pong"} - Reply to a server ping; must arrive within PONG_TIMEOUT seconds
    Frames that are not valid JSON objects are logged and ignored.
    """
    # Accept first so that first-message authentication can read from the socket.
    await websocket.accept()

    user_id, _user, error_reason = await authenticate_websocket(websocket, token)
    if user_id is None:
        if error_reason == "invalid_token":
            await websocket.send_json({"type": "error", "message": "Invalid or expired token"})
        await websocket.close(code=4001, reason="Invalid or expired token")
        return

    # verify_project_access returns (False, None) for a missing project, so the
    # "not found" case must be checked before the access check. Otherwise the
    # 4004 close is unreachable.
    has_access, project = await verify_project_access(user_id, project_id)
    if project is None:
        await websocket.close(code=4004, reason="Project not found")
        return
    if not has_access:
        await websocket.close(code=4003, reason="Access denied to this project")
        return

    await manager.join_project(websocket, user_id, project_id)
    subscriber = ProjectTaskSubscriber(project_id)

    async def handle_redis_message(event_data: dict):
        """Forward a Redis pub/sub task event to the client unchanged."""
        try:
            await websocket.send_json(event_data)
        except Exception as e:
            logger.error(f"Error forwarding task event to WebSocket: {e}")

    redis_task = None
    try:
        await websocket.send_json({
            "type": "connected",
            "data": {
                "project_id": project_id,
                "user_id": user_id,
                "project_title": project.title,
            },
        })
        logger.info(f"User {user_id} connected to project {project_id} WebSocket")

        # Forward Redis pub/sub task events concurrently with the receive loop below.
        await subscriber.start()
        redis_task = asyncio.create_task(subscriber.listen(handle_redis_message))

        # Heartbeat state. A monotonic clock is used so wall-clock adjustments
        # (NTP, manual changes) cannot trigger spurious pings or pong timeouts.
        waiting_for_pong = False
        ping_sent_at = 0.0
        last_activity = time.monotonic()

        while True:
            now = time.monotonic()
            if waiting_for_pong:
                timeout = PONG_TIMEOUT - (now - ping_sent_at)
                if timeout <= 0:
                    logger.warning(f"Pong timeout for user {user_id} in project {project_id}, disconnecting")
                    break
            else:
                timeout = PING_INTERVAL - (now - last_activity)
                if timeout <= 0:
                    # Idle for PING_INTERVAL: probe the client.
                    try:
                        await websocket.send_json({"type": "ping"})
                    except Exception:
                        break
                    waiting_for_pong = True
                    ping_sent_at = time.monotonic()
                    last_activity = ping_sent_at
                    timeout = PONG_TIMEOUT

            try:
                data = await asyncio.wait_for(websocket.receive_json(), timeout=timeout)
            except asyncio.TimeoutError:
                # Re-evaluate at the top of the loop: it either sends the due
                # ping or detects that the pong deadline has passed.
                continue
            except ValueError:
                # Frame was not valid JSON: it still proves the client is alive,
                # but one bad frame should not drop the connection.
                last_activity = time.monotonic()
                logger.warning(f"Ignoring malformed WebSocket frame from user {user_id} in project {project_id}")
                continue

            last_activity = time.monotonic()
            if not isinstance(data, dict):
                logger.warning(f"Ignoring non-object WebSocket frame from user {user_id} in project {project_id}")
                continue

            msg_type = data.get("type")
            if msg_type == "ping":
                await websocket.send_json({"type": "pong"})
            elif msg_type == "pong":
                waiting_for_pong = False
                logger.debug(f"Pong received from user {user_id} in project {project_id}")
    except WebSocketDisconnect:
        logger.info(f"User {user_id} disconnected from project {project_id} WebSocket")
    except Exception as e:
        logger.error(f"WebSocket error for project {project_id}: {e}")
    finally:
        # Stop forwarding before leaving the project room.
        if redis_task:
            redis_task.cancel()
            try:
                await redis_task
            except asyncio.CancelledError:
                pass
        await subscriber.stop()
        await manager.leave_project(websocket, user_id, project_id)
        logger.info(f"User {user_id} left project {project_id} room")