feat: implement 5 QA-driven security and quality proposals

Implemented proposals from comprehensive QA review:

1. extend-csrf-protection
   - Add POST to CSRF protected methods in frontend
   - Global CSRF middleware for all state-changing operations
   - Update tests with CSRF token fixtures

2. tighten-cors-websocket-security
   - Replace wildcard CORS with explicit method/header lists
   - Disable query parameter auth in production (code 4002)
   - Add per-user WebSocket connection limit (max 5, code 4005)

3. shorten-jwt-expiry
   - Reduce JWT expiry from 7 days to 60 minutes
   - Add refresh token support with 7-day expiry
   - Implement token rotation on refresh
   - Frontend auto-refresh when token near expiry (<5 min)

4. fix-frontend-quality
   - Add React.lazy() code splitting for all pages
   - Fix useCallback dependency arrays (Dashboard, Comments)
   - Add localStorage data validation in AuthContext
   - Complete i18n for AttachmentUpload component

5. enhance-backend-validation
   - Add SecurityAuditMiddleware for access denied logging
   - Add ErrorSanitizerMiddleware for production error messages
   - Protect /health/detailed with admin authentication
   - Add input length validation (comment 5000, desc 10000)

All 521 backend tests passing. Frontend builds successfully.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
beabigegg
2026-01-12 23:19:05 +08:00
parent df50d5e7f8
commit 35c90fe76b
48 changed files with 2132 additions and 403 deletions

View File

