feat: implement 5 QA-driven security and quality proposals

Implemented proposals from comprehensive QA review:

1. extend-csrf-protection
   - Add POST to CSRF protected methods in frontend
   - Global CSRF middleware for all state-changing operations
   - Update tests with CSRF token fixtures

2. tighten-cors-websocket-security
   - Replace wildcard CORS with explicit method/header lists
   - Disable query parameter auth in production (code 4002)
   - Add per-user WebSocket connection limit (max 5, code 4005)

3. shorten-jwt-expiry
   - Reduce JWT expiry from 7 days to 60 minutes
   - Add refresh token support with 7-day expiry
   - Implement token rotation on refresh
   - Frontend auto-refresh when token near expiry (<5 min)

4. fix-frontend-quality
   - Add React.lazy() code splitting for all pages
   - Fix useCallback dependency arrays (Dashboard, Comments)
   - Add localStorage data validation in AuthContext
   - Complete i18n for AttachmentUpload component

5. enhance-backend-validation
   - Add SecurityAuditMiddleware for access denied logging
   - Add ErrorSanitizerMiddleware for production error messages
   - Protect /health/detailed with admin authentication
   - Add input length validation (comment 5000, desc 10000)

All 521 backend tests passing. Frontend builds successfully.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
beabigegg
2026-01-12 23:19:05 +08:00
parent df50d5e7f8
commit 35c90fe76b
48 changed files with 2132 additions and 403 deletions

View File

