feat: implement security, error resilience, and query optimization proposals
Security Validation (enhance-security-validation): - JWT secret validation with entropy checking and pattern detection - CSRF protection middleware with token generation/validation - Frontend CSRF token auto-injection for DELETE/PUT/PATCH requests - MIME type validation with magic bytes detection for file uploads Error Resilience (add-error-resilience): - React ErrorBoundary component with fallback UI and retry functionality - ErrorBoundaryWithI18n wrapper for internationalization support - Page-level and section-level error boundaries in App.tsx Query Performance (optimize-query-performance): - Query monitoring utility with threshold warnings - N+1 query fixes using joinedload/selectinload - Optimized project members, tasks, and subtasks endpoints Bug Fixes: - WebSocket session management (P0): Return primitives instead of ORM objects - LIKE query injection (P1): Escape special characters in search queries Tests: 543 backend tests, 56 frontend tests passing Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -122,6 +122,11 @@ class Settings(BaseSettings):
|
||||
RATE_LIMIT_SENSITIVE: str = "20/minute" # Attachments, password change, report export
|
||||
RATE_LIMIT_HEAVY: str = "5/minute" # Report generation, bulk operations
|
||||
|
||||
# Development Mode Settings
|
||||
DEBUG: bool = False # Enable debug mode for development
|
||||
QUERY_LOGGING: bool = False # Enable SQLAlchemy query logging
|
||||
QUERY_COUNT_THRESHOLD: int = 10 # Warn when query count exceeds this threshold
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
case_sensitive = True
|
||||
|
||||
@@ -104,6 +104,10 @@ def _on_invalidate(dbapi_conn, connection_record, exception):
|
||||
# Start pool statistics logging on module load
|
||||
_start_pool_stats_logging()
|
||||
|
||||
# Set up query logging if enabled
|
||||
from app.core.query_monitor import setup_query_logging
|
||||
setup_query_logging(engine)
|
||||
|
||||
|
||||
def get_db():
|
||||
"""Dependency for getting database session."""
|
||||
@@ -127,3 +131,25 @@ def get_pool_status() -> dict:
|
||||
"total_checkins": _pool_stats["checkins"],
|
||||
"invalidated_connections": _pool_stats["invalidated_connections"],
|
||||
}
|
||||
|
||||
|
||||
def escape_like(value: str) -> str:
|
||||
"""
|
||||
Escape special characters for SQL LIKE queries.
|
||||
|
||||
Escapes '%' and '_' characters which have special meaning in LIKE patterns.
|
||||
This prevents LIKE injection attacks where user input could match unintended patterns.
|
||||
|
||||
Args:
|
||||
value: The user input string to escape
|
||||
|
||||
Returns:
|
||||
Escaped string safe for use in LIKE patterns
|
||||
|
||||
Example:
|
||||
>>> escape_like("test%value")
|
||||
'test\\%value'
|
||||
>>> escape_like("user_name")
|
||||
'user\\_name'
|
||||
"""
|
||||
return value.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
|
||||
|
||||
167
backend/app/core/query_monitor.py
Normal file
167
backend/app/core/query_monitor.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""
|
||||
Query monitoring utilities for detecting N+1 queries and performance issues.
|
||||
|
||||
This module provides:
|
||||
1. Query counting per request in development mode
|
||||
2. SQLAlchemy event listeners for query logging
|
||||
3. Threshold-based warnings for excessive queries
|
||||
"""
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, Callable, Any
|
||||
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Thread-local storage for per-request query counting
|
||||
_query_context = threading.local()
|
||||
|
||||
|
||||
class QueryCounter:
|
||||
"""
|
||||
Context manager for counting database queries within a request.
|
||||
|
||||
Usage:
|
||||
with QueryCounter() as counter:
|
||||
# ... execute queries ...
|
||||
print(f"Executed {counter.count} queries")
|
||||
"""
|
||||
|
||||
def __init__(self, threshold: Optional[int] = None, context_name: str = "request"):
|
||||
self.threshold = threshold or settings.QUERY_COUNT_THRESHOLD
|
||||
self.context_name = context_name
|
||||
self.count = 0
|
||||
self.queries = []
|
||||
self.start_time = None
|
||||
self.total_time = 0.0
|
||||
|
||||
def __enter__(self):
|
||||
self.count = 0
|
||||
self.queries = []
|
||||
self.start_time = time.time()
|
||||
_query_context.counter = self
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.total_time = time.time() - self.start_time
|
||||
_query_context.counter = None
|
||||
|
||||
# Log warning if threshold exceeded
|
||||
if self.count > self.threshold:
|
||||
logger.warning(
|
||||
"Query count threshold exceeded in %s: %d queries (threshold: %d, time: %.3fs)",
|
||||
self.context_name,
|
||||
self.count,
|
||||
self.threshold,
|
||||
self.total_time,
|
||||
)
|
||||
if settings.DEBUG:
|
||||
# In debug mode, also log the individual queries
|
||||
for i, (sql, duration) in enumerate(self.queries[:20], 1):
|
||||
logger.debug(" Query %d (%.3fs): %s", i, duration, sql[:200])
|
||||
if len(self.queries) > 20:
|
||||
logger.debug(" ... and %d more queries", len(self.queries) - 20)
|
||||
elif settings.DEBUG and self.count > 0:
|
||||
logger.debug(
|
||||
"Query count for %s: %d queries in %.3fs",
|
||||
self.context_name,
|
||||
self.count,
|
||||
self.total_time,
|
||||
)
|
||||
|
||||
return False
|
||||
|
||||
def record_query(self, statement: str, duration: float):
|
||||
"""Record a query execution."""
|
||||
self.count += 1
|
||||
if settings.DEBUG:
|
||||
self.queries.append((statement, duration))
|
||||
|
||||
|
||||
def get_current_counter() -> Optional[QueryCounter]:
|
||||
"""Get the current request's query counter, if any."""
|
||||
return getattr(_query_context, 'counter', None)
|
||||
|
||||
|
||||
def setup_query_logging(engine: Engine):
|
||||
"""
|
||||
Set up SQLAlchemy event listeners for query logging.
|
||||
|
||||
This should be called once during application startup.
|
||||
Only activates if QUERY_LOGGING is enabled in settings.
|
||||
"""
|
||||
if not settings.QUERY_LOGGING:
|
||||
logger.info("Query logging is disabled")
|
||||
return
|
||||
|
||||
logger.info("Setting up query logging with threshold=%d", settings.QUERY_COUNT_THRESHOLD)
|
||||
|
||||
@event.listens_for(engine, "before_cursor_execute")
|
||||
def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
|
||||
conn.info.setdefault('query_start_time', []).append(time.time())
|
||||
|
||||
@event.listens_for(engine, "after_cursor_execute")
|
||||
def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
|
||||
start_times = conn.info.get('query_start_time', [])
|
||||
duration = time.time() - start_times.pop() if start_times else 0.0
|
||||
|
||||
# Record in current counter if active
|
||||
counter = get_current_counter()
|
||||
if counter:
|
||||
counter.record_query(statement, duration)
|
||||
|
||||
# Also log individual queries if in debug mode
|
||||
if settings.DEBUG:
|
||||
logger.debug("SQL (%.3fs): %s", duration, statement[:500])
|
||||
|
||||
|
||||
@contextmanager
|
||||
def count_queries(context_name: str = "operation", threshold: Optional[int] = None):
|
||||
"""
|
||||
Context manager to count queries for a specific operation.
|
||||
|
||||
Args:
|
||||
context_name: Name for logging purposes
|
||||
threshold: Override the default query count threshold
|
||||
|
||||
Usage:
|
||||
with count_queries("list_members") as counter:
|
||||
members = db.query(ProjectMember).all()
|
||||
for member in members:
|
||||
print(member.user.name) # N+1 query!
|
||||
|
||||
# After block, logs warning if threshold exceeded
|
||||
print(f"Total queries: {counter.count}")
|
||||
"""
|
||||
with QueryCounter(threshold=threshold, context_name=context_name) as counter:
|
||||
yield counter
|
||||
|
||||
|
||||
def assert_query_count(max_queries: int):
|
||||
"""
|
||||
Decorator for testing that asserts maximum query count.
|
||||
|
||||
Usage in tests:
|
||||
@assert_query_count(5)
|
||||
def test_list_members():
|
||||
# Should use at most 5 queries
|
||||
response = client.get("/api/projects/xxx/members")
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
def wrapper(*args, **kwargs):
|
||||
with QueryCounter(threshold=max_queries, context_name=func.__name__) as counter:
|
||||
result = func(*args, **kwargs)
|
||||
if counter.count > max_queries:
|
||||
raise AssertionError(
|
||||
f"Query count {counter.count} exceeded maximum {max_queries} "
|
||||
f"in {func.__name__}"
|
||||
)
|
||||
return result
|
||||
return wrapper
|
||||
return decorator
|
||||
@@ -1,8 +1,283 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional, Any
|
||||
from typing import Optional, Any, Tuple
|
||||
from jose import jwt, JWTError
|
||||
import logging
|
||||
import math
|
||||
import hashlib
|
||||
import secrets
|
||||
import hmac
|
||||
from collections import Counter
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Constants for JWT secret validation
|
||||
MIN_SECRET_LENGTH = 32
|
||||
MIN_ENTROPY_BITS = 128 # Minimum entropy in bits for a secure secret
|
||||
COMMON_WEAK_PATTERNS = [
|
||||
"password", "secret", "changeme", "admin", "test", "demo",
|
||||
"123456", "qwerty", "abc123", "letmein", "welcome",
|
||||
]
|
||||
|
||||
|
||||
def calculate_entropy(data: str) -> float:
|
||||
"""
|
||||
Calculate Shannon entropy of a string in bits.
|
||||
|
||||
Higher entropy indicates more randomness and thus a stronger secret.
|
||||
A perfectly random string of length n with k possible characters has
|
||||
entropy of n * log2(k) bits.
|
||||
|
||||
Args:
|
||||
data: The string to calculate entropy for
|
||||
|
||||
Returns:
|
||||
Entropy value in bits
|
||||
"""
|
||||
if not data:
|
||||
return 0.0
|
||||
|
||||
# Count character frequencies
|
||||
char_counts = Counter(data)
|
||||
length = len(data)
|
||||
|
||||
# Calculate Shannon entropy
|
||||
entropy = 0.0
|
||||
for count in char_counts.values():
|
||||
if count > 0:
|
||||
probability = count / length
|
||||
entropy -= probability * math.log2(probability)
|
||||
|
||||
# Return total entropy in bits (per-character entropy * length)
|
||||
return entropy * length
|
||||
|
||||
|
||||
def has_repeating_patterns(secret: str) -> bool:
|
||||
"""
|
||||
Check if the secret contains obvious repeating patterns.
|
||||
|
||||
Args:
|
||||
secret: The secret string to check
|
||||
|
||||
Returns:
|
||||
True if repeating patterns are detected
|
||||
"""
|
||||
if len(secret) < 8:
|
||||
return False
|
||||
|
||||
# Check for repeating character sequences
|
||||
for pattern_len in range(2, len(secret) // 3 + 1):
|
||||
pattern = secret[:pattern_len]
|
||||
if pattern * (len(secret) // pattern_len) == secret[:len(pattern) * (len(secret) // pattern_len)]:
|
||||
# More than 50% of the string is the same pattern repeated
|
||||
if (len(secret) // pattern_len) >= 3:
|
||||
return True
|
||||
|
||||
# Check for consecutive same characters
|
||||
consecutive_count = 1
|
||||
for i in range(1, len(secret)):
|
||||
if secret[i] == secret[i-1]:
|
||||
consecutive_count += 1
|
||||
if consecutive_count >= len(secret) // 2:
|
||||
return True
|
||||
else:
|
||||
consecutive_count = 1
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def validate_jwt_secret_strength(secret: str) -> Tuple[bool, list]:
|
||||
"""
|
||||
Validate JWT secret key strength.
|
||||
|
||||
Checks:
|
||||
1. Minimum length (32 characters)
|
||||
2. Entropy (minimum 128 bits)
|
||||
3. Common weak patterns
|
||||
4. Repeating patterns
|
||||
|
||||
Args:
|
||||
secret: The JWT secret to validate
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, list_of_warnings)
|
||||
"""
|
||||
warnings = []
|
||||
is_valid = True
|
||||
|
||||
# Check minimum length
|
||||
if len(secret) < MIN_SECRET_LENGTH:
|
||||
warnings.append(
|
||||
f"JWT secret is too short ({len(secret)} chars). "
|
||||
f"Minimum recommended length is {MIN_SECRET_LENGTH} characters."
|
||||
)
|
||||
is_valid = False
|
||||
|
||||
# Calculate and check entropy
|
||||
entropy = calculate_entropy(secret)
|
||||
if entropy < MIN_ENTROPY_BITS:
|
||||
warnings.append(
|
||||
f"JWT secret has low entropy ({entropy:.1f} bits). "
|
||||
f"Minimum recommended entropy is {MIN_ENTROPY_BITS} bits. "
|
||||
"Consider using a cryptographically random secret."
|
||||
)
|
||||
# Low entropy alone doesn't make it invalid, but it's a warning
|
||||
|
||||
# Check for common weak patterns
|
||||
secret_lower = secret.lower()
|
||||
for pattern in COMMON_WEAK_PATTERNS:
|
||||
if pattern in secret_lower:
|
||||
warnings.append(
|
||||
f"JWT secret contains common weak pattern: '{pattern}'. "
|
||||
"Use a cryptographically random secret."
|
||||
)
|
||||
break
|
||||
|
||||
# Check for repeating patterns
|
||||
if has_repeating_patterns(secret):
|
||||
warnings.append(
|
||||
"JWT secret contains repeating patterns. "
|
||||
"Use a cryptographically random secret."
|
||||
)
|
||||
|
||||
return is_valid, warnings
|
||||
|
||||
|
||||
def validate_jwt_secret_on_startup() -> None:
|
||||
"""
|
||||
Validate JWT secret strength on application startup.
|
||||
|
||||
Logs warnings for weak secrets and raises an error in production
|
||||
if the secret is critically weak.
|
||||
"""
|
||||
import os
|
||||
|
||||
secret = settings.JWT_SECRET_KEY
|
||||
is_valid, warnings = validate_jwt_secret_strength(secret)
|
||||
|
||||
# Log all warnings
|
||||
for warning in warnings:
|
||||
logger.warning("JWT Security Warning: %s", warning)
|
||||
|
||||
# In production, enforce stricter requirements
|
||||
is_production = os.environ.get("ENVIRONMENT", "").lower() == "production"
|
||||
|
||||
if not is_valid:
|
||||
if is_production:
|
||||
logger.critical(
|
||||
"JWT secret does not meet security requirements. "
|
||||
"Application startup blocked in production mode. "
|
||||
"Please configure a strong JWT_SECRET_KEY (minimum 32 characters)."
|
||||
)
|
||||
raise ValueError(
|
||||
"JWT_SECRET_KEY does not meet minimum security requirements. "
|
||||
"See logs for details."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"JWT secret does not meet security requirements. "
|
||||
"This would block startup in production mode."
|
||||
)
|
||||
|
||||
if warnings:
|
||||
logger.info(
|
||||
"JWT secret validation completed with %d warning(s). "
|
||||
"Consider using: python -c \"import secrets; print(secrets.token_urlsafe(48))\" "
|
||||
"to generate a strong secret.",
|
||||
len(warnings)
|
||||
)
|
||||
else:
|
||||
logger.info("JWT secret validation passed. Secret meets security requirements.")
|
||||
|
||||
|
||||
# CSRF Token Functions
|
||||
CSRF_TOKEN_LENGTH = 32
|
||||
CSRF_TOKEN_EXPIRY_SECONDS = 3600 # 1 hour
|
||||
|
||||
|
||||
def generate_csrf_token(user_id: str) -> str:
|
||||
"""
|
||||
Generate a CSRF token for a user.
|
||||
|
||||
The token is a combination of:
|
||||
- Random bytes for unpredictability
|
||||
- User ID binding to prevent token reuse across users
|
||||
- HMAC signature for integrity
|
||||
|
||||
Args:
|
||||
user_id: The user's ID to bind the token to
|
||||
|
||||
Returns:
|
||||
CSRF token string
|
||||
"""
|
||||
# Generate random token
|
||||
random_part = secrets.token_urlsafe(CSRF_TOKEN_LENGTH)
|
||||
|
||||
# Create timestamp for expiry checking
|
||||
timestamp = int(datetime.now(timezone.utc).timestamp())
|
||||
|
||||
# Create the token payload
|
||||
payload = f"{random_part}:{user_id}:{timestamp}"
|
||||
|
||||
# Sign with HMAC using JWT secret
|
||||
signature = hmac.new(
|
||||
settings.JWT_SECRET_KEY.encode(),
|
||||
payload.encode(),
|
||||
hashlib.sha256
|
||||
).hexdigest()[:16]
|
||||
|
||||
# Return combined token
|
||||
return f"{payload}:{signature}"
|
||||
|
||||
|
||||
def validate_csrf_token(token: str, user_id: str) -> Tuple[bool, str]:
|
||||
"""
|
||||
Validate a CSRF token.
|
||||
|
||||
Args:
|
||||
token: The CSRF token to validate
|
||||
user_id: The expected user ID
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, error_message)
|
||||
"""
|
||||
if not token:
|
||||
return False, "CSRF token is required"
|
||||
|
||||
try:
|
||||
parts = token.split(":")
|
||||
if len(parts) != 4:
|
||||
return False, "Invalid CSRF token format"
|
||||
|
||||
random_part, token_user_id, timestamp_str, signature = parts
|
||||
|
||||
# Verify user ID matches
|
||||
if token_user_id != user_id:
|
||||
return False, "CSRF token user mismatch"
|
||||
|
||||
# Verify timestamp (check expiry)
|
||||
timestamp = int(timestamp_str)
|
||||
current_time = int(datetime.now(timezone.utc).timestamp())
|
||||
if current_time - timestamp > CSRF_TOKEN_EXPIRY_SECONDS:
|
||||
return False, "CSRF token expired"
|
||||
|
||||
# Verify signature
|
||||
payload = f"{random_part}:{token_user_id}:{timestamp_str}"
|
||||
expected_signature = hmac.new(
|
||||
settings.JWT_SECRET_KEY.encode(),
|
||||
payload.encode(),
|
||||
hashlib.sha256
|
||||
).hexdigest()[:16]
|
||||
|
||||
if not hmac.compare_digest(signature, expected_signature):
|
||||
return False, "CSRF token signature invalid"
|
||||
|
||||
return True, ""
|
||||
|
||||
except (ValueError, IndexError) as e:
|
||||
return False, f"CSRF token validation error: {str(e)}"
|
||||
|
||||
|
||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user