Files
PROJECT-CONTORL/backend/app/middleware/csrf.py
beabigegg 35c90fe76b feat: implement 5 QA-driven security and quality proposals
Implemented proposals from comprehensive QA review:

1. extend-csrf-protection
   - Add POST to CSRF protected methods in frontend
   - Global CSRF middleware for all state-changing operations
   - Update tests with CSRF token fixtures

2. tighten-cors-websocket-security
   - Replace wildcard CORS with explicit method/header lists
   - Disable query parameter auth in production (code 4002)
   - Add per-user WebSocket connection limit (max 5, code 4005)

3. shorten-jwt-expiry
   - Reduce JWT expiry from 7 days to 60 minutes
   - Add refresh token support with 7-day expiry
   - Implement token rotation on refresh
   - Frontend auto-refresh when token near expiry (<5 min)

4. fix-frontend-quality
   - Add React.lazy() code splitting for all pages
   - Fix useCallback dependency arrays (Dashboard, Comments)
   - Add localStorage data validation in AuthContext
   - Complete i18n for AttachmentUpload component

5. enhance-backend-validation
   - Add SecurityAuditMiddleware for access denied logging
   - Add ErrorSanitizerMiddleware for production error messages
   - Protect /health/detailed with admin authentication
   - Add input length validation (comment 5000, desc 10000)

All 521 backend tests passing. Frontend builds successfully.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-12 23:19:05 +08:00

293 lines
8.9 KiB
Python

"""
CSRF (Cross-Site Request Forgery) Protection Middleware.
This module provides CSRF protection for all state-changing operations.
It validates CSRF tokens globally for authenticated POST, PUT, PATCH, DELETE requests.
"""
import inspect
import logging
from functools import wraps
from typing import Optional, Callable, List, Set

from fastapi import HTTPException, status
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse

from app.core.security import validate_csrf_token, generate_csrf_token, decode_access_token
logger = logging.getLogger(__name__)

# Request header in which clients send their CSRF token.
CSRF_TOKEN_HEADER = "X-CSRF-Token"

# State-changing HTTP methods that are meant to carry a CSRF token.
CSRF_PROTECTED_METHODS: Set[str] = {"DELETE", "PATCH", "POST", "PUT"}

# Read-only HTTP methods that never require a CSRF token.
CSRF_SAFE_METHODS: Set[str] = {"GET", "HEAD", "OPTIONS"}

# Exact request paths that bypass CSRF validation.
CSRF_EXCLUDED_PATHS: Set[str] = {
    # Login endpoints: the caller has no session (and hence no CSRF token) yet.
    "/api/auth/login",
    "/api/v1/auth/login",
    # Health-check endpoints.
    "/health",
    "/health/detailed",
    "/health/live",
    "/health/ready",
    # WebSocket endpoints authenticate through their own mechanism.
    "/api/ws",
    "/ws",
}

# Path prefixes that bypass CSRF validation (WebSocket sub-routes).
CSRF_EXCLUDED_PREFIXES: List[str] = [
    "/api/ws/",
    "/ws/",
]
class CSRFProtectionError(HTTPException):
    """HTTP 403 error raised when a request fails CSRF validation.

    Args:
        detail: Human-readable reason returned in the error response.
    """

    def __init__(self, detail: str = "CSRF validation failed"):
        forbidden_code = status.HTTP_403_FORBIDDEN
        super().__init__(detail=detail, status_code=forbidden_code)
class CSRFMiddleware(BaseHTTPMiddleware):
    """
    Application-wide CSRF guard.

    For every request whose method is not in ``CSRF_SAFE_METHODS``, whose path
    is not excluded, and which carries a decodable Bearer token, the
    ``X-CSRF-Token`` header must be present and valid for the token's ``sub``
    claim. Failures are answered directly with a 403 JSON response; all other
    requests are forwarded unchanged.
    """

    async def dispatch(self, request: Request, call_next):
        """Check the CSRF header when required, then forward the request."""
        verb = request.method.upper()
        route = request.url.path

        # Read-only methods never need a token.
        if verb in CSRF_SAFE_METHODS:
            return await call_next(request)

        if self._is_excluded_path(route):
            logger.debug("CSRF validation skipped for excluded path: %s", route)
            return await call_next(request)

        subject = self._extract_user_id_from_token(request)
        if subject is None:
            # No usable bearer token: leave any rejection to the auth layer.
            logger.debug(
                "CSRF validation skipped (no auth token): %s %s",
                verb, route,
            )
            return await call_next(request)

        supplied_token = request.headers.get(CSRF_TOKEN_HEADER)
        if not supplied_token:
            logger.warning(
                "CSRF validation failed: Missing token for user %s on %s %s",
                subject, verb, route,
            )
            return self._forbidden("CSRF token is required")

        accepted, reason = validate_csrf_token(supplied_token, subject)
        if not accepted:
            logger.warning(
                "CSRF validation failed for user %s on %s %s: %s",
                subject, verb, route, reason,
            )
            return self._forbidden(reason)

        logger.debug(
            "CSRF validation passed for user %s on %s %s",
            subject, verb, route,
        )
        return await call_next(request)

    @staticmethod
    def _forbidden(detail) -> JSONResponse:
        """Build the 403 JSON response used for every CSRF rejection."""
        return JSONResponse(
            status_code=status.HTTP_403_FORBIDDEN,
            content={"detail": detail},
        )

    def _is_excluded_path(self, path: str) -> bool:
        """Return True when ``path`` is exempt from CSRF validation."""
        return path in CSRF_EXCLUDED_PATHS or any(
            path.startswith(prefix) for prefix in CSRF_EXCLUDED_PREFIXES
        )

    def _extract_user_id_from_token(self, request: Request) -> Optional[str]:
        """
        Return the ``sub`` claim of the request's Bearer token.

        Returns None when the Authorization header is absent, is not of the
        form ``Bearer <token>``, or the token cannot be decoded.
        """
        header_value = request.headers.get("Authorization")
        if not header_value:
            return None

        pieces = header_value.split()
        if len(pieces) != 2:
            return None
        scheme, credentials = pieces
        if scheme.lower() != "bearer":
            return None

        try:
            claims = decode_access_token(credentials)
            return None if claims is None else claims.get("sub")
        except Exception as exc:
            # Best-effort: an undecodable token is treated as unauthenticated.
            logger.debug("Failed to decode token for CSRF validation: %s", exc)
            return None
