feat: implement security, error resilience, and query optimization proposals

Security Validation (enhance-security-validation):
- JWT secret validation with entropy checking and pattern detection
- CSRF protection middleware with token generation/validation
- Frontend CSRF token auto-injection for DELETE/PUT/PATCH requests
- MIME type validation with magic bytes detection for file uploads

Error Resilience (add-error-resilience):
- React ErrorBoundary component with fallback UI and retry functionality
- ErrorBoundaryWithI18n wrapper for internationalization support
- Page-level and section-level error boundaries in App.tsx

Query Performance (optimize-query-performance):
- Query monitoring utility with threshold warnings
- N+1 query fixes using joinedload/selectinload
- Optimized project members, tasks, and subtasks endpoints

Bug Fixes:
- WebSocket session management (P0): Return primitives instead of ORM objects
- LIKE query injection (P1): Escape special characters in search queries

Tests: 543 backend tests, 56 frontend tests passing

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
beabigegg
2026-01-11 18:41:19 +08:00
parent 2cb591ef23
commit 679b89ae4c
41 changed files with 3673 additions and 153 deletions

View File

@@ -122,6 +122,11 @@ class Settings(BaseSettings):
RATE_LIMIT_SENSITIVE: str = "20/minute" # Attachments, password change, report export
RATE_LIMIT_HEAVY: str = "5/minute" # Report generation, bulk operations
# Development Mode Settings
DEBUG: bool = False # Enable debug mode for development
QUERY_LOGGING: bool = False # Enable SQLAlchemy query logging
QUERY_COUNT_THRESHOLD: int = 10 # Warn when query count exceeds this threshold
class Config:
env_file = ".env"
case_sensitive = True

View File

@@ -104,6 +104,10 @@ def _on_invalidate(dbapi_conn, connection_record, exception):
# Start pool statistics logging on module load
_start_pool_stats_logging()
# Set up query logging if enabled
from app.core.query_monitor import setup_query_logging
setup_query_logging(engine)
def get_db():
"""Dependency for getting database session."""
@@ -127,3 +131,25 @@ def get_pool_status() -> dict:
"total_checkins": _pool_stats["checkins"],
"invalidated_connections": _pool_stats["invalidated_connections"],
}
def escape_like(value: str) -> str:
"""
Escape special characters for SQL LIKE queries.
Escapes '%' and '_' characters which have special meaning in LIKE patterns.
This prevents LIKE injection attacks where user input could match unintended patterns.
Args:
value: The user input string to escape
Returns:
Escaped string safe for use in LIKE patterns
Example:
>>> escape_like("test%value")
'test\\%value'
>>> escape_like("user_name")
'user\\_name'
"""
return value.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")

View File

