Files
PROJECT-CONTORL/backend/app/api/websocket/router.py
beabigegg 679b89ae4c feat: implement security, error resilience, and query optimization proposals
Security Validation (enhance-security-validation):
- JWT secret validation with entropy checking and pattern detection
- CSRF protection middleware with token generation/validation
- Frontend CSRF token auto-injection for DELETE/PUT/PATCH requests
- MIME type validation with magic bytes detection for file uploads

Error Resilience (add-error-resilience):
- React ErrorBoundary component with fallback UI and retry functionality
- ErrorBoundaryWithI18n wrapper for internationalization support
- Page-level and section-level error boundaries in App.tsx

Query Performance (optimize-query-performance):
- Query monitoring utility with threshold warnings
- N+1 query fixes using joinedload/selectinload
- Optimized project members, tasks, and subtasks endpoints

Bug Fixes:
- WebSocket session management (P0): Return primitives instead of ORM objects
- LIKE query injection (P1): Escape special characters in search queries

Tests: 543 backend tests, 56 frontend tests passing

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-11 18:41:19 +08:00

517 lines
19 KiB
Python

import asyncio
import os
import logging
import time
from typing import Optional
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
from sqlalchemy.orm import Session
from app.core import database
from app.core.security import decode_access_token
from app.core.redis import get_redis_sync
from app.models import User, Notification, Project
from app.services.websocket_manager import manager
from app.core.redis_pubsub import NotificationSubscriber, ProjectTaskSubscriber
from app.middleware.auth import check_project_access
logger = logging.getLogger(__name__)
router = APIRouter(tags=["websocket"])

# Heartbeat tuning (seconds)
PING_INTERVAL = 60.0  # idle time after which the server sends a ping
PONG_TIMEOUT = 30.0  # grace period for the pong before the connection is dropped

# Seconds a client has to authenticate: 10 normally, 1 when the TESTING
# environment variable is set to "true".
AUTH_TIMEOUT = 1.0 if os.getenv("TESTING") == "true" else 10.0
async def get_user_from_token(token: str) -> str | None:
    """
    Validate a JWT and return the id of the user it belongs to.

    The token is accepted only when all of the following hold:
    1. ``decode_access_token`` yields a payload containing a ``sub`` claim.
    2. The value stored in Redis under ``session:<user_id>`` equals the token.
    3. A ``User`` row with that id exists and is active.

    Args:
        token: The raw JWT presented by the client.

    Returns:
        The user id if the token is valid, None otherwise.

    Note:
        Only the primitive user id is returned. ORM objects must not leave
        this function because they become detached once the session is closed.
    """
    payload = decode_access_token(token)
    if payload is None:
        return None
    user_id = payload.get("sub")
    if user_id is None:
        return None

    def _session_and_user_are_valid() -> bool:
        """Blocking Redis + database checks; runs in a worker thread."""
        # Verify session in Redis
        redis_client = get_redis_sync()
        stored_token = redis_client.get(f"session:{user_id}")
        # The client may hand back bytes if it was not created with
        # decode_responses=True; normalise so the comparison is str-to-str.
        if isinstance(stored_token, bytes):
            stored_token = stored_token.decode("utf-8", errors="replace")
        if stored_token is None or stored_token != token:
            return False

        # Verify user exists and is active
        db = database.SessionLocal()
        try:
            user = db.query(User).filter(User.id == user_id).first()
            return user is not None and bool(user.is_active)
        finally:
            db.close()

    # The Redis client and the SQLAlchemy session are synchronous. Calling
    # them directly here would stall the event loop (and with it every other
    # WebSocket's heartbeat) for the duration of the round trips.
    if await asyncio.to_thread(_session_and_user_are_valid):
        return user_id
    return None
