from typing import Optional

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request


class AuditMiddleware(BaseHTTPMiddleware):
    """Middleware that attaches audit metadata to every incoming request.

    The metadata dict is stored on ``request.state.audit_metadata`` before the
    request is passed on, so downstream handlers can read it with
    :func:`get_audit_metadata`.
    """

    async def dispatch(self, request: Request, call_next):
        """Record audit metadata on ``request.state`` and forward the request.

        Args:
            request: The incoming request.
            call_next: Callable that hands the request to the next
                middleware/endpoint and returns its response.

        Returns:
            The response produced by ``call_next``, unmodified.
        """
        request.state.audit_metadata = {
            "ip_address": self.get_client_ip(request),
            "user_agent": request.headers.get("user-agent", ""),
            "method": request.method,
            "path": str(request.url.path),
        }
        return await call_next(request)

    @staticmethod
    def get_client_ip(request: Request) -> str:
        """Return the best-guess client IP address for ``request``.

        Lookup order:

        1. The leftmost entry of ``X-Forwarded-For``, if non-blank.
        2. ``X-Real-IP``, if non-blank (surrounding whitespace removed).
        3. The host of the direct peer connection (``request.client``).
        4. The literal string ``"unknown"`` when none of the above is set.

        NOTE(review): ``X-Forwarded-For`` and ``X-Real-IP`` are ordinary
        request headers and can be forged by the client unless a trusted
        reverse proxy overwrites them. Confirm the deployment guarantees this
        before treating the result as authoritative in audit records.
        """
        forwarded = request.headers.get("x-forwarded-for")
        if forwarded:
            # The leftmost entry is the originating client when each proxy
            # appends the address it received the request from. A blank
            # entry (e.g. ", 10.0.0.1") must not be recorded as an empty IP,
            # so fall through to the next source instead.
            first_hop = forwarded.split(",")[0].strip()
            if first_hop:
                return first_hop

        real_ip = (request.headers.get("x-real-ip") or "").strip()
        if real_ip:
            return real_ip

        # No usable proxy headers: use the address of the direct connection.
        if request.client:
            return request.client.host
        return "unknown"


def get_audit_metadata(request: Request) -> Optional[dict]:
    """Return the audit metadata stored on ``request.state``.

    Args:
        request: The request whose state should be inspected.

    Returns:
        The dict stored in ``request.state.audit_metadata`` (populated by
        ``AuditMiddleware.dispatch``), or ``None`` if that attribute was
        never set, e.g. when the middleware is not installed.
    """
    return getattr(request.state, "audit_metadata", None)