Implemented proposals from comprehensive QA review: 1. extend-csrf-protection - Add POST to CSRF protected methods in frontend - Global CSRF middleware for all state-changing operations - Update tests with CSRF token fixtures 2. tighten-cors-websocket-security - Replace wildcard CORS with explicit method/header lists - Disable query parameter auth in production (code 4002) - Add per-user WebSocket connection limit (max 5, code 4005) 3. shorten-jwt-expiry - Reduce JWT expiry from 7 days to 60 minutes - Add refresh token support with 7-day expiry - Implement token rotation on refresh - Frontend auto-refresh when token near expiry (<5 min) 4. fix-frontend-quality - Add React.lazy() code splitting for all pages - Fix useCallback dependency arrays (Dashboard, Comments) - Add localStorage data validation in AuthContext - Complete i18n for AttachmentUpload component 5. enhance-backend-validation - Add SecurityAuditMiddleware for access denied logging - Add ErrorSanitizerMiddleware for production error messages - Protect /health/detailed with admin authentication - Add input length validation (comment 5000, desc 10000) All 521 backend tests passing. Frontend builds successfully. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
496 lines
14 KiB
Python
496 lines
14 KiB
Python
from datetime import datetime, timedelta, timezone
|
|
from typing import Optional, Any, Tuple
|
|
from jose import jwt, JWTError
|
|
import logging
|
|
import math
|
|
import hashlib
|
|
import secrets
|
|
import hmac
|
|
from collections import Counter
|
|
|
|
from app.core.config import settings
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# --- JWT secret validation settings ---
# Secrets shorter than this many characters fail validation.
MIN_SECRET_LENGTH = 32
# Secrets whose estimated entropy (in bits) is below this get a warning.
MIN_ENTROPY_BITS = 128
# Lower-case substrings that mark a secret as containing a weak pattern.
COMMON_WEAK_PATTERNS = [
    "password",
    "secret",
    "changeme",
    "admin",
    "test",
    "demo",
    "123456",
    "qwerty",
    "abc123",
    "letmein",
    "welcome",
]
|
|
|
|
|
|
def calculate_entropy(data: str) -> float:
    """
    Estimate the total Shannon entropy of a string, in bits.

    The per-character entropy is computed from the observed character
    frequencies in ``data`` and then multiplied by the string length, so a
    string made of a single repeated character scores 0 and a string of
    n distinct characters scores n * log2(n).

    Args:
        data: The string to measure

    Returns:
        Total entropy in bits (0.0 for an empty value)
    """
    if not data:
        return 0.0

    total_chars = len(data)

    # Shannon entropy per character: -sum(p * log2(p)) over distinct chars.
    bits_per_char = 0.0 - sum(
        (occurrences / total_chars) * math.log2(occurrences / total_chars)
        for occurrences in Counter(data).values()
    )

    # Scale the per-character figure up to the whole string.
    return bits_per_char * total_chars