@@ -3,12 +3,28 @@ from sqlalchemy.orm import Session
from app.core.config import settings
from app.core.database import get_db
from app.core.security import create_access_token, create_token_payload
from app.core.security import (
create_access_token,
create_token_payload,
generate_refresh_token,
store_refresh_token,
validate_refresh_token,
invalidate_refresh_token,
invalidate_all_user_refresh_tokens,
decode_refresh_token_user_id,
)
from app.core.redis import get_redis
from app.core.rate_limiter import limiter
from app.models.user import User
from app.models.audit_log import AuditAction
from app.schemas.auth import LoginRequest, LoginResponse, UserInfo, CSRFTokenResponse
from app.schemas.auth import (
LoginRequest,
LoginResponse,
UserInfo,
CSRFTokenResponse,
RefreshTokenRequest,
RefreshTokenResponse,
)
from app.services.auth_client import (
verify_credentials,
AuthAPIError,
@@ -119,6 +135,9 @@ async def login(
# Create access token
access_token = create_access_token(token_data)
# Generate refresh token
refresh_token = generate_refresh_token()
# Store session in Redis (sync with JWT expiry)
redis_client.setex(
f"session:{user.id}",
@@ -126,6 +145,9 @@ async def login(
access_token,
)
# Store refresh token in Redis with user binding
store_refresh_token(redis_client, user.id, refresh_token)
# Log successful login
AuditService.log_event(
db=db,
@@ -141,6 +163,8 @@ async def login(
return LoginResponse(
access_token=access_token,
refresh_token=refresh_token,
expires_in=settings.JWT_EXPIRE_MINUTES * 60,
user=UserInfo(
id=user.id,
email=user.email,
@@ -158,14 +182,114 @@ async def logout(
redis_client=Depends(get_redis),
):
"""
Logout user and invalidate session.
Logout user and invalidate session and all refresh tokens.
"""
# Remove session from Redis
redis_client.delete(f"session:{current_user.id}")
# Invalidate all refresh tokens for this user
invalidate_all_user_refresh_tokens(redis_client, current_user.id)
return {"detail": "Successfully logged out"}
@router.post("/refresh", response_model=RefreshTokenResponse)
@limiter.limit("10/minute")
async def refresh_access_token(
request: Request,
refresh_request: RefreshTokenRequest,
db: Session = Depends(get_db),
redis_client=Depends(get_redis),
):
"""
Refresh access token using a valid refresh token.
This endpoint implements refresh token rotation:
- Validates the provided refresh token
- Issues a new access token
- Issues a new refresh token (rotating the old one)
- Invalidates the old refresh token
This provides enhanced security by ensuring refresh tokens are single-use.
"""
old_refresh_token = refresh_request.refresh_token
# Find the user ID associated with this refresh token
user_id = decode_refresh_token_user_id(old_refresh_token, redis_client)
if user_id is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or expired refresh token",
headers={"WWW-Authenticate": "Bearer"},
)
# Validate the refresh token is still valid and bound to this user
if not validate_refresh_token(redis_client, user_id, old_refresh_token):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or expired refresh token",
headers={"WWW-Authenticate": "Bearer"},
)
# Get user from database
user = db.query(User).filter(User.id == user_id).first()
if user is None:
# Invalidate the token since user no longer exists
invalidate_refresh_token(redis_client, user_id, old_refresh_token)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User not found",
headers={"WWW-Authenticate": "Bearer"},
)
if not user.is_active:
# Invalidate all tokens for disabled user
invalidate_all_user_refresh_tokens(redis_client, user_id)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="User account is disabled",
)
# Invalidate the old refresh token (rotation)
invalidate_refresh_token(redis_client, user_id, old_refresh_token)
# Get role name
role_name = user.role.name if user.role else None
# Create new token payload
token_data = create_token_payload(
user_id=user.id,
email=user.email,
role=role_name,
department_id=user.department_id,
is_system_admin=user.is_system_admin,
)
# Create new access token
new_access_token = create_access_token(token_data)
# Generate new refresh token (rotation)
new_refresh_token = generate_refresh_token()
# Store new session in Redis
redis_client.setex(
f"session:{user.id}",
settings.JWT_EXPIRE_MINUTES * 60,
new_access_token,
)
# Store new refresh token
store_refresh_token(redis_client, user.id, new_refresh_token)
return RefreshTokenResponse(
access_token=new_access_token,
refresh_token=new_refresh_token,
expires_in=settings.JWT_EXPIRE_MINUTES * 60,
)
@router.get("/me", response_model=UserInfo)
async def get_current_user_info(
current_user: User = Depends(get_current_user),

View File

@@ -9,6 +9,7 @@ from sqlalchemy.orm import Session
from app.core import database
from app.core.security import decode_access_token
from app.core.redis import get_redis_sync
from app.core.config import settings
from app.models import User, Notification, Project
from app.services.websocket_manager import manager
from app.core.redis_pubsub import NotificationSubscriber, ProjectTaskSubscriber
@@ -72,14 +73,24 @@ async def authenticate_websocket(
Supports two authentication methods:
1. First message authentication (preferred, more secure)
- Client sends: {"type": "auth", "token": "<jwt_token>"}
2. Query parameter authentication (deprecated, for backward compatibility)
2. Query parameter authentication (disabled in production, for backward compatibility only)
- Client connects with: ?token=<jwt_token>
Returns:
Tuple of (user_id, error_reason). user_id is None if authentication fails.
Error reasons: "invalid_token", "invalid_message", "missing_token",
"timeout", "error", "query_auth_disabled"
"""
# If token provided via query parameter (backward compatibility)
if query_token:
# Reject query parameter auth in production for security
if settings.ENVIRONMENT == "production":
logger.warning(
"WebSocket query parameter authentication attempted in production environment. "
"This is disabled for security reasons."
)
return None, "query_auth_disabled"
logger.warning(
"WebSocket authentication via query parameter is deprecated. "
"Please use first-message authentication for better security."
@@ -195,9 +206,21 @@ async def websocket_notifications(
user_id, error_reason = await authenticate_websocket(websocket, token)
if user_id is None:
if error_reason == "invalid_token":
if error_reason == "query_auth_disabled":
await websocket.send_json({"type": "error", "message": "Query parameter auth disabled in production"})
await websocket.close(code=4002, reason="Query parameter auth disabled in production")
elif error_reason == "invalid_token":
await websocket.send_json({"type": "error", "message": "Invalid or expired token"})
await websocket.close(code=4001, reason="Invalid or expired token")
await websocket.close(code=4001, reason="Invalid or expired token")
else:
await websocket.close(code=4001, reason="Invalid or expired token")
return
# Check connection limit before accepting
can_connect, reject_reason = await manager.check_connection_limit(user_id)
if not can_connect:
await websocket.send_json({"type": "error", "message": reject_reason})
await websocket.close(code=4005, reason=reject_reason)
return
await manager.connect(websocket, user_id)
@@ -394,9 +417,21 @@ async def websocket_project_sync(
user_id, error_reason = await authenticate_websocket(websocket, token)
if user_id is None:
if error_reason == "invalid_token":
if error_reason == "query_auth_disabled":
await websocket.send_json({"type": "error", "message": "Query parameter auth disabled in production"})
await websocket.close(code=4002, reason="Query parameter auth disabled in production")
elif error_reason == "invalid_token":
await websocket.send_json({"type": "error", "message": "Invalid or expired token"})
await websocket.close(code=4001, reason="Invalid or expired token")
await websocket.close(code=4001, reason="Invalid or expired token")
else:
await websocket.close(code=4001, reason="Invalid or expired token")
return
# Check connection limit before accepting
can_connect, reject_reason = await manager.check_connection_limit(user_id)
if not can_connect:
await websocket.send_json({"type": "error", "message": reject_reason})
await websocket.close(code=4005, reason=reject_reason)
return
# Verify user has access to the project

View File

@@ -28,7 +28,8 @@ class Settings(BaseSettings):
# JWT - Must be set in environment, no default allowed
JWT_SECRET_KEY: str = ""
JWT_ALGORITHM: str = "HS256"
JWT_EXPIRE_MINUTES: int = 10080 # 7 days
JWT_EXPIRE_MINUTES: int = 60 # 1 hour (short-lived access token)
REFRESH_TOKEN_EXPIRE_DAYS: int = 7 # Refresh token valid for 7 days
@field_validator("JWT_SECRET_KEY")
@classmethod
@@ -127,6 +128,12 @@ class Settings(BaseSettings):
QUERY_LOGGING: bool = False # Enable SQLAlchemy query logging
QUERY_COUNT_THRESHOLD: int = 10 # Warn when query count exceeds this threshold
# Environment
ENVIRONMENT: str = "development" # Options: development, staging, production
# WebSocket Settings
MAX_WEBSOCKET_CONNECTIONS_PER_USER: int = 5 # Maximum concurrent WebSocket connections per user
class Config:
env_file = ".env"
case_sensitive = True

View File

@@ -356,3 +356,140 @@ def create_token_payload(
"department_id": department_id,
"is_system_admin": is_system_admin,
}
# Refresh Token Functions

# Random bytes fed to token_urlsafe; 32 bytes yields a 43-char URL-safe string.
REFRESH_TOKEN_BYTES = 32


def generate_refresh_token() -> str:
    """Return a fresh, cryptographically random refresh token.

    Built with ``secrets.token_urlsafe`` so the value is safe to carry in
    JSON bodies and URLs without further encoding.
    """
    return secrets.token_urlsafe(REFRESH_TOKEN_BYTES)
def get_refresh_token_key(user_id: str, token: str) -> str:
    """Build the Redis key under which a refresh token is stored.

    The raw token never appears in the key: a truncated SHA-256 digest is
    used instead, so a Redis key listing does not leak usable tokens.

    Args:
        user_id: Owner of the token.
        token: The raw refresh token.

    Returns:
        Key of the form ``refresh_token:{user_id}:{digest16}``.
    """
    digest = hashlib.sha256(token.encode()).hexdigest()[:16]
    return f"refresh_token:{user_id}:{digest}"
def store_refresh_token(redis_client, user_id: str, token: str) -> None:
    """Persist a refresh token in Redis, bound to its owning user.

    The value stored is the user ID, and the entry carries a TTL of
    ``REFRESH_TOKEN_EXPIRE_DAYS`` so stale tokens are garbage-collected
    by Redis itself.

    Args:
        redis_client: Redis client instance.
        user_id: The user's ID.
        token: The refresh token to store.
    """
    lifetime_seconds = settings.REFRESH_TOKEN_EXPIRE_DAYS * 86400
    redis_client.setex(get_refresh_token_key(user_id, token), lifetime_seconds, user_id)
def validate_refresh_token(redis_client, user_id: str, token: str) -> bool:
    """Check that *token* exists in Redis and is bound to *user_id*.

    Args:
        redis_client: Redis client instance.
        user_id: The expected owner.
        token: The refresh token to check.

    Returns:
        True when the stored binding matches the expected user, else False.
    """
    stored = redis_client.get(get_refresh_token_key(user_id, token))
    if stored is None:
        return False
    # redis-py returns bytes unless decode_responses=True is configured.
    if isinstance(stored, bytes):
        stored = stored.decode("utf-8")
    return stored == user_id
def invalidate_refresh_token(redis_client, user_id: str, token: str) -> bool:
    """Delete a single refresh token from Redis.

    Args:
        redis_client: Redis client instance.
        user_id: The user's ID.
        token: The refresh token to remove.

    Returns:
        True when a key was actually removed, False when it was absent.
    """
    deleted = redis_client.delete(get_refresh_token_key(user_id, token))
    # Sync clients return an int delete-count; tolerate other truthy results.
    if isinstance(deleted, int):
        return deleted > 0
    return bool(deleted)
def invalidate_all_user_refresh_tokens(redis_client, user_id: str) -> int:
    """Delete every refresh token Redis holds for *user_id*.

    Used on logout and when an account is disabled, so no outstanding
    refresh token survives.

    Args:
        redis_client: Redis client instance.
        user_id: The user's ID.

    Returns:
        The number of token keys removed.
    """
    removed = 0
    # SCAN iterates incrementally and avoids blocking Redis the way KEYS would.
    for key in redis_client.scan_iter(match=f"refresh_token:{user_id}:*"):
        redis_client.delete(key)
        removed += 1
    return removed
def decode_refresh_token_user_id(token: str, redis_client) -> Optional[str]:
    """Look up which user a refresh token belongs to.

    Scans Redis for any key whose hashed-token suffix matches *token*.
    NOTE(review): this is an O(keyspace) SCAN on every refresh request; if
    refresh volume grows, consider a reverse index (token-hash -> user_id)
    so the owner can be found with a single GET.

    Args:
        token: The raw refresh token presented by the client.
        redis_client: Redis client instance.

    Returns:
        The owning user's ID, or None when no matching key exists.
    """
    digest = hashlib.sha256(token.encode()).hexdigest()[:16]
    for key in redis_client.scan_iter(match=f"refresh_token:*:{digest}"):
        if isinstance(key, bytes):
            key = key.decode("utf-8")
        # Keys have the shape refresh_token:{user_id}:{digest}.
        parts = key.split(":")
        if len(parts) == 3:
            return parts[1]
    return None

View File

@@ -1,7 +1,7 @@
import os
from contextlib import asynccontextmanager
from datetime import datetime
from fastapi import FastAPI, Request, APIRouter
from fastapi import FastAPI, Request, APIRouter, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from slowapi import _rate_limit_exceeded_handler
@@ -9,6 +9,9 @@ from slowapi.errors import RateLimitExceeded
from sqlalchemy import text
from app.middleware.audit import AuditMiddleware
from app.middleware.csrf import CSRFMiddleware
from app.middleware.security_audit import SecurityAuditMiddleware
from app.middleware.error_sanitizer import ErrorSanitizerMiddleware
from app.core.scheduler import start_scheduler, shutdown_scheduler, scheduler
from app.core.rate_limiter import limiter
from app.core.deprecation import DeprecationMiddleware
@@ -61,6 +64,8 @@ from app.core.database import get_pool_status, engine
from app.core.redis import redis_client
from app.services.notification_service import get_redis_fallback_status
from app.services.file_storage_service import file_storage_service
from app.middleware.auth import require_system_admin
from app.models import User
app = FastAPI(
title="Project Control API",
@@ -73,18 +78,28 @@ app = FastAPI(
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# CORS middleware
# CORS middleware - Explicit methods and headers for security
app.add_middleware(
CORSMiddleware,
allow_origins=settings.CORS_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
allow_headers=["Content-Type", "Authorization", "X-CSRF-Token", "X-Request-ID"],
)
# Error sanitizer middleware - sanitizes error messages in production
# Must be first in the chain to intercept all error responses
app.add_middleware(ErrorSanitizerMiddleware)
# Audit middleware - extracts request metadata for audit logging
app.add_middleware(AuditMiddleware)
# Security audit middleware - logs 401/403 responses to audit trail
app.add_middleware(SecurityAuditMiddleware)
# CSRF middleware - validates CSRF tokens for state-changing requests
app.add_middleware(CSRFMiddleware)
# Deprecation middleware - adds deprecation headers to legacy /api/ routes
app.add_middleware(DeprecationMiddleware)
@@ -252,14 +267,20 @@ async def readiness_check():
@app.get("/health/detailed")
async def detailed_health_check():
"""Detailed health check endpoint.
async def detailed_health_check(
current_user: User = Depends(require_system_admin),
):
"""Detailed health check endpoint (requires system admin).
Returns comprehensive status of all system components:
- database: Connection pool status and connectivity
- redis: Connection status and fallback queue status
- storage: File storage validation status
- scheduler: Background job scheduler status
Note: This endpoint requires system admin authentication because it exposes
sensitive infrastructure details including connection pool statistics and
internal service states.
"""
db_health = check_database_health()
redis_health = check_redis_health()

View File

@@ -1,38 +1,55 @@
"""
CSRF (Cross-Site Request Forgery) Protection Middleware.
This module provides CSRF protection for sensitive state-changing operations.
It validates CSRF tokens for specified protected endpoints.
This module provides CSRF protection for all state-changing operations.
It validates CSRF tokens globally for authenticated POST, PUT, PATCH, DELETE requests.
"""
from fastapi import Request, HTTPException, status, Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from typing import Optional, Callable, List
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse
from fastapi import HTTPException, status
from typing import Optional, Callable, List, Set
from functools import wraps
import logging
from app.core.security import validate_csrf_token, generate_csrf_token
from app.core.security import validate_csrf_token, generate_csrf_token, decode_access_token
logger = logging.getLogger(__name__)
# Header name for CSRF token
CSRF_TOKEN_HEADER = "X-CSRF-Token"
# List of endpoint patterns that require CSRF protection
# These are sensitive state-changing operations
CSRF_PROTECTED_PATTERNS = [
# User operations
"/api/v1/users/{user_id}/admin", # Admin status change
"/api/users/{user_id}/admin", # Legacy
# Password changes would go here if implemented
# Delete operations
"/api/attachments/{attachment_id}", # DELETE method
"/api/tasks/{task_id}", # DELETE method (soft delete)
"/api/projects/{project_id}", # DELETE method
]
# Methods that require CSRF protection (all state-changing operations)
CSRF_PROTECTED_METHODS = {"POST", "PUT", "PATCH", "DELETE"}
# Methods that require CSRF protection
CSRF_PROTECTED_METHODS = ["DELETE", "PUT", "PATCH"]
# Safe methods that don't require CSRF protection
CSRF_SAFE_METHODS = {"GET", "HEAD", "OPTIONS"}
# Public endpoints that don't require CSRF validation
# These are endpoints that either:
# 1. Don't require authentication (login, health checks)
# 2. Are not state-changing in a security-sensitive way
CSRF_EXCLUDED_PATHS: Set[str] = {
# Authentication endpoints (unauthenticated)
"/api/auth/login",
"/api/v1/auth/login",
# Health check endpoints (unauthenticated)
"/health",
"/health/live",
"/health/ready",
"/health/detailed",
# WebSocket endpoints (use different auth mechanism)
"/api/ws",
"/ws",
}
# Path prefixes that are excluded from CSRF validation
CSRF_EXCLUDED_PREFIXES: List[str] = [
# WebSocket paths
"/api/ws/",
"/ws/",
]
class CSRFProtectionError(HTTPException):
@@ -45,6 +62,114 @@ class CSRFProtectionError(HTTPException):
)
class CSRFMiddleware(BaseHTTPMiddleware):
    """
    Global CSRF protection middleware.

    Validates CSRF tokens for all authenticated state-changing requests
    (POST, PUT, PATCH, DELETE) except for explicitly excluded endpoints.
    """

    async def dispatch(self, request: Request, call_next):
        """Process the request and validate CSRF token if needed."""
        method = request.method.upper()
        path = request.url.path
        # Safe (read-only) methods cannot mutate state; no token required.
        if method in CSRF_SAFE_METHODS:
            return await call_next(request)
        # Explicitly excluded paths (login, health checks, WebSocket upgrades).
        if self._is_excluded_path(path):
            logger.debug("CSRF validation skipped for excluded path: %s", path)
            return await call_next(request)
        # Identify the caller from the bearer token, if any.
        user_id = self._extract_user_id_from_token(request)
        # Unauthenticated requests are left for the auth layer to reject;
        # CSRF only matters once a browser-held credential is in play.
        if user_id is None:
            logger.debug(
                "CSRF validation skipped (no auth token): %s %s",
                method, path
            )
            return await call_next(request)
        # Authenticated state-changing request: the CSRF header is mandatory.
        csrf_token = request.headers.get(CSRF_TOKEN_HEADER)
        if not csrf_token:
            logger.warning(
                "CSRF validation failed: Missing token for user %s on %s %s",
                user_id, method, path
            )
            # Returned directly (not raised) because middleware runs outside
            # FastAPI's exception handlers.
            return JSONResponse(
                status_code=status.HTTP_403_FORBIDDEN,
                content={"detail": "CSRF token is required"}
            )
        # Token must be valid and bound to this specific user.
        is_valid, error_message = validate_csrf_token(csrf_token, user_id)
        if not is_valid:
            logger.warning(
                "CSRF validation failed for user %s on %s %s: %s",
                user_id, method, path, error_message
            )
            return JSONResponse(
                status_code=status.HTTP_403_FORBIDDEN,
                content={"detail": error_message}
            )
        logger.debug(
            "CSRF validation passed for user %s on %s %s",
            user_id, method, path
        )
        return await call_next(request)

    def _is_excluded_path(self, path: str) -> bool:
        """Check if the path is excluded from CSRF validation."""
        # Check exact path matches
        if path in CSRF_EXCLUDED_PATHS:
            return True
        # Check path prefixes (e.g. WebSocket routes with dynamic suffixes)
        for prefix in CSRF_EXCLUDED_PREFIXES:
            if path.startswith(prefix):
                return True
        return False

    def _extract_user_id_from_token(self, request: Request) -> Optional[str]:
        """
        Extract user ID from the Authorization header.

        Returns None if no valid token is found (unauthenticated request).
        """
        auth_header = request.headers.get("Authorization")
        if not auth_header:
            return None
        # Parse "Bearer <token>"; anything else is treated as unauthenticated.
        parts = auth_header.split()
        if len(parts) != 2 or parts[0].lower() != "bearer":
            return None
        token = parts[1]
        # Decode the JWT; "sub" carries the user ID per create_token_payload.
        try:
            payload = decode_access_token(token)
            if payload is None:
                return None
            return payload.get("sub")
        except Exception as e:
            # An undecodable token is simply treated as no authentication;
            # the auth dependency downstream will produce the proper 401.
            logger.debug("Failed to decode token for CSRF validation: %s", e)
            return None
def require_csrf_token(func: Callable) -> Callable:
"""
Decorator to require CSRF token validation for an endpoint.

View File

@@ -0,0 +1,187 @@
"""Error message sanitization middleware for production environments.
This middleware intercepts error responses and sanitizes them to prevent
information disclosure in production environments. Detailed error messages
are only shown when DEBUG mode is enabled.
"""
import json
import logging
from typing import Optional
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response, JSONResponse
from app.core.config import settings
logger = logging.getLogger(__name__)
# Generic error messages for production
GENERIC_ERROR_MESSAGES = {
400: "Bad Request",
401: "Authentication required",
403: "Access denied",
404: "Resource not found",
405: "Method not allowed",
409: "Request conflict",
422: "Validation error",
429: "Too many requests",
500: "Internal server error",
502: "Service unavailable",
503: "Service temporarily unavailable",
504: "Request timeout",
}
# Status codes that should preserve their original message even in production
# These are typically user-facing validation errors that don't leak sensitive info
PRESERVE_MESSAGE_CODES = {
400, # Bad request - users need to know what's wrong with their request
401, # Unauthorized - users need to know why auth failed
403, # Forbidden - users need to know what permission they lack
404, # Not found - usually safe to preserve
409, # Conflict - users need to know about conflicts
422, # Validation errors - users need to know what to fix
}
# Patterns that indicate sensitive information in error messages
SENSITIVE_PATTERNS = [
"traceback",
"stack trace",
"file path",
"/usr/",
"/var/",
"/home/",
"connection refused",
"connection error",
"timeout connecting",
"database error",
"sql",
"query failed",
"password",
"secret",
"token",
"key=",
"credentials",
".py line",
"exception in",
]
def _contains_sensitive_info(message: str) -> bool:
"""Check if an error message contains potentially sensitive information."""
if not message:
return False
message_lower = message.lower()
return any(pattern.lower() in message_lower for pattern in SENSITIVE_PATTERNS)
def _sanitize_detail(detail, status_code: int):
    """Sanitize error detail, removing sensitive information in production.

    Args:
        detail: The error detail (string, list of validation errors, or dict).
        status_code: The HTTP status code of the response.

    Returns:
        The original detail in DEBUG mode; otherwise a sanitized version
        safe to expose to clients.
    """
    # In debug mode, return original detail unchanged for developers.
    if settings.DEBUG:
        return detail

    generic = GENERIC_ERROR_MESSAGES.get(status_code, "An error occurred")

    # For preserved status codes, keep the detail if it doesn't contain sensitive info
    if status_code in PRESERVE_MESSAGE_CODES:
        if isinstance(detail, str):
            # BUG FIX: a sensitive string previously fell through both
            # guards and was returned verbatim; now it is replaced by the
            # generic message for this status code.
            return detail if not _contains_sensitive_info(detail) else generic
        if isinstance(detail, list):
            # Pydantic validation errors: keep the loc/msg/type structure
            # but scrub any message that looks sensitive.
            sanitized = []
            for item in detail:
                if isinstance(item, dict):
                    sanitized_item = {}
                    if 'loc' in item:
                        sanitized_item['loc'] = item['loc']
                    if 'msg' in item and not _contains_sensitive_info(str(item['msg'])):
                        sanitized_item['msg'] = item['msg']
                    else:
                        sanitized_item['msg'] = 'Validation failed'
                    if 'type' in item:
                        sanitized_item['type'] = item['type']
                    sanitized.append(sanitized_item)
                else:
                    sanitized.append(item if not _contains_sensitive_info(str(item)) else 'Invalid value')
            return sanitized
        # Non-str/non-list (e.g. dict) details are passed through unchanged,
        # matching the original behavior for preserved codes.
        return detail

    # For all other status codes (notably 5xx), hide the detail entirely.
    return generic
class ErrorSanitizerMiddleware(BaseHTTPMiddleware):
    """Middleware to sanitize error responses in production.

    This middleware:
    1. Intercepts error responses (4xx and 5xx status codes)
    2. Parses JSON response bodies
    3. Sanitizes the 'detail' field to remove sensitive information
    4. Returns the sanitized response

    In DEBUG mode, original error messages are preserved for development.
    """

    async def dispatch(self, request: Request, call_next) -> Response:
        response = await call_next(request)
        # Success responses pass through untouched.
        if response.status_code < 400:
            return response
        # Only JSON bodies are inspected; other content types pass through.
        content_type = response.headers.get("content-type", "")
        if "application/json" not in content_type:
            return response
        # Drain the streaming body so it can be parsed. After this loop the
        # iterator is consumed, so every path below must rebuild a response
        # (the empty-body early return is safe: there is nothing left to send).
        body = b""
        async for chunk in response.body_iterator:
            body += chunk
        if not body:
            return response
        try:
            data = json.loads(body)
        except (json.JSONDecodeError, UnicodeDecodeError):
            # Not valid JSON: re-emit the already-consumed bytes unchanged.
            return Response(
                content=body,
                status_code=response.status_code,
                headers=dict(response.headers),
                media_type=response.media_type,
            )
        # Sanitize the FastAPI-style "detail" field if present.
        if "detail" in data:
            original_detail = data["detail"]
            data["detail"] = _sanitize_detail(original_detail, response.status_code)
            # Record that sanitization occurred; only the length of the
            # original detail is logged, never its content.
            if not settings.DEBUG and original_detail != data["detail"]:
                logger.warning(
                    "Sanitized error response",
                    extra={
                        "status_code": response.status_code,
                        "path": str(request.url.path),
                        "method": request.method,
                        "original_detail_length": len(str(original_detail)),
                    }
                )
        # Rebuild the response; content-length/content-type are dropped so
        # JSONResponse can recompute them for the (possibly shorter) body.
        return JSONResponse(
            content=data,
            status_code=response.status_code,
            headers={
                k: v for k, v in response.headers.items()
                if k.lower() not in ("content-length", "content-type")
            },
        )

View File

@@ -0,0 +1,215 @@
"""Security audit middleware for logging access denials and suspicious auth patterns."""
import time
from collections import defaultdict
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Tuple
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
from sqlalchemy.orm import Session
from app.core.database import SessionLocal
from app.models import AuditLog, AuditAction
from app.services.audit_service import AuditService
# In-memory storage for tracking auth failures
# Structure: {ip_address: [(timestamp, path), ...]}
_auth_failure_tracker: Dict[str, List[Tuple[float, str]]] = defaultdict(list)
# Configuration constants
AUTH_FAILURE_THRESHOLD = 5 # Number of failures to trigger suspicious pattern alert
AUTH_FAILURE_WINDOW_SECONDS = 600 # 10 minutes window
def _cleanup_old_failures(ip: str) -> None:
"""Remove auth failures older than the tracking window."""
if ip not in _auth_failure_tracker:
return
cutoff = time.time() - AUTH_FAILURE_WINDOW_SECONDS
_auth_failure_tracker[ip] = [
(ts, path) for ts, path in _auth_failure_tracker[ip]
if ts > cutoff
]
# Clean up empty entries
if not _auth_failure_tracker[ip]:
del _auth_failure_tracker[ip]
def _track_auth_failure(ip: str, path: str) -> int:
"""Track an auth failure and return the count in the window."""
_cleanup_old_failures(ip)
_auth_failure_tracker[ip].append((time.time(), path))
return len(_auth_failure_tracker[ip])
def _get_recent_failures(ip: str) -> List[str]:
"""Get list of paths that failed auth for this IP."""
_cleanup_old_failures(ip)
return [path for _, path in _auth_failure_tracker.get(ip, [])]
class SecurityAuditMiddleware(BaseHTTPMiddleware):
"""Middleware to audit security-related events like 401/403 responses."""
async def dispatch(self, request: Request, call_next) -> Response:
response = await call_next(request)
# Only process 401 and 403 responses
if response.status_code not in (401, 403):
return response
# Get client IP from audit metadata if available
ip_address = self._get_client_ip(request)
path = str(request.url.path)
method = request.method
# Get user_id if available from request state (set by auth middleware)
user_id = getattr(request.state, 'user_id', None)
db: Session = SessionLocal()
try:
if response.status_code == 403:
self._log_access_denied(db, ip_address, path, method, user_id, request)
elif response.status_code == 401:
self._log_auth_failure(db, ip_address, path, method, request)
db.commit()
except Exception:
db.rollback()
# Don't fail the request due to audit logging errors
finally:
db.close()
return response
def _get_client_ip(self, request: Request) -> str:
"""Get the real client IP address from request."""
# Check for audit metadata first (set by AuditMiddleware)
audit_metadata = getattr(request.state, 'audit_metadata', None)
if audit_metadata and 'ip_address' in audit_metadata:
return audit_metadata['ip_address']
# Fallback to checking headers directly
forwarded = request.headers.get("x-forwarded-for")
if forwarded:
return forwarded.split(",")[0].strip()
real_ip = request.headers.get("x-real-ip")
if real_ip:
return real_ip
if request.client:
return request.client.host
return "unknown"
def _log_access_denied(
self,
db: Session,
ip_address: str,
path: str,
method: str,
user_id: str | None,
request: Request,
) -> None:
"""Log a 403 Forbidden response to the audit trail."""
request_metadata = {
"ip_address": ip_address,
"user_agent": request.headers.get("user-agent", ""),
"method": method,
"path": path,
}
# Try to extract resource info from path
resource_type = self._extract_resource_type(path)
AuditService.log_event(
db=db,
event_type="security.access_denied",
resource_type=resource_type,
action=AuditAction.ACCESS_DENIED,
user_id=user_id,
resource_id=None,
changes=[{
"attempted_path": path,
"attempted_method": method,
"ip_address": ip_address,
}],
request_metadata=request_metadata,
)
def _log_auth_failure(
    self,
    db: Session,
    ip_address: str,
    path: str,
    method: str,
    request: Request,
) -> None:
    """Record a 401 Unauthorized response and flag suspicious patterns.

    Each failure is counted per source IP inside a sliding window; once
    the running count reaches AUTH_FAILURE_THRESHOLD an additional
    ``security.suspicious_auth_pattern`` event is emitted summarising
    the distinct paths that were probed.
    """
    # Count this failure first so the log entry carries the running total.
    failure_count = _track_auth_failure(ip_address, path)
    meta = {
        "ip_address": ip_address,
        "user_agent": request.headers.get("user-agent", ""),
        "method": method,
        "path": path,
    }
    AuditService.log_event(
        db=db,
        event_type="security.auth_failed",
        resource_type=self._extract_resource_type(path),
        action=AuditAction.AUTH_FAILED,
        user_id=None,  # a 401 means there is no authenticated user
        resource_id=None,
        changes=[{
            "attempted_path": path,
            "attempted_method": method,
            "ip_address": ip_address,
            "failure_count_in_window": failure_count,
        }],
        request_metadata=meta,
    )
    if failure_count < AUTH_FAILURE_THRESHOLD:
        return
    # Too many failures from one IP inside the window: emit the
    # higher-sensitivity pattern event with the deduplicated path list.
    AuditService.log_event(
        db=db,
        event_type="security.suspicious_auth_pattern",
        resource_type="security",
        action=AuditAction.AUTH_FAILED,
        user_id=None,
        resource_id=None,
        changes=[{
            "ip_address": ip_address,
            "failure_count": failure_count,
            "window_minutes": AUTH_FAILURE_WINDOW_SECONDS // 60,
            "attempted_paths": list(set(_get_recent_failures(ip_address))),
        }],
        request_metadata=meta,
    )
def _extract_resource_type(self, path: str) -> str:
"""Extract resource type from path for audit logging."""
# Remove /api/ or /api/v1/ prefix
clean_path = path
if clean_path.startswith("/api/v1/"):
clean_path = clean_path[8:]
elif clean_path.startswith("/api/"):
clean_path = clean_path[5:]
# Get the first path segment as resource type
parts = clean_path.strip("/").split("/")
if parts and parts[0]:
return parts[0]
return "unknown"

View File

@@ -13,6 +13,8 @@ class AuditAction(str, enum.Enum):
RESTORE = "restore"
LOGIN = "login"
LOGOUT = "logout"
ACCESS_DENIED = "access_denied"
AUTH_FAILED = "auth_failed"
class SensitivityLevel(str, enum.Enum):
@@ -42,10 +44,20 @@ EVENT_SENSITIVITY = {
"attachment.upload": SensitivityLevel.LOW,
"attachment.download": SensitivityLevel.LOW,
"attachment.delete": SensitivityLevel.MEDIUM,
# Security events
"security.access_denied": SensitivityLevel.MEDIUM,
"security.auth_failed": SensitivityLevel.MEDIUM,
"security.suspicious_auth_pattern": SensitivityLevel.HIGH,
}
# Events that should trigger alerts
ALERT_EVENTS = {"project.delete", "user.permission_change", "user.admin_change", "role.permission_change"}
ALERT_EVENTS = {
"project.delete",
"user.permission_change",
"user.admin_change",
"role.permission_change",
"security.suspicious_auth_pattern",
}
class AuditLog(Base):
@@ -57,7 +69,7 @@ class AuditLog(Base):
resource_id = Column(String(36), nullable=True)
user_id = Column(String(36), ForeignKey("pjctrl_users.id", ondelete="SET NULL"), nullable=True)
action = Column(
Enum("create", "update", "delete", "restore", "login", "logout", name="audit_action_enum"),
Enum("create", "update", "delete", "restore", "login", "logout", "access_denied", "auth_failed", name="audit_action_enum"),
nullable=False
)
changes = Column(JSON, nullable=True)

View File

@@ -9,10 +9,25 @@ class LoginRequest(BaseModel):
class LoginResponse(BaseModel):
    """Payload returned on successful login: token pair plus user info."""
    # Short-lived access token presented on subsequent API requests.
    access_token: str
    # Companion refresh token used to obtain a new access token (see
    # RefreshTokenRequest / RefreshTokenResponse).
    refresh_token: str
    token_type: str = "bearer"
    # Default 3600 s matches the documented access-token lifetime.
    expires_in: int = Field(default=3600, description="Access token expiry in seconds")
    # Basic profile of the user who just authenticated.
    user: "UserInfo"
class RefreshTokenRequest(BaseModel):
    """Request body for refresh token endpoint."""
    # Token previously issued at login (or by an earlier refresh).
    refresh_token: str = Field(..., description="The refresh token to use for obtaining a new access token")
class RefreshTokenResponse(BaseModel):
    """Response for refresh token endpoint."""
    # Newly minted short-lived access token.
    access_token: str
    refresh_token: str  # New refresh token (rotation): replaces the one just used
    token_type: str = "bearer"
    expires_in: int = Field(default=3600, description="Access token expiry in seconds")
class UserInfo(BaseModel):
id: str
email: str

View File

@@ -4,12 +4,12 @@ from pydantic import BaseModel, Field
class CommentCreate(BaseModel):
content: str = Field(..., min_length=1, max_length=10000)
content: str = Field(..., min_length=1, max_length=5000)
parent_comment_id: Optional[str] = None
class CommentUpdate(BaseModel):
content: str = Field(..., min_length=1, max_length=10000)
content: str = Field(..., min_length=1, max_length=5000)
class CommentAuthor(BaseModel):

View File

@@ -25,7 +25,7 @@ class CustomFieldDefinition(BaseModel):
class ProjectTemplateBase(BaseModel):
"""Base schema for project template."""
name: str = Field(..., min_length=1, max_length=200)
description: Optional[str] = None
description: Optional[str] = Field(None, max_length=2000)
is_public: bool = Field(default=False)
task_statuses: Optional[List[TaskStatusDefinition]] = None
custom_fields: Optional[List[CustomFieldDefinition]] = None
@@ -43,7 +43,7 @@ class ProjectTemplateCreate(ProjectTemplateBase):
class ProjectTemplateUpdate(BaseModel):
"""Schema for updating a project template."""
name: Optional[str] = Field(None, min_length=1, max_length=200)
description: Optional[str] = None
description: Optional[str] = Field(None, max_length=2000)
is_public: Optional[bool] = None
task_statuses: Optional[List[TaskStatusDefinition]] = None
custom_fields: Optional[List[CustomFieldDefinition]] = None

View File

@@ -4,6 +4,7 @@ import logging
from typing import Dict, Set, Optional, Tuple
from fastapi import WebSocket
from app.core.redis import get_redis_sync
from app.core.config import settings
logger = logging.getLogger(__name__)
@@ -19,13 +20,48 @@ class ConnectionManager:
self._lock = asyncio.Lock()
self._project_lock = asyncio.Lock()
async def check_connection_limit(self, user_id: str) -> Tuple[bool, Optional[str]]:
    """Decide whether *user_id* may open another WebSocket connection.

    Returns a ``(can_connect, reject_reason)`` pair: ``(True, None)``
    while the user's open-connection count is below
    ``settings.MAX_WEBSOCKET_CONNECTIONS_PER_USER``, otherwise
    ``(False, "Too many connections")``.
    """
    limit = settings.MAX_WEBSOCKET_CONNECTIONS_PER_USER
    # Read the count under the lock so it cannot race with connect()
    # and disconnect().
    async with self._lock:
        open_count = len(self.active_connections.get(user_id, set()))
        if open_count >= limit:
            logger.warning(
                f"User {user_id} exceeded WebSocket connection limit "
                f"({open_count}/{limit})"
            )
            return False, "Too many connections"
    return True, None
def get_user_connection_count(self, user_id: str) -> int:
    """Return how many WebSocket connections *user_id* currently holds (0 if none)."""
    connections = self.active_connections.get(user_id)
    return len(connections) if connections is not None else 0
async def connect(self, websocket: WebSocket, user_id: str):
    """Track a new WebSocket connection for *user_id*.

    Note: the WebSocket must already be accepted before calling this
    method, and the connection limit should have been checked via
    check_connection_limit(). This method only registers the socket —
    the stale ``await websocket.accept()`` left over from the previous
    version is removed, since accepting twice would raise and would
    bypass the accept-before-register contract documented above.
    """
    async with self._lock:
        # setdefault covers the first connection for this user.
        self.active_connections.setdefault(user_id, set()).add(websocket)
        logger.debug(
            f"User {user_id} connected. Total connections: "
            f"{len(self.active_connections[user_id])}"
        )
async def disconnect(self, websocket: WebSocket, user_id: str):
"""Remove a WebSocket connection."""