"""JWT and CSRF security helpers.

Provides:
- strength checks for the configured JWT secret (length, estimated entropy,
  well-known weak substrings, repeating patterns),
- HMAC-signed, user-bound, time-limited CSRF tokens,
- creation and decoding of JWT access tokens.
"""

import hashlib
import hmac
import logging
import math
import os
import secrets
from collections import Counter
from datetime import datetime, timedelta, timezone
from typing import Optional, Any, List, Tuple

from jose import jwt, JWTError

from app.core.config import settings

logger = logging.getLogger(__name__)

# Constants for JWT secret validation
MIN_SECRET_LENGTH = 32
MIN_ENTROPY_BITS = 128  # Minimum entropy in bits for a secure secret
COMMON_WEAK_PATTERNS = [
    "password",
    "secret",
    "changeme",
    "admin",
    "test",
    "demo",
    "123456",
    "qwerty",
    "abc123",
    "letmein",
    "welcome",
]


def calculate_entropy(data: str) -> float:
    """
    Estimate the total Shannon entropy of a string in bits.

    The per-character entropy is computed from the character frequencies
    observed in ``data`` itself and then multiplied by the string length.
    This is an empirical estimate based only on the characters present,
    not a measurement of how the string was generated.

    Args:
        data: The string to calculate entropy for

    Returns:
        Entropy value in bits (0.0 for an empty string)
    """
    if not data:
        return 0.0

    length = len(data)

    # Every count produced by Counter is >= 1, so log2 is always defined.
    per_char_entropy = 0.0
    for count in Counter(data).values():
        probability = count / length
        per_char_entropy -= probability * math.log2(probability)

    # Total entropy in bits (per-character entropy * length)
    return per_char_entropy * length


