"""Security audit middleware for logging access denials and suspicious auth patterns.""" import time from collections import defaultdict from datetime import datetime, timedelta, timezone from typing import Dict, List, Tuple from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request from starlette.responses import Response from sqlalchemy.orm import Session from app.core.database import SessionLocal from app.models import AuditLog, AuditAction from app.services.audit_service import AuditService # In-memory storage for tracking auth failures # Structure: {ip_address: [(timestamp, path), ...]} _auth_failure_tracker: Dict[str, List[Tuple[float, str]]] = defaultdict(list) # Configuration constants AUTH_FAILURE_THRESHOLD = 5 # Number of failures to trigger suspicious pattern alert AUTH_FAILURE_WINDOW_SECONDS = 600 # 10 minutes window def _cleanup_old_failures(ip: str) -> None: """Remove auth failures older than the tracking window.""" if ip not in _auth_failure_tracker: return cutoff = time.time() - AUTH_FAILURE_WINDOW_SECONDS _auth_failure_tracker[ip] = [ (ts, path) for ts, path in _auth_failure_tracker[ip] if ts > cutoff ] # Clean up empty entries if not _auth_failure_tracker[ip]: del _auth_failure_tracker[ip] def _track_auth_failure(ip: str, path: str) -> int: """Track an auth failure and return the count in the window.""" _cleanup_old_failures(ip) _auth_failure_tracker[ip].append((time.time(), path)) return len(_auth_failure_tracker[ip]) def _get_recent_failures(ip: str) -> List[str]: """Get list of paths that failed auth for this IP.""" _cleanup_old_failures(ip) return [path for _, path in _auth_failure_tracker.get(ip, [])] class SecurityAuditMiddleware(BaseHTTPMiddleware): """Middleware to audit security-related events like 401/403 responses.""" async def dispatch(self, request: Request, call_next) -> Response: response = await call_next(request) # Only process 401 and 403 responses if response.status_code not in (401, 403): return response # Get client IP from audit metadata if available ip_address = self._get_client_ip(request) path = str(request.url.path) method = request.method # Get user_id if available from request state (set by auth middleware) user_id = getattr(request.state, 'user_id', None) db: Session = SessionLocal() try: if response.status_code == 403: self._log_access_denied(db, ip_address, path, method, user_id, request) elif response.status_code == 401: self._log_auth_failure(db, ip_address, path, method, request) db.commit() except Exception: db.rollback() # Don't fail the request due to audit logging errors finally: db.close() return response def _get_client_ip(self, request: Request) -> str: """Get the real client IP address from request.""" # Check for audit metadata first (set by AuditMiddleware) audit_metadata = getattr(request.state, 'audit_metadata', None) if audit_metadata and 'ip_address' in audit_metadata: return audit_metadata['ip_address'] # Fallback to checking headers directly forwarded = request.headers.get("x-forwarded-for") if forwarded: return forwarded.split(",")[0].strip() real_ip = request.headers.get("x-real-ip") if real_ip: return real_ip if request.client: return request.client.host return "unknown" def _log_access_denied( self, db: Session, ip_address: str, path: str, method: str, user_id: str | None, request: Request, ) -> None: """Log a 403 Forbidden response to the audit trail.""" request_metadata = { "ip_address": ip_address, "user_agent": request.headers.get("user-agent", ""), "method": method, "path": path, } # Try to extract resource info from path resource_type = self._extract_resource_type(path) AuditService.log_event( db=db, event_type="security.access_denied", resource_type=resource_type, action=AuditAction.ACCESS_DENIED, user_id=user_id, resource_id=None, changes=[{ "attempted_path": path, "attempted_method": method, "ip_address": ip_address, }], request_metadata=request_metadata, ) def _log_auth_failure( self, db: Session, ip_address: str, path: str, method: str, request: Request, ) -> None: """Log a 401 Unauthorized response and check for suspicious patterns.""" # Track this failure failure_count = _track_auth_failure(ip_address, path) request_metadata = { "ip_address": ip_address, "user_agent": request.headers.get("user-agent", ""), "method": method, "path": path, } resource_type = self._extract_resource_type(path) # Log the auth failure AuditService.log_event( db=db, event_type="security.auth_failed", resource_type=resource_type, action=AuditAction.AUTH_FAILED, user_id=None, # No user for 401 resource_id=None, changes=[{ "attempted_path": path, "attempted_method": method, "ip_address": ip_address, "failure_count_in_window": failure_count, }], request_metadata=request_metadata, ) # Check for suspicious pattern if failure_count >= AUTH_FAILURE_THRESHOLD: recent_paths = _get_recent_failures(ip_address) AuditService.log_event( db=db, event_type="security.suspicious_auth_pattern", resource_type="security", action=AuditAction.AUTH_FAILED, user_id=None, resource_id=None, changes=[{ "ip_address": ip_address, "failure_count": failure_count, "window_minutes": AUTH_FAILURE_WINDOW_SECONDS // 60, "attempted_paths": list(set(recent_paths)), }], request_metadata=request_metadata, ) def _extract_resource_type(self, path: str) -> str: """Extract resource type from path for audit logging.""" # Remove /api/ or /api/v1/ prefix clean_path = path if clean_path.startswith("/api/v1/"): clean_path = clean_path[8:] elif clean_path.startswith("/api/"): clean_path = clean_path[5:] # Get the first path segment as resource type parts = clean_path.strip("/").split("/") if parts and parts[0]: return parts[0] return "unknown"