def require_csrf_token(func: Callable) -> Callable:
    """
    Decorator to require CSRF token validation for an endpoint.

    Usage:
        @router.delete("/resource/{id}")
        @require_csrf_token
        async def delete_resource(request: Request, id: str, current_user: User = Depends(get_current_user)):
            ...

    The X-CSRF-Token header is validated against ``current_user.id``.

    Requirements on the wrapped endpoint:
    - The authenticated user must arrive as a keyword argument named
      ``current_user``.
    - A ``Request`` must be among the arguments. It is looked up first as
      the ``request`` keyword, then by type among all positional and keyword
      arguments, so a parameter named e.g. ``req`` also works.

    Both ``async def`` and plain ``def`` endpoints are supported. Because the
    wrapper is itself a coroutine function, a wrapped plain ``def`` endpoint
    runs on the event loop rather than in FastAPI's threadpool.

    Raises:
        CSRFProtectionError: (HTTP 403) if the request or user is not
            available, or the CSRF token is missing or invalid.
    """
    @wraps(func)
    async def wrapper(*args, **kwargs):
        request: Optional[Request] = kwargs.get("request")
        current_user = kwargs.get("current_user")

        if request is None:
            # FastAPI passes endpoint parameters as keyword arguments, so a
            # Request declared under another name lives in kwargs values;
            # positional args cover direct (non-FastAPI) calls.
            request = next(
                (
                    value
                    for value in (*args, *kwargs.values())
                    if isinstance(value, Request)
                ),
                None,
            )

        if request is None:
            logger.error("CSRF validation failed: Request object not found")
            raise CSRFProtectionError("Internal error: Request not available")

        if current_user is None:
            logger.error("CSRF validation failed: User not authenticated")
            raise CSRFProtectionError("Authentication required for CSRF-protected endpoint")

        csrf_token = request.headers.get(CSRF_TOKEN_HEADER)
        if not csrf_token:
            logger.warning(
                "CSRF validation failed: Missing token for user %s on %s %s",
                current_user.id, request.method, request.url.path
            )
            raise CSRFProtectionError("CSRF token is required")

        is_valid, error_message = validate_csrf_token(csrf_token, current_user.id)
        if not is_valid:
            logger.warning(
                "CSRF validation failed for user %s on %s %s: %s",
                current_user.id, request.method, request.url.path, error_message
            )
            raise CSRFProtectionError(error_message)

        logger.debug(
            "CSRF validation passed for user %s on %s %s",
            current_user.id, request.method, request.url.path
        )

        # Awaiting the result of a plain ``def`` endpoint would raise
        # TypeError, so only await when the endpoint produced an awaitable.
        result = func(*args, **kwargs)
        if inspect.isawaitable(result):
            result = await result
        return result

    return wrapper
def get_csrf_token_for_user(user_id: str) -> str:
    """
    Issue a CSRF token bound to the given user.

    Intended for login-style endpoints that hand the client a token to echo
    back in the X-CSRF-Token header on later state-changing requests.

    Args:
        user_id: ID of the user the token is issued for.

    Returns:
        The CSRF token string produced by ``generate_csrf_token``.
    """
    issued_token = generate_csrf_token(user_id)
    return issued_token
async def validate_csrf_for_request(
    request: Request,
    user_id: str,
    skip_methods: Optional[List[str]] = None
) -> bool:
    """
    Validate the CSRF token carried by a request.

    A utility for endpoints that want explicit validation instead of the
    ``require_csrf_token`` decorator.

    Args:
        request: The FastAPI request object.
        user_id: The current user's ID the token must be bound to.
        skip_methods: HTTP methods that bypass validation. Matching is
            case-insensitive. Defaults to the module's safe methods
            (GET, HEAD, OPTIONS).

    Returns:
        True if the method is skipped or the token is valid.

    Raises:
        CSRFProtectionError: If the token is missing or invalid.
    """
    if skip_methods is None:
        exempt_methods = CSRF_SAFE_METHODS
    else:
        # The request method is upper-cased below, so normalize the
        # caller's list too; otherwise e.g. ["get"] would never match.
        exempt_methods = {method.upper() for method in skip_methods}

    # Skip validation for exempt (safe) methods.
    if request.method.upper() in exempt_methods:
        return True

    csrf_token = request.headers.get(CSRF_TOKEN_HEADER)
    if not csrf_token:
        raise CSRFProtectionError("CSRF token is required")

    is_valid, error_message = validate_csrf_token(csrf_token, user_id)
    if not is_valid:
        raise CSRFProtectionError(error_message)

    return True