@@ -0,0 +1,167 @@
"""
Query monitoring utilities for detecting N+1 queries and performance issues.
This module provides:
1. Query counting per request in development mode
2. SQLAlchemy event listeners for query logging
3. Threshold-based warnings for excessive queries
"""
import logging
import threading
import time
from contextlib import contextmanager
from typing import Optional, Callable, Any
from sqlalchemy import event
from sqlalchemy.engine import Engine
from app.core.config import settings
logger = logging.getLogger(__name__)
# Thread-local storage for per-request query counting
_query_context = threading.local()
class QueryCounter:
"""
Context manager for counting database queries within a request.
Usage:
with QueryCounter() as counter:
# ... execute queries ...
print(f"Executed {counter.count} queries")
"""
def __init__(self, threshold: Optional[int] = None, context_name: str = "request"):
self.threshold = threshold or settings.QUERY_COUNT_THRESHOLD
self.context_name = context_name
self.count = 0
self.queries = []
self.start_time = None
self.total_time = 0.0
def __enter__(self):
self.count = 0
self.queries = []
self.start_time = time.time()
_query_context.counter = self
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.total_time = time.time() - self.start_time
_query_context.counter = None
# Log warning if threshold exceeded
if self.count > self.threshold:
logger.warning(
"Query count threshold exceeded in %s: %d queries (threshold: %d, time: %.3fs)",
self.context_name,
self.count,
self.threshold,
self.total_time,
)
if settings.DEBUG:
# In debug mode, also log the individual queries
for i, (sql, duration) in enumerate(self.queries[:20], 1):
logger.debug(" Query %d (%.3fs): %s", i, duration, sql[:200])
if len(self.queries) > 20:
logger.debug(" ... and %d more queries", len(self.queries) - 20)
elif settings.DEBUG and self.count > 0:
logger.debug(
"Query count for %s: %d queries in %.3fs",
self.context_name,
self.count,
self.total_time,
)
return False
def record_query(self, statement: str, duration: float):
"""Record a query execution."""
self.count += 1
if settings.DEBUG:
self.queries.append((statement, duration))
def get_current_counter() -> Optional[QueryCounter]:
"""Get the current request's query counter, if any."""
return getattr(_query_context, 'counter', None)
def setup_query_logging(engine: Engine):
"""
Set up SQLAlchemy event listeners for query logging.
This should be called once during application startup.
Only activates if QUERY_LOGGING is enabled in settings.
"""
if not settings.QUERY_LOGGING:
logger.info("Query logging is disabled")
return
logger.info("Setting up query logging with threshold=%d", settings.QUERY_COUNT_THRESHOLD)
@event.listens_for(engine, "before_cursor_execute")
def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
conn.info.setdefault('query_start_time', []).append(time.time())
@event.listens_for(engine, "after_cursor_execute")
def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
start_times = conn.info.get('query_start_time', [])
duration = time.time() - start_times.pop() if start_times else 0.0
# Record in current counter if active
counter = get_current_counter()
if counter:
counter.record_query(statement, duration)
# Also log individual queries if in debug mode
if settings.DEBUG:
logger.debug("SQL (%.3fs): %s", duration, statement[:500])
@contextmanager
def count_queries(context_name: str = "operation", threshold: Optional[int] = None):
"""
Context manager to count queries for a specific operation.
Args:
context_name: Name for logging purposes
threshold: Override the default query count threshold
Usage:
with count_queries("list_members") as counter:
members = db.query(ProjectMember).all()
for member in members:
print(member.user.name) # N+1 query!
# After block, logs warning if threshold exceeded
print(f"Total queries: {counter.count}")
"""
with QueryCounter(threshold=threshold, context_name=context_name) as counter:
yield counter
def assert_query_count(max_queries: int):
"""
Decorator for testing that asserts maximum query count.
Usage in tests:
@assert_query_count(5)
def test_list_members():
# Should use at most 5 queries
response = client.get("/api/projects/xxx/members")
"""
def decorator(func: Callable) -> Callable:
def wrapper(*args, **kwargs):
with QueryCounter(threshold=max_queries, context_name=func.__name__) as counter:
result = func(*args, **kwargs)
if counter.count > max_queries:
raise AssertionError(
f"Query count {counter.count} exceeded maximum {max_queries} "
f"in {func.__name__}"
)
return result
return wrapper
return decorator

View File

