Files
PROJECT-CONTORL/backend/app/core/security.py
beabigegg 35c90fe76b feat: implement 5 QA-driven security and quality proposals
Implemented proposals from comprehensive QA review:

1. extend-csrf-protection
   - Add POST to CSRF protected methods in frontend
   - Global CSRF middleware for all state-changing operations
   - Update tests with CSRF token fixtures

2. tighten-cors-websocket-security
   - Replace wildcard CORS with explicit method/header lists
   - Disable query parameter auth in production (code 4002)
   - Add per-user WebSocket connection limit (max 5, code 4005)

3. shorten-jwt-expiry
   - Reduce JWT expiry from 7 days to 60 minutes
   - Add refresh token support with 7-day expiry
   - Implement token rotation on refresh
   - Frontend auto-refresh when token near expiry (<5 min)

4. fix-frontend-quality
   - Add React.lazy() code splitting for all pages
   - Fix useCallback dependency arrays (Dashboard, Comments)
   - Add localStorage data validation in AuthContext
   - Complete i18n for AttachmentUpload component

5. enhance-backend-validation
   - Add SecurityAuditMiddleware for access denied logging
   - Add ErrorSanitizerMiddleware for production error messages
   - Protect /health/detailed with admin authentication
   - Add input length validation (comment 5000, desc 10000)

All 521 backend tests passing. Frontend builds successfully.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-12 23:19:05 +08:00

496 lines
14 KiB
Python

from datetime import datetime, timedelta, timezone
from typing import Optional, Any, Tuple
from jose import jwt, JWTError
import logging
import math
import hashlib
import secrets
import hmac
from collections import Counter
from app.core.config import settings
# Module-level logger; JWT secret validation reports its findings through it.
logger = logging.getLogger(__name__)
# Constants for JWT secret validation
# Secrets shorter than this many characters are marked invalid by
# validate_jwt_secret_strength().
MIN_SECRET_LENGTH = 32
MIN_ENTROPY_BITS = 128  # Minimum entropy in bits for a secure secret; falling below only adds a warning
# Substrings checked case-insensitively against the secret; a hit only adds a warning.
COMMON_WEAK_PATTERNS = [
    "password", "secret", "changeme", "admin", "test", "demo",
    "123456", "qwerty", "abc123", "letmein", "welcome",
]
def calculate_entropy(data: str) -> float:
    """
    Estimate the total Shannon entropy of ``data`` in bits.

    The per-character entropy is computed from the character frequencies
    actually observed in ``data`` and then multiplied by the string length.
    This is an empirical estimate. It is not the theoretical ``n * log2(k)``
    maximum for an alphabet of ``k`` symbols, so a truly random string will
    usually score somewhat below that maximum.

    Args:
        data: The string to measure

    Returns:
        Estimated entropy in bits (0.0 for an empty string)
    """
    if not data:
        return 0.0

    total = len(data)
    bits_per_char = 0.0
    # Counter only holds characters that occur, so every share is > 0 and
    # log2 is always defined.
    for occurrences in Counter(data).values():
        share = occurrences / total
        bits_per_char -= share * math.log2(share)
    return bits_per_char * total
