Implemented proposals from comprehensive QA review: 1. extend-csrf-protection - Add POST to CSRF protected methods in frontend - Global CSRF middleware for all state-changing operations - Update tests with CSRF token fixtures 2. tighten-cors-websocket-security - Replace wildcard CORS with explicit method/header lists - Disable query parameter auth in production (code 4002) - Add per-user WebSocket connection limit (max 5, code 4005) 3. shorten-jwt-expiry - Reduce JWT expiry from 7 days to 60 minutes - Add refresh token support with 7-day expiry - Implement token rotation on refresh - Frontend auto-refresh when token near expiry (<5 min) 4. fix-frontend-quality - Add React.lazy() code splitting for all pages - Fix useCallback dependency arrays (Dashboard, Comments) - Add localStorage data validation in AuthContext - Complete i18n for AttachmentUpload component 5. enhance-backend-validation - Add SecurityAuditMiddleware for access denied logging - Add ErrorSanitizerMiddleware for production error messages - Protect /health/detailed with admin authentication - Add input length validation (comment 5000, desc 10000) All 521 backend tests passing. Frontend builds successfully. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
293 lines
8.9 KiB
Python
293 lines
8.9 KiB
Python
"""
|
|
CSRF (Cross-Site Request Forgery) Protection Middleware.
|
|
|
|
This module provides CSRF protection for all state-changing operations.
|
|
It validates CSRF tokens globally for authenticated POST, PUT, PATCH, DELETE requests.
|
|
"""
|
|
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
from starlette.requests import Request
|
|
from starlette.responses import JSONResponse
|
|
from fastapi import HTTPException, status
|
|
from typing import Optional, Callable, List, Set
|
|
from functools import wraps
|
|
import logging
|
|
|
|
from app.core.security import validate_csrf_token, generate_csrf_token, decode_access_token
|
|
|
|
logger = logging.getLogger(__name__)

# Request header through which clients send their CSRF token.
CSRF_TOKEN_HEADER = "X-CSRF-Token"

# State-changing HTTP methods that must carry a valid CSRF token.
CSRF_PROTECTED_METHODS = set("POST PUT PATCH DELETE".split())

# Read-only HTTP methods that are never CSRF-checked.
CSRF_SAFE_METHODS = set("GET HEAD OPTIONS".split())

# Exact request paths exempt from CSRF validation, grouped by reason.
CSRF_EXCLUDED_PATHS: Set[str] = (
    # Login: the client has no session (and therefore no token) yet.
    {"/api/auth/login", "/api/v1/auth/login"}
    # Health-check endpoints.
    | {"/health", "/health/live", "/health/ready", "/health/detailed"}
    # WebSocket entry points, which authenticate through a different mechanism.
    | {"/api/ws", "/ws"}
)

# Path prefixes exempt from CSRF validation (WebSocket sub-routes).
CSRF_EXCLUDED_PREFIXES: List[str] = ["/api/ws/", "/ws/"]
|
|
|
|
|
|
class CSRFProtectionError(HTTPException):
    """HTTP 403 error raised when a request fails CSRF validation.

    Args:
        detail: Reason sent back to the client as the error detail.
            Defaults to a generic "CSRF validation failed" message.
    """

    def __init__(self, detail: str = "CSRF validation failed"):
        forbidden = status.HTTP_403_FORBIDDEN
        super().__init__(detail=detail, status_code=forbidden)
