Files
PROJECT-CONTORL/backend/app/api/websocket/router.py
beabigegg 35c90fe76b feat: implement 5 QA-driven security and quality proposals
Implemented proposals from comprehensive QA review:

1. extend-csrf-protection
   - Add POST to CSRF protected methods in frontend
   - Global CSRF middleware for all state-changing operations
   - Update tests with CSRF token fixtures

2. tighten-cors-websocket-security
   - Replace wildcard CORS with explicit method/header lists
   - Disable query parameter auth in production (code 4002)
   - Add per-user WebSocket connection limit (max 5, code 4005)

3. shorten-jwt-expiry
   - Reduce JWT expiry from 7 days to 60 minutes
   - Add refresh token support with 7-day expiry
   - Implement token rotation on refresh
   - Frontend auto-refresh when token near expiry (<5 min)

4. fix-frontend-quality
   - Add React.lazy() code splitting for all pages
   - Fix useCallback dependency arrays (Dashboard, Comments)
   - Add localStorage data validation in AuthContext
   - Complete i18n for AttachmentUpload component

5. enhance-backend-validation
   - Add SecurityAuditMiddleware for access denied logging
   - Add ErrorSanitizerMiddleware for production error messages
   - Protect /health/detailed with admin authentication
   - Add input length validation (comment 5000, desc 10000)

All 521 backend tests passing. Frontend builds successfully.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-12 23:19:05 +08:00

552 lines
20 KiB
Python

import asyncio
import os
import logging
import time
from typing import Optional
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
from sqlalchemy.orm import Session
from app.core import database
from app.core.security import decode_access_token
from app.core.redis import get_redis_sync
from app.core.config import settings
from app.models import User, Notification, Project
from app.services.websocket_manager import manager
from app.core.redis_pubsub import NotificationSubscriber, ProjectTaskSubscriber
from app.middleware.auth import check_project_access
# Module-level logger and the router that the WebSocket endpoints register on.
logger = logging.getLogger(__name__)
router = APIRouter(tags=["websocket"])

# Heartbeat configuration
PING_INTERVAL = 60.0  # Seconds without any client message before the server sends a ping
PONG_TIMEOUT = 30.0  # Seconds to wait for a pong after a ping before disconnecting

# Seconds a client has to send its auth message: 10 normally, shortened to 1
# when the TESTING environment variable is exactly "true".
AUTH_TIMEOUT = 1.0 if os.getenv("TESTING") == "true" else 10.0
async def get_user_from_token(token: str) -> str | None:
    """
    Resolve an access token to the ID of the user it belongs to.

    The token is accepted only when all of the following hold:
    - it decodes successfully and carries a "sub" claim,
    - it equals the value stored in Redis under ``session:<user_id>``,
    - a user row with that ID exists and is active.

    Returns:
        The user ID on success, None otherwise.

    Note: Only the plain ID is returned, never the ORM object, because the
    database session is closed before this function returns.
    """
    claims = decode_access_token(token)
    subject = claims.get("sub") if claims is not None else None
    if subject is None:
        return None

    # The session entry in Redis must hold exactly this token.
    session_token = get_redis_sync().get(f"session:{subject}")
    if session_token is None or session_token != token:
        return None

    # The user must exist and be active.
    session = database.SessionLocal()
    try:
        account = session.query(User).filter(User.id == subject).first()
        usable = account is not None and bool(account.is_active)
        return subject if usable else None
    finally:
        session.close()