@@ -1,38 +1,55 @@
"""
CSRF (Cross-Site Request Forgery) Protection Middleware.
This module provides CSRF protection for sensitive state-changing operations.
It validates CSRF tokens for specified protected endpoints.
This module provides CSRF protection for all state-changing operations.
It validates CSRF tokens globally for authenticated POST, PUT, PATCH, DELETE requests.
"""
from fastapi import Request, HTTPException, status, Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from typing import Optional, Callable, List
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse
from fastapi import HTTPException, status
from typing import Optional, Callable, List, Set
from functools import wraps
import logging
from app.core.security import validate_csrf_token, generate_csrf_token
from app.core.security import validate_csrf_token, generate_csrf_token, decode_access_token
logger = logging.getLogger(__name__)
# Header name for CSRF token
CSRF_TOKEN_HEADER = "X-CSRF-Token"
# List of endpoint patterns that require CSRF protection
# These are sensitive state-changing operations
CSRF_PROTECTED_PATTERNS = [
# User operations
"/api/v1/users/{user_id}/admin", # Admin status change
"/api/users/{user_id}/admin", # Legacy
# Password changes would go here if implemented
# Delete operations
"/api/attachments/{attachment_id}", # DELETE method
"/api/tasks/{task_id}", # DELETE method (soft delete)
"/api/projects/{project_id}", # DELETE method
]
# Methods that require CSRF protection (all state-changing operations)
CSRF_PROTECTED_METHODS = {"POST", "PUT", "PATCH", "DELETE"}
# Methods that require CSRF protection
CSRF_PROTECTED_METHODS = ["DELETE", "PUT", "PATCH"]
# Safe methods that don't require CSRF protection
CSRF_SAFE_METHODS = {"GET", "HEAD", "OPTIONS"}
# Public endpoints that don't require CSRF validation
# These are endpoints that either:
# 1. Don't require authentication (login, health checks)
# 2. Are not state-changing in a security-sensitive way
CSRF_EXCLUDED_PATHS: Set[str] = {
# Authentication endpoints (unauthenticated)
"/api/auth/login",
"/api/v1/auth/login",
# Health check endpoints (unauthenticated)
"/health",
"/health/live",
"/health/ready",
"/health/detailed",
# WebSocket endpoints (use different auth mechanism)
"/api/ws",
"/ws",
}
# Path prefixes that are excluded from CSRF validation
CSRF_EXCLUDED_PREFIXES: List[str] = [
# WebSocket paths
"/api/ws/",
"/ws/",
]
class CSRFProtectionError(HTTPException):
@@ -45,6 +62,114 @@ class CSRFProtectionError(HTTPException):
)
class CSRFMiddleware(BaseHTTPMiddleware):
"""
Global CSRF protection middleware.
Validates CSRF tokens for all authenticated state-changing requests
(POST, PUT, PATCH, DELETE) except for explicitly excluded endpoints.
"""
async def dispatch(self, request: Request, call_next):
"""Process the request and validate CSRF token if needed."""
method = request.method.upper()
path = request.url.path
# Skip CSRF validation for safe methods
if method in CSRF_SAFE_METHODS:
return await call_next(request)
# Skip CSRF validation for excluded paths
if self._is_excluded_path(path):
logger.debug("CSRF validation skipped for excluded path: %s", path)
return await call_next(request)
# Try to extract user ID from the Authorization header
user_id = self._extract_user_id_from_token(request)
# If no user ID (unauthenticated request), skip CSRF validation
# The authentication middleware will handle unauthorized access
if user_id is None:
logger.debug(
"CSRF validation skipped (no auth token): %s %s",
method, path
)
return await call_next(request)
# Get CSRF token from header
csrf_token = request.headers.get(CSRF_TOKEN_HEADER)
if not csrf_token:
logger.warning(
"CSRF validation failed: Missing token for user %s on %s %s",
user_id, method, path
)
return JSONResponse(
status_code=status.HTTP_403_FORBIDDEN,
content={"detail": "CSRF token is required"}
)
# Validate the token
is_valid, error_message = validate_csrf_token(csrf_token, user_id)
if not is_valid:
logger.warning(
"CSRF validation failed for user %s on %s %s: %s",
user_id, method, path, error_message
)
return JSONResponse(
status_code=status.HTTP_403_FORBIDDEN,
content={"detail": error_message}
)
logger.debug(
"CSRF validation passed for user %s on %s %s",
user_id, method, path
)
return await call_next(request)
def _is_excluded_path(self, path: str) -> bool:
"""Check if the path is excluded from CSRF validation."""
# Check exact path matches
if path in CSRF_EXCLUDED_PATHS:
return True
# Check path prefixes
for prefix in CSRF_EXCLUDED_PREFIXES:
if path.startswith(prefix):
return True
return False
def _extract_user_id_from_token(self, request: Request) -> Optional[str]:
"""
Extract user ID from the Authorization header.
Returns None if no valid token is found (unauthenticated request).
"""
auth_header = request.headers.get("Authorization")
if not auth_header:
return None
# Parse Bearer token
parts = auth_header.split()
if len(parts) != 2 or parts[0].lower() != "bearer":
return None
token = parts[1]
# Decode the token to get user ID
try:
payload = decode_access_token(token)
if payload is None:
return None
return payload.get("sub")
except Exception as e:
logger.debug("Failed to decode token for CSRF validation: %s", e)
return None
def require_csrf_token(func: Callable) -> Callable:
"""
Decorator to require CSRF token validation for an endpoint.

View File

@@ -0,0 +1,187 @@
"""Error message sanitization middleware for production environments.
This middleware intercepts error responses and sanitizes them to prevent
information disclosure in production environments. Detailed error messages
are only shown when DEBUG mode is enabled.
"""
import json
import logging
from typing import Optional
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response, JSONResponse
from app.core.config import settings
logger = logging.getLogger(__name__)
# Generic error messages for production
GENERIC_ERROR_MESSAGES = {
400: "Bad Request",
401: "Authentication required",
403: "Access denied",
404: "Resource not found",
405: "Method not allowed",
409: "Request conflict",
422: "Validation error",
429: "Too many requests",
500: "Internal server error",
502: "Service unavailable",
503: "Service temporarily unavailable",
504: "Request timeout",
}
# Status codes that should preserve their original message even in production
# These are typically user-facing validation errors that don't leak sensitive info
PRESERVE_MESSAGE_CODES = {
400, # Bad request - users need to know what's wrong with their request
401, # Unauthorized - users need to know why auth failed
403, # Forbidden - users need to know what permission they lack
404, # Not found - usually safe to preserve
409, # Conflict - users need to know about conflicts
422, # Validation errors - users need to know what to fix
}
# Patterns that indicate sensitive information in error messages
SENSITIVE_PATTERNS = [
"traceback",
"stack trace",
"file path",
"/usr/",
"/var/",
"/home/",
"connection refused",
"connection error",
"timeout connecting",
"database error",
"sql",
"query failed",
"password",
"secret",
"token",
"key=",
"credentials",
".py line",
"exception in",
]
def _contains_sensitive_info(message: str) -> bool:
"""Check if an error message contains potentially sensitive information."""
if not message:
return False
message_lower = message.lower()
return any(pattern.lower() in message_lower for pattern in SENSITIVE_PATTERNS)
def _sanitize_detail(detail: any, status_code: int) -> any:
"""Sanitize error detail, removing sensitive information in production.
Args:
detail: The error detail (can be string, list, or dict)
status_code: The HTTP status code
Returns:
Sanitized detail for production, or original detail for debug mode
"""
# In debug mode, return original detail
if settings.DEBUG:
return detail
# For preserved status codes, keep the detail if it doesn't contain sensitive info
if status_code in PRESERVE_MESSAGE_CODES:
if isinstance(detail, str) and not _contains_sensitive_info(detail):
return detail
if isinstance(detail, list):
# For validation errors (list of dicts), keep the structure but sanitize
sanitized = []
for item in detail:
if isinstance(item, dict):
# Keep loc, msg, type for pydantic validation errors
sanitized_item = {}
if 'loc' in item:
sanitized_item['loc'] = item['loc']
if 'msg' in item and not _contains_sensitive_info(str(item['msg'])):
sanitized_item['msg'] = item['msg']
else:
sanitized_item['msg'] = 'Validation failed'
if 'type' in item:
sanitized_item['type'] = item['type']
sanitized.append(sanitized_item)
else:
sanitized.append(item if not _contains_sensitive_info(str(item)) else 'Invalid value')
return sanitized
return detail
# For other status codes, use generic message
return GENERIC_ERROR_MESSAGES.get(status_code, "An error occurred")
class ErrorSanitizerMiddleware(BaseHTTPMiddleware):
"""Middleware to sanitize error responses in production.
This middleware:
1. Intercepts error responses (4xx and 5xx status codes)
2. Parses JSON response bodies
3. Sanitizes the 'detail' field to remove sensitive information
4. Returns the sanitized response
In DEBUG mode, original error messages are preserved for development.
"""
async def dispatch(self, request: Request, call_next) -> Response:
response = await call_next(request)
# Only process error responses with JSON content
if response.status_code < 400:
return response
content_type = response.headers.get("content-type", "")
if "application/json" not in content_type:
return response
# Read the response body
body = b""
async for chunk in response.body_iterator:
body += chunk
if not body:
return response
try:
data = json.loads(body)
except (json.JSONDecodeError, UnicodeDecodeError):
# Not valid JSON, return as-is
return Response(
content=body,
status_code=response.status_code,
headers=dict(response.headers),
media_type=response.media_type,
)
# Sanitize the detail field if present
if "detail" in data:
original_detail = data["detail"]
data["detail"] = _sanitize_detail(original_detail, response.status_code)
# Log the original error in production for debugging
if not settings.DEBUG and original_detail != data["detail"]:
logger.warning(
"Sanitized error response",
extra={
"status_code": response.status_code,
"path": str(request.url.path),
"method": request.method,
"original_detail_length": len(str(original_detail)),
}
)
# Return the sanitized response
return JSONResponse(
content=data,
status_code=response.status_code,
headers={
k: v for k, v in response.headers.items()
if k.lower() not in ("content-length", "content-type")
},
)

View File

@@ -0,0 +1,215 @@
"""Security audit middleware for logging access denials and suspicious auth patterns."""
import time
from collections import defaultdict
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Tuple
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
from sqlalchemy.orm import Session
from app.core.database import SessionLocal
from app.models import AuditLog, AuditAction
from app.services.audit_service import AuditService
# In-memory storage for tracking auth failures
# Structure: {ip_address: [(timestamp, path), ...]}
_auth_failure_tracker: Dict[str, List[Tuple[float, str]]] = defaultdict(list)
# Configuration constants
AUTH_FAILURE_THRESHOLD = 5 # Number of failures to trigger suspicious pattern alert
AUTH_FAILURE_WINDOW_SECONDS = 600 # 10 minutes window
def _cleanup_old_failures(ip: str) -> None:
"""Remove auth failures older than the tracking window."""
if ip not in _auth_failure_tracker:
return
cutoff = time.time() - AUTH_FAILURE_WINDOW_SECONDS
_auth_failure_tracker[ip] = [
(ts, path) for ts, path in _auth_failure_tracker[ip]
if ts > cutoff
]
# Clean up empty entries
if not _auth_failure_tracker[ip]:
del _auth_failure_tracker[ip]
def _track_auth_failure(ip: str, path: str) -> int:
"""Track an auth failure and return the count in the window."""
_cleanup_old_failures(ip)
_auth_failure_tracker[ip].append((time.time(), path))
return len(_auth_failure_tracker[ip])
def _get_recent_failures(ip: str) -> List[str]:
"""Get list of paths that failed auth for this IP."""
_cleanup_old_failures(ip)
return [path for _, path in _auth_failure_tracker.get(ip, [])]
class SecurityAuditMiddleware(BaseHTTPMiddleware):
"""Middleware to audit security-related events like 401/403 responses."""
async def dispatch(self, request: Request, call_next) -> Response:
response = await call_next(request)
# Only process 401 and 403 responses
if response.status_code not in (401, 403):
return response
# Get client IP from audit metadata if available
ip_address = self._get_client_ip(request)
path = str(request.url.path)
method = request.method
# Get user_id if available from request state (set by auth middleware)
user_id = getattr(request.state, 'user_id', None)
db: Session = SessionLocal()
try:
if response.status_code == 403:
self._log_access_denied(db, ip_address, path, method, user_id, request)
elif response.status_code == 401:
self._log_auth_failure(db, ip_address, path, method, request)
db.commit()
except Exception:
db.rollback()
# Don't fail the request due to audit logging errors
finally:
db.close()
return response
def _get_client_ip(self, request: Request) -> str:
"""Get the real client IP address from request."""
# Check for audit metadata first (set by AuditMiddleware)
audit_metadata = getattr(request.state, 'audit_metadata', None)
if audit_metadata and 'ip_address' in audit_metadata:
return audit_metadata['ip_address']
# Fallback to checking headers directly
forwarded = request.headers.get("x-forwarded-for")
if forwarded:
return forwarded.split(",")[0].strip()
real_ip = request.headers.get("x-real-ip")
if real_ip:
return real_ip
if request.client:
return request.client.host
return "unknown"
def _log_access_denied(
self,
db: Session,
ip_address: str,
path: str,
method: str,
user_id: str | None,
request: Request,
) -> None:
"""Log a 403 Forbidden response to the audit trail."""
request_metadata = {
"ip_address": ip_address,
"user_agent": request.headers.get("user-agent", ""),
"method": method,
"path": path,
}
# Try to extract resource info from path
resource_type = self._extract_resource_type(path)
AuditService.log_event(
db=db,
event_type="security.access_denied",
resource_type=resource_type,
action=AuditAction.ACCESS_DENIED,
user_id=user_id,
resource_id=None,
changes=[{
"attempted_path": path,
"attempted_method": method,
"ip_address": ip_address,
}],
request_metadata=request_metadata,
)
def _log_auth_failure(
self,
db: Session,
ip_address: str,
path: str,
method: str,
request: Request,
) -> None:
"""Log a 401 Unauthorized response and check for suspicious patterns."""
# Track this failure
failure_count = _track_auth_failure(ip_address, path)
request_metadata = {
"ip_address": ip_address,
"user_agent": request.headers.get("user-agent", ""),
"method": method,
"path": path,
}
resource_type = self._extract_resource_type(path)
# Log the auth failure
AuditService.log_event(
db=db,
event_type="security.auth_failed",
resource_type=resource_type,
action=AuditAction.AUTH_FAILED,
user_id=None, # No user for 401
resource_id=None,
changes=[{
"attempted_path": path,
"attempted_method": method,
"ip_address": ip_address,
"failure_count_in_window": failure_count,
}],
request_metadata=request_metadata,
)
# Check for suspicious pattern
if failure_count >= AUTH_FAILURE_THRESHOLD:
recent_paths = _get_recent_failures(ip_address)
AuditService.log_event(
db=db,
event_type="security.suspicious_auth_pattern",
resource_type="security",
action=AuditAction.AUTH_FAILED,
user_id=None,
resource_id=None,
changes=[{
"ip_address": ip_address,
"failure_count": failure_count,
"window_minutes": AUTH_FAILURE_WINDOW_SECONDS // 60,
"attempted_paths": list(set(recent_paths)),
}],
request_metadata=request_metadata,
)
def _extract_resource_type(self, path: str) -> str:
"""Extract resource type from path for audit logging."""
# Remove /api/ or /api/v1/ prefix
clean_path = path
if clean_path.startswith("/api/v1/"):
clean_path = clean_path[8:]
elif clean_path.startswith("/api/"):
clean_path = clean_path[5:]
# Get the first path segment as resource type
parts = clean_path.strip("/").split("/")
if parts and parts[0]:
return parts[0]
return "unknown"