async def authenticate_websocket(
    websocket: WebSocket,
    query_token: Optional[str] = None
) -> tuple[str | None, Optional[str]]:
    """
    Authenticate an already-accepted WebSocket connection.

    Supports two authentication methods:
    1. First message authentication (preferred, more secure)
       - Client sends: {"type": "auth", "token": "<jwt_token>"}
       - The message must arrive within AUTH_TIMEOUT seconds.
    2. Query parameter authentication (deprecated, for backward compatibility)
       - Client connects with: ?token=<jwt_token>

    Args:
        websocket: The accepted connection to read the auth message from.
        query_token: Token taken from the query string, if any. When given,
            no message is read from the socket.

    Returns:
        Tuple of (user_id, error_reason). On success error_reason is None.
        On failure user_id is None and error_reason is one of
        "invalid_token", "invalid_message", "missing_token", "timeout"
        or "error".
    """
    if query_token:
        # Token provided via query parameter (backward compatibility)
        logger.warning(
            "WebSocket authentication via query parameter is deprecated. "
            "Please use first-message authentication for better security."
        )
        token = query_token
    else:
        # Wait for the authentication message, bounded by AUTH_TIMEOUT
        try:
            data = await asyncio.wait_for(
                websocket.receive_json(),
                timeout=AUTH_TIMEOUT
            )
        except asyncio.TimeoutError:
            logger.warning("WebSocket authentication timeout after %.1f seconds", AUTH_TIMEOUT)
            return None, "timeout"
        except WebSocketDisconnect:
            # A client going away before authenticating is not a server error.
            logger.info("WebSocket client disconnected before authenticating")
            return None, "error"
        except Exception as e:
            # Covers malformed JSON and transport failures.
            logger.error("Error during WebSocket authentication: %s", e)
            return None, "error"

        # receive_json() returns whatever JSON value the client sent; only an
        # object can carry the "type"/"token" fields.
        if not isinstance(data, dict):
            logger.warning(
                "Expected a JSON object as auth message, got: %s",
                type(data).__name__,
            )
            return None, "invalid_message"

        msg_type = data.get("type")
        if msg_type != "auth":
            logger.warning("Expected 'auth' message type, got: %s", msg_type)
            return None, "invalid_message"

        token = data.get("token")
        if not token:
            logger.warning("No token provided in auth message")
            return None, "missing_token"
        if not isinstance(token, str):
            logger.warning("Token in auth message is not a string")
            return None, "invalid_token"

    # Both paths validate the token the same way, so a backend failure
    # (Redis or database unavailable) is reported as "error" for either one
    # instead of escaping from the query-parameter path only.
    try:
        user_id = await get_user_from_token(token)
    except Exception as e:
        logger.error("Error during WebSocket authentication: %s", e)
        return None, "error"

    if user_id is None:
        return None, "invalid_token"
    return user_id, None
async def get_unread_notifications(user_id: str) -> list[dict]:
    """
    Return every unread notification of a user, newest first.

    Args:
        user_id: Id of the user whose notifications are loaded.

    Returns:
        A list of plain dicts (id, type, reference_type, reference_id, title,
        message, is_read, created_at as ISO-8601 string or None). The rows are
        converted to dicts before the session is closed so no ORM object
        leaves this function.
    """

    def _load() -> list[dict]:
        """Blocking query; runs in a worker thread."""
        db = database.SessionLocal()
        try:
            notifications = (
                db.query(Notification)
                # "== False" is a SQLAlchemy column expression, not a Python
                # boolean comparison.
                .filter(Notification.user_id == user_id, Notification.is_read == False)  # noqa: E712
                .order_by(Notification.created_at.desc())
                .all()
            )
            return [
                {
                    "id": n.id,
                    "type": n.type,
                    "reference_type": n.reference_type,
                    "reference_id": n.reference_id,
                    "title": n.title,
                    "message": n.message,
                    "is_read": n.is_read,
                    "created_at": n.created_at.isoformat() if n.created_at else None,
                }
                for n in notifications
            ]
        finally:
            db.close()

    # The session is synchronous; run the query off the event loop so other
    # WebSocket connections keep being served while it executes.
    return await asyncio.to_thread(_load)
async def get_unread_count(user_id: str) -> int:
    """
    Return the number of unread notifications of a user.

    Args:
        user_id: Id of the user whose unread notifications are counted.

    Returns:
        The count of Notification rows for that user with is_read false.
    """

    def _count() -> int:
        """Blocking query; runs in a worker thread."""
        db = database.SessionLocal()
        try:
            return (
                db.query(Notification)
                # "== False" is a SQLAlchemy column expression, not a Python
                # boolean comparison.
                .filter(Notification.user_id == user_id, Notification.is_read == False)  # noqa: E712
                .count()
            )
        finally:
            db.close()

    # The session is synchronous; run the query off the event loop so other
    # WebSocket connections keep being served while it executes.
    return await asyncio.to_thread(_count)