async def authenticate_websocket(
    websocket: WebSocket,
    query_token: Optional[str] = None
) -> tuple[str | None, Optional[str]]:
    """
    Authenticate a WebSocket connection.

    Supports two authentication methods:
    1. First message authentication (preferred, more secure)
       - Client sends: {"type": "auth", "token": "<jwt_token>"}
    2. Query parameter authentication (disabled in production, for backward
       compatibility only)
       - Client connects with: ?token=<jwt_token>

    Args:
        websocket: The connection to read the auth message from. It is only
            read when no query token is supplied.
        query_token: Token taken from the ``?token=`` query parameter, if any.
            An empty value falls back to first-message authentication.

    Returns:
        Tuple of (user_id, error_reason). user_id is None if authentication
        fails, in which case error_reason is one of: "invalid_token",
        "invalid_message", "missing_token", "timeout", "error",
        "query_auth_disabled".
    """
    token = query_token
    if token:
        # Token provided via query parameter (backward compatibility).
        # Reject query parameter auth in production for security.
        if settings.ENVIRONMENT == "production":
            logger.warning(
                "WebSocket query parameter authentication attempted in production environment. "
                "This is disabled for security reasons."
            )
            return None, "query_auth_disabled"
        logger.warning(
            "WebSocket authentication via query parameter is deprecated. "
            "Please use first-message authentication for better security."
        )
    else:
        # Wait for the authentication message. Only the receive call sits in
        # this try, so a timeout raised later during token validation is not
        # misreported as an authentication timeout.
        try:
            data = await asyncio.wait_for(
                websocket.receive_json(),
                timeout=AUTH_TIMEOUT
            )
        except asyncio.TimeoutError:
            logger.warning("WebSocket authentication timeout after %.1f seconds", AUTH_TIMEOUT)
            return None, "timeout"
        except WebSocketDisconnect:
            # The client left before authenticating; this is not a server fault.
            logger.info("WebSocket client disconnected before authenticating")
            return None, "error"
        except Exception as e:
            logger.error("Error during WebSocket authentication: %s", e)
            return None, "error"

        # Valid JSON that is not an object (list, string, number, ...) has no
        # .get(); treat it as a malformed auth message rather than an error.
        if not isinstance(data, dict):
            logger.warning(
                "Expected JSON object for auth message, got: %s",
                type(data).__name__,
            )
            return None, "invalid_message"

        msg_type = data.get("type")
        if msg_type != "auth":
            logger.warning("Expected 'auth' message type, got: %s", msg_type)
            return None, "invalid_message"

        token = data.get("token")
        if not token:
            logger.warning("No token provided in auth message")
            return None, "missing_token"

    # Both paths validate the token here, so a failure while validating
    # (raised by get_user_from_token) maps to "error" on either path.
    try:
        user_id = await get_user_from_token(token)
    except Exception as e:
        logger.error("Error during WebSocket authentication: %s", e)
        return None, "error"

    if user_id is None:
        return None, "invalid_token"
    return user_id, None
async def get_unread_notifications(user_id: str) -> list[dict]:
    """
    Return the user's unread notifications as plain dicts, newest first.

    Each dict carries the keys id, type, reference_type, reference_id, title,
    message, is_read and created_at (ISO-8601 string, or None when unset).
    The rows are converted before the database session is closed.
    """
    session = database.SessionLocal()
    try:
        unread_query = (
            session.query(Notification)
            .filter(
                Notification.user_id == user_id,
                Notification.is_read == False,  # noqa: E712 - SQL expression, not a Python comparison
            )
            .order_by(Notification.created_at.desc())
        )

        payloads: list[dict] = []
        for row in unread_query.all():
            created = row.created_at
            payloads.append({
                "id": row.id,
                "type": row.type,
                "reference_type": row.reference_type,
                "reference_id": row.reference_id,
                "title": row.title,
                "message": row.message,
                "is_read": row.is_read,
                "created_at": created.isoformat() if created else None,
            })
        return payloads
    finally:
        session.close()
async def get_unread_count(user_id: str) -> int:
    """Return how many notifications of the given user are still unread."""
    session = database.SessionLocal()
    try:
        unread = session.query(Notification).filter(
            Notification.user_id == user_id,
            Notification.is_read == False,  # noqa: E712 - SQL expression, not a Python comparison
        )
        return unread.count()
    finally:
        session.close()