|
|
|
|
|
|
class CSRFMiddleware(BaseHTTPMiddleware):
    """
    Application-wide CSRF guard.

    A request must carry a valid ``X-CSRF-Token`` header when all of the
    following hold:
      * its method is not in ``CSRF_SAFE_METHODS``;
      * its path is not exempted by ``CSRF_EXCLUDED_PATHS`` or
        ``CSRF_EXCLUDED_PREFIXES``;
      * it presents a Bearer token that decodes to a user ID.

    Requests without a decodable Bearer token are passed through unchanged;
    rejecting them is left to the authentication layer.
    """

    async def dispatch(self, request: Request, call_next):
        """Apply the CSRF check to ``request``, then forward it downstream."""
        verb = request.method.upper()
        route = request.url.path

        # Read-only methods never need a token.
        if verb in CSRF_SAFE_METHODS:
            return await call_next(request)

        if self._is_excluded_path(route):
            logger.debug("CSRF validation skipped for excluded path: %s", route)
            return await call_next(request)

        subject = self._extract_user_id_from_token(request)
        if subject is None:
            # No identifiable user to bind a token to; the auth layer decides
            # whether the request is allowed at all.
            logger.debug("CSRF validation skipped (no auth token): %s %s", verb, route)
            return await call_next(request)

        supplied = request.headers.get(CSRF_TOKEN_HEADER)
        if not supplied:
            logger.warning(
                "CSRF validation failed: Missing token for user %s on %s %s",
                subject, verb, route,
            )
            return self._forbidden("CSRF token is required")

        ok, reason = validate_csrf_token(supplied, subject)
        if not ok:
            logger.warning(
                "CSRF validation failed for user %s on %s %s: %s",
                subject, verb, route, reason,
            )
            return self._forbidden(reason)

        logger.debug("CSRF validation passed for user %s on %s %s", subject, verb, route)
        return await call_next(request)

    @staticmethod
    def _forbidden(detail) -> JSONResponse:
        """Build the 403 JSON response used for every CSRF rejection."""
        return JSONResponse(
            status_code=status.HTTP_403_FORBIDDEN,
            content={"detail": detail},
        )

    def _is_excluded_path(self, path: str) -> bool:
        """Return True if ``path`` is exempt from CSRF checks (exact match or prefix)."""
        return path in CSRF_EXCLUDED_PATHS or any(
            path.startswith(prefix) for prefix in CSRF_EXCLUDED_PREFIXES
        )

    def _extract_user_id_from_token(self, request: Request) -> Optional[str]:
        """
        Return the ``sub`` claim of the request's Bearer token.

        Returns None when there is no Authorization header, when it is not
        of the form ``Bearer <token>``, when the token cannot be decoded, or
        when decoding raises.

        NOTE(review): only the Authorization header is inspected. If the app
        also authenticates via cookies, those requests are treated as
        unauthenticated here and skip the CSRF check — confirm that is intended.
        """
        header_value = request.headers.get("Authorization")
        if not header_value:
            return None

        pieces = header_value.split()
        if len(pieces) != 2:
            return None
        scheme, credential = pieces
        if scheme.lower() != "bearer":
            return None

        try:
            claims = decode_access_token(credential)
            return None if claims is None else claims.get("sub")
        except Exception as exc:
            # Best-effort: an undecodable token is treated as unauthenticated.
            logger.debug("Failed to decode token for CSRF validation: %s", exc)
            return None