|
|
|
|
|
|
def has_repeating_patterns(secret: str) -> bool:
    """
    Detect obvious repetition inside a secret.

    Two heuristics are applied:
    - a leading chunk of 2..len//3 characters that, tiled as many whole
      times as fit, reproduces the start of the secret;
    - a run of identical consecutive characters at least len//2 long.

    Args:
        secret: The secret string to inspect

    Returns:
        True if either heuristic matches; always False for secrets
        shorter than 8 characters
    """
    size = len(secret)
    if size < 8:
        return False

    # Heuristic 1: tiled prefix. chunk_len never exceeds size // 3, so the
    # chunk always fits at least three times.
    for chunk_len in range(2, size // 3 + 1):
        repeats = size // chunk_len
        covered = chunk_len * repeats
        if repeats >= 3 and secret[:chunk_len] * repeats == secret[:covered]:
            return True

    # Heuristic 2: long run of one character.
    run_threshold = size // 2
    run_length = 1
    for previous, current in zip(secret, secret[1:]):
        if current != previous:
            run_length = 1
            continue
        run_length += 1
        if run_length >= run_threshold:
            return True

    return False
|
|
|
|
|
|
def validate_jwt_secret_strength(secret: str) -> Tuple[bool, list]:
    """
    Assess how strong a JWT secret key is.

    Four checks are run, each of which may add one warning:
    1. Length below MIN_SECRET_LENGTH
    2. Estimated entropy below MIN_ENTROPY_BITS
    3. Presence of a common weak substring (only the first hit is reported)
    4. Obvious repeating patterns

    Only the length check affects validity; the other checks produce
    warnings without failing the secret.

    Args:
        secret: The JWT secret to assess

    Returns:
        Tuple of (is_valid, list_of_warnings)
    """
    findings = []

    # 1. Length: the only check that makes the secret invalid.
    too_short = len(secret) < MIN_SECRET_LENGTH
    if too_short:
        findings.append(
            f"JWT secret is too short ({len(secret)} chars). "
            f"Minimum recommended length is {MIN_SECRET_LENGTH} characters."
        )

    # 2. Entropy: advisory only.
    entropy = calculate_entropy(secret)
    if entropy < MIN_ENTROPY_BITS:
        findings.append(
            f"JWT secret has low entropy ({entropy:.1f} bits). "
            f"Minimum recommended entropy is {MIN_ENTROPY_BITS} bits. "
            "Consider using a cryptographically random secret."
        )

    # 3. Weak substrings, matched case-insensitively; report the first one.
    lowered = secret.lower()
    pattern = next(
        (candidate for candidate in COMMON_WEAK_PATTERNS if candidate in lowered),
        None,
    )
    if pattern is not None:
        findings.append(
            f"JWT secret contains common weak pattern: '{pattern}'. "
            "Use a cryptographically random secret."
        )

    # 4. Repetition.
    if has_repeating_patterns(secret):
        findings.append(
            "JWT secret contains repeating patterns. "
            "Use a cryptographically random secret."
        )

    return (not too_short), findings
|
|
|
|
|
|
def validate_jwt_secret_on_startup() -> None:
    """
    Check the configured JWT secret when the application starts.

    Every warning from the strength check is logged. If the secret is
    invalid and the ENVIRONMENT variable equals "production"
    (case-insensitive), a critical message is logged and startup is
    aborted; otherwise an invalid secret only produces a warning.

    Raises:
        ValueError: If the secret is invalid in production mode
    """
    import os

    is_valid, warnings = validate_jwt_secret_strength(settings.JWT_SECRET_KEY)

    # Surface each individual finding.
    for warning in warnings:
        logger.warning("JWT Security Warning: %s", warning)

    # Production is detected purely from the ENVIRONMENT variable.
    is_production = os.environ.get("ENVIRONMENT", "").lower() == "production"

    if not is_valid and is_production:
        logger.critical(
            "JWT secret does not meet security requirements. "
            "Application startup blocked in production mode. "
            "Please configure a strong JWT_SECRET_KEY (minimum 32 characters)."
        )
        raise ValueError(
            "JWT_SECRET_KEY does not meet minimum security requirements. "
            "See logs for details."
        )

    if not is_valid:
        logger.warning(
            "JWT secret does not meet security requirements. "
            "This would block startup in production mode."
        )

    # Final summary line.
    if not warnings:
        logger.info("JWT secret validation passed. Secret meets security requirements.")
        return

    logger.info(
        "JWT secret validation completed with %d warning(s). "
        "Consider using: python -c \"import secrets; print(secrets.token_urlsafe(48))\" "
        "to generate a strong secret.",
        len(warnings)
    )
|
|
|
|
|
|
# CSRF Token Functions
# Size argument handed to secrets.token_urlsafe for the random part.
CSRF_TOKEN_LENGTH = 32
# Maximum accepted token age, in seconds (one hour).
CSRF_TOKEN_EXPIRY_SECONDS = 60 * 60
|
|
|
|
|
|
def generate_csrf_token(user_id: str) -> str:
    """
    Build a signed CSRF token bound to a user.

    The token has four colon-separated fields:
    random part, user ID, issue timestamp (UTC epoch seconds), and the
    first 16 hex characters of an HMAC-SHA256 over the first three
    fields, keyed with the JWT secret.

    Args:
        user_id: The ID of the user the token is bound to

    Returns:
        CSRF token string
    """
    # Unpredictable component.
    nonce = secrets.token_urlsafe(CSRF_TOKEN_LENGTH)

    # Issue time, used later for the expiry check.
    issued_at = int(datetime.now(timezone.utc).timestamp())

    unsigned = f"{nonce}:{user_id}:{issued_at}"

    # Integrity tag: truncated HMAC-SHA256 keyed with the JWT secret.
    mac = hmac.new(
        settings.JWT_SECRET_KEY.encode(),
        unsigned.encode(),
        hashlib.sha256,
    )
    short_tag = mac.hexdigest()[:16]

    return f"{unsigned}:{short_tag}"
|
|
|
|
|
|
def validate_csrf_token(token: str, user_id: str) -> Tuple[bool, str]:
    """
    Validate a CSRF token.

    The token must consist of four colon-separated fields
    (random part, user ID, timestamp, signature). Checks run in this
    order: presence, field count, user ID match, age against
    CSRF_TOKEN_EXPIRY_SECONDS, and finally the truncated HMAC-SHA256
    signature keyed with the JWT secret.

    Args:
        token: The CSRF token to validate
        user_id: The expected user ID

    Returns:
        Tuple of (is_valid, error_message); error_message is "" on success
    """
    if not token:
        return False, "CSRF token is required"

    try:
        parts = token.split(":")
        if len(parts) != 4:
            return False, "Invalid CSRF token format"

        random_part, token_user_id, timestamp_str, signature = parts

        # Verify user ID matches
        if token_user_id != user_id:
            return False, "CSRF token user mismatch"

        # Verify timestamp (check expiry); int() raises ValueError on junk
        timestamp = int(timestamp_str)
        current_time = int(datetime.now(timezone.utc).timestamp())
        if current_time - timestamp > CSRF_TOKEN_EXPIRY_SECONDS:
            return False, "CSRF token expired"

        # Recompute the signature over the fields exactly as received
        payload = f"{random_part}:{token_user_id}:{timestamp_str}"
        expected_signature = hmac.new(
            settings.JWT_SECRET_KEY.encode(),
            payload.encode(),
            hashlib.sha256
        ).hexdigest()[:16]

        # Compare as bytes: hmac.compare_digest raises TypeError for str
        # arguments containing non-ASCII characters, and the signature
        # field comes straight from the (untrusted) token. Encoding first
        # keeps the constant-time comparison and turns such input into a
        # normal "signature invalid" result instead of an unhandled error.
        if not hmac.compare_digest(
            signature.encode("utf-8"),
            expected_signature.encode("utf-8"),
        ):
            return False, "CSRF token signature invalid"

        return True, ""

    except (ValueError, IndexError) as e:
        # ValueError covers a non-numeric timestamp and un-encodable text
        # (UnicodeEncodeError is a ValueError subclass).
        return False, f"CSRF token validation error: {str(e)}"
|
|
|
|
|
|
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """
    Create a JWT access token.

    A copy of ``data`` is extended with "exp" (expiry) and "iat"
    (issued-at) claims, both timezone-aware UTC datetimes, and then
    encoded with the configured secret and algorithm. The caller's dict
    is not modified.

    Args:
        data: Data to encode in the token
        expires_delta: Optional custom lifetime. Any timedelta is honored,
            including timedelta(0); only None selects the default of
            settings.JWT_EXPIRE_MINUTES minutes.

    Returns:
        Encoded JWT token string
    """
    to_encode = data.copy()
    now = datetime.now(timezone.utc)

    # Compare against None explicitly: timedelta(0) is falsy, so a plain
    # truthiness test would silently replace a zero lifetime with the
    # default one.
    if expires_delta is not None:
        expire = now + expires_delta
    else:
        expire = now + timedelta(minutes=settings.JWT_EXPIRE_MINUTES)

    to_encode.update({"exp": expire, "iat": now})

    return jwt.encode(
        to_encode,
        settings.JWT_SECRET_KEY,
        algorithm=settings.JWT_ALGORITHM
    )
|
|
|
|
|
|
def decode_access_token(token: str) -> Optional[dict]:
    """
    Decode a JWT access token and verify it against the configured secret.

    Only the algorithm named in settings.JWT_ALGORITHM is accepted.

    Args:
        token: The JWT token to decode

    Returns:
        The decoded payload, or None when decoding raises JWTError
    """
    try:
        return jwt.decode(
            token,
            settings.JWT_SECRET_KEY,
            algorithms=[settings.JWT_ALGORITHM],
        )
    except JWTError:
        # Any JWTError from the library is reported to the caller as None.
        return None
|
|
|
|
|
|
def create_token_payload(
    user_id: str,
    email: str,
    role: str,
    department_id: Optional[str],
    is_system_admin: bool
) -> dict:
    """
    Assemble the standard set of claims stored in a token.

    Args:
        user_id: User's unique ID (stored under "sub")
        email: User's email
        role: User's role name
        department_id: User's department ID (can be None)
        is_system_admin: Whether user is a system admin

    Returns:
        dict: Token payload with keys sub, email, role, department_id
        and is_system_admin
    """
    payload = dict(
        sub=user_id,
        email=email,
        role=role,
        department_id=department_id,
        is_system_admin=is_system_admin,
    )
    return payload
|
|
|
|
|
|
# Refresh Token Functions
# Number of random bytes requested for each refresh token.
REFRESH_TOKEN_BYTES = 32


def generate_refresh_token() -> str:
    """
    Produce a new random refresh token.

    The token is drawn from the ``secrets`` module using
    REFRESH_TOKEN_BYTES bytes of randomness.

    Returns:
        A URL-safe base64-encoded random token
    """
    fresh_token = secrets.token_urlsafe(nbytes=REFRESH_TOKEN_BYTES)
    return fresh_token
|
|
|
|
|
|
def get_refresh_token_key(user_id: str, token: str) -> str:
    """
    Build the Redis key under which a refresh token is stored.

    The key has the form ``refresh_token:{user_id}:{hash}``, where the
    hash is the first 16 hex characters of the token's SHA-256 digest.

    Args:
        user_id: The user's ID
        token: The refresh token

    Returns:
        Redis key string
    """
    # Use a digest prefix so the raw token never appears in a key name.
    full_digest = hashlib.sha256(token.encode()).hexdigest()
    short_digest = full_digest[:16]
    return f"refresh_token:{user_id}:{short_digest}"
|
|
|
|
|
|
def store_refresh_token(redis_client, user_id: str, token: str) -> None:
    """
    Save a refresh token in Redis, bound to its user.

    The entry is written with ``setex`` under the key returned by
    get_refresh_token_key, with the user ID as its value and a
    time-to-live of settings.REFRESH_TOKEN_EXPIRE_DAYS days.

    Args:
        redis_client: Redis client instance
        user_id: The user's ID
        token: The refresh token to store
    """
    redis_key = get_refresh_token_key(user_id, token)

    # Convert the configured lifetime from days to seconds.
    lifetime_seconds = settings.REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60

    redis_client.setex(redis_key, lifetime_seconds, user_id)
|
|
|
|
|
|
def validate_refresh_token(redis_client, user_id: str, token: str) -> bool:
    """
    Check that a refresh token is stored in Redis for the given user.

    Args:
        redis_client: Redis client instance
        user_id: The expected user ID
        token: The refresh token to validate

    Returns:
        True if an entry exists and its stored value equals user_id,
        False otherwise
    """
    owner = redis_client.get(get_refresh_token_key(user_id, token))
    if owner is None:
        return False

    # The client may hand back bytes; normalise to text before comparing.
    owner_text = owner.decode("utf-8") if isinstance(owner, bytes) else owner

    return owner_text == user_id
|
|
|
|
|
|
def invalidate_refresh_token(redis_client, user_id: str, token: str) -> bool:
    """
    Remove a single refresh token from Redis.

    Args:
        redis_client: Redis client instance
        user_id: The user's ID
        token: The refresh token to invalidate

    Returns:
        True if token was deleted, False if it didn't exist
    """
    outcome = redis_client.delete(get_refresh_token_key(user_id, token))

    # An integer result is treated as a deletion count; anything else
    # is judged by its truthiness.
    if isinstance(outcome, int):
        return outcome > 0
    return bool(outcome)
|
|
|
|
|
|
def invalidate_all_user_refresh_tokens(redis_client, user_id: str) -> int:
    """
    Remove every refresh token stored for a user.

    Keys matching ``refresh_token:{user_id}:*`` are discovered with
    ``scan_iter`` and deleted one at a time.

    Args:
        redis_client: Redis client instance
        user_id: The user's ID

    Returns:
        Number of keys that were found and submitted for deletion
    """
    removed = 0
    matching_keys = redis_client.scan_iter(match=f"refresh_token:{user_id}:*")
    for removed, redis_key in enumerate(matching_keys, start=1):
        redis_client.delete(redis_key)
    return removed
|
|
|
|
|
|
def decode_refresh_token_user_id(token: str, redis_client) -> Optional[str]:
    """
    Look up which user a refresh token belongs to by scanning Redis keys.

    Intended for callers that hold only the token. The scan matches
    ``refresh_token:*:{hash}``, where the hash is the first 16 hex
    characters of the token's SHA-256 digest, and the user ID is read
    from the middle field of the first key that has exactly three
    colon-separated fields. This requires a key scan, so it costs more
    than a direct lookup with a known user ID.

    Args:
        token: The refresh token
        redis_client: Redis client instance

    Returns:
        User ID if found, None otherwise
    """
    short_digest = hashlib.sha256(token.encode()).hexdigest()[:16]

    for raw_key in redis_client.scan_iter(match=f"refresh_token:*:{short_digest}"):
        # Keys may arrive as bytes depending on the client configuration.
        text_key = raw_key.decode("utf-8") if isinstance(raw_key, bytes) else raw_key

        # Expected layout: refresh_token:{user_id}:{token_hash}
        segments = text_key.split(":")
        if len(segments) == 3:
            return segments[1]

    return None
|