@router.websocket("/ws/notifications")
async def websocket_notifications(
    websocket: WebSocket,
    token: Optional[str] = Query(None, description="JWT token (deprecated, use first-message auth)"),
):
    """
    WebSocket endpoint for real-time notifications.

    Authentication methods (in order of preference):
    1. First message authentication (recommended):
       - Connect without token: ws://host/ws/notifications
       - Send: {"type": "auth", "token": "<jwt_token>"}
       - Must authenticate within AUTH_TIMEOUT seconds or the connection is closed
    2. Query parameter (deprecated, for backward compatibility):
       - Connect with: ws://host/ws/notifications?token=<jwt_token>

    A failed authentication closes the socket with code 4001.

    Messages sent by server:
    - {"type": "connected", "data": {"user_id": "...", "message": "..."}} - Connection success
    - {"type": "unread_sync", "data": {"notifications": [...], "unread_count": N}} - All unread on connect
    - {"type": "notification", "data": {...}} - New notification
    - {"type": "unread_count", "data": {"unread_count": N}} - Unread count update
    - {"type": "ping"} - Server keepalive ping
    - {"type": "pong"} - Response to client ping
    - {"type": "error", "message": "..."} - Sent before closing on an invalid token

    Messages accepted from client:
    - {"type": "auth", "token": "..."} - Authentication (must be first message if no query token)
    - {"type": "ping"} - Client keepalive ping
    - {"type": "pong"} - Answer to a server ping

    Heartbeat: after PING_INTERVAL seconds without any client message the
    server sends a ping; if no pong arrives within PONG_TIMEOUT seconds the
    handler stops and cleans up.
    """
    # Accept WebSocket connection first so auth errors can be reported in-band
    await websocket.accept()

    # Authenticate
    user_id, error_reason = await authenticate_websocket(websocket, token)
    if user_id is None:
        if error_reason == "invalid_token":
            await websocket.send_json({"type": "error", "message": "Invalid or expired token"})
        await websocket.close(code=4001, reason="Invalid or expired token")
        return

    await manager.connect(websocket, user_id)
    subscriber = NotificationSubscriber(user_id)

    async def handle_redis_message(notification_data: dict):
        """Forward Redis pub/sub messages to WebSocket."""
        try:
            await websocket.send_json({
                "type": "notification",
                "data": notification_data,
            })
            # Also send updated unread count
            unread_count = await get_unread_count(user_id)
            await websocket.send_json({
                "type": "unread_count",
                "data": {"unread_count": unread_count},
            })
        except Exception as e:
            # Best effort: a failed forward must not kill the listener task.
            logger.error("Error forwarding notification to WebSocket: %s", e)

    redis_task = None
    try:
        # Send initial connection success message
        await websocket.send_json({
            "type": "connected",
            "data": {"user_id": user_id, "message": "Connected to notification service"},
        })

        # Send all unread notifications on connect (unread_sync)
        unread_notifications = await get_unread_notifications(user_id)
        await websocket.send_json({
            "type": "unread_sync",
            "data": {
                "notifications": unread_notifications,
                "unread_count": len(unread_notifications),
            },
        })

        # Start Redis pub/sub subscription in background
        await subscriber.start()
        redis_task = asyncio.create_task(subscriber.listen(handle_redis_message))

        # Heartbeat tracking. A monotonic clock is used because wall-clock
        # adjustments must not trigger spurious pings or disconnects.
        waiting_for_pong = False
        ping_sent_at = 0.0
        last_activity = time.monotonic()

        while True:
            now = time.monotonic()
            if waiting_for_pong:
                # Only the remainder of the pong grace period is left to wait
                timeout = PONG_TIMEOUT - (now - ping_sent_at)
                if timeout <= 0:
                    logger.warning("Pong timeout for user %s, disconnecting", user_id)
                    break
            else:
                # Wait out the rest of the idle interval, then ping
                timeout = PING_INTERVAL - (now - last_activity)
                if timeout <= 0:
                    try:
                        await websocket.send_json({"type": "ping"})
                    except Exception:
                        # The socket is unusable; stop and clean up
                        break
                    waiting_for_pong = True
                    ping_sent_at = last_activity = time.monotonic()
                    timeout = PONG_TIMEOUT

            try:
                # Wait for messages from client
                data = await asyncio.wait_for(
                    websocket.receive_json(),
                    timeout=timeout
                )
            except asyncio.TimeoutError:
                # The top of the loop decides whether this means "time to
                # ping" or "pong overdue".
                continue

            last_activity = time.monotonic()
            # Clients may send any JSON value; only objects carry a "type".
            msg_type = data.get("type") if isinstance(data, dict) else None

            if msg_type == "ping":
                # Client keepalive
                await websocket.send_json({"type": "pong"})
            elif msg_type == "pong":
                # Response to our ping
                waiting_for_pong = False
                logger.debug("Pong received from user %s", user_id)
    except WebSocketDisconnect:
        pass
    except Exception as e:
        logger.error("WebSocket error: %s", e)
    finally:
        # Every cleanup step is isolated so that a failing Redis listener or
        # subscriber can never prevent the connection from being removed
        # from the manager.
        try:
            if redis_task is not None:
                redis_task.cancel()
                try:
                    await redis_task
                except asyncio.CancelledError:
                    pass
                except Exception:
                    # The listener had already died with its own error
                    logger.exception("Redis listener for user %s ended with an error", user_id)
            await subscriber.stop()
        except Exception:
            logger.exception("Error stopping notification subscriber for user %s", user_id)
        finally:
            await manager.disconnect(websocket, user_id)