|
|
|
|
|
|
def require_csrf_token(func: Callable) -> Callable:
    """
    Decorator that enforces CSRF token validation on an async endpoint.

    Usage:
        @router.delete("/resource/{id}")
        @require_csrf_token
        async def delete_resource(request: Request, id: str, current_user: User = Depends(get_current_user)):
            ...

    Requirements on the wrapped endpoint:
      * It must receive the authenticated user as the ``current_user``
        keyword argument. The user's ``id`` is the value the token is
        validated against.
      * It must receive a ``Request``, either as the ``request`` keyword
        argument or under any other parameter name or position.

    The ``X-CSRF-Token`` header is checked before the endpoint body runs.

    Raises:
        CSRFProtectionError: (HTTP 403) if no Request or user can be found,
            the token header is missing, or the token fails validation.
    """
    @wraps(func)
    async def wrapper(*args, **kwargs):
        current_user = kwargs.get("current_user")

        # FastAPI passes endpoint parameters as keywords. An endpoint that
        # names its Request parameter something other than ``request``
        # (e.g. ``req``) would otherwise never be found, so fall back to
        # scanning every positional and keyword value.
        request: Optional[Request] = kwargs.get("request")
        if request is None:
            request = next(
                (value for value in (*args, *kwargs.values()) if isinstance(value, Request)),
                None,
            )

        if request is None:
            logger.error("CSRF validation failed: Request object not found")
            raise CSRFProtectionError("Internal error: Request not available")

        if current_user is None:
            logger.error("CSRF validation failed: User not authenticated")
            raise CSRFProtectionError("Authentication required for CSRF-protected endpoint")

        csrf_token = request.headers.get(CSRF_TOKEN_HEADER)
        if not csrf_token:
            logger.warning(
                "CSRF validation failed: Missing token for user %s on %s %s",
                current_user.id, request.method, request.url.path
            )
            raise CSRFProtectionError("CSRF token is required")

        # CSRFMiddleware validates against the JWT ``sub`` claim (a string),
        # and the sibling helpers type the user id as ``str``. Normalise here
        # so the decorator and middleware validate against the same value.
        # (NOTE(review): assumes ``current_user.id`` may be int/UUID — confirm
        # against the User model; str() is a no-op for string ids.)
        is_valid, error_message = validate_csrf_token(csrf_token, str(current_user.id))
        if not is_valid:
            logger.warning(
                "CSRF validation failed for user %s on %s %s: %s",
                current_user.id, request.method, request.url.path, error_message
            )
            raise CSRFProtectionError(error_message)

        logger.debug(
            "CSRF validation passed for user %s on %s %s",
            current_user.id, request.method, request.url.path
        )

        return await func(*args, **kwargs)

    return wrapper
|
|
|
|
|
|
def get_csrf_token_for_user(user_id: str) -> str:
    """
    Issue a fresh CSRF token bound to ``user_id``.

    Meant for login-style endpoints, so the client receives a token it can
    echo back in the ``X-CSRF-Token`` header on later state-changing
    requests.

    Args:
        user_id: Identifier of the user the token is issued for.

    Returns:
        The token string produced by ``generate_csrf_token``.
    """
    token = generate_csrf_token(user_id)
    return token
|
|
|
|
|
|
async def validate_csrf_for_request(
    request: Request,
    user_id: str,
    skip_methods: Optional[List[str]] = None
) -> bool:
    """
    Validate the CSRF token carried by a request.

    Use this to check the token inline in an endpoint, instead of using the
    ``require_csrf_token`` decorator.

    Args:
        request: The FastAPI request object.
        user_id: The current user's ID; the token must be valid for it.
        skip_methods: HTTP methods exempt from validation, compared
            case-insensitively. Defaults to ``CSRF_SAFE_METHODS``
            (GET, HEAD, OPTIONS).

    Returns:
        True if the method is exempt or the token is valid.

    Raises:
        CSRFProtectionError: If the ``X-CSRF-Token`` header is missing or
            the token fails validation.
    """
    method = request.method.upper()

    if skip_methods is None:
        # Same safe-method set that CSRFMiddleware uses.
        exempt_methods = CSRF_SAFE_METHODS
    else:
        # The request method is upper-cased, so normalise the caller's list
        # too; otherwise e.g. ["get"] would silently never match.
        exempt_methods = {m.upper() for m in skip_methods}

    if method in exempt_methods:
        return True

    csrf_token = request.headers.get(CSRF_TOKEN_HEADER)
    if not csrf_token:
        logger.warning(
            "CSRF validation failed: Missing token for user %s on %s %s",
            user_id, method, request.url.path
        )
        raise CSRFProtectionError("CSRF token is required")

    is_valid, error_message = validate_csrf_token(csrf_token, user_id)
    if not is_valid:
        logger.warning(
            "CSRF validation failed for user %s on %s %s: %s",
            user_id, method, request.url.path, error_message
        )
        raise CSRFProtectionError(error_message)

    return True
|