def has_repeating_patterns(secret: str) -> bool:
    """
    Detect obviously repetitive secrets.

    Two heuristics are applied to secrets of at least 8 characters:
      * Tiling: for each prefix length from 2 up to ``len(secret) // 3``,
        the prefix repeated as many whole times as fits must reproduce the
        secret (ignoring a trailing remainder shorter than the prefix). Such
        a prefix is always repeated at least three times. Only repetition
        that starts at the very beginning of the secret is detected.
      * Runs: a run of one identical character whose length reaches
        ``len(secret) // 2``.

    Args:
        secret: The secret string to inspect

    Returns:
        True if either heuristic fires, False otherwise (always False for
        secrets shorter than 8 characters)
    """
    size = len(secret)
    if size < 8:
        return False

    for chunk_len in range(2, size // 3 + 1):
        repeats = size // chunk_len
        if secret[:chunk_len] * repeats == secret[:chunk_len * repeats]:
            return True

    run = 1
    for prev, cur in zip(secret, secret[1:]):
        if cur != prev:
            run = 1
            continue
        run += 1
        if run >= size // 2:
            return True
    return False
def validate_jwt_secret_strength(secret: str) -> Tuple[bool, list]:
    """
    Assess how strong a JWT signing secret is.

    Four checks run in this order, each adding at most one warning:
      1. Length of at least MIN_SECRET_LENGTH characters. This is the only
         check that makes the secret invalid.
      2. Estimated entropy (calculate_entropy) of at least MIN_ENTROPY_BITS
         bits. Warning only.
      3. No substring from COMMON_WEAK_PATTERNS, case-insensitive. Only the
         first hit is reported. Warning only.
      4. No repetition as detected by has_repeating_patterns(). Warning only.

    Args:
        secret: The JWT secret to assess

    Returns:
        Tuple of (is_valid, list_of_warnings)
    """
    problems = []

    too_short = len(secret) < MIN_SECRET_LENGTH
    if too_short:
        problems.append(
            f"JWT secret is too short ({len(secret)} chars). "
            f"Minimum recommended length is {MIN_SECRET_LENGTH} characters."
        )

    bits = calculate_entropy(secret)
    if bits < MIN_ENTROPY_BITS:
        # Reported, but low entropy does not fail validation on its own.
        problems.append(
            f"JWT secret has low entropy ({bits:.1f} bits). "
            f"Minimum recommended entropy is {MIN_ENTROPY_BITS} bits. "
            "Consider using a cryptographically random secret."
        )

    lowered = secret.lower()
    weak_hit = next((word for word in COMMON_WEAK_PATTERNS if word in lowered), None)
    if weak_hit is not None:
        problems.append(
            f"JWT secret contains common weak pattern: '{weak_hit}'. "
            "Use a cryptographically random secret."
        )

    if has_repeating_patterns(secret):
        problems.append(
            "JWT secret contains repeating patterns. "
            "Use a cryptographically random secret."
        )

    return not too_short, problems
def validate_jwt_secret_on_startup() -> None:
    """
    Check settings.JWT_SECRET_KEY when the application starts.

    Every finding from validate_jwt_secret_strength() is logged as a warning.
    If the secret fails validation and the ENVIRONMENT environment variable
    equals "production" (case-insensitive), a critical message is logged and
    startup is aborted. Outside production, the failure is only logged. A
    closing info message summarises the outcome.

    Raises:
        ValueError: In production, when the secret fails validation.
    """
    import os

    secret_ok, findings = validate_jwt_secret_strength(settings.JWT_SECRET_KEY)
    for finding in findings:
        logger.warning("JWT Security Warning: %s", finding)

    # The environment is read from the process environment, not from settings.
    in_production = os.environ.get("ENVIRONMENT", "").lower() == "production"

    if not secret_ok and in_production:
        logger.critical(
            "JWT secret does not meet security requirements. "
            "Application startup blocked in production mode. "
            "Please configure a strong JWT_SECRET_KEY (minimum 32 characters)."
        )
        raise ValueError(
            "JWT_SECRET_KEY does not meet minimum security requirements. "
            "See logs for details."
        )
    if not secret_ok:
        logger.warning(
            "JWT secret does not meet security requirements. "
            "This would block startup in production mode."
        )

    if not findings:
        logger.info("JWT secret validation passed. Secret meets security requirements.")
        return

    logger.info(
        "JWT secret validation completed with %d warning(s). "
        "Consider using: python -c \"import secrets; print(secrets.token_urlsafe(48))\" "
        "to generate a strong secret.",
        len(findings)
    )
# CSRF Token Functions
# Number of random BYTES passed to secrets.token_urlsafe() for each CSRF
# token; the encoded random text is longer than this.
CSRF_TOKEN_LENGTH = 32
# Maximum token age, in seconds, accepted by validate_csrf_token().
CSRF_TOKEN_EXPIRY_SECONDS = 3600  # 1 hour
def generate_csrf_token(user_id: str) -> str:
    """
    Issue a signed CSRF token bound to ``user_id``.

    Token layout: ``{random}:{user_id}:{issued_at}:{signature}``.
      * ``random`` is URL-safe random text from CSRF_TOKEN_LENGTH random bytes.
      * ``issued_at`` is the current UTC time in whole Unix seconds.
      * ``signature`` is the first 16 hex characters of an HMAC-SHA256 over
        the first three fields, keyed with settings.JWT_SECRET_KEY.

    Note: validate_csrf_token() splits tokens on ':' and requires exactly four
    fields. A ``user_id`` containing ':' therefore yields a token that will
    never validate.

    Args:
        user_id: Identifier of the user the token is issued to

    Returns:
        The CSRF token string
    """
    nonce = secrets.token_urlsafe(CSRF_TOKEN_LENGTH)
    issued_at = int(datetime.now(timezone.utc).timestamp())
    body = f"{nonce}:{user_id}:{issued_at}"

    mac = hmac.new(settings.JWT_SECRET_KEY.encode(), body.encode(), hashlib.sha256)
    short_signature = mac.hexdigest()[:16]
    return f"{body}:{short_signature}"
def validate_csrf_token(token: str, user_id: str) -> Tuple[bool, str]:
    """
    Validate a CSRF token of the form produced by generate_csrf_token().

    Checks are applied in this order:
      1. The token is present.
      2. It has exactly four ':'-separated fields.
      3. The embedded user ID equals ``user_id``.
      4. The token is no older than CSRF_TOKEN_EXPIRY_SECONDS.
      5. The HMAC-SHA256 signature (truncated to 16 hex characters, keyed
         with settings.JWT_SECRET_KEY) matches.

    Args:
        token: The CSRF token to validate (client-supplied, untrusted)
        user_id: The expected user ID

    Returns:
        Tuple of (is_valid, error_message); error_message is "" on success.
    """
    if not token:
        return False, "CSRF token is required"

    try:
        parts = token.split(":")
        if len(parts) != 4:
            return False, "Invalid CSRF token format"

        random_part, token_user_id, timestamp_str, signature = parts

        # Verify user ID matches
        if token_user_id != user_id:
            return False, "CSRF token user mismatch"

        # A non-numeric timestamp raises ValueError, handled below.
        timestamp = int(timestamp_str)
        current_time = int(datetime.now(timezone.utc).timestamp())
        if current_time - timestamp > CSRF_TOKEN_EXPIRY_SECONDS:
            return False, "CSRF token expired"

        # Recompute the signature over the exact text the client sent.
        payload = f"{random_part}:{token_user_id}:{timestamp_str}"
        expected_signature = hmac.new(
            settings.JWT_SECRET_KEY.encode(),
            payload.encode(),
            hashlib.sha256
        ).hexdigest()[:16]

        # hmac.compare_digest() raises TypeError when given str arguments that
        # contain non-ASCII characters, and the signature field is
        # attacker-controlled. Comparing UTF-8 bytes keeps the comparison
        # constant-time without letting such input escape as an unhandled
        # exception. Unencodable input (e.g. lone surrogates) raises
        # UnicodeEncodeError, a ValueError, which is handled below.
        if not hmac.compare_digest(
            signature.encode("utf-8"), expected_signature.encode("utf-8")
        ):
            return False, "CSRF token signature invalid"

        return True, ""
    except (ValueError, IndexError) as e:
        return False, f"CSRF token validation error: {str(e)}"
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """
    Create a signed JWT access token.

    A shallow copy of ``data`` is extended with ``exp`` (expiry) and ``iat``
    (issued-at) claims, both timezone-aware UTC datetimes. The copy is then
    signed with settings.JWT_SECRET_KEY using settings.JWT_ALGORITHM.

    Args:
        data: Claims to encode; the dict itself is not modified
        expires_delta: Token lifetime. When None, the lifetime is
            settings.JWT_EXPIRE_MINUTES minutes. Any other value, including
            timedelta(0) or a negative delta, is used exactly as given.

    Returns:
        Encoded JWT token string
    """
    to_encode = data.copy()
    now = datetime.now(timezone.utc)
    # Compare against None explicitly: timedelta(0) is falsy, so a plain
    # truthiness test would silently replace a zero lifetime with the default.
    if expires_delta is None:
        expires_delta = timedelta(minutes=settings.JWT_EXPIRE_MINUTES)
    to_encode.update({"exp": now + expires_delta, "iat": now})
    return jwt.encode(
        to_encode,
        settings.JWT_SECRET_KEY,
        algorithm=settings.JWT_ALGORITHM
    )
def decode_access_token(token: str) -> Optional[dict]:
    """
    Verify and decode a JWT access token.

    The token is verified with settings.JWT_SECRET_KEY, and only
    settings.JWT_ALGORITHM is accepted.

    Args:
        token: The encoded JWT

    Returns:
        The claims dict when the token verifies. None whenever decoding
        raises JWTError (e.g. a bad signature, a malformed token, or an
        expired token).
    """
    try:
        claims = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
    except JWTError:
        return None
    return claims
def create_token_payload(
    user_id: str,
    email: str,
    role: str,
    department_id: Optional[str],
    is_system_admin: bool
) -> dict:
    """
    Assemble the standard claim set for an access token.

    The user ID goes into the ``sub`` claim. The remaining arguments are
    stored under keys of the same name.

    Args:
        user_id: User's unique ID (stored as ``sub``)
        email: User's email
        role: User's role name
        department_id: User's department ID, or None
        is_system_admin: Whether the user is a system admin

    Returns:
        dict: The token payload
    """
    claims = dict(
        sub=user_id,
        email=email,
        role=role,
        department_id=department_id,
        is_system_admin=is_system_admin,
    )
    return claims
# Refresh Token Functions
REFRESH_TOKEN_BYTES = 32  # random bytes per refresh token (before base64 encoding)


def generate_refresh_token() -> str:
    """
    Produce a new, unpredictable refresh token.

    REFRESH_TOKEN_BYTES bytes are drawn from the ``secrets`` module and
    returned as URL-safe base64 text.

    Returns:
        The refresh token string
    """
    fresh_token = secrets.token_urlsafe(REFRESH_TOKEN_BYTES)
    return fresh_token
def get_refresh_token_key(user_id: str, token: str) -> str:
    """
    Build the Redis key under which a refresh token is stored.

    The raw token never appears in the key. Only the first 16 hex characters
    of its SHA-256 digest are used, giving
    ``refresh_token:{user_id}:{token_hash}``.

    Args:
        user_id: The owning user's ID
        token: The refresh token

    Returns:
        Redis key string
    """
    digest = hashlib.sha256(token.encode()).hexdigest()
    return "refresh_token:{}:{}".format(user_id, digest[:16])
def store_refresh_token(redis_client, user_id: str, token: str) -> None:
    """
    Persist a refresh token in Redis, bound to its user.

    The value written is ``user_id``, stored under the key from
    get_refresh_token_key(). It expires after settings.REFRESH_TOKEN_EXPIRE_DAYS
    days via SETEX.

    Args:
        redis_client: Redis client instance
        user_id: The owning user's ID
        token: The refresh token to store
    """
    redis_client.setex(
        get_refresh_token_key(user_id, token),
        settings.REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60,  # days -> seconds
        user_id,
    )
def validate_refresh_token(redis_client, user_id: str, token: str) -> bool:
    """
    Check that a refresh token is stored in Redis for the given user.

    Args:
        redis_client: Redis client instance
        user_id: The user the token should belong to
        token: The refresh token to check

    Returns:
        True only if the token's key exists and its stored value equals
        ``user_id``
    """
    stored = redis_client.get(get_refresh_token_key(user_id, token))
    if stored is None:
        return False
    # The client may hand back bytes rather than str; normalise before comparing.
    owner = stored.decode("utf-8") if isinstance(stored, bytes) else stored
    return owner == user_id
def invalidate_refresh_token(redis_client, user_id: str, token: str) -> bool:
    """
    Delete a single refresh token from Redis.

    Args:
        redis_client: Redis client instance
        user_id: The owning user's ID
        token: The refresh token to remove

    Returns:
        True if a key was deleted, False if it did not exist
    """
    removed = redis_client.delete(get_refresh_token_key(user_id, token))
    if isinstance(removed, int):
        return removed > 0
    # Non-integer replies fall back to plain truthiness.
    return bool(removed)
def invalidate_all_user_refresh_tokens(redis_client, user_id: str) -> int:
    """
    Invalidate all refresh tokens for a user.

    Every key matching ``refresh_token:{user_id}:*`` is collected via SCAN
    and removed with a single DELETE call.

    SCAN may yield the same key more than once, and a key can expire between
    being scanned and being deleted. The keys are therefore de-duplicated
    first, and the count returned is what DELETE reports as actually removed
    rather than the number of keys the scan produced.

    Args:
        redis_client: Redis client instance
        user_id: The user's ID

    Returns:
        Number of tokens invalidated
    """
    pattern = f"refresh_token:{user_id}:*"
    keys = set(redis_client.scan_iter(match=pattern))
    if not keys:
        # Redis DEL needs at least one key argument, so skip the call entirely.
        return 0
    return int(redis_client.delete(*keys))
def decode_refresh_token_user_id(token: str, redis_client) -> Optional[str]:
    """
    Find the user ID that owns a refresh token by scanning Redis.

    Used when only the token is available. The user segment of the key is
    wildcarded, so this is a SCAN over every stored refresh-token key rather
    than a direct lookup. Its cost grows with the total number of refresh
    tokens across all users.

    Keys have the form ``refresh_token:{user_id}:{token_hash}`` (see
    get_refresh_token_key()). The user ID is taken as everything between the
    first and the last ':'. User IDs that themselves contain ':' are
    therefore still recovered instead of being silently skipped.

    Args:
        token: The refresh token
        redis_client: Redis client instance

    Returns:
        User ID if found, None otherwise
    """
    token_hash = hashlib.sha256(token.encode()).hexdigest()[:16]
    pattern = f"refresh_token:*:{token_hash}"

    for key in redis_client.scan_iter(match=pattern):
        if isinstance(key, bytes):
            key = key.decode("utf-8")
        namespace, _, remainder = key.partition(":")
        user_id, sep, key_hash = remainder.rpartition(":")
        # Skip anything that is not exactly refresh_token:<user_id>:<token_hash>.
        if namespace == "refresh_token" and sep and key_hash == token_hash:
            return user_id
    return None