"""
CSRF (Cross-Site Request Forgery) Protection Middleware.

This module provides CSRF protection for state-changing operations. The
global middleware validates the ``X-CSRF-Token`` header for authenticated
requests whose method is not one of GET, HEAD or OPTIONS (i.e. POST, PUT,
PATCH, DELETE and any other non-safe method). A decorator and a helper
function are also provided for per-endpoint validation.
"""

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse
from fastapi import HTTPException, status
from typing import Optional, Callable, List, Set
from functools import wraps
import logging

from app.core.security import validate_csrf_token, generate_csrf_token, decode_access_token

logger = logging.getLogger(__name__)

# Header name for CSRF token
CSRF_TOKEN_HEADER = "X-CSRF-Token"

# Methods that require CSRF protection (all state-changing operations).
# NOTE: informational only -- CSRFMiddleware validates every method that is
# NOT in CSRF_SAFE_METHODS, so unusual methods are checked too (fail closed).
CSRF_PROTECTED_METHODS = {"POST", "PUT", "PATCH", "DELETE"}

# Safe methods that don't require CSRF protection
CSRF_SAFE_METHODS = {"GET", "HEAD", "OPTIONS"}

# Client-facing detail used when validation fails without a specific message.
_DEFAULT_FAILURE_DETAIL = "CSRF validation failed"

# Public endpoints that don't require CSRF validation.
# These are endpoints that either:
# 1. Don't require authentication (login, health checks)
# 2. Are not state-changing in a security-sensitive way
CSRF_EXCLUDED_PATHS: Set[str] = {
    # Authentication endpoints (unauthenticated)
    "/api/auth/login",
    "/api/v1/auth/login",
    # Health check endpoints (unauthenticated)
    "/health",
    "/health/live",
    "/health/ready",
    "/health/detailed",
    # WebSocket endpoints (use different auth mechanism)
    "/api/ws",
    "/ws",
}

# Path prefixes that are excluded from CSRF validation
CSRF_EXCLUDED_PREFIXES: List[str] = [
    # WebSocket paths
    "/api/ws/",
    "/ws/",
]


class CSRFProtectionError(HTTPException):
    """Custom exception for CSRF validation failures (always HTTP 403)."""

    def __init__(self, detail: str = _DEFAULT_FAILURE_DETAIL):
        super().__init__(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=detail,
        )


def _check_csrf_header(request: Request, user_id) -> Optional[str]:
    """Validate the CSRF header of *request* against *user_id*.

    Shared by the middleware, the decorator and ``validate_csrf_for_request``
    so that all three enforce and log identically.

    Args:
        request: The incoming request; its ``X-CSRF-Token`` header is read.
        user_id: Identifier passed unchanged to ``validate_csrf_token``.

    Returns:
        ``None`` when the token is present and valid; otherwise the
        client-facing error detail to report with a 403.
    """
    method = request.method.upper()
    path = request.url.path

    csrf_token = request.headers.get(CSRF_TOKEN_HEADER)
    if not csrf_token:
        logger.warning(
            "CSRF validation failed: Missing token for user %s on %s %s",
            user_id, method, path,
        )
        return "CSRF token is required"

    is_valid, error_message = validate_csrf_token(csrf_token, user_id)
    if not is_valid:
        logger.warning(
            "CSRF validation failed for user %s on %s %s: %s",
            user_id, method, path, error_message,
        )
        # Never report an empty/None detail to the client.
        return error_message or _DEFAULT_FAILURE_DETAIL

    logger.debug(
        "CSRF validation passed for user %s on %s %s",
        user_id, method, path,
    )
    return None