async def verify_project_access(user_id: str, project_id: str) -> tuple[bool, str | None, str | None]:
    """
    Check if user has access to the project.

    Args:
        user_id: The user's ID
        project_id: The project's ID

    Returns:
        Tuple of (has_access: bool, project_title: str | None, error: str | None)
        - has_access: True if user can access the project
        - project_title: The project title (only if access granted)
        - error: Error code if access denied:
          "user_not_found" (user missing or inactive), "project_not_found",
          or "access_denied"

    Note:
        The title is read while the session is still open and only primitives
        are returned, to avoid detached instance errors when accessing ORM
        object attributes.
    """

    def _check() -> tuple[bool, str | None, str | None]:
        """Blocking lookups; runs in a worker thread."""
        db = database.SessionLocal()
        try:
            # Get the user
            user = db.query(User).filter(User.id == user_id).first()
            if user is None or not user.is_active:
                return False, None, "user_not_found"

            # Get the project
            project = db.query(Project).filter(Project.id == project_id).first()
            if project is None:
                return False, None, "project_not_found"

            # Check access using existing middleware function
            if not check_project_access(user, project):
                return False, None, "access_denied"

            # Extract title while session is still open
            return True, project.title, None
        finally:
            db.close()

    # The session is synchronous; run the lookups off the event loop so other
    # WebSocket connections keep being served while they execute.
    return await asyncio.to_thread(_check)