@@ -1,8 +1,283 @@
from datetime import datetime, timedelta, timezone
from typing import Optional, Any
from typing import Optional, Any, Tuple
from jose import jwt, JWTError
import logging
import math
import hashlib
import secrets
import hmac
from collections import Counter
from app.core.config import settings
logger = logging.getLogger(__name__)
# Constants for JWT secret validation
MIN_SECRET_LENGTH = 32
MIN_ENTROPY_BITS = 128 # Minimum entropy in bits for a secure secret
COMMON_WEAK_PATTERNS = [
"password", "secret", "changeme", "admin", "test", "demo",
"123456", "qwerty", "abc123", "letmein", "welcome",
]
def calculate_entropy(data: str) -> float:
"""
Calculate Shannon entropy of a string in bits.
Higher entropy indicates more randomness and thus a stronger secret.
A perfectly random string of length n with k possible characters has
entropy of n * log2(k) bits.
Args:
data: The string to calculate entropy for
Returns:
Entropy value in bits
"""
if not data:
return 0.0
# Count character frequencies
char_counts = Counter(data)
length = len(data)
# Calculate Shannon entropy
entropy = 0.0
for count in char_counts.values():
if count > 0:
probability = count / length
entropy -= probability * math.log2(probability)
# Return total entropy in bits (per-character entropy * length)
return entropy * length
def has_repeating_patterns(secret: str) -> bool:
"""
Check if the secret contains obvious repeating patterns.
Args:
secret: The secret string to check
Returns:
True if repeating patterns are detected
"""
if len(secret) < 8:
return False
# Check for repeating character sequences
for pattern_len in range(2, len(secret) // 3 + 1):
pattern = secret[:pattern_len]
if pattern * (len(secret) // pattern_len) == secret[:len(pattern) * (len(secret) // pattern_len)]:
# More than 50% of the string is the same pattern repeated
if (len(secret) // pattern_len) >= 3:
return True
# Check for consecutive same characters
consecutive_count = 1
for i in range(1, len(secret)):
if secret[i] == secret[i-1]:
consecutive_count += 1
if consecutive_count >= len(secret) // 2:
return True
else:
consecutive_count = 1
return False
def validate_jwt_secret_strength(secret: str) -> Tuple[bool, list]:
"""
Validate JWT secret key strength.
Checks:
1. Minimum length (32 characters)
2. Entropy (minimum 128 bits)
3. Common weak patterns
4. Repeating patterns
Args:
secret: The JWT secret to validate
Returns:
Tuple of (is_valid, list_of_warnings)
"""
warnings = []
is_valid = True
# Check minimum length
if len(secret) < MIN_SECRET_LENGTH:
warnings.append(
f"JWT secret is too short ({len(secret)} chars). "
f"Minimum recommended length is {MIN_SECRET_LENGTH} characters."
)
is_valid = False
# Calculate and check entropy
entropy = calculate_entropy(secret)
if entropy < MIN_ENTROPY_BITS:
warnings.append(
f"JWT secret has low entropy ({entropy:.1f} bits). "
f"Minimum recommended entropy is {MIN_ENTROPY_BITS} bits. "
"Consider using a cryptographically random secret."
)
# Low entropy alone doesn't make it invalid, but it's a warning
# Check for common weak patterns
secret_lower = secret.lower()
for pattern in COMMON_WEAK_PATTERNS:
if pattern in secret_lower:
warnings.append(
f"JWT secret contains common weak pattern: '{pattern}'. "
"Use a cryptographically random secret."
)
break
# Check for repeating patterns
if has_repeating_patterns(secret):
warnings.append(
"JWT secret contains repeating patterns. "
"Use a cryptographically random secret."
)
return is_valid, warnings
def validate_jwt_secret_on_startup() -> None:
"""
Validate JWT secret strength on application startup.
Logs warnings for weak secrets and raises an error in production
if the secret is critically weak.
"""
import os
secret = settings.JWT_SECRET_KEY
is_valid, warnings = validate_jwt_secret_strength(secret)
# Log all warnings
for warning in warnings:
logger.warning("JWT Security Warning: %s", warning)
# In production, enforce stricter requirements
is_production = os.environ.get("ENVIRONMENT", "").lower() == "production"
if not is_valid:
if is_production:
logger.critical(
"JWT secret does not meet security requirements. "
"Application startup blocked in production mode. "
"Please configure a strong JWT_SECRET_KEY (minimum 32 characters)."
)
raise ValueError(
"JWT_SECRET_KEY does not meet minimum security requirements. "
"See logs for details."
)
else:
logger.warning(
"JWT secret does not meet security requirements. "
"This would block startup in production mode."
)
if warnings:
logger.info(
"JWT secret validation completed with %d warning(s). "
"Consider using: python -c \"import secrets; print(secrets.token_urlsafe(48))\" "
"to generate a strong secret.",
len(warnings)
)
else:
logger.info("JWT secret validation passed. Secret meets security requirements.")
# CSRF Token Functions
CSRF_TOKEN_LENGTH = 32
CSRF_TOKEN_EXPIRY_SECONDS = 3600 # 1 hour
def generate_csrf_token(user_id: str) -> str:
"""
Generate a CSRF token for a user.
The token is a combination of:
- Random bytes for unpredictability
- User ID binding to prevent token reuse across users
- HMAC signature for integrity
Args:
user_id: The user's ID to bind the token to
Returns:
CSRF token string
"""
# Generate random token
random_part = secrets.token_urlsafe(CSRF_TOKEN_LENGTH)
# Create timestamp for expiry checking
timestamp = int(datetime.now(timezone.utc).timestamp())
# Create the token payload
payload = f"{random_part}:{user_id}:{timestamp}"
# Sign with HMAC using JWT secret
signature = hmac.new(
settings.JWT_SECRET_KEY.encode(),
payload.encode(),
hashlib.sha256
).hexdigest()[:16]
# Return combined token
return f"{payload}:{signature}"
def validate_csrf_token(token: str, user_id: str) -> Tuple[bool, str]:
"""
Validate a CSRF token.
Args:
token: The CSRF token to validate
user_id: The expected user ID
Returns:
Tuple of (is_valid, error_message)
"""
if not token:
return False, "CSRF token is required"
try:
parts = token.split(":")
if len(parts) != 4:
return False, "Invalid CSRF token format"
random_part, token_user_id, timestamp_str, signature = parts
# Verify user ID matches
if token_user_id != user_id:
return False, "CSRF token user mismatch"
# Verify timestamp (check expiry)
timestamp = int(timestamp_str)
current_time = int(datetime.now(timezone.utc).timestamp())
if current_time - timestamp > CSRF_TOKEN_EXPIRY_SECONDS:
return False, "CSRF token expired"
# Verify signature
payload = f"{random_part}:{token_user_id}:{timestamp_str}"
expected_signature = hmac.new(
settings.JWT_SECRET_KEY.encode(),
payload.encode(),
hashlib.sha256
).hexdigest()[:16]
if not hmac.compare_digest(signature, expected_signature):
return False, "CSRF token signature invalid"
return True, ""
except (ValueError, IndexError) as e:
return False, f"CSRF token validation error: {str(e)}"
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
"""