"""Error message sanitization middleware for production environments. This middleware intercepts error responses and sanitizes them to prevent information disclosure in production environments. Detailed error messages are only shown when DEBUG mode is enabled. """ import json import logging from typing import Optional from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request from starlette.responses import Response, JSONResponse from app.core.config import settings logger = logging.getLogger(__name__) # Generic error messages for production GENERIC_ERROR_MESSAGES = { 400: "Bad Request", 401: "Authentication required", 403: "Access denied", 404: "Resource not found", 405: "Method not allowed", 409: "Request conflict", 422: "Validation error", 429: "Too many requests", 500: "Internal server error", 502: "Service unavailable", 503: "Service temporarily unavailable", 504: "Request timeout", } # Status codes that should preserve their original message even in production # These are typically user-facing validation errors that don't leak sensitive info PRESERVE_MESSAGE_CODES = { 400, # Bad request - users need to know what's wrong with their request 401, # Unauthorized - users need to know why auth failed 403, # Forbidden - users need to know what permission they lack 404, # Not found - usually safe to preserve 409, # Conflict - users need to know about conflicts 422, # Validation errors - users need to know what to fix } # Patterns that indicate sensitive information in error messages SENSITIVE_PATTERNS = [ "traceback", "stack trace", "file path", "/usr/", "/var/", "/home/", "connection refused", "connection error", "timeout connecting", "database error", "sql", "query failed", "password", "secret", "token", "key=", "credentials", ".py line", "exception in", ] def _contains_sensitive_info(message: str) -> bool: """Check if an error message contains potentially sensitive information.""" if not message: return False message_lower = message.lower() return any(pattern.lower() in message_lower for pattern in SENSITIVE_PATTERNS) def _sanitize_detail(detail: any, status_code: int) -> any: """Sanitize error detail, removing sensitive information in production. Args: detail: The error detail (can be string, list, or dict) status_code: The HTTP status code Returns: Sanitized detail for production, or original detail for debug mode """ # In debug mode, return original detail if settings.DEBUG: return detail # For preserved status codes, keep the detail if it doesn't contain sensitive info if status_code in PRESERVE_MESSAGE_CODES: if isinstance(detail, str) and not _contains_sensitive_info(detail): return detail if isinstance(detail, list): # For validation errors (list of dicts), keep the structure but sanitize sanitized = [] for item in detail: if isinstance(item, dict): # Keep loc, msg, type for pydantic validation errors sanitized_item = {} if 'loc' in item: sanitized_item['loc'] = item['loc'] if 'msg' in item and not _contains_sensitive_info(str(item['msg'])): sanitized_item['msg'] = item['msg'] else: sanitized_item['msg'] = 'Validation failed' if 'type' in item: sanitized_item['type'] = item['type'] sanitized.append(sanitized_item) else: sanitized.append(item if not _contains_sensitive_info(str(item)) else 'Invalid value') return sanitized return detail # For other status codes, use generic message return GENERIC_ERROR_MESSAGES.get(status_code, "An error occurred") class ErrorSanitizerMiddleware(BaseHTTPMiddleware): """Middleware to sanitize error responses in production. This middleware: 1. Intercepts error responses (4xx and 5xx status codes) 2. Parses JSON response bodies 3. Sanitizes the 'detail' field to remove sensitive information 4. Returns the sanitized response In DEBUG mode, original error messages are preserved for development. """ async def dispatch(self, request: Request, call_next) -> Response: response = await call_next(request) # Only process error responses with JSON content if response.status_code < 400: return response content_type = response.headers.get("content-type", "") if "application/json" not in content_type: return response # Read the response body body = b"" async for chunk in response.body_iterator: body += chunk if not body: return response try: data = json.loads(body) except (json.JSONDecodeError, UnicodeDecodeError): # Not valid JSON, return as-is return Response( content=body, status_code=response.status_code, headers=dict(response.headers), media_type=response.media_type, ) # Sanitize the detail field if present if "detail" in data: original_detail = data["detail"] data["detail"] = _sanitize_detail(original_detail, response.status_code) # Log the original error in production for debugging if not settings.DEBUG and original_detail != data["detail"]: logger.warning( "Sanitized error response", extra={ "status_code": response.status_code, "path": str(request.url.path), "method": request.method, "original_detail_length": len(str(original_detail)), } ) # Return the sanitized response return JSONResponse( content=data, status_code=response.status_code, headers={ k: v for k, v in response.headers.items() if k.lower() not in ("content-length", "content-type") }, )