def has_repeating_patterns(secret: str) -> bool:
    """
    Check if the secret contains obvious repeating patterns.

    Two heuristics are applied:
    1. The secret is (apart from a trailing remainder shorter than the
       pattern) its own leading 2..len//3 characters repeated back to back.
    2. A single character occurs in a consecutive run of at least half the
       secret's length.

    Args:
        secret: The secret string to check

    Returns:
        True if repeating patterns are detected; always False for secrets
        shorter than 8 characters
    """
    total = len(secret)
    if total < 8:
        return False

    # Check for a repeating leading sequence. Because pattern_len is at most
    # total // 3, the pattern always fits at least three times.
    for pattern_len in range(2, total // 3 + 1):
        repeats = total // pattern_len
        if secret[:pattern_len] * repeats == secret[:pattern_len * repeats]:
            return True

    # Check for a long run of one identical character
    run_threshold = total // 2
    run_length = 1
    for previous, current in zip(secret, secret[1:]):
        if current == previous:
            run_length += 1
            if run_length >= run_threshold:
                return True
        else:
            run_length = 1

    return False


def validate_jwt_secret_strength(secret: str) -> Tuple[bool, List[str]]:
    """
    Validate JWT secret key strength.

    Checks:
    1. Minimum length (MIN_SECRET_LENGTH characters) - the only check that
       makes the secret invalid
    2. Estimated entropy (minimum MIN_ENTROPY_BITS bits) - warning only
    3. Common weak patterns (case-insensitive substring match) - warning only
    4. Repeating patterns - warning only

    Args:
        secret: The JWT secret to validate. ``None`` is treated like an
            empty secret so a missing setting is reported instead of
            raising ``TypeError``.

    Returns:
        Tuple of (is_valid, list_of_warnings)
    """
    if secret is None:
        secret = ""

    warnings: List[str] = []
    is_valid = True

    # Check minimum length
    if len(secret) < MIN_SECRET_LENGTH:
        warnings.append(
            f"JWT secret is too short ({len(secret)} chars). "
            f"Minimum recommended length is {MIN_SECRET_LENGTH} characters."
        )
        is_valid = False

    # Calculate and check entropy.
    # Low entropy alone doesn't make it invalid, but it's a warning.
    entropy = calculate_entropy(secret)
    if entropy < MIN_ENTROPY_BITS:
        warnings.append(
            f"JWT secret has low entropy ({entropy:.1f} bits). "
            f"Minimum recommended entropy is {MIN_ENTROPY_BITS} bits. "
            "Consider using a cryptographically random secret."
        )

    # Check for common weak patterns; only the first match is reported
    secret_lower = secret.lower()
    for pattern in COMMON_WEAK_PATTERNS:
        if pattern in secret_lower:
            warnings.append(
                f"JWT secret contains common weak pattern: '{pattern}'. "
                "Use a cryptographically random secret."
            )
            break

    # Check for repeating patterns
    if has_repeating_patterns(secret):
        warnings.append(
            "JWT secret contains repeating patterns. "
            "Use a cryptographically random secret."
        )

    return is_valid, warnings


def validate_jwt_secret_on_startup() -> None:
    """
    Validate JWT secret strength on application startup.

    Logs every warning returned by ``validate_jwt_secret_strength``. When the
    secret is invalid and the ``ENVIRONMENT`` environment variable equals
    "production" (case-insensitive), startup is aborted.

    Raises:
        ValueError: If the secret is invalid while running in production.
    """
    secret = settings.JWT_SECRET_KEY
    is_valid, warnings = validate_jwt_secret_strength(secret)

    # Log all warnings
    for warning in warnings:
        logger.warning("JWT Security Warning: %s", warning)

    # In production, enforce stricter requirements
    is_production = os.environ.get("ENVIRONMENT", "").lower() == "production"

    if not is_valid:
        if is_production:
            logger.critical(
                "JWT secret does not meet security requirements. "
                "Application startup blocked in production mode. "
                "Please configure a strong JWT_SECRET_KEY (minimum 32 characters)."
            )
            raise ValueError(
                "JWT_SECRET_KEY does not meet minimum security requirements. "
                "See logs for details."
            )
        else:
            logger.warning(
                "JWT secret does not meet security requirements. "
                "This would block startup in production mode."
            )

    if warnings:
        logger.info(
            "JWT secret validation completed with %d warning(s). "
            "Consider using: python -c \"import secrets; print(secrets.token_urlsafe(48))\" "
            "to generate a strong secret.",
            len(warnings)
        )
    else:
        logger.info("JWT secret validation passed. Secret meets security requirements.")


# CSRF Token Functions
CSRF_TOKEN_LENGTH = 32
CSRF_TOKEN_EXPIRY_SECONDS = 3600  # 1 hour


def _sign_csrf_payload(payload: str) -> str:
    """Return the truncated (16 hex chars) HMAC-SHA256 signature of a CSRF payload."""
    return hmac.new(
        settings.JWT_SECRET_KEY.encode(),
        payload.encode(),
        hashlib.sha256
    ).hexdigest()[:16]


def generate_csrf_token(user_id: str) -> str:
    """
    Generate a CSRF token for a user.

    The token has the form ``random:user_id:timestamp:signature`` where:
    - ``random`` is a URL-safe random string for unpredictability
    - ``user_id`` binds the token to one user
    - ``timestamp`` is the UTC issue time in whole seconds (used for expiry)
    - ``signature`` is an HMAC over the first three fields, keyed with the
      JWT secret, for integrity

    Args:
        user_id: The user's ID to bind the token to

    Returns:
        CSRF token string
    """
    # Generate random token
    random_part = secrets.token_urlsafe(CSRF_TOKEN_LENGTH)

    # Create timestamp for expiry checking
    timestamp = int(datetime.now(timezone.utc).timestamp())

    # Create the token payload and sign it
    payload = f"{random_part}:{user_id}:{timestamp}"
    signature = _sign_csrf_payload(payload)

    # Return combined token
    return f"{payload}:{signature}"


def validate_csrf_token(token: str, user_id: str) -> Tuple[bool, str]:
    """
    Validate a CSRF token.

    Checks, in order: token format, user binding, expiry, HMAC signature.

    Args:
        token: The CSRF token to validate
        user_id: The expected user ID

    Returns:
        Tuple of (is_valid, error_message); the message is empty on success
    """
    if not token:
        return False, "CSRF token is required"

    try:
        # The random part (URL-safe alphabet), timestamp (digits) and
        # signature (hex) never contain ':', but the user ID may. Split the
        # first field from the left and the last two from the right so a
        # user ID containing ':' still round-trips.
        random_part, separator, remainder = token.partition(":")
        tail = remainder.rsplit(":", 2)
        if not separator or len(tail) != 3:
            return False, "Invalid CSRF token format"
        token_user_id, timestamp_str, signature = tail

        # Verify user ID matches
        if token_user_id != user_id:
            return False, "CSRF token user mismatch"

        # Verify timestamp (check expiry)
        timestamp = int(timestamp_str)
        current_time = int(datetime.now(timezone.utc).timestamp())
        if current_time - timestamp > CSRF_TOKEN_EXPIRY_SECONDS:
            return False, "CSRF token expired"

        # Verify signature
        payload = f"{random_part}:{token_user_id}:{timestamp_str}"
        expected_signature = _sign_csrf_payload(payload)

        # Compare as bytes: hmac.compare_digest raises TypeError for str
        # arguments containing non-ASCII characters, and the signature here
        # is untrusted input.
        if not hmac.compare_digest(signature.encode(), expected_signature.encode()):
            return False, "CSRF token signature invalid"

        return True, ""

    except (ValueError, IndexError) as e:
        # ValueError also covers a non-numeric timestamp and text that
        # cannot be encoded (UnicodeEncodeError is a ValueError subclass).
        return False, f"CSRF token validation error: {str(e)}"


def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """
    Create a JWT access token.

    The input dict is copied, then ``exp`` and ``iat`` claims are added
    before signing with the configured secret and algorithm.

    Args:
        data: Data to encode in the token
        expires_delta: Optional custom lifetime. When ``None`` the default
            of ``settings.JWT_EXPIRE_MINUTES`` minutes is used. An explicit
            ``timedelta(0)`` is honoured rather than replaced by the default.

    Returns:
        Encoded JWT token string
    """
    to_encode = data.copy()
    now = datetime.now(timezone.utc)

    # Test against None explicitly: timedelta(0) is falsy but still a
    # deliberate caller choice.
    if expires_delta is not None:
        expire = now + expires_delta
    else:
        expire = now + timedelta(minutes=settings.JWT_EXPIRE_MINUTES)

    to_encode.update({"exp": expire, "iat": now})

    encoded_jwt = jwt.encode(
        to_encode,
        settings.JWT_SECRET_KEY,
        algorithm=settings.JWT_ALGORITHM
    )
    return encoded_jwt


def decode_access_token(token: str) -> Optional[dict]:
    """
    Decode and verify a JWT access token.

    Args:
        token: The JWT token to decode

    Returns:
        Decoded token payload, or None when decoding raises ``JWTError``
    """
    try:
        payload = jwt.decode(
            token,
            settings.JWT_SECRET_KEY,
            algorithms=[settings.JWT_ALGORITHM]
        )
        return payload
    except JWTError:
        return None


def create_token_payload(
    user_id: str,
    email: str,
    role: str,
    department_id: Optional[str],
    is_system_admin: bool
) -> dict:
    """
    Create a standardized token payload.

    Args:
        user_id: User's unique ID (stored under the "sub" key)
        email: User's email
        role: User's role name
        department_id: User's department ID (can be None)
        is_system_admin: Whether user is a system admin

    Returns:
        dict: Token payload
    """
    return {
        "sub": user_id,
        "email": email,
        "role": role,
        "department_id": department_id,
        "is_system_admin": is_system_admin,
    }