class CSRFMiddleware(BaseHTTPMiddleware):
    """
    Global CSRF protection middleware.

    Validates CSRF tokens for authenticated requests whose method is not in
    ``CSRF_SAFE_METHODS``, except for explicitly excluded endpoints.

    A request counts as authenticated only if it carries an
    ``Authorization: Bearer <token>`` header that ``decode_access_token``
    accepts and whose payload has a ``sub`` claim. All other requests are
    passed through unchecked; rejecting unauthenticated access is left to
    the authentication layer.
    """

    async def dispatch(self, request: Request, call_next):
        """Process the request and validate the CSRF token if needed."""
        method = request.method.upper()
        path = request.url.path

        # Safe (read-only) methods never need a CSRF token.
        if method in CSRF_SAFE_METHODS:
            return await call_next(request)

        if self._is_excluded_path(path):
            logger.debug("CSRF validation skipped for excluded path: %s", path)
            return await call_next(request)

        user_id = self._extract_user_id_from_token(request)

        # Unauthenticated request: nothing to bind a CSRF token to.
        # The authentication layer will handle unauthorized access.
        if user_id is None:
            logger.debug(
                "CSRF validation skipped (no auth token): %s %s", method, path
            )
            return await call_next(request)

        failure = _check_csrf_header(request, user_id)
        if failure is not None:
            return JSONResponse(
                status_code=status.HTTP_403_FORBIDDEN,
                content={"detail": failure},
            )

        return await call_next(request)

    def _is_excluded_path(self, path: str) -> bool:
        """Return True if *path* is an exact excluded path or has an excluded prefix."""
        return path in CSRF_EXCLUDED_PATHS or path.startswith(
            tuple(CSRF_EXCLUDED_PREFIXES)
        )

    def _extract_user_id_from_token(self, request: Request) -> Optional[str]:
        """
        Extract the user ID (the ``sub`` claim) from the Authorization header.

        Returns None if there is no ``Bearer`` token, the token cannot be
        decoded, or the payload has no ``sub`` claim.
        """
        auth_header = request.headers.get("Authorization")
        if not auth_header:
            return None

        # Expect exactly "Bearer <token>" (scheme is case-insensitive).
        parts = auth_header.split()
        if len(parts) != 2 or parts[0].lower() != "bearer":
            return None

        token = parts[1]
        try:
            payload = decode_access_token(token)
            if payload is None:
                return None
            return payload.get("sub")
        except Exception as e:
            # Best effort: an undecodable token is treated as unauthenticated
            # here; the authentication layer is responsible for rejecting it.
            logger.debug("Failed to decode token for CSRF validation: %s", e)
            return None


def require_csrf_token(func: Callable) -> Callable:
    """
    Decorator to require CSRF token validation for an endpoint.

    Usage:
        @router.delete("/resource/{id}")
        @require_csrf_token
        async def delete_resource(request: Request, id: str,
                                  current_user: User = Depends(get_current_user)):
            ...

    The wrapped endpoint must receive a ``Request`` (as the ``request``
    keyword or positionally) and a ``current_user`` keyword argument with an
    ``id`` attribute. The ``X-CSRF-Token`` header is validated against
    ``current_user.id``.

    Raises:
        CSRFProtectionError: If the request or user is missing, or the token
            is absent or invalid.
    """
    @wraps(func)
    async def wrapper(*args, **kwargs):
        request: Optional[Request] = kwargs.get("request")
        current_user = kwargs.get("current_user")

        if request is None:
            # Fall back to a positionally-passed Request.
            request = next((arg for arg in args if isinstance(arg, Request)), None)

        if request is None:
            logger.error("CSRF validation failed: Request object not found")
            raise CSRFProtectionError("Internal error: Request not available")

        if current_user is None:
            logger.error("CSRF validation failed: User not authenticated")
            raise CSRFProtectionError("Authentication required for CSRF-protected endpoint")

        # NOTE(review): current_user.id is passed as-is, whereas the
        # middleware passes the JWT "sub" claim. If User.id is not a str,
        # confirm validate_csrf_token treats both forms identically.
        failure = _check_csrf_header(request, current_user.id)
        if failure is not None:
            raise CSRFProtectionError(failure)

        return await func(*args, **kwargs)

    return wrapper


def get_csrf_token_for_user(user_id: str) -> str:
    """
    Generate a CSRF token for a user.

    This function can be called from login endpoints to provide the client
    with a CSRF token.

    Args:
        user_id: The user's ID

    Returns:
        CSRF token string
    """
    return generate_csrf_token(user_id)


async def validate_csrf_for_request(
    request: Request,
    user_id: str,
    skip_methods: Optional[List[str]] = None,
) -> bool:
    """
    Validate the CSRF token for a request.

    This utility can be used directly in endpoints without the decorator.

    Args:
        request: The FastAPI request object
        user_id: The current user's ID
        skip_methods: HTTP methods to skip validation for, compared
            case-insensitively (default: GET, HEAD, OPTIONS)

    Returns:
        True if validation passes or the method is skipped

    Raises:
        CSRFProtectionError: If validation fails
    """
    if skip_methods is None:
        skip = CSRF_SAFE_METHODS
    else:
        # Normalise so callers may pass e.g. ["get"] as well as ["GET"].
        skip = {m.upper() for m in skip_methods}

    if request.method.upper() in skip:
        return True

    failure = _check_csrf_header(request, user_id)
    if failure is not None:
        raise CSRFProtectionError(failure)

    return True