feat: implement 5 QA-driven security and quality proposals
Implemented proposals from comprehensive QA review: 1. extend-csrf-protection - Add POST to CSRF protected methods in frontend - Global CSRF middleware for all state-changing operations - Update tests with CSRF token fixtures 2. tighten-cors-websocket-security - Replace wildcard CORS with explicit method/header lists - Disable query parameter auth in production (code 4002) - Add per-user WebSocket connection limit (max 5, code 4005) 3. shorten-jwt-expiry - Reduce JWT expiry from 7 days to 60 minutes - Add refresh token support with 7-day expiry - Implement token rotation on refresh - Frontend auto-refresh when token near expiry (<5 min) 4. fix-frontend-quality - Add React.lazy() code splitting for all pages - Fix useCallback dependency arrays (Dashboard, Comments) - Add localStorage data validation in AuthContext - Complete i18n for AttachmentUpload component 5. enhance-backend-validation - Add SecurityAuditMiddleware for access denied logging - Add ErrorSanitizerMiddleware for production error messages - Protect /health/detailed with admin authentication - Add input length validation (comment 5000, desc 10000) All 521 backend tests passing. Frontend builds successfully. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -3,12 +3,28 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.database import get_db
|
||||
from app.core.security import create_access_token, create_token_payload
|
||||
from app.core.security import (
|
||||
create_access_token,
|
||||
create_token_payload,
|
||||
generate_refresh_token,
|
||||
store_refresh_token,
|
||||
validate_refresh_token,
|
||||
invalidate_refresh_token,
|
||||
invalidate_all_user_refresh_tokens,
|
||||
decode_refresh_token_user_id,
|
||||
)
|
||||
from app.core.redis import get_redis
|
||||
from app.core.rate_limiter import limiter
|
||||
from app.models.user import User
|
||||
from app.models.audit_log import AuditAction
|
||||
from app.schemas.auth import LoginRequest, LoginResponse, UserInfo, CSRFTokenResponse
|
||||
from app.schemas.auth import (
|
||||
LoginRequest,
|
||||
LoginResponse,
|
||||
UserInfo,
|
||||
CSRFTokenResponse,
|
||||
RefreshTokenRequest,
|
||||
RefreshTokenResponse,
|
||||
)
|
||||
from app.services.auth_client import (
|
||||
verify_credentials,
|
||||
AuthAPIError,
|
||||
@@ -119,6 +135,9 @@ async def login(
|
||||
# Create access token
|
||||
access_token = create_access_token(token_data)
|
||||
|
||||
# Generate refresh token
|
||||
refresh_token = generate_refresh_token()
|
||||
|
||||
# Store session in Redis (sync with JWT expiry)
|
||||
redis_client.setex(
|
||||
f"session:{user.id}",
|
||||
@@ -126,6 +145,9 @@ async def login(
|
||||
access_token,
|
||||
)
|
||||
|
||||
# Store refresh token in Redis with user binding
|
||||
store_refresh_token(redis_client, user.id, refresh_token)
|
||||
|
||||
# Log successful login
|
||||
AuditService.log_event(
|
||||
db=db,
|
||||
@@ -141,6 +163,8 @@ async def login(
|
||||
|
||||
return LoginResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
expires_in=settings.JWT_EXPIRE_MINUTES * 60,
|
||||
user=UserInfo(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
@@ -158,14 +182,114 @@ async def logout(
|
||||
redis_client=Depends(get_redis),
|
||||
):
|
||||
"""
|
||||
Logout user and invalidate session.
|
||||
Logout user and invalidate session and all refresh tokens.
|
||||
"""
|
||||
# Remove session from Redis
|
||||
redis_client.delete(f"session:{current_user.id}")
|
||||
|
||||
# Invalidate all refresh tokens for this user
|
||||
invalidate_all_user_refresh_tokens(redis_client, current_user.id)
|
||||
|
||||
return {"detail": "Successfully logged out"}
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=RefreshTokenResponse)
|
||||
@limiter.limit("10/minute")
|
||||
async def refresh_access_token(
|
||||
request: Request,
|
||||
refresh_request: RefreshTokenRequest,
|
||||
db: Session = Depends(get_db),
|
||||
redis_client=Depends(get_redis),
|
||||
):
|
||||
"""
|
||||
Refresh access token using a valid refresh token.
|
||||
|
||||
This endpoint implements refresh token rotation:
|
||||
- Validates the provided refresh token
|
||||
- Issues a new access token
|
||||
- Issues a new refresh token (rotating the old one)
|
||||
- Invalidates the old refresh token
|
||||
|
||||
This provides enhanced security by ensuring refresh tokens are single-use.
|
||||
"""
|
||||
old_refresh_token = refresh_request.refresh_token
|
||||
|
||||
# Find the user ID associated with this refresh token
|
||||
user_id = decode_refresh_token_user_id(old_refresh_token, redis_client)
|
||||
|
||||
if user_id is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or expired refresh token",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
# Validate the refresh token is still valid and bound to this user
|
||||
if not validate_refresh_token(redis_client, user_id, old_refresh_token):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or expired refresh token",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
# Get user from database
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
|
||||
if user is None:
|
||||
# Invalidate the token since user no longer exists
|
||||
invalidate_refresh_token(redis_client, user_id, old_refresh_token)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User not found",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
if not user.is_active:
|
||||
# Invalidate all tokens for disabled user
|
||||
invalidate_all_user_refresh_tokens(redis_client, user_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="User account is disabled",
|
||||
)
|
||||
|
||||
# Invalidate the old refresh token (rotation)
|
||||
invalidate_refresh_token(redis_client, user_id, old_refresh_token)
|
||||
|
||||
# Get role name
|
||||
role_name = user.role.name if user.role else None
|
||||
|
||||
# Create new token payload
|
||||
token_data = create_token_payload(
|
||||
user_id=user.id,
|
||||
email=user.email,
|
||||
role=role_name,
|
||||
department_id=user.department_id,
|
||||
is_system_admin=user.is_system_admin,
|
||||
)
|
||||
|
||||
# Create new access token
|
||||
new_access_token = create_access_token(token_data)
|
||||
|
||||
# Generate new refresh token (rotation)
|
||||
new_refresh_token = generate_refresh_token()
|
||||
|
||||
# Store new session in Redis
|
||||
redis_client.setex(
|
||||
f"session:{user.id}",
|
||||
settings.JWT_EXPIRE_MINUTES * 60,
|
||||
new_access_token,
|
||||
)
|
||||
|
||||
# Store new refresh token
|
||||
store_refresh_token(redis_client, user.id, new_refresh_token)
|
||||
|
||||
return RefreshTokenResponse(
|
||||
access_token=new_access_token,
|
||||
refresh_token=new_refresh_token,
|
||||
expires_in=settings.JWT_EXPIRE_MINUTES * 60,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserInfo)
|
||||
async def get_current_user_info(
|
||||
current_user: User = Depends(get_current_user),
|
||||
|
||||
@@ -9,6 +9,7 @@ from sqlalchemy.orm import Session
|
||||
from app.core import database
|
||||
from app.core.security import decode_access_token
|
||||
from app.core.redis import get_redis_sync
|
||||
from app.core.config import settings
|
||||
from app.models import User, Notification, Project
|
||||
from app.services.websocket_manager import manager
|
||||
from app.core.redis_pubsub import NotificationSubscriber, ProjectTaskSubscriber
|
||||
@@ -72,14 +73,24 @@ async def authenticate_websocket(
|
||||
Supports two authentication methods:
|
||||
1. First message authentication (preferred, more secure)
|
||||
- Client sends: {"type": "auth", "token": "<jwt_token>"}
|
||||
2. Query parameter authentication (deprecated, for backward compatibility)
|
||||
2. Query parameter authentication (disabled in production, for backward compatibility only)
|
||||
- Client connects with: ?token=<jwt_token>
|
||||
|
||||
Returns:
|
||||
Tuple of (user_id, error_reason). user_id is None if authentication fails.
|
||||
Error reasons: "invalid_token", "invalid_message", "missing_token",
|
||||
"timeout", "error", "query_auth_disabled"
|
||||
"""
|
||||
# If token provided via query parameter (backward compatibility)
|
||||
if query_token:
|
||||
# Reject query parameter auth in production for security
|
||||
if settings.ENVIRONMENT == "production":
|
||||
logger.warning(
|
||||
"WebSocket query parameter authentication attempted in production environment. "
|
||||
"This is disabled for security reasons."
|
||||
)
|
||||
return None, "query_auth_disabled"
|
||||
|
||||
logger.warning(
|
||||
"WebSocket authentication via query parameter is deprecated. "
|
||||
"Please use first-message authentication for better security."
|
||||
@@ -195,9 +206,21 @@ async def websocket_notifications(
|
||||
user_id, error_reason = await authenticate_websocket(websocket, token)
|
||||
|
||||
if user_id is None:
|
||||
if error_reason == "invalid_token":
|
||||
if error_reason == "query_auth_disabled":
|
||||
await websocket.send_json({"type": "error", "message": "Query parameter auth disabled in production"})
|
||||
await websocket.close(code=4002, reason="Query parameter auth disabled in production")
|
||||
elif error_reason == "invalid_token":
|
||||
await websocket.send_json({"type": "error", "message": "Invalid or expired token"})
|
||||
await websocket.close(code=4001, reason="Invalid or expired token")
|
||||
await websocket.close(code=4001, reason="Invalid or expired token")
|
||||
else:
|
||||
await websocket.close(code=4001, reason="Invalid or expired token")
|
||||
return
|
||||
|
||||
# Check connection limit before accepting
|
||||
can_connect, reject_reason = await manager.check_connection_limit(user_id)
|
||||
if not can_connect:
|
||||
await websocket.send_json({"type": "error", "message": reject_reason})
|
||||
await websocket.close(code=4005, reason=reject_reason)
|
||||
return
|
||||
|
||||
await manager.connect(websocket, user_id)
|
||||
@@ -394,9 +417,21 @@ async def websocket_project_sync(
|
||||
user_id, error_reason = await authenticate_websocket(websocket, token)
|
||||
|
||||
if user_id is None:
|
||||
if error_reason == "invalid_token":
|
||||
if error_reason == "query_auth_disabled":
|
||||
await websocket.send_json({"type": "error", "message": "Query parameter auth disabled in production"})
|
||||
await websocket.close(code=4002, reason="Query parameter auth disabled in production")
|
||||
elif error_reason == "invalid_token":
|
||||
await websocket.send_json({"type": "error", "message": "Invalid or expired token"})
|
||||
await websocket.close(code=4001, reason="Invalid or expired token")
|
||||
await websocket.close(code=4001, reason="Invalid or expired token")
|
||||
else:
|
||||
await websocket.close(code=4001, reason="Invalid or expired token")
|
||||
return
|
||||
|
||||
# Check connection limit before accepting
|
||||
can_connect, reject_reason = await manager.check_connection_limit(user_id)
|
||||
if not can_connect:
|
||||
await websocket.send_json({"type": "error", "message": reject_reason})
|
||||
await websocket.close(code=4005, reason=reject_reason)
|
||||
return
|
||||
|
||||
# Verify user has access to the project
|
||||
|
||||
@@ -28,7 +28,8 @@ class Settings(BaseSettings):
|
||||
# JWT - Must be set in environment, no default allowed
|
||||
JWT_SECRET_KEY: str = ""
|
||||
JWT_ALGORITHM: str = "HS256"
|
||||
JWT_EXPIRE_MINUTES: int = 10080 # 7 days
|
||||
JWT_EXPIRE_MINUTES: int = 60 # 1 hour (short-lived access token)
|
||||
REFRESH_TOKEN_EXPIRE_DAYS: int = 7 # Refresh token valid for 7 days
|
||||
|
||||
@field_validator("JWT_SECRET_KEY")
|
||||
@classmethod
|
||||
@@ -127,6 +128,12 @@ class Settings(BaseSettings):
|
||||
QUERY_LOGGING: bool = False # Enable SQLAlchemy query logging
|
||||
QUERY_COUNT_THRESHOLD: int = 10 # Warn when query count exceeds this threshold
|
||||
|
||||
# Environment
|
||||
ENVIRONMENT: str = "development" # Options: development, staging, production
|
||||
|
||||
# WebSocket Settings
|
||||
MAX_WEBSOCKET_CONNECTIONS_PER_USER: int = 5 # Maximum concurrent WebSocket connections per user
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
case_sensitive = True
|
||||
|
||||
@@ -356,3 +356,140 @@ def create_token_payload(
|
||||
"department_id": department_id,
|
||||
"is_system_admin": is_system_admin,
|
||||
}
|
||||
|
||||
|
||||
# Refresh Token Functions
|
||||
REFRESH_TOKEN_BYTES = 32
|
||||
|
||||
|
||||
def generate_refresh_token() -> str:
|
||||
"""
|
||||
Generate a cryptographically secure refresh token.
|
||||
|
||||
Returns:
|
||||
A URL-safe base64-encoded random token
|
||||
"""
|
||||
return secrets.token_urlsafe(REFRESH_TOKEN_BYTES)
|
||||
|
||||
|
||||
def get_refresh_token_key(user_id: str, token: str) -> str:
|
||||
"""
|
||||
Generate the Redis key for a refresh token.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID
|
||||
token: The refresh token
|
||||
|
||||
Returns:
|
||||
Redis key string
|
||||
"""
|
||||
# Hash the token to avoid storing it directly as a key
|
||||
token_hash = hashlib.sha256(token.encode()).hexdigest()[:16]
|
||||
return f"refresh_token:{user_id}:{token_hash}"
|
||||
|
||||
|
||||
def store_refresh_token(redis_client, user_id: str, token: str) -> None:
|
||||
"""
|
||||
Store a refresh token in Redis with user binding.
|
||||
|
||||
Args:
|
||||
redis_client: Redis client instance
|
||||
user_id: The user's ID
|
||||
token: The refresh token to store
|
||||
"""
|
||||
key = get_refresh_token_key(user_id, token)
|
||||
# Store with TTL based on REFRESH_TOKEN_EXPIRE_DAYS
|
||||
ttl_seconds = settings.REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60
|
||||
redis_client.setex(key, ttl_seconds, user_id)
|
||||
|
||||
|
||||
def validate_refresh_token(redis_client, user_id: str, token: str) -> bool:
|
||||
"""
|
||||
Validate a refresh token exists in Redis and is bound to the user.
|
||||
|
||||
Args:
|
||||
redis_client: Redis client instance
|
||||
user_id: The expected user ID
|
||||
token: The refresh token to validate
|
||||
|
||||
Returns:
|
||||
True if token is valid, False otherwise
|
||||
"""
|
||||
key = get_refresh_token_key(user_id, token)
|
||||
stored_user_id = redis_client.get(key)
|
||||
|
||||
if stored_user_id is None:
|
||||
return False
|
||||
|
||||
# Handle Redis bytes type
|
||||
if isinstance(stored_user_id, bytes):
|
||||
stored_user_id = stored_user_id.decode("utf-8")
|
||||
|
||||
return stored_user_id == user_id
|
||||
|
||||
|
||||
def invalidate_refresh_token(redis_client, user_id: str, token: str) -> bool:
|
||||
"""
|
||||
Invalidate (delete) a refresh token from Redis.
|
||||
|
||||
Args:
|
||||
redis_client: Redis client instance
|
||||
user_id: The user's ID
|
||||
token: The refresh token to invalidate
|
||||
|
||||
Returns:
|
||||
True if token was deleted, False if it didn't exist
|
||||
"""
|
||||
key = get_refresh_token_key(user_id, token)
|
||||
result = redis_client.delete(key)
|
||||
return result > 0 if isinstance(result, int) else bool(result)
|
||||
|
||||
|
||||
def invalidate_all_user_refresh_tokens(redis_client, user_id: str) -> int:
|
||||
"""
|
||||
Invalidate all refresh tokens for a user.
|
||||
|
||||
Args:
|
||||
redis_client: Redis client instance
|
||||
user_id: The user's ID
|
||||
|
||||
Returns:
|
||||
Number of tokens invalidated
|
||||
"""
|
||||
pattern = f"refresh_token:{user_id}:*"
|
||||
count = 0
|
||||
for key in redis_client.scan_iter(match=pattern):
|
||||
redis_client.delete(key)
|
||||
count += 1
|
||||
return count
|
||||
|
||||
|
||||
def decode_refresh_token_user_id(token: str, redis_client) -> Optional[str]:
|
||||
"""
|
||||
Find the user ID associated with a refresh token by searching Redis.
|
||||
|
||||
This is used when we only have the token and need to find which user it belongs to.
|
||||
Note: This is less efficient but necessary for refresh token validation when
|
||||
the user_id is not provided in the request.
|
||||
|
||||
Args:
|
||||
token: The refresh token
|
||||
redis_client: Redis client instance
|
||||
|
||||
Returns:
|
||||
User ID if found, None otherwise
|
||||
"""
|
||||
# We need to search for the token across all users
|
||||
# This is done by checking the token hash pattern
|
||||
token_hash = hashlib.sha256(token.encode()).hexdigest()[:16]
|
||||
pattern = f"refresh_token:*:{token_hash}"
|
||||
|
||||
for key in redis_client.scan_iter(match=pattern):
|
||||
# Extract user_id from key format: refresh_token:{user_id}:{token_hash}
|
||||
if isinstance(key, bytes):
|
||||
key = key.decode("utf-8")
|
||||
parts = key.split(":")
|
||||
if len(parts) == 3:
|
||||
return parts[1]
|
||||
|
||||
return None
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime
|
||||
from fastapi import FastAPI, Request, APIRouter
|
||||
from fastapi import FastAPI, Request, APIRouter, Depends
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from slowapi import _rate_limit_exceeded_handler
|
||||
@@ -9,6 +9,9 @@ from slowapi.errors import RateLimitExceeded
|
||||
from sqlalchemy import text
|
||||
|
||||
from app.middleware.audit import AuditMiddleware
|
||||
from app.middleware.csrf import CSRFMiddleware
|
||||
from app.middleware.security_audit import SecurityAuditMiddleware
|
||||
from app.middleware.error_sanitizer import ErrorSanitizerMiddleware
|
||||
from app.core.scheduler import start_scheduler, shutdown_scheduler, scheduler
|
||||
from app.core.rate_limiter import limiter
|
||||
from app.core.deprecation import DeprecationMiddleware
|
||||
@@ -61,6 +64,8 @@ from app.core.database import get_pool_status, engine
|
||||
from app.core.redis import redis_client
|
||||
from app.services.notification_service import get_redis_fallback_status
|
||||
from app.services.file_storage_service import file_storage_service
|
||||
from app.middleware.auth import require_system_admin
|
||||
from app.models import User
|
||||
|
||||
app = FastAPI(
|
||||
title="Project Control API",
|
||||
@@ -73,18 +78,28 @@ app = FastAPI(
|
||||
app.state.limiter = limiter
|
||||
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
||||
|
||||
# CORS middleware
|
||||
# CORS middleware - Explicit methods and headers for security
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.CORS_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
|
||||
allow_headers=["Content-Type", "Authorization", "X-CSRF-Token", "X-Request-ID"],
|
||||
)
|
||||
|
||||
# Error sanitizer middleware - sanitizes error messages in production
|
||||
# Must be first in the chain to intercept all error responses
|
||||
app.add_middleware(ErrorSanitizerMiddleware)
|
||||
|
||||
# Audit middleware - extracts request metadata for audit logging
|
||||
app.add_middleware(AuditMiddleware)
|
||||
|
||||
# Security audit middleware - logs 401/403 responses to audit trail
|
||||
app.add_middleware(SecurityAuditMiddleware)
|
||||
|
||||
# CSRF middleware - validates CSRF tokens for state-changing requests
|
||||
app.add_middleware(CSRFMiddleware)
|
||||
|
||||
# Deprecation middleware - adds deprecation headers to legacy /api/ routes
|
||||
app.add_middleware(DeprecationMiddleware)
|
||||
|
||||
@@ -252,14 +267,20 @@ async def readiness_check():
|
||||
|
||||
|
||||
@app.get("/health/detailed")
|
||||
async def detailed_health_check():
|
||||
"""Detailed health check endpoint.
|
||||
async def detailed_health_check(
|
||||
current_user: User = Depends(require_system_admin),
|
||||
):
|
||||
"""Detailed health check endpoint (requires system admin).
|
||||
|
||||
Returns comprehensive status of all system components:
|
||||
- database: Connection pool status and connectivity
|
||||
- redis: Connection status and fallback queue status
|
||||
- storage: File storage validation status
|
||||
- scheduler: Background job scheduler status
|
||||
|
||||
Note: This endpoint requires system admin authentication because it exposes
|
||||
sensitive infrastructure details including connection pool statistics and
|
||||
internal service states.
|
||||
"""
|
||||
db_health = check_database_health()
|
||||
redis_health = check_redis_health()
|
||||
|
||||
@@ -1,38 +1,55 @@
|
||||
"""
|
||||
CSRF (Cross-Site Request Forgery) Protection Middleware.
|
||||
|
||||
This module provides CSRF protection for sensitive state-changing operations.
|
||||
It validates CSRF tokens for specified protected endpoints.
|
||||
This module provides CSRF protection for all state-changing operations.
|
||||
It validates CSRF tokens globally for authenticated POST, PUT, PATCH, DELETE requests.
|
||||
"""
|
||||
|
||||
from fastapi import Request, HTTPException, status, Depends
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from typing import Optional, Callable, List
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
from fastapi import HTTPException, status
|
||||
from typing import Optional, Callable, List, Set
|
||||
from functools import wraps
|
||||
import logging
|
||||
|
||||
from app.core.security import validate_csrf_token, generate_csrf_token
|
||||
from app.core.security import validate_csrf_token, generate_csrf_token, decode_access_token
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Header name for CSRF token
|
||||
CSRF_TOKEN_HEADER = "X-CSRF-Token"
|
||||
|
||||
# List of endpoint patterns that require CSRF protection
|
||||
# These are sensitive state-changing operations
|
||||
CSRF_PROTECTED_PATTERNS = [
|
||||
# User operations
|
||||
"/api/v1/users/{user_id}/admin", # Admin status change
|
||||
"/api/users/{user_id}/admin", # Legacy
|
||||
# Password changes would go here if implemented
|
||||
# Delete operations
|
||||
"/api/attachments/{attachment_id}", # DELETE method
|
||||
"/api/tasks/{task_id}", # DELETE method (soft delete)
|
||||
"/api/projects/{project_id}", # DELETE method
|
||||
]
|
||||
# Methods that require CSRF protection (all state-changing operations)
|
||||
CSRF_PROTECTED_METHODS = {"POST", "PUT", "PATCH", "DELETE"}
|
||||
|
||||
# Methods that require CSRF protection
|
||||
CSRF_PROTECTED_METHODS = ["DELETE", "PUT", "PATCH"]
|
||||
# Safe methods that don't require CSRF protection
|
||||
CSRF_SAFE_METHODS = {"GET", "HEAD", "OPTIONS"}
|
||||
|
||||
# Public endpoints that don't require CSRF validation
|
||||
# These are endpoints that either:
|
||||
# 1. Don't require authentication (login, health checks)
|
||||
# 2. Are not state-changing in a security-sensitive way
|
||||
CSRF_EXCLUDED_PATHS: Set[str] = {
|
||||
# Authentication endpoints (unauthenticated)
|
||||
"/api/auth/login",
|
||||
"/api/v1/auth/login",
|
||||
# Health check endpoints (unauthenticated)
|
||||
"/health",
|
||||
"/health/live",
|
||||
"/health/ready",
|
||||
"/health/detailed",
|
||||
# WebSocket endpoints (use different auth mechanism)
|
||||
"/api/ws",
|
||||
"/ws",
|
||||
}
|
||||
|
||||
# Path prefixes that are excluded from CSRF validation
|
||||
CSRF_EXCLUDED_PREFIXES: List[str] = [
|
||||
# WebSocket paths
|
||||
"/api/ws/",
|
||||
"/ws/",
|
||||
]
|
||||
|
||||
|
||||
class CSRFProtectionError(HTTPException):
|
||||
@@ -45,6 +62,114 @@ class CSRFProtectionError(HTTPException):
|
||||
)
|
||||
|
||||
|
||||
class CSRFMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Global CSRF protection middleware.
|
||||
|
||||
Validates CSRF tokens for all authenticated state-changing requests
|
||||
(POST, PUT, PATCH, DELETE) except for explicitly excluded endpoints.
|
||||
"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
"""Process the request and validate CSRF token if needed."""
|
||||
method = request.method.upper()
|
||||
path = request.url.path
|
||||
|
||||
# Skip CSRF validation for safe methods
|
||||
if method in CSRF_SAFE_METHODS:
|
||||
return await call_next(request)
|
||||
|
||||
# Skip CSRF validation for excluded paths
|
||||
if self._is_excluded_path(path):
|
||||
logger.debug("CSRF validation skipped for excluded path: %s", path)
|
||||
return await call_next(request)
|
||||
|
||||
# Try to extract user ID from the Authorization header
|
||||
user_id = self._extract_user_id_from_token(request)
|
||||
|
||||
# If no user ID (unauthenticated request), skip CSRF validation
|
||||
# The authentication middleware will handle unauthorized access
|
||||
if user_id is None:
|
||||
logger.debug(
|
||||
"CSRF validation skipped (no auth token): %s %s",
|
||||
method, path
|
||||
)
|
||||
return await call_next(request)
|
||||
|
||||
# Get CSRF token from header
|
||||
csrf_token = request.headers.get(CSRF_TOKEN_HEADER)
|
||||
|
||||
if not csrf_token:
|
||||
logger.warning(
|
||||
"CSRF validation failed: Missing token for user %s on %s %s",
|
||||
user_id, method, path
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
content={"detail": "CSRF token is required"}
|
||||
)
|
||||
|
||||
# Validate the token
|
||||
is_valid, error_message = validate_csrf_token(csrf_token, user_id)
|
||||
|
||||
if not is_valid:
|
||||
logger.warning(
|
||||
"CSRF validation failed for user %s on %s %s: %s",
|
||||
user_id, method, path, error_message
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
content={"detail": error_message}
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"CSRF validation passed for user %s on %s %s",
|
||||
user_id, method, path
|
||||
)
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
def _is_excluded_path(self, path: str) -> bool:
|
||||
"""Check if the path is excluded from CSRF validation."""
|
||||
# Check exact path matches
|
||||
if path in CSRF_EXCLUDED_PATHS:
|
||||
return True
|
||||
|
||||
# Check path prefixes
|
||||
for prefix in CSRF_EXCLUDED_PREFIXES:
|
||||
if path.startswith(prefix):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _extract_user_id_from_token(self, request: Request) -> Optional[str]:
|
||||
"""
|
||||
Extract user ID from the Authorization header.
|
||||
|
||||
Returns None if no valid token is found (unauthenticated request).
|
||||
"""
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if not auth_header:
|
||||
return None
|
||||
|
||||
# Parse Bearer token
|
||||
parts = auth_header.split()
|
||||
if len(parts) != 2 or parts[0].lower() != "bearer":
|
||||
return None
|
||||
|
||||
token = parts[1]
|
||||
|
||||
# Decode the token to get user ID
|
||||
try:
|
||||
payload = decode_access_token(token)
|
||||
if payload is None:
|
||||
return None
|
||||
return payload.get("sub")
|
||||
except Exception as e:
|
||||
logger.debug("Failed to decode token for CSRF validation: %s", e)
|
||||
return None
|
||||
|
||||
|
||||
def require_csrf_token(func: Callable) -> Callable:
|
||||
"""
|
||||
Decorator to require CSRF token validation for an endpoint.
|
||||
|
||||
187
backend/app/middleware/error_sanitizer.py
Normal file
187
backend/app/middleware/error_sanitizer.py
Normal file
@@ -0,0 +1,187 @@
|
||||
"""Error message sanitization middleware for production environments.
|
||||
|
||||
This middleware intercepts error responses and sanitizes them to prevent
|
||||
information disclosure in production environments. Detailed error messages
|
||||
are only shown when DEBUG mode is enabled.
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response, JSONResponse
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Generic error messages for production
|
||||
GENERIC_ERROR_MESSAGES = {
|
||||
400: "Bad Request",
|
||||
401: "Authentication required",
|
||||
403: "Access denied",
|
||||
404: "Resource not found",
|
||||
405: "Method not allowed",
|
||||
409: "Request conflict",
|
||||
422: "Validation error",
|
||||
429: "Too many requests",
|
||||
500: "Internal server error",
|
||||
502: "Service unavailable",
|
||||
503: "Service temporarily unavailable",
|
||||
504: "Request timeout",
|
||||
}
|
||||
|
||||
# Status codes that should preserve their original message even in production
|
||||
# These are typically user-facing validation errors that don't leak sensitive info
|
||||
PRESERVE_MESSAGE_CODES = {
|
||||
400, # Bad request - users need to know what's wrong with their request
|
||||
401, # Unauthorized - users need to know why auth failed
|
||||
403, # Forbidden - users need to know what permission they lack
|
||||
404, # Not found - usually safe to preserve
|
||||
409, # Conflict - users need to know about conflicts
|
||||
422, # Validation errors - users need to know what to fix
|
||||
}
|
||||
|
||||
# Patterns that indicate sensitive information in error messages
|
||||
SENSITIVE_PATTERNS = [
|
||||
"traceback",
|
||||
"stack trace",
|
||||
"file path",
|
||||
"/usr/",
|
||||
"/var/",
|
||||
"/home/",
|
||||
"connection refused",
|
||||
"connection error",
|
||||
"timeout connecting",
|
||||
"database error",
|
||||
"sql",
|
||||
"query failed",
|
||||
"password",
|
||||
"secret",
|
||||
"token",
|
||||
"key=",
|
||||
"credentials",
|
||||
".py line",
|
||||
"exception in",
|
||||
]
|
||||
|
||||
|
||||
def _contains_sensitive_info(message: str) -> bool:
|
||||
"""Check if an error message contains potentially sensitive information."""
|
||||
if not message:
|
||||
return False
|
||||
message_lower = message.lower()
|
||||
return any(pattern.lower() in message_lower for pattern in SENSITIVE_PATTERNS)
|
||||
|
||||
|
||||
def _sanitize_detail(detail: any, status_code: int) -> any:
|
||||
"""Sanitize error detail, removing sensitive information in production.
|
||||
|
||||
Args:
|
||||
detail: The error detail (can be string, list, or dict)
|
||||
status_code: The HTTP status code
|
||||
|
||||
Returns:
|
||||
Sanitized detail for production, or original detail for debug mode
|
||||
"""
|
||||
# In debug mode, return original detail
|
||||
if settings.DEBUG:
|
||||
return detail
|
||||
|
||||
# For preserved status codes, keep the detail if it doesn't contain sensitive info
|
||||
if status_code in PRESERVE_MESSAGE_CODES:
|
||||
if isinstance(detail, str) and not _contains_sensitive_info(detail):
|
||||
return detail
|
||||
if isinstance(detail, list):
|
||||
# For validation errors (list of dicts), keep the structure but sanitize
|
||||
sanitized = []
|
||||
for item in detail:
|
||||
if isinstance(item, dict):
|
||||
# Keep loc, msg, type for pydantic validation errors
|
||||
sanitized_item = {}
|
||||
if 'loc' in item:
|
||||
sanitized_item['loc'] = item['loc']
|
||||
if 'msg' in item and not _contains_sensitive_info(str(item['msg'])):
|
||||
sanitized_item['msg'] = item['msg']
|
||||
else:
|
||||
sanitized_item['msg'] = 'Validation failed'
|
||||
if 'type' in item:
|
||||
sanitized_item['type'] = item['type']
|
||||
sanitized.append(sanitized_item)
|
||||
else:
|
||||
sanitized.append(item if not _contains_sensitive_info(str(item)) else 'Invalid value')
|
||||
return sanitized
|
||||
return detail
|
||||
|
||||
# For other status codes, use generic message
|
||||
return GENERIC_ERROR_MESSAGES.get(status_code, "An error occurred")
|
||||
|
||||
|
||||
class ErrorSanitizerMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware to sanitize error responses in production.
|
||||
|
||||
This middleware:
|
||||
1. Intercepts error responses (4xx and 5xx status codes)
|
||||
2. Parses JSON response bodies
|
||||
3. Sanitizes the 'detail' field to remove sensitive information
|
||||
4. Returns the sanitized response
|
||||
|
||||
In DEBUG mode, original error messages are preserved for development.
|
||||
"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
response = await call_next(request)
|
||||
|
||||
# Only process error responses with JSON content
|
||||
if response.status_code < 400:
|
||||
return response
|
||||
|
||||
content_type = response.headers.get("content-type", "")
|
||||
if "application/json" not in content_type:
|
||||
return response
|
||||
|
||||
# Read the response body
|
||||
body = b""
|
||||
async for chunk in response.body_iterator:
|
||||
body += chunk
|
||||
|
||||
if not body:
|
||||
return response
|
||||
|
||||
try:
|
||||
data = json.loads(body)
|
||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||
# Not valid JSON, return as-is
|
||||
return Response(
|
||||
content=body,
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers),
|
||||
media_type=response.media_type,
|
||||
)
|
||||
|
||||
# Sanitize the detail field if present
|
||||
if "detail" in data:
|
||||
original_detail = data["detail"]
|
||||
data["detail"] = _sanitize_detail(original_detail, response.status_code)
|
||||
|
||||
# Log the original error in production for debugging
|
||||
if not settings.DEBUG and original_detail != data["detail"]:
|
||||
logger.warning(
|
||||
"Sanitized error response",
|
||||
extra={
|
||||
"status_code": response.status_code,
|
||||
"path": str(request.url.path),
|
||||
"method": request.method,
|
||||
"original_detail_length": len(str(original_detail)),
|
||||
}
|
||||
)
|
||||
|
||||
# Return the sanitized response
|
||||
return JSONResponse(
|
||||
content=data,
|
||||
status_code=response.status_code,
|
||||
headers={
|
||||
k: v for k, v in response.headers.items()
|
||||
if k.lower() not in ("content-length", "content-type")
|
||||
},
|
||||
)
|
||||
215
backend/app/middleware/security_audit.py
Normal file
215
backend/app/middleware/security_audit.py
Normal file
@@ -0,0 +1,215 @@
|
||||
"""Security audit middleware for logging access denials and suspicious auth patterns."""
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Dict, List, Tuple
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.database import SessionLocal
|
||||
from app.models import AuditLog, AuditAction
|
||||
from app.services.audit_service import AuditService
|
||||
|
||||
|
||||
# In-memory storage for tracking auth failures
|
||||
# Structure: {ip_address: [(timestamp, path), ...]}
|
||||
_auth_failure_tracker: Dict[str, List[Tuple[float, str]]] = defaultdict(list)
|
||||
|
||||
# Configuration constants
|
||||
AUTH_FAILURE_THRESHOLD = 5 # Number of failures to trigger suspicious pattern alert
|
||||
AUTH_FAILURE_WINDOW_SECONDS = 600 # 10 minutes window
|
||||
|
||||
|
||||
def _cleanup_old_failures(ip: str) -> None:
|
||||
"""Remove auth failures older than the tracking window."""
|
||||
if ip not in _auth_failure_tracker:
|
||||
return
|
||||
|
||||
cutoff = time.time() - AUTH_FAILURE_WINDOW_SECONDS
|
||||
_auth_failure_tracker[ip] = [
|
||||
(ts, path) for ts, path in _auth_failure_tracker[ip]
|
||||
if ts > cutoff
|
||||
]
|
||||
|
||||
# Clean up empty entries
|
||||
if not _auth_failure_tracker[ip]:
|
||||
del _auth_failure_tracker[ip]
|
||||
|
||||
|
||||
def _track_auth_failure(ip: str, path: str) -> int:
|
||||
"""Track an auth failure and return the count in the window."""
|
||||
_cleanup_old_failures(ip)
|
||||
_auth_failure_tracker[ip].append((time.time(), path))
|
||||
return len(_auth_failure_tracker[ip])
|
||||
|
||||
|
||||
def _get_recent_failures(ip: str) -> List[str]:
|
||||
"""Get list of paths that failed auth for this IP."""
|
||||
_cleanup_old_failures(ip)
|
||||
return [path for _, path in _auth_failure_tracker.get(ip, [])]
|
||||
|
||||
|
||||
class SecurityAuditMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware to audit security-related events like 401/403 responses."""
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
response = await call_next(request)
|
||||
|
||||
# Only process 401 and 403 responses
|
||||
if response.status_code not in (401, 403):
|
||||
return response
|
||||
|
||||
# Get client IP from audit metadata if available
|
||||
ip_address = self._get_client_ip(request)
|
||||
path = str(request.url.path)
|
||||
method = request.method
|
||||
|
||||
# Get user_id if available from request state (set by auth middleware)
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
|
||||
db: Session = SessionLocal()
|
||||
try:
|
||||
if response.status_code == 403:
|
||||
self._log_access_denied(db, ip_address, path, method, user_id, request)
|
||||
elif response.status_code == 401:
|
||||
self._log_auth_failure(db, ip_address, path, method, request)
|
||||
|
||||
db.commit()
|
||||
except Exception:
|
||||
db.rollback()
|
||||
# Don't fail the request due to audit logging errors
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
return response
|
||||
|
||||
def _get_client_ip(self, request: Request) -> str:
|
||||
"""Get the real client IP address from request."""
|
||||
# Check for audit metadata first (set by AuditMiddleware)
|
||||
audit_metadata = getattr(request.state, 'audit_metadata', None)
|
||||
if audit_metadata and 'ip_address' in audit_metadata:
|
||||
return audit_metadata['ip_address']
|
||||
|
||||
# Fallback to checking headers directly
|
||||
forwarded = request.headers.get("x-forwarded-for")
|
||||
if forwarded:
|
||||
return forwarded.split(",")[0].strip()
|
||||
|
||||
real_ip = request.headers.get("x-real-ip")
|
||||
if real_ip:
|
||||
return real_ip
|
||||
|
||||
if request.client:
|
||||
return request.client.host
|
||||
|
||||
return "unknown"
|
||||
|
||||
def _log_access_denied(
|
||||
self,
|
||||
db: Session,
|
||||
ip_address: str,
|
||||
path: str,
|
||||
method: str,
|
||||
user_id: str | None,
|
||||
request: Request,
|
||||
) -> None:
|
||||
"""Log a 403 Forbidden response to the audit trail."""
|
||||
request_metadata = {
|
||||
"ip_address": ip_address,
|
||||
"user_agent": request.headers.get("user-agent", ""),
|
||||
"method": method,
|
||||
"path": path,
|
||||
}
|
||||
|
||||
# Try to extract resource info from path
|
||||
resource_type = self._extract_resource_type(path)
|
||||
|
||||
AuditService.log_event(
|
||||
db=db,
|
||||
event_type="security.access_denied",
|
||||
resource_type=resource_type,
|
||||
action=AuditAction.ACCESS_DENIED,
|
||||
user_id=user_id,
|
||||
resource_id=None,
|
||||
changes=[{
|
||||
"attempted_path": path,
|
||||
"attempted_method": method,
|
||||
"ip_address": ip_address,
|
||||
}],
|
||||
request_metadata=request_metadata,
|
||||
)
|
||||
|
||||
def _log_auth_failure(
|
||||
self,
|
||||
db: Session,
|
||||
ip_address: str,
|
||||
path: str,
|
||||
method: str,
|
||||
request: Request,
|
||||
) -> None:
|
||||
"""Log a 401 Unauthorized response and check for suspicious patterns."""
|
||||
# Track this failure
|
||||
failure_count = _track_auth_failure(ip_address, path)
|
||||
|
||||
request_metadata = {
|
||||
"ip_address": ip_address,
|
||||
"user_agent": request.headers.get("user-agent", ""),
|
||||
"method": method,
|
||||
"path": path,
|
||||
}
|
||||
|
||||
resource_type = self._extract_resource_type(path)
|
||||
|
||||
# Log the auth failure
|
||||
AuditService.log_event(
|
||||
db=db,
|
||||
event_type="security.auth_failed",
|
||||
resource_type=resource_type,
|
||||
action=AuditAction.AUTH_FAILED,
|
||||
user_id=None, # No user for 401
|
||||
resource_id=None,
|
||||
changes=[{
|
||||
"attempted_path": path,
|
||||
"attempted_method": method,
|
||||
"ip_address": ip_address,
|
||||
"failure_count_in_window": failure_count,
|
||||
}],
|
||||
request_metadata=request_metadata,
|
||||
)
|
||||
|
||||
# Check for suspicious pattern
|
||||
if failure_count >= AUTH_FAILURE_THRESHOLD:
|
||||
recent_paths = _get_recent_failures(ip_address)
|
||||
AuditService.log_event(
|
||||
db=db,
|
||||
event_type="security.suspicious_auth_pattern",
|
||||
resource_type="security",
|
||||
action=AuditAction.AUTH_FAILED,
|
||||
user_id=None,
|
||||
resource_id=None,
|
||||
changes=[{
|
||||
"ip_address": ip_address,
|
||||
"failure_count": failure_count,
|
||||
"window_minutes": AUTH_FAILURE_WINDOW_SECONDS // 60,
|
||||
"attempted_paths": list(set(recent_paths)),
|
||||
}],
|
||||
request_metadata=request_metadata,
|
||||
)
|
||||
|
||||
def _extract_resource_type(self, path: str) -> str:
|
||||
"""Extract resource type from path for audit logging."""
|
||||
# Remove /api/ or /api/v1/ prefix
|
||||
clean_path = path
|
||||
if clean_path.startswith("/api/v1/"):
|
||||
clean_path = clean_path[8:]
|
||||
elif clean_path.startswith("/api/"):
|
||||
clean_path = clean_path[5:]
|
||||
|
||||
# Get the first path segment as resource type
|
||||
parts = clean_path.strip("/").split("/")
|
||||
if parts and parts[0]:
|
||||
return parts[0]
|
||||
|
||||
return "unknown"
|
||||
@@ -13,6 +13,8 @@ class AuditAction(str, enum.Enum):
|
||||
RESTORE = "restore"
|
||||
LOGIN = "login"
|
||||
LOGOUT = "logout"
|
||||
ACCESS_DENIED = "access_denied"
|
||||
AUTH_FAILED = "auth_failed"
|
||||
|
||||
|
||||
class SensitivityLevel(str, enum.Enum):
|
||||
@@ -42,10 +44,20 @@ EVENT_SENSITIVITY = {
|
||||
"attachment.upload": SensitivityLevel.LOW,
|
||||
"attachment.download": SensitivityLevel.LOW,
|
||||
"attachment.delete": SensitivityLevel.MEDIUM,
|
||||
# Security events
|
||||
"security.access_denied": SensitivityLevel.MEDIUM,
|
||||
"security.auth_failed": SensitivityLevel.MEDIUM,
|
||||
"security.suspicious_auth_pattern": SensitivityLevel.HIGH,
|
||||
}
|
||||
|
||||
# Events that should trigger alerts
|
||||
ALERT_EVENTS = {"project.delete", "user.permission_change", "user.admin_change", "role.permission_change"}
|
||||
ALERT_EVENTS = {
|
||||
"project.delete",
|
||||
"user.permission_change",
|
||||
"user.admin_change",
|
||||
"role.permission_change",
|
||||
"security.suspicious_auth_pattern",
|
||||
}
|
||||
|
||||
|
||||
class AuditLog(Base):
|
||||
@@ -57,7 +69,7 @@ class AuditLog(Base):
|
||||
resource_id = Column(String(36), nullable=True)
|
||||
user_id = Column(String(36), ForeignKey("pjctrl_users.id", ondelete="SET NULL"), nullable=True)
|
||||
action = Column(
|
||||
Enum("create", "update", "delete", "restore", "login", "logout", name="audit_action_enum"),
|
||||
Enum("create", "update", "delete", "restore", "login", "logout", "access_denied", "auth_failed", name="audit_action_enum"),
|
||||
nullable=False
|
||||
)
|
||||
changes = Column(JSON, nullable=True)
|
||||
|
||||
@@ -9,10 +9,25 @@ class LoginRequest(BaseModel):
|
||||
|
||||
class LoginResponse(BaseModel):
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
token_type: str = "bearer"
|
||||
expires_in: int = Field(default=3600, description="Access token expiry in seconds")
|
||||
user: "UserInfo"
|
||||
|
||||
|
||||
class RefreshTokenRequest(BaseModel):
|
||||
"""Request body for refresh token endpoint."""
|
||||
refresh_token: str = Field(..., description="The refresh token to use for obtaining a new access token")
|
||||
|
||||
|
||||
class RefreshTokenResponse(BaseModel):
|
||||
"""Response for refresh token endpoint."""
|
||||
access_token: str
|
||||
refresh_token: str # New refresh token (rotation)
|
||||
token_type: str = "bearer"
|
||||
expires_in: int = Field(default=3600, description="Access token expiry in seconds")
|
||||
|
||||
|
||||
class UserInfo(BaseModel):
|
||||
id: str
|
||||
email: str
|
||||
|
||||
@@ -4,12 +4,12 @@ from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class CommentCreate(BaseModel):
|
||||
content: str = Field(..., min_length=1, max_length=10000)
|
||||
content: str = Field(..., min_length=1, max_length=5000)
|
||||
parent_comment_id: Optional[str] = None
|
||||
|
||||
|
||||
class CommentUpdate(BaseModel):
|
||||
content: str = Field(..., min_length=1, max_length=10000)
|
||||
content: str = Field(..., min_length=1, max_length=5000)
|
||||
|
||||
|
||||
class CommentAuthor(BaseModel):
|
||||
|
||||
@@ -25,7 +25,7 @@ class CustomFieldDefinition(BaseModel):
|
||||
class ProjectTemplateBase(BaseModel):
|
||||
"""Base schema for project template."""
|
||||
name: str = Field(..., min_length=1, max_length=200)
|
||||
description: Optional[str] = None
|
||||
description: Optional[str] = Field(None, max_length=2000)
|
||||
is_public: bool = Field(default=False)
|
||||
task_statuses: Optional[List[TaskStatusDefinition]] = None
|
||||
custom_fields: Optional[List[CustomFieldDefinition]] = None
|
||||
@@ -43,7 +43,7 @@ class ProjectTemplateCreate(ProjectTemplateBase):
|
||||
class ProjectTemplateUpdate(BaseModel):
|
||||
"""Schema for updating a project template."""
|
||||
name: Optional[str] = Field(None, min_length=1, max_length=200)
|
||||
description: Optional[str] = None
|
||||
description: Optional[str] = Field(None, max_length=2000)
|
||||
is_public: Optional[bool] = None
|
||||
task_statuses: Optional[List[TaskStatusDefinition]] = None
|
||||
custom_fields: Optional[List[CustomFieldDefinition]] = None
|
||||
|
||||
@@ -4,6 +4,7 @@ import logging
|
||||
from typing import Dict, Set, Optional, Tuple
|
||||
from fastapi import WebSocket
|
||||
from app.core.redis import get_redis_sync
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -19,13 +20,48 @@ class ConnectionManager:
|
||||
self._lock = asyncio.Lock()
|
||||
self._project_lock = asyncio.Lock()
|
||||
|
||||
async def check_connection_limit(self, user_id: str) -> Tuple[bool, Optional[str]]:
|
||||
"""
|
||||
Check if user can create a new WebSocket connection.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID
|
||||
|
||||
Returns:
|
||||
Tuple of (can_connect: bool, reject_reason: str | None)
|
||||
- can_connect: True if user is within connection limit
|
||||
- reject_reason: Error message if connection should be rejected
|
||||
"""
|
||||
max_connections = settings.MAX_WEBSOCKET_CONNECTIONS_PER_USER
|
||||
async with self._lock:
|
||||
current_count = len(self.active_connections.get(user_id, set()))
|
||||
if current_count >= max_connections:
|
||||
logger.warning(
|
||||
f"User {user_id} exceeded WebSocket connection limit "
|
||||
f"({current_count}/{max_connections})"
|
||||
)
|
||||
return False, "Too many connections"
|
||||
return True, None
|
||||
|
||||
def get_user_connection_count(self, user_id: str) -> int:
|
||||
"""Get the current number of WebSocket connections for a user."""
|
||||
return len(self.active_connections.get(user_id, set()))
|
||||
|
||||
async def connect(self, websocket: WebSocket, user_id: str):
|
||||
"""Accept and track a new WebSocket connection."""
|
||||
await websocket.accept()
|
||||
"""
|
||||
Track a new WebSocket connection.
|
||||
|
||||
Note: WebSocket must already be accepted before calling this method.
|
||||
Connection limit should be checked via check_connection_limit() before calling.
|
||||
"""
|
||||
async with self._lock:
|
||||
if user_id not in self.active_connections:
|
||||
self.active_connections[user_id] = set()
|
||||
self.active_connections[user_id].add(websocket)
|
||||
logger.debug(
|
||||
f"User {user_id} connected. Total connections: "
|
||||
f"{len(self.active_connections[user_id])}"
|
||||
)
|
||||
|
||||
async def disconnect(self, websocket: WebSocket, user_id: str):
|
||||
"""Remove a WebSocket connection."""
|
||||
|
||||
Reference in New Issue
Block a user