@router.websocket("/ws/projects/{project_id}")
async def websocket_project_sync(
    websocket: WebSocket,
    project_id: str,
    token: Optional[str] = Query(None, description="JWT token (deprecated, use first-message auth)"),
):
    """
    WebSocket endpoint for project task real-time sync.

    Authentication methods (in order of preference):
    1. First message authentication (recommended):
       - Connect without token: ws://host/ws/projects/{project_id}
       - Send: {"type": "auth", "token": "<jwt_token>"}
       - Must authenticate within AUTH_TIMEOUT seconds or the connection is closed
    2. Query parameter (deprecated, for backward compatibility):
       - Connect with: ws://host/ws/projects/{project_id}?token=<jwt_token>

    Close codes used before the sync starts:
    - 4001: authentication failed
    - 4004: project not found
    - 4003: any other access failure

    Messages sent by server:
    - {"type": "connected", "data": {"project_id": "...", "user_id": "...", "project_title": "..."}}
    - {"type": "task_created", "data": {...}, "triggered_by": "..."}
    - {"type": "task_updated", "data": {...}, "triggered_by": "..."}
    - {"type": "task_status_changed", "data": {...}, "triggered_by": "..."}
    - {"type": "task_deleted", "data": {...}, "triggered_by": "..."}
    - {"type": "task_assigned", "data": {...}, "triggered_by": "..."}
    - {"type": "ping"} / {"type": "pong"}

    Messages accepted from client:
    - {"type": "auth", "token": "..."} - Authentication (must be first message if no query token)
    - {"type": "ping"} - Client keepalive ping
    - {"type": "pong"} - Answer to a server ping

    Heartbeat: after PING_INTERVAL seconds without any client message the
    server sends a ping; if no pong arrives within PONG_TIMEOUT seconds the
    handler stops and cleans up.
    """
    # Accept WebSocket connection first so auth errors can be reported in-band
    await websocket.accept()

    # Authenticate user
    user_id, error_reason = await authenticate_websocket(websocket, token)
    if user_id is None:
        if error_reason == "invalid_token":
            await websocket.send_json({"type": "error", "message": "Invalid or expired token"})
        await websocket.close(code=4001, reason="Invalid or expired token")
        return

    # Verify user has access to the project
    has_access, project_title, access_error = await verify_project_access(user_id, project_id)
    if not has_access:
        if access_error == "project_not_found":
            await websocket.close(code=4004, reason="Project not found")
        else:
            await websocket.close(code=4003, reason="Access denied to this project")
        return

    # Join project room
    await manager.join_project(websocket, user_id, project_id)

    # Create Redis subscriber for project task events
    subscriber = ProjectTaskSubscriber(project_id)

    async def handle_redis_message(event_data: dict):
        """Forward Redis pub/sub task events to WebSocket."""
        try:
            # Forward the event directly (it already contains type, data, triggered_by)
            await websocket.send_json(event_data)
        except Exception as e:
            # Best effort: a failed forward must not kill the listener task.
            logger.error("Error forwarding task event to WebSocket: %s", e)

    redis_task = None
    try:
        # Send initial connection success message
        await websocket.send_json({
            "type": "connected",
            "data": {
                "project_id": project_id,
                "user_id": user_id,
                "project_title": project_title,
            },
        })
        logger.info("User %s connected to project %s WebSocket", user_id, project_id)

        # Start Redis pub/sub subscription in background
        await subscriber.start()
        redis_task = asyncio.create_task(subscriber.listen(handle_redis_message))

        # Heartbeat tracking. A monotonic clock is used because wall-clock
        # adjustments must not trigger spurious pings or disconnects.
        waiting_for_pong = False
        ping_sent_at = 0.0
        last_activity = time.monotonic()

        while True:
            now = time.monotonic()
            if waiting_for_pong:
                # Only the remainder of the pong grace period is left to wait
                timeout = PONG_TIMEOUT - (now - ping_sent_at)
                if timeout <= 0:
                    logger.warning(
                        "Pong timeout for user %s in project %s, disconnecting",
                        user_id,
                        project_id,
                    )
                    break
            else:
                # Wait out the rest of the idle interval, then ping
                timeout = PING_INTERVAL - (now - last_activity)
                if timeout <= 0:
                    try:
                        await websocket.send_json({"type": "ping"})
                    except Exception:
                        # The socket is unusable; stop and clean up
                        break
                    waiting_for_pong = True
                    ping_sent_at = last_activity = time.monotonic()
                    timeout = PONG_TIMEOUT

            try:
                # Wait for messages from client
                data = await asyncio.wait_for(
                    websocket.receive_json(),
                    timeout=timeout
                )
            except asyncio.TimeoutError:
                # The top of the loop decides whether this means "time to
                # ping" or "pong overdue".
                continue

            last_activity = time.monotonic()
            # Clients may send any JSON value; only objects carry a "type".
            msg_type = data.get("type") if isinstance(data, dict) else None

            if msg_type == "ping":
                # Client keepalive
                await websocket.send_json({"type": "pong"})
            elif msg_type == "pong":
                # Response to our ping
                waiting_for_pong = False
                logger.debug("Pong received from user %s in project %s", user_id, project_id)
    except WebSocketDisconnect:
        logger.info("User %s disconnected from project %s WebSocket", user_id, project_id)
    except Exception as e:
        logger.error("WebSocket error for project %s: %s", project_id, e)
    finally:
        # Every cleanup step is isolated so that a failing Redis listener or
        # subscriber can never prevent the socket from leaving the project
        # room.
        try:
            if redis_task is not None:
                redis_task.cancel()
                try:
                    await redis_task
                except asyncio.CancelledError:
                    pass
                except Exception:
                    # The listener had already died with its own error
                    logger.exception(
                        "Redis listener for project %s ended with an error", project_id
                    )
            await subscriber.stop()
        except Exception:
            logger.exception("Error stopping task subscriber for project %s", project_id)
        finally:
            await manager.leave_project(websocket, user_id, project_id)
            logger.info("User %s left project %s room", user_id, project_id)