"""
CSRF (Cross-Site Request Forgery) protection middleware.

This module provides CSRF protection for sensitive state-changing operations.
Clients send a token in the ``X-CSRF-Token`` header; it is checked against the
authenticated user with ``app.core.security.validate_csrf_token``.

Two entry points are offered:

* ``require_csrf_token`` -- a decorator applied to individual endpoints.
* ``validate_csrf_for_request`` -- a coroutine that can be awaited directly
  from inside an endpoint body.
"""
from fastapi import Request, HTTPException, status, Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from typing import Optional, Callable, List
from functools import wraps
import logging

from app.core.security import validate_csrf_token, generate_csrf_token

logger = logging.getLogger(__name__)

# Header name for CSRF token
CSRF_TOKEN_HEADER = "X-CSRF-Token"

# Endpoint path patterns that require CSRF protection (sensitive
# state-changing operations).
# NOTE(review): this list is not referenced inside this module; presumably it
# is consumed elsewhere (e.g. by routing/middleware code) -- verify.
CSRF_PROTECTED_PATTERNS = [
    # User operations
    "/api/v1/users/{user_id}/admin",  # Admin status change
    "/api/users/{user_id}/admin",  # Legacy
    # Password changes would go here if implemented
    # Delete operations
    "/api/attachments/{attachment_id}",  # DELETE method
    "/api/tasks/{task_id}",  # DELETE method (soft delete)
    "/api/projects/{project_id}",  # DELETE method
]

# Methods that require CSRF protection.
# NOTE(review): POST is not listed; confirm that is intentional.
CSRF_PROTECTED_METHODS = ["DELETE", "PUT", "PATCH"]

# Detail used when the token validator reports failure without a message.
_DEFAULT_CSRF_ERROR = "CSRF validation failed"


class CSRFProtectionError(HTTPException):
    """HTTP 403 Forbidden raised when CSRF validation fails."""

    def __init__(self, detail: str = "CSRF validation failed"):
        super().__init__(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=detail,
        )


def _find_request(args: tuple, kwargs: dict) -> Optional[Request]:
    """Locate the Request object among an endpoint's call arguments.

    A ``request`` keyword argument wins if present. Otherwise every positional
    and keyword argument is scanned for a ``Request`` instance. Scanning the
    keyword values matters because FastAPI passes endpoint parameters as
    keywords named after the parameter, so ``req: Request`` arrives as
    ``kwargs["req"]`` rather than in ``args``.
    """
    request = kwargs.get("request")
    if request is not None:
        return request
    for candidate in (*args, *kwargs.values()):
        if isinstance(candidate, Request):
            return candidate
    return None


def require_csrf_token(func: Callable) -> Callable:
    """
    Decorator to require CSRF token validation for an endpoint.

    Usage:
        @router.delete("/resource/{id}")
        @require_csrf_token
        async def delete_resource(request: Request, id: str, current_user: User = Depends(get_current_user)):
            ...

    The decorated endpoint must:

    * be an ``async def`` (the wrapper awaits its result);
    * accept a ``Request`` parameter (any name);
    * accept the authenticated user as a keyword parameter named exactly
      ``current_user``, whose ``.id`` is passed to ``validate_csrf_token``.

    ``functools.wraps`` sets ``__wrapped__``, which ``inspect.signature``
    follows, so the framework still sees the endpoint's original parameters.

    Raises:
        CSRFProtectionError: (HTTP 403) if the request or user is missing,
            the ``X-CSRF-Token`` header is absent/empty, or the token does
            not validate for the current user.
    """

    @wraps(func)
    async def wrapper(*args, **kwargs):
        request = _find_request(args, kwargs)
        current_user = kwargs.get("current_user")

        if request is None:
            logger.error("CSRF validation failed: Request object not found")
            raise CSRFProtectionError("Internal error: Request not available")

        if current_user is None:
            logger.error("CSRF validation failed: User not authenticated")
            raise CSRFProtectionError("Authentication required for CSRF-protected endpoint")

        # Get CSRF token from header
        csrf_token = request.headers.get(CSRF_TOKEN_HEADER)
        if not csrf_token:
            logger.warning(
                "CSRF validation failed: Missing token for user %s on %s %s",
                current_user.id, request.method, request.url.path,
            )
            raise CSRFProtectionError("CSRF token is required")

        # Validate the token against the authenticated user
        is_valid, error_message = validate_csrf_token(csrf_token, current_user.id)
        if not is_valid:
            logger.warning(
                "CSRF validation failed for user %s on %s %s: %s",
                current_user.id, request.method, request.url.path, error_message,
            )
            # Never send an empty/None detail to the client.
            raise CSRFProtectionError(error_message or _DEFAULT_CSRF_ERROR)

        logger.debug(
            "CSRF validation passed for user %s on %s %s",
            current_user.id, request.method, request.url.path,
        )
        return await func(*args, **kwargs)

    return wrapper


def get_csrf_token_for_user(user_id: str) -> str:
    """
    Generate a CSRF token for a user.

    Thin wrapper over ``app.core.security.generate_csrf_token``; intended to be
    called from login endpoints so the client receives a CSRF token.

    Args:
        user_id: The user's ID

    Returns:
        CSRF token string
    """
    return generate_csrf_token(user_id)


async def validate_csrf_for_request(
    request: Request,
    user_id: str,
    skip_methods: Optional[List[str]] = None,
) -> bool:
    """
    Validate the CSRF token for a request.

    A utility that can be awaited directly inside an endpoint, as an
    alternative to the ``require_csrf_token`` decorator. It does not log.

    Args:
        request: The FastAPI request object
        user_id: The current user's ID
        skip_methods: HTTP methods to skip validation for, compared
            case-insensitively (default: GET, HEAD, OPTIONS). Pass an empty
            list to validate every method.

    Returns:
        True if validation passes (or the method is skipped)

    Raises:
        CSRFProtectionError: If the token is missing or invalid
    """
    if skip_methods is None:
        skip_methods = ["GET", "HEAD", "OPTIONS"]

    # Normalise so callers may pass e.g. ["get", "head"].
    safe_methods = {method.upper() for method in skip_methods}
    if request.method.upper() in safe_methods:
        return True

    # Get CSRF token from header
    csrf_token = request.headers.get(CSRF_TOKEN_HEADER)
    if not csrf_token:
        raise CSRFProtectionError("CSRF token is required")

    is_valid, error_message = validate_csrf_token(csrf_token, user_id)
    if not is_valid:
        # Never send an empty/None detail to the client.
        raise CSRFProtectionError(error_message or _DEFAULT_CSRF_ERROR)

    return True