@router.websocket("/ws/notifications")
async def websocket_notifications(
    websocket: WebSocket,
    token: Optional[str] = Query(None, description="JWT token (deprecated, use first-message auth)"),
):
    """
    WebSocket endpoint for real-time notifications.

    Authentication methods (in order of preference):
    1. First message authentication (recommended):
       - Connect without token: ws://host/ws/notifications
       - Send: {"type": "auth", "token": "<jwt_token>"}
       - Must authenticate within AUTH_TIMEOUT seconds or the connection is closed
    2. Query parameter (deprecated, for backward compatibility):
       - Connect with: ws://host/ws/notifications?token=<jwt_token>

    Close codes used by this endpoint:
    - 4001: authentication failed
    - 4002: query parameter auth attempted in production
    - 4005: connection limit check rejected the connection

    Messages sent by server:
    - {"type": "connected", "data": {"user_id": "...", "message": "..."}} - Connection success
    - {"type": "unread_sync", "data": {"notifications": [...], "unread_count": N}} - All unread on connect
    - {"type": "notification", "data": {...}} - New notification
    - {"type": "unread_count", "data": {"unread_count": N}} - Unread count update
    - {"type": "ping"} - Server keepalive ping
    - {"type": "pong"} - Response to client ping

    Messages accepted from client:
    - {"type": "auth", "token": "..."} - Authentication (must be first message if no query token)
    - {"type": "ping"} - Client keepalive ping
    - {"type": "pong"} - Reply to a server ping
    """
    # Accept the WebSocket connection first so error frames can be sent.
    await websocket.accept()

    # Authenticate
    user_id, error_reason = await authenticate_websocket(websocket, token)
    if user_id is None:
        if error_reason == "query_auth_disabled":
            await websocket.send_json({"type": "error", "message": "Query parameter auth disabled in production"})
            await websocket.close(code=4002, reason="Query parameter auth disabled in production")
        elif error_reason == "invalid_token":
            await websocket.send_json({"type": "error", "message": "Invalid or expired token"})
            await websocket.close(code=4001, reason="Invalid or expired token")
        else:
            await websocket.close(code=4001, reason="Invalid or expired token")
        return

    # Check the connection limit before registering this connection with the
    # manager (the socket itself has already been accepted above).
    can_connect, reject_reason = await manager.check_connection_limit(user_id)
    if not can_connect:
        await websocket.send_json({"type": "error", "message": reject_reason})
        await websocket.close(code=4005, reason=reject_reason)
        return

    await manager.connect(websocket, user_id)
    subscriber = NotificationSubscriber(user_id)

    async def handle_redis_message(notification_data: dict):
        """Forward Redis pub/sub messages to WebSocket."""
        try:
            await websocket.send_json({
                "type": "notification",
                "data": notification_data,
            })
            # Also send updated unread count
            unread_count = await get_unread_count(user_id)
            await websocket.send_json({
                "type": "unread_count",
                "data": {"unread_count": unread_count},
            })
        except Exception as e:
            logger.error(f"Error forwarding notification to WebSocket: {e}")

    redis_task = None
    try:
        # Send initial connection success message
        await websocket.send_json({
            "type": "connected",
            "data": {"user_id": user_id, "message": "Connected to notification service"},
        })

        # Send all unread notifications on connect (unread_sync)
        unread_notifications = await get_unread_notifications(user_id)
        await websocket.send_json({
            "type": "unread_sync",
            "data": {
                "notifications": unread_notifications,
                "unread_count": len(unread_notifications),
            },
        })

        # Start Redis pub/sub subscription in background
        await subscriber.start()
        redis_task = asyncio.create_task(subscriber.listen(handle_redis_message))

        # Heartbeat tracking. Intervals are measured with the monotonic clock
        # so wall-clock adjustments cannot cause spurious pong timeouts or
        # delayed pings.
        waiting_for_pong = False
        ping_sent_at = 0.0
        last_activity = time.monotonic()

        while True:
            # Calculate appropriate timeout based on state
            if waiting_for_pong:
                # When waiting for pong, use remaining pong timeout
                remaining = PONG_TIMEOUT - (time.monotonic() - ping_sent_at)
                if remaining <= 0:
                    logger.warning(f"Pong timeout for user {user_id}, disconnecting")
                    break
                timeout = remaining
            else:
                # When not waiting, use remaining ping interval
                remaining = PING_INTERVAL - (time.monotonic() - last_activity)
                if remaining <= 0:
                    # Time to send ping immediately
                    try:
                        await websocket.send_json({"type": "ping"})
                        waiting_for_pong = True
                        ping_sent_at = time.monotonic()
                        last_activity = ping_sent_at
                        timeout = PONG_TIMEOUT
                    except Exception:
                        # The ping could not be sent; stop serving this socket.
                        break
                else:
                    timeout = remaining

            try:
                # Wait for messages from client
                data = await asyncio.wait_for(
                    websocket.receive_json(),
                    timeout=timeout
                )
                last_activity = time.monotonic()
                msg_type = data.get("type")

                # Handle ping message from client
                if msg_type == "ping":
                    await websocket.send_json({"type": "pong"})
                # Handle pong message from client (response to our ping)
                elif msg_type == "pong":
                    waiting_for_pong = False
                    logger.debug(f"Pong received from user {user_id}")
            except asyncio.TimeoutError:
                if waiting_for_pong:
                    # Strict timeout check
                    if time.monotonic() - ping_sent_at >= PONG_TIMEOUT:
                        logger.warning(f"Pong timeout for user {user_id}, disconnecting")
                        break
                # If not waiting_for_pong, loop will handle sending ping at top
    except WebSocketDisconnect:
        pass
    except Exception as e:
        logger.error(f"WebSocket error: {e}")
    finally:
        # Clean up Redis subscription. Each step is isolated so a failure in
        # one cannot skip the next; in particular manager.disconnect() must
        # always run or the connection would stay registered.
        if redis_task is not None:
            redis_task.cancel()
            try:
                await redis_task
            except asyncio.CancelledError:
                pass
            except Exception as e:
                # The listener had already failed; awaiting it re-raises
                # its stored exception.
                logger.error(f"Redis listener for user {user_id} failed: {e}")
        try:
            await subscriber.stop()
        except Exception as e:
            logger.error(f"Error stopping notification subscriber for user {user_id}: {e}")
        finally:
            await manager.disconnect(websocket, user_id)
async def verify_project_access(user_id: str, project_id: str) -> tuple[bool, str | None, str | None]:
    """
    Decide whether a user may access a project.

    Args:
        user_id: The user's ID
        project_id: The project's ID

    Returns:
        Tuple of (has_access, project_title, error):
        - has_access: True if the user can access the project
        - project_title: The project title (only when access is granted)
        - error: None on success, otherwise "user_not_found" (missing or
          inactive user), "project_not_found" or "access_denied"

    Note: The title is read while the session is still open, so callers never
    touch a detached ORM instance.
    """
    session = database.SessionLocal()
    try:
        # Look up the user; a missing and an inactive user are reported alike.
        account = session.query(User).filter(User.id == user_id).first()
        if not (account is not None and account.is_active):
            return False, None, "user_not_found"

        # Look up the project.
        target = session.query(Project).filter(Project.id == project_id).first()
        if target is None:
            return False, None, "project_not_found"

        # Delegate the permission decision to the shared middleware helper.
        if not check_project_access(account, target):
            return False, None, "access_denied"

        return True, target.title, None
    finally:
        session.close()
@router.websocket("/ws/projects/{project_id}")
async def websocket_project_sync(
    websocket: WebSocket,
    project_id: str,
    token: Optional[str] = Query(None, description="JWT token (deprecated, use first-message auth)"),
):
    """
    WebSocket endpoint for project task real-time sync.

    Authentication methods (in order of preference):
    1. First message authentication (recommended):
       - Connect without token: ws://host/ws/projects/{project_id}
       - Send: {"type": "auth", "token": "<jwt_token>"}
       - Must authenticate within AUTH_TIMEOUT seconds or the connection is closed
    2. Query parameter (deprecated, for backward compatibility):
       - Connect with: ws://host/ws/projects/{project_id}?token=<jwt_token>

    Close codes used by this endpoint:
    - 4001: authentication failed
    - 4002: query parameter auth attempted in production
    - 4003: access to the project denied (or user missing/inactive)
    - 4004: project not found
    - 4005: connection limit check rejected the connection

    Messages sent by server:
    - {"type": "connected", "data": {"project_id": "...", "user_id": "...", "project_title": "..."}}
    - {"type": "task_created", "data": {...}, "triggered_by": "..."}
    - {"type": "task_updated", "data": {...}, "triggered_by": "..."}
    - {"type": "task_status_changed", "data": {...}, "triggered_by": "..."}
    - {"type": "task_deleted", "data": {...}, "triggered_by": "..."}
    - {"type": "task_assigned", "data": {...}, "triggered_by": "..."}
    - {"type": "ping"} / {"type": "pong"}

    Messages accepted from client:
    - {"type": "auth", "token": "..."} - Authentication (must be first message if no query token)
    - {"type": "ping"} - Client keepalive ping
    - {"type": "pong"} - Reply to a server ping
    """
    # Accept the WebSocket connection first so error frames can be sent.
    await websocket.accept()

    # Authenticate user
    user_id, error_reason = await authenticate_websocket(websocket, token)
    if user_id is None:
        if error_reason == "query_auth_disabled":
            await websocket.send_json({"type": "error", "message": "Query parameter auth disabled in production"})
            await websocket.close(code=4002, reason="Query parameter auth disabled in production")
        elif error_reason == "invalid_token":
            await websocket.send_json({"type": "error", "message": "Invalid or expired token"})
            await websocket.close(code=4001, reason="Invalid or expired token")
        else:
            await websocket.close(code=4001, reason="Invalid or expired token")
        return

    # Check the connection limit before joining the project room (the socket
    # itself has already been accepted above).
    can_connect, reject_reason = await manager.check_connection_limit(user_id)
    if not can_connect:
        await websocket.send_json({"type": "error", "message": reject_reason})
        await websocket.close(code=4005, reason=reject_reason)
        return

    # Verify user has access to the project
    has_access, project_title, access_error = await verify_project_access(user_id, project_id)
    if not has_access:
        if access_error == "project_not_found":
            await websocket.close(code=4004, reason="Project not found")
        else:
            await websocket.close(code=4003, reason="Access denied to this project")
        return

    # Join project room
    await manager.join_project(websocket, user_id, project_id)

    # Create Redis subscriber for project task events
    subscriber = ProjectTaskSubscriber(project_id)

    async def handle_redis_message(event_data: dict):
        """Forward Redis pub/sub task events to WebSocket."""
        try:
            # Forward the event directly (it already contains type, data, triggered_by)
            await websocket.send_json(event_data)
        except Exception as e:
            logger.error(f"Error forwarding task event to WebSocket: {e}")

    redis_task = None
    try:
        # Send initial connection success message
        await websocket.send_json({
            "type": "connected",
            "data": {
                "project_id": project_id,
                "user_id": user_id,
                "project_title": project_title,
            },
        })
        logger.info(f"User {user_id} connected to project {project_id} WebSocket")

        # Start Redis pub/sub subscription in background
        await subscriber.start()
        redis_task = asyncio.create_task(subscriber.listen(handle_redis_message))

        # Heartbeat tracking (same configuration as the notifications
        # endpoint). Intervals are measured with the monotonic clock so
        # wall-clock adjustments cannot cause spurious pong timeouts or
        # delayed pings.
        waiting_for_pong = False
        ping_sent_at = 0.0
        last_activity = time.monotonic()

        while True:
            # Calculate appropriate timeout based on state
            if waiting_for_pong:
                # When waiting for pong, use remaining pong timeout
                remaining = PONG_TIMEOUT - (time.monotonic() - ping_sent_at)
                if remaining <= 0:
                    logger.warning(f"Pong timeout for user {user_id} in project {project_id}, disconnecting")
                    break
                timeout = remaining
            else:
                # When not waiting, use remaining ping interval
                remaining = PING_INTERVAL - (time.monotonic() - last_activity)
                if remaining <= 0:
                    # Time to send ping immediately
                    try:
                        await websocket.send_json({"type": "ping"})
                        waiting_for_pong = True
                        ping_sent_at = time.monotonic()
                        last_activity = ping_sent_at
                        timeout = PONG_TIMEOUT
                    except Exception:
                        # The ping could not be sent; stop serving this socket.
                        break
                else:
                    timeout = remaining

            try:
                # Wait for messages from client
                data = await asyncio.wait_for(
                    websocket.receive_json(),
                    timeout=timeout
                )
                last_activity = time.monotonic()
                msg_type = data.get("type")

                # Handle ping message from client
                if msg_type == "ping":
                    await websocket.send_json({"type": "pong"})
                # Handle pong message from client (response to our ping)
                elif msg_type == "pong":
                    waiting_for_pong = False
                    logger.debug(f"Pong received from user {user_id} in project {project_id}")
            except asyncio.TimeoutError:
                if waiting_for_pong:
                    # Strict timeout check
                    if time.monotonic() - ping_sent_at >= PONG_TIMEOUT:
                        logger.warning(f"Pong timeout for user {user_id} in project {project_id}, disconnecting")
                        break
                # If not waiting_for_pong, loop will handle sending ping at top
    except WebSocketDisconnect:
        logger.info(f"User {user_id} disconnected from project {project_id} WebSocket")
    except Exception as e:
        logger.error(f"WebSocket error for project {project_id}: {e}")
    finally:
        # Clean up Redis subscription. Each step is isolated so a failure in
        # one cannot skip the next; in particular manager.leave_project() must
        # always run or the socket would stay in the project room.
        if redis_task is not None:
            redis_task.cancel()
            try:
                await redis_task
            except asyncio.CancelledError:
                pass
            except Exception as e:
                # The listener had already failed; awaiting it re-raises
                # its stored exception.
                logger.error(f"Redis listener for project {project_id} failed: {e}")
        try:
            await subscriber.stop()
        except Exception as e:
            logger.error(f"Error stopping task subscriber for project {project_id}: {e}")
        finally:
            await manager.leave_project(websocket, user_id, project_id)
            logger.info(f"User {user_id} left project {project_id} room")