feat: implement 5 QA-driven security and quality proposals
Implemented proposals from comprehensive QA review: 1. extend-csrf-protection - Add POST to CSRF protected methods in frontend - Global CSRF middleware for all state-changing operations - Update tests with CSRF token fixtures 2. tighten-cors-websocket-security - Replace wildcard CORS with explicit method/header lists - Disable query parameter auth in production (code 4002) - Add per-user WebSocket connection limit (max 5, code 4005) 3. shorten-jwt-expiry - Reduce JWT expiry from 7 days to 60 minutes - Add refresh token support with 7-day expiry - Implement token rotation on refresh - Frontend auto-refresh when token near expiry (<5 min) 4. fix-frontend-quality - Add React.lazy() code splitting for all pages - Fix useCallback dependency arrays (Dashboard, Comments) - Add localStorage data validation in AuthContext - Complete i18n for AttachmentUpload component 5. enhance-backend-validation - Add SecurityAuditMiddleware for access denied logging - Add ErrorSanitizerMiddleware for production error messages - Protect /health/detailed with admin authentication - Add input length validation (comment 5000, desc 10000) All 521 backend tests passing. Frontend builds successfully. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
215
backend/app/middleware/security_audit.py
Normal file
215
backend/app/middleware/security_audit.py
Normal file
@@ -0,0 +1,215 @@
|
||||
"""Security audit middleware for logging access denials and suspicious auth patterns."""
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Dict, List, Tuple
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.database import SessionLocal
|
||||
from app.models import AuditLog, AuditAction
|
||||
from app.services.audit_service import AuditService
|
||||
|
||||
|
||||
# In-memory storage for tracking auth failures
|
||||
# Structure: {ip_address: [(timestamp, path), ...]}
|
||||
_auth_failure_tracker: Dict[str, List[Tuple[float, str]]] = defaultdict(list)
|
||||
|
||||
# Configuration constants
|
||||
AUTH_FAILURE_THRESHOLD = 5 # Number of failures to trigger suspicious pattern alert
|
||||
AUTH_FAILURE_WINDOW_SECONDS = 600 # 10 minutes window
|
||||
|
||||
|
||||
def _cleanup_old_failures(ip: str) -> None:
|
||||
"""Remove auth failures older than the tracking window."""
|
||||
if ip not in _auth_failure_tracker:
|
||||
return
|
||||
|
||||
cutoff = time.time() - AUTH_FAILURE_WINDOW_SECONDS
|
||||
_auth_failure_tracker[ip] = [
|
||||
(ts, path) for ts, path in _auth_failure_tracker[ip]
|
||||
if ts > cutoff
|
||||
]
|
||||
|
||||
# Clean up empty entries
|
||||
if not _auth_failure_tracker[ip]:
|
||||
del _auth_failure_tracker[ip]
|
||||
|
||||
|
||||
def _track_auth_failure(ip: str, path: str) -> int:
|
||||
"""Track an auth failure and return the count in the window."""
|
||||
_cleanup_old_failures(ip)
|
||||
_auth_failure_tracker[ip].append((time.time(), path))
|
||||
return len(_auth_failure_tracker[ip])
|
||||
|
||||
|
||||
def _get_recent_failures(ip: str) -> List[str]:
|
||||
"""Get list of paths that failed auth for this IP."""
|
||||
_cleanup_old_failures(ip)
|
||||
return [path for _, path in _auth_failure_tracker.get(ip, [])]
|
||||
|
||||
|
||||
class SecurityAuditMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware to audit security-related events like 401/403 responses."""
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
response = await call_next(request)
|
||||
|
||||
# Only process 401 and 403 responses
|
||||
if response.status_code not in (401, 403):
|
||||
return response
|
||||
|
||||
# Get client IP from audit metadata if available
|
||||
ip_address = self._get_client_ip(request)
|
||||
path = str(request.url.path)
|
||||
method = request.method
|
||||
|
||||
# Get user_id if available from request state (set by auth middleware)
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
|
||||
db: Session = SessionLocal()
|
||||
try:
|
||||
if response.status_code == 403:
|
||||
self._log_access_denied(db, ip_address, path, method, user_id, request)
|
||||
elif response.status_code == 401:
|
||||
self._log_auth_failure(db, ip_address, path, method, request)
|
||||
|
||||
db.commit()
|
||||
except Exception:
|
||||
db.rollback()
|
||||
# Don't fail the request due to audit logging errors
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
return response
|
||||
|
||||
def _get_client_ip(self, request: Request) -> str:
|
||||
"""Get the real client IP address from request."""
|
||||
# Check for audit metadata first (set by AuditMiddleware)
|
||||
audit_metadata = getattr(request.state, 'audit_metadata', None)
|
||||
if audit_metadata and 'ip_address' in audit_metadata:
|
||||
return audit_metadata['ip_address']
|
||||
|
||||
# Fallback to checking headers directly
|
||||
forwarded = request.headers.get("x-forwarded-for")
|
||||
if forwarded:
|
||||
return forwarded.split(",")[0].strip()
|
||||
|
||||
real_ip = request.headers.get("x-real-ip")
|
||||
if real_ip:
|
||||
return real_ip
|
||||
|
||||
if request.client:
|
||||
return request.client.host
|
||||
|
||||
return "unknown"
|
||||
|
||||
def _log_access_denied(
|
||||
self,
|
||||
db: Session,
|
||||
ip_address: str,
|
||||
path: str,
|
||||
method: str,
|
||||
user_id: str | None,
|
||||
request: Request,
|
||||
) -> None:
|
||||
"""Log a 403 Forbidden response to the audit trail."""
|
||||
request_metadata = {
|
||||
"ip_address": ip_address,
|
||||
"user_agent": request.headers.get("user-agent", ""),
|
||||
"method": method,
|
||||
"path": path,
|
||||
}
|
||||
|
||||
# Try to extract resource info from path
|
||||
resource_type = self._extract_resource_type(path)
|
||||
|
||||
AuditService.log_event(
|
||||
db=db,
|
||||
event_type="security.access_denied",
|
||||
resource_type=resource_type,
|
||||
action=AuditAction.ACCESS_DENIED,
|
||||
user_id=user_id,
|
||||
resource_id=None,
|
||||
changes=[{
|
||||
"attempted_path": path,
|
||||
"attempted_method": method,
|
||||
"ip_address": ip_address,
|
||||
}],
|
||||
request_metadata=request_metadata,
|
||||
)
|
||||
|
||||
def _log_auth_failure(
|
||||
self,
|
||||
db: Session,
|
||||
ip_address: str,
|
||||
path: str,
|
||||
method: str,
|
||||
request: Request,
|
||||
) -> None:
|
||||
"""Log a 401 Unauthorized response and check for suspicious patterns."""
|
||||
# Track this failure
|
||||
failure_count = _track_auth_failure(ip_address, path)
|
||||
|
||||
request_metadata = {
|
||||
"ip_address": ip_address,
|
||||
"user_agent": request.headers.get("user-agent", ""),
|
||||
"method": method,
|
||||
"path": path,
|
||||
}
|
||||
|
||||
resource_type = self._extract_resource_type(path)
|
||||
|
||||
# Log the auth failure
|
||||
AuditService.log_event(
|
||||
db=db,
|
||||
event_type="security.auth_failed",
|
||||
resource_type=resource_type,
|
||||
action=AuditAction.AUTH_FAILED,
|
||||
user_id=None, # No user for 401
|
||||
resource_id=None,
|
||||
changes=[{
|
||||
"attempted_path": path,
|
||||
"attempted_method": method,
|
||||
"ip_address": ip_address,
|
||||
"failure_count_in_window": failure_count,
|
||||
}],
|
||||
request_metadata=request_metadata,
|
||||
)
|
||||
|
||||
# Check for suspicious pattern
|
||||
if failure_count >= AUTH_FAILURE_THRESHOLD:
|
||||
recent_paths = _get_recent_failures(ip_address)
|
||||
AuditService.log_event(
|
||||
db=db,
|
||||
event_type="security.suspicious_auth_pattern",
|
||||
resource_type="security",
|
||||
action=AuditAction.AUTH_FAILED,
|
||||
user_id=None,
|
||||
resource_id=None,
|
||||
changes=[{
|
||||
"ip_address": ip_address,
|
||||
"failure_count": failure_count,
|
||||
"window_minutes": AUTH_FAILURE_WINDOW_SECONDS // 60,
|
||||
"attempted_paths": list(set(recent_paths)),
|
||||
}],
|
||||
request_metadata=request_metadata,
|
||||
)
|
||||
|
||||
def _extract_resource_type(self, path: str) -> str:
|
||||
"""Extract resource type from path for audit logging."""
|
||||
# Remove /api/ or /api/v1/ prefix
|
||||
clean_path = path
|
||||
if clean_path.startswith("/api/v1/"):
|
||||
clean_path = clean_path[8:]
|
||||
elif clean_path.startswith("/api/"):
|
||||
clean_path = clean_path[5:]
|
||||
|
||||
# Get the first path segment as resource type
|
||||
parts = clean_path.strip("/").split("/")
|
||||
if parts and parts[0]:
|
||||
return parts[0]
|
||||
|
||||
return "unknown"
|
||||
Reference in New Issue
Block a user