feat: implement security, error resilience, and query optimization proposals

Security Validation (enhance-security-validation):
- JWT secret validation with entropy checking and pattern detection
- CSRF protection middleware with token generation/validation
- Frontend CSRF token auto-injection for DELETE/PUT/PATCH requests
- MIME type validation with magic bytes detection for file uploads

Error Resilience (add-error-resilience):
- React ErrorBoundary component with fallback UI and retry functionality
- ErrorBoundaryWithI18n wrapper for internationalization support
- Page-level and section-level error boundaries in App.tsx

Query Performance (optimize-query-performance):
- Query monitoring utility with threshold warnings
- N+1 query fixes using joinedload/selectinload
- Optimized project members, tasks, and subtasks endpoints

Bug Fixes:
- WebSocket session management (P0): Return primitives instead of ORM objects
- LIKE query injection (P1): Escape special characters in search queries

Tests: 543 backend tests, 56 frontend tests passing

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
beabigegg
2026-01-11 18:41:19 +08:00
parent 2cb591ef23
commit 679b89ae4c
41 changed files with 3673 additions and 153 deletions

View File

@@ -24,6 +24,7 @@ from app.services.encryption_service import (
MasterKeyNotConfiguredError,
DecryptionError,
)
from app.middleware.csrf import require_csrf_token
logger = logging.getLogger(__name__)
@@ -610,13 +611,14 @@ async def download_attachment(
@router.delete("/attachments/{attachment_id}")
@require_csrf_token
async def delete_attachment(
attachment_id: str,
request: Request,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Soft delete an attachment."""
"""Soft delete an attachment. Requires CSRF token in X-CSRF-Token header."""
attachment = get_attachment_with_access_check(db, attachment_id, current_user, require_edit=True)
# Soft delete

View File

@@ -8,7 +8,7 @@ from app.core.redis import get_redis
from app.core.rate_limiter import limiter
from app.models.user import User
from app.models.audit_log import AuditAction
from app.schemas.auth import LoginRequest, LoginResponse, UserInfo
from app.schemas.auth import LoginRequest, LoginResponse, UserInfo, CSRFTokenResponse
from app.services.auth_client import (
verify_credentials,
AuthAPIError,
@@ -16,6 +16,7 @@ from app.services.auth_client import (
)
from app.services.audit_service import AuditService
from app.middleware.auth import get_current_user
from app.middleware.csrf import get_csrf_token_for_user
router = APIRouter()
@@ -182,3 +183,23 @@ async def get_current_user_info(
department_id=current_user.department_id,
is_system_admin=current_user.is_system_admin,
)
@router.get("/csrf-token", response_model=CSRFTokenResponse)
async def get_csrf_token(
    current_user: User = Depends(get_current_user),
):
    """Issue a CSRF token bound to the authenticated user.

    Clients must send this token in the ``X-CSRF-Token`` header for
    sensitive state-changing operations (DELETE, PUT, PATCH).  The token
    expires after one hour and should be refreshed before then.
    """
    token = get_csrf_token_for_user(current_user.id)
    # expires_in mirrors the 1-hour expiry enforced at validation time.
    return CSRFTokenResponse(csrf_token=token, expires_in=3600)

View File

@@ -1,7 +1,7 @@
import uuid
from typing import List
from fastapi import APIRouter, Depends, HTTPException, status, Request
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, joinedload, selectinload
from app.core.database import get_db
from app.models import User, Space, Project, TaskStatus, AuditAction, ProjectMember, ProjectTemplate, CustomField
@@ -55,6 +55,8 @@ async def list_projects_in_space(
):
"""
List all projects in a space that the user can access.
Optimized to avoid N+1 queries by using joinedload/selectinload for relationships.
"""
space = db.query(Space).filter(Space.id == space_id, Space.is_active == True).first()
@@ -70,13 +72,21 @@ async def list_projects_in_space(
detail="Access denied",
)
projects = db.query(Project).filter(Project.space_id == space_id, Project.is_active == True).all()
# Use joinedload to eagerly load owner, space, and department
# Use selectinload for tasks (one-to-many) to avoid cartesian product issues
projects = db.query(Project).options(
joinedload(Project.owner),
joinedload(Project.space),
joinedload(Project.department),
selectinload(Project.tasks),
).filter(Project.space_id == space_id, Project.is_active == True).all()
# Filter by project access
accessible_projects = [p for p in projects if check_project_access(current_user, p)]
result = []
for project in accessible_projects:
# Access pre-loaded relationships - no additional queries needed
task_count = len(project.tasks) if project.tasks else 0
result.append(ProjectWithDetails(
id=project.id,
@@ -422,6 +432,10 @@ async def list_project_members(
List all members of a project.
Only users with project access can view the member list.
Optimized to avoid N+1 queries by using joinedload for user relationships.
This loads all members and their related users in at most 2 queries instead of
one query per member.
"""
project = db.query(Project).filter(Project.id == project_id, Project.is_active == True).first()
@@ -437,14 +451,20 @@ async def list_project_members(
detail="Access denied",
)
members = db.query(ProjectMember).filter(
# Use joinedload to eagerly load user and added_by_user relationships
# This avoids N+1 queries when accessing member.user and member.added_by_user
members = db.query(ProjectMember).options(
joinedload(ProjectMember.user).joinedload(User.department),
joinedload(ProjectMember.added_by_user),
).filter(
ProjectMember.project_id == project_id
).all()
member_list = []
for member in members:
user = db.query(User).filter(User.id == member.user_id).first()
added_by_user = db.query(User).filter(User.id == member.added_by).first()
# Access pre-loaded relationships - no additional queries needed
user = member.user
added_by_user = member.added_by_user
member_list.append(ProjectMemberWithDetails(
id=member.id,

View File

@@ -3,7 +3,7 @@ import uuid
from datetime import datetime, timezone, timedelta
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, status, Query, Request
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, joinedload, selectinload
from app.core.database import get_db
from app.core.redis_pubsub import publish_task_event
@@ -110,6 +110,9 @@ async def list_tasks(
The due_after and due_before parameters are useful for calendar view
to fetch tasks within a specific date range.
Optimized to avoid N+1 queries by using selectinload for task relationships.
This batch loads assignees, statuses, creators and subtasks efficiently.
"""
project = db.query(Project).filter(Project.id == project_id).first()
@@ -125,7 +128,15 @@ async def list_tasks(
detail="Access denied",
)
query = db.query(Task).filter(Task.project_id == project_id)
# Use selectinload to eagerly load task relationships
# This avoids N+1 queries when accessing task.assignee, task.status, etc.
query = db.query(Task).options(
selectinload(Task.assignee),
selectinload(Task.status),
selectinload(Task.creator),
selectinload(Task.subtasks),
selectinload(Task.custom_values),
).filter(Task.project_id == project_id)
# Filter deleted tasks (only admin can include deleted)
if include_deleted and current_user.is_system_admin:
@@ -1112,6 +1123,8 @@ async def list_subtasks(
):
"""
List subtasks of a task.
Optimized to avoid N+1 queries by using selectinload for task relationships.
"""
task = db.query(Task).filter(Task.id == task_id).first()
@@ -1127,7 +1140,13 @@ async def list_subtasks(
detail="Access denied",
)
query = db.query(Task).filter(Task.parent_task_id == task_id)
# Use selectinload to eagerly load subtask relationships
query = db.query(Task).options(
selectinload(Task.assignee),
selectinload(Task.status),
selectinload(Task.creator),
selectinload(Task.subtasks),
).filter(Task.parent_task_id == task_id)
# Filter deleted subtasks (only admin can include deleted)
if not (include_deleted and current_user.is_system_admin):

View File

@@ -3,7 +3,7 @@ from sqlalchemy.orm import Session
from sqlalchemy import or_
from typing import List
from app.core.database import get_db
from app.core.database import get_db, escape_like
from app.core.redis import get_redis
from app.models.user import User
from app.models.role import Role
@@ -16,6 +16,7 @@ from app.middleware.auth import (
check_department_access,
)
from app.middleware.audit import get_audit_metadata
from app.middleware.csrf import require_csrf_token
from app.services.audit_service import AuditService
router = APIRouter()
@@ -32,11 +33,13 @@ async def search_users(
Search users by name or email. Used for @mention autocomplete.
Returns users matching the query, limited to same department unless system admin.
"""
# Escape special LIKE characters to prevent injection
escaped_q = escape_like(q)
query = db.query(User).filter(
User.is_active == True,
or_(
User.name.ilike(f"%{q}%"),
User.email.ilike(f"%{q}%"),
User.name.ilike(f"%{escaped_q}%", escape="\\"),
User.email.ilike(f"%{escaped_q}%", escape="\\"),
)
)
@@ -197,6 +200,7 @@ async def assign_role(
@router.patch("/{user_id}/admin", response_model=UserResponse)
@require_csrf_token
async def set_admin_status(
user_id: str,
is_admin: bool,
@@ -205,7 +209,7 @@ async def set_admin_status(
current_user: User = Depends(require_system_admin),
):
"""
Set or revoke system administrator status. Requires system admin.
Set or revoke system administrator status. Requires system admin and CSRF token.
"""
user = db.query(User).filter(User.id == user_id).first()
if not user:

View File

@@ -27,29 +27,37 @@ if os.getenv("TESTING") == "true":
AUTH_TIMEOUT = 1.0
async def get_user_from_token(token: str) -> str | None:
    """
    Validate token and return user_id.

    Args:
        token: JWT access token supplied by the WebSocket client.

    Returns:
        user_id if valid, None otherwise.

    Note: This function properly closes the database session after validation.
    Do not return ORM objects as they become detached after session close.
    """
    payload = decode_access_token(token)
    if payload is None:
        return None

    user_id = payload.get("sub")
    if user_id is None:
        return None

    # Verify session in Redis: the stored token must match exactly.
    redis_client = get_redis_sync()
    stored_token = redis_client.get(f"session:{user_id}")
    if stored_token is None or stored_token != token:
        return None

    # Verify user exists and is active; only the primitive user_id
    # escapes the session scope.
    db = database.SessionLocal()
    try:
        user = db.query(User).filter(User.id == user_id).first()
        if user is None or not user.is_active:
            return None
        return user_id
    finally:
        db.close()
@@ -57,7 +65,7 @@ async def get_user_from_token(token: str) -> tuple[str | None, User | None]:
async def authenticate_websocket(
websocket: WebSocket,
query_token: Optional[str] = None
) -> tuple[str | None, User | None, Optional[str]]:
) -> tuple[str | None, Optional[str]]:
"""
Authenticate WebSocket connection.
@@ -67,7 +75,8 @@ async def authenticate_websocket(
2. Query parameter authentication (deprecated, for backward compatibility)
- Client connects with: ?token=<jwt_token>
Returns (user_id, user) if authenticated, (None, None) otherwise.
Returns:
Tuple of (user_id, error_reason). user_id is None if authentication fails.
"""
# If token provided via query parameter (backward compatibility)
if query_token:
@@ -75,10 +84,10 @@ async def authenticate_websocket(
"WebSocket authentication via query parameter is deprecated. "
"Please use first-message authentication for better security."
)
user_id, user = await get_user_from_token(query_token)
user_id = await get_user_from_token(query_token)
if user_id is None:
return None, None, "invalid_token"
return user_id, user, None
return None, "invalid_token"
return user_id, None
# Wait for authentication message with timeout
try:
@@ -90,24 +99,24 @@ async def authenticate_websocket(
msg_type = data.get("type")
if msg_type != "auth":
logger.warning("Expected 'auth' message type, got: %s", msg_type)
return None, None, "invalid_message"
return None, "invalid_message"
token = data.get("token")
if not token:
logger.warning("No token provided in auth message")
return None, None, "missing_token"
return None, "missing_token"
user_id, user = await get_user_from_token(token)
user_id = await get_user_from_token(token)
if user_id is None:
return None, None, "invalid_token"
return user_id, user, None
return None, "invalid_token"
return user_id, None
except asyncio.TimeoutError:
logger.warning("WebSocket authentication timeout after %.1f seconds", AUTH_TIMEOUT)
return None, None, "timeout"
return None, "timeout"
except Exception as e:
logger.error("Error during WebSocket authentication: %s", e)
return None, None, "error"
return None, "error"
async def get_unread_notifications(user_id: str) -> list[dict]:
@@ -183,7 +192,7 @@ async def websocket_notifications(
await websocket.accept()
# Authenticate
user_id, user, error_reason = await authenticate_websocket(websocket, token)
user_id, error_reason = await authenticate_websocket(websocket, token)
if user_id is None:
if error_reason == "invalid_token":
@@ -306,7 +315,7 @@ async def websocket_notifications(
await manager.disconnect(websocket, user_id)
async def verify_project_access(user_id: str, project_id: str) -> tuple[bool, Project | None]:
async def verify_project_access(user_id: str, project_id: str) -> tuple[bool, str | None, str | None]:
"""
Check if user has access to the project.
@@ -315,23 +324,34 @@ async def verify_project_access(user_id: str, project_id: str) -> tuple[bool, Pr
project_id: The project's ID
Returns:
Tuple of (has_access: bool, project: Project | None)
Tuple of (has_access: bool, project_title: str | None, error: str | None)
- has_access: True if user can access the project
- project_title: The project title (only if access granted)
- error: Error code if access denied ("user_not_found", "project_not_found", "access_denied")
Note: This function extracts needed data before closing the session to avoid
detached instance errors when accessing ORM object attributes.
"""
db = database.SessionLocal()
try:
# Get the user
user = db.query(User).filter(User.id == user_id).first()
if user is None or not user.is_active:
return False, None
return False, None, "user_not_found"
# Get the project
project = db.query(Project).filter(Project.id == project_id).first()
if project is None:
return False, None
return False, None, "project_not_found"
# Check access using existing middleware function
has_access = check_project_access(user, project)
return has_access, project
if not has_access:
return False, None, "access_denied"
# Extract title while session is still open
project_title = project.title
return True, project_title, None
finally:
db.close()
@@ -371,7 +391,7 @@ async def websocket_project_sync(
await websocket.accept()
# Authenticate user
user_id, user, error_reason = await authenticate_websocket(websocket, token)
user_id, error_reason = await authenticate_websocket(websocket, token)
if user_id is None:
if error_reason == "invalid_token":
@@ -380,14 +400,13 @@ async def websocket_project_sync(
return
# Verify user has access to the project
has_access, project = await verify_project_access(user_id, project_id)
has_access, project_title, access_error = await verify_project_access(user_id, project_id)
if not has_access:
await websocket.close(code=4003, reason="Access denied to this project")
return
if project is None:
await websocket.close(code=4004, reason="Project not found")
if access_error == "project_not_found":
await websocket.close(code=4004, reason="Project not found")
else:
await websocket.close(code=4003, reason="Access denied to this project")
return
# Join project room
@@ -413,7 +432,7 @@ async def websocket_project_sync(
"data": {
"project_id": project_id,
"user_id": user_id,
"project_title": project.title if project else None,
"project_title": project_title,
},
})

View File

@@ -122,6 +122,11 @@ class Settings(BaseSettings):
RATE_LIMIT_SENSITIVE: str = "20/minute" # Attachments, password change, report export
RATE_LIMIT_HEAVY: str = "5/minute" # Report generation, bulk operations
# Development Mode Settings
DEBUG: bool = False # Enable debug mode for development
QUERY_LOGGING: bool = False # Enable SQLAlchemy query logging
QUERY_COUNT_THRESHOLD: int = 10 # Warn when query count exceeds this threshold
class Config:
env_file = ".env"
case_sensitive = True

View File

@@ -104,6 +104,10 @@ def _on_invalidate(dbapi_conn, connection_record, exception):
# Start pool statistics logging on module load
_start_pool_stats_logging()
# Set up query logging if enabled
from app.core.query_monitor import setup_query_logging
setup_query_logging(engine)
def get_db():
"""Dependency for getting database session."""
@@ -127,3 +131,25 @@ def get_pool_status() -> dict:
"total_checkins": _pool_stats["checkins"],
"invalidated_connections": _pool_stats["invalidated_connections"],
}
def escape_like(value: str) -> str:
    """
    Escape special characters for SQL LIKE queries.

    Escapes '%' and '_' (LIKE wildcards) and the backslash escape
    character itself, preventing LIKE injection where user input could
    match unintended patterns.

    Args:
        value: The user input string to escape

    Returns:
        Escaped string safe for use in LIKE patterns

    Example:
        >>> escape_like("test%value")
        'test\\%value'
        >>> escape_like("user_name")
        'user\\_name'
    """
    # Single C-level pass over the string; mapping the backslash too means
    # pre-existing backslashes cannot un-escape a following wildcard.
    table = str.maketrans({"\\": "\\\\", "%": "\\%", "_": "\\_"})
    return value.translate(table)

View File

@@ -0,0 +1,167 @@
"""
Query monitoring utilities for detecting N+1 queries and performance issues.
This module provides:
1. Query counting per request in development mode
2. SQLAlchemy event listeners for query logging
3. Threshold-based warnings for excessive queries
"""
import logging
import threading
import time
from contextlib import contextmanager
from typing import Optional, Callable, Any
from sqlalchemy import event
from sqlalchemy.engine import Engine
from app.core.config import settings
logger = logging.getLogger(__name__)
# Thread-local storage for per-request query counting
_query_context = threading.local()
class QueryCounter:
    """
    Context manager that counts database queries executed in its scope.

    While active, the counter registers itself in thread-local storage so
    the engine-level listeners installed by setup_query_logging can report
    each executed statement to it.  On exit, a warning is logged when the
    query count exceeded the configured threshold.

    Usage:
        with QueryCounter() as counter:
            # ... execute queries ...
        print(f"Executed {counter.count} queries")
    """

    def __init__(self, threshold: Optional[int] = None, context_name: str = "request"):
        # `or` (not an `is None` test) keeps the original fallback: a
        # threshold of 0 also falls back to the configured default.
        self.threshold = threshold or settings.QUERY_COUNT_THRESHOLD
        self.context_name = context_name
        self.count = 0
        self.queries = []  # (statement, duration) pairs; populated in DEBUG only
        self.start_time = None
        self.total_time = 0.0

    def __enter__(self):
        # Reset state so one instance can be reused across blocks.
        self.count = 0
        self.queries = []
        self.start_time = time.time()
        _query_context.counter = self
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.total_time = time.time() - self.start_time
        _query_context.counter = None

        if self.count > self.threshold:
            logger.warning(
                "Query count threshold exceeded in %s: %d queries (threshold: %d, time: %.3fs)",
                self.context_name,
                self.count,
                self.threshold,
                self.total_time,
            )
            if settings.DEBUG:
                # Cap detailed output at 20 statements to keep logs bounded.
                for index, (sql, duration) in enumerate(self.queries[:20], 1):
                    logger.debug(" Query %d (%.3fs): %s", index, duration, sql[:200])
                remainder = len(self.queries) - 20
                if remainder > 0:
                    logger.debug(" ... and %d more queries", remainder)
        elif settings.DEBUG and self.count > 0:
            logger.debug(
                "Query count for %s: %d queries in %.3fs",
                self.context_name,
                self.count,
                self.total_time,
            )
        # Never suppress exceptions raised inside the block.
        return False

    def record_query(self, statement: str, duration: float):
        """Register one executed statement (called by the event listener)."""
        self.count += 1
        # Statement text is retained only in DEBUG mode to bound memory use.
        if settings.DEBUG:
            self.queries.append((statement, duration))
def get_current_counter() -> Optional[QueryCounter]:
    """Return this thread's active QueryCounter, or None when none is set."""
    try:
        return _query_context.counter
    except AttributeError:
        # Thread-local attribute never set on this thread.
        return None
def setup_query_logging(engine: Engine):
    """
    Attach SQLAlchemy event listeners that time and count queries.

    Call once during application startup.  Does nothing unless
    QUERY_LOGGING is enabled in settings.
    """
    if not settings.QUERY_LOGGING:
        logger.info("Query logging is disabled")
        return

    logger.info("Setting up query logging with threshold=%d", settings.QUERY_COUNT_THRESHOLD)

    @event.listens_for(engine, "before_cursor_execute")
    def _mark_start(conn, cursor, statement, parameters, context, executemany):
        # A stack of start times tolerates nested/overlapping executions.
        conn.info.setdefault('query_start_time', []).append(time.time())

    @event.listens_for(engine, "after_cursor_execute")
    def _report(conn, cursor, statement, parameters, context, executemany):
        stack = conn.info.get('query_start_time', [])
        elapsed = time.time() - stack.pop() if stack else 0.0

        # Feed the per-request counter when one is active on this thread.
        active = get_current_counter()
        if active:
            active.record_query(statement, elapsed)

        # In debug mode, also log each statement individually (truncated).
        if settings.DEBUG:
            logger.debug("SQL (%.3fs): %s", elapsed, statement[:500])
@contextmanager
def count_queries(context_name: str = "operation", threshold: Optional[int] = None):
    """
    Context manager counting queries for one named operation.

    Args:
        context_name: Name used in log output.
        threshold: Optional override of the default query-count threshold.

    Usage:
        with count_queries("list_members") as counter:
            members = db.query(ProjectMember).all()
            for member in members:
                print(member.user.name)  # N+1 query!
        # After the block, a warning is logged if the threshold was exceeded
        print(f"Total queries: {counter.count}")
    """
    tracker = QueryCounter(threshold=threshold, context_name=context_name)
    with tracker:
        yield tracker
def assert_query_count(max_queries: int):
    """
    Decorator for tests that asserts a maximum executed-query count.

    Args:
        max_queries: Highest number of queries the decorated callable may
            issue before an AssertionError is raised.

    Raises:
        AssertionError: When the decorated callable exceeds ``max_queries``.

    Usage in tests:
        @assert_query_count(5)
        def test_list_members():
            # Should use at most 5 queries
            response = client.get("/api/projects/xxx/members")
    """
    # Local import: this module does not import functools at the top level.
    from functools import wraps

    def decorator(func: Callable) -> Callable:
        # wraps preserves the test's __name__/__doc__ so pytest collection
        # and failure reports show the real test, not "wrapper".
        @wraps(func)
        def wrapper(*args, **kwargs):
            with QueryCounter(threshold=max_queries, context_name=func.__name__) as counter:
                result = func(*args, **kwargs)
            if counter.count > max_queries:
                raise AssertionError(
                    f"Query count {counter.count} exceeded maximum {max_queries} "
                    f"in {func.__name__}"
                )
            return result
        return wrapper
    return decorator

View File

@@ -1,8 +1,283 @@
from datetime import datetime, timedelta, timezone
from typing import Optional, Any
from typing import Optional, Any, Tuple
from jose import jwt, JWTError
import logging
import math
import hashlib
import secrets
import hmac
from collections import Counter
from app.core.config import settings
logger = logging.getLogger(__name__)
# Constants for JWT secret validation
MIN_SECRET_LENGTH = 32
MIN_ENTROPY_BITS = 128 # Minimum entropy in bits for a secure secret
COMMON_WEAK_PATTERNS = [
"password", "secret", "changeme", "admin", "test", "demo",
"123456", "qwerty", "abc123", "letmein", "welcome",
]
def calculate_entropy(data: str) -> float:
    """
    Return the total Shannon entropy of *data*, in bits.

    The per-character Shannon entropy is computed from observed character
    frequencies and multiplied by the string length, approximating the
    number of bits needed to encode the whole string.  Higher values mean
    more randomness and thus a stronger secret.

    Args:
        data: The string to calculate entropy for

    Returns:
        Entropy value in bits (0.0 for an empty string)
    """
    if not data:
        return 0.0

    length = len(data)
    # Per-character entropy: -sum(p * log2(p)) over observed symbols.
    per_char = -sum(
        (count / length) * math.log2(count / length)
        for count in Counter(data).values()
    )
    # Scale by length to get total entropy in bits.
    return per_char * length
def has_repeating_patterns(secret: str) -> bool:
    """
    Heuristically detect obvious repetition in *secret*.

    Two checks are applied:
      * a short prefix tiled at least three times over the string, and
      * a run of a single character covering at least half the string.

    Strings shorter than 8 characters are never flagged.

    Args:
        secret: The secret string to check

    Returns:
        True if repeating patterns are detected
    """
    n = len(secret)
    if n < 8:
        return False

    # Prefix-tiling check: does secret[:k], repeated, reproduce the string
    # (ignoring a possible partial tail) with at least three full tiles?
    for k in range(2, n // 3 + 1):
        tiles = n // k
        if tiles >= 3 and secret[:k] * tiles == secret[: k * tiles]:
            return True

    # Run-length check: a single character repeated for >= half the string.
    run = 1
    for prev, cur in zip(secret, secret[1:]):
        run = run + 1 if cur == prev else 1
        if run >= n // 2:
            return True

    return False
def validate_jwt_secret_strength(secret: str) -> Tuple[bool, list]:
    """
    Assess the strength of a JWT signing secret.

    Checks:
    1. Minimum length (MIN_SECRET_LENGTH characters) — the only hard failure
    2. Shannon entropy (MIN_ENTROPY_BITS) — advisory warning only
    3. Common weak substrings — advisory warning only
    4. Repeating patterns — advisory warning only

    Args:
        secret: The JWT secret to validate

    Returns:
        Tuple of (is_valid, list_of_warnings)
    """
    warnings = []

    # Length is the only check that makes the secret invalid.
    is_valid = len(secret) >= MIN_SECRET_LENGTH
    if not is_valid:
        warnings.append(
            f"JWT secret is too short ({len(secret)} chars). "
            f"Minimum recommended length is {MIN_SECRET_LENGTH} characters."
        )

    # Low entropy alone doesn't make it invalid, but it's a warning.
    entropy = calculate_entropy(secret)
    if entropy < MIN_ENTROPY_BITS:
        warnings.append(
            f"JWT secret has low entropy ({entropy:.1f} bits). "
            f"Minimum recommended entropy is {MIN_ENTROPY_BITS} bits. "
            "Consider using a cryptographically random secret."
        )

    # Report at most one matching weak substring (first match wins).
    lowered = secret.lower()
    matched = next((p for p in COMMON_WEAK_PATTERNS if p in lowered), None)
    if matched is not None:
        warnings.append(
            f"JWT secret contains common weak pattern: '{matched}'. "
            "Use a cryptographically random secret."
        )

    if has_repeating_patterns(secret):
        warnings.append(
            "JWT secret contains repeating patterns. "
            "Use a cryptographically random secret."
        )

    return is_valid, warnings
def validate_jwt_secret_on_startup() -> None:
    """
    Validate JWT secret strength at application startup.

    All findings are logged as warnings.  When the secret fails the hard
    requirements and ENVIRONMENT=production, startup is blocked with a
    ValueError; outside production the failure is only logged.
    """
    import os

    is_valid, warnings = validate_jwt_secret_strength(settings.JWT_SECRET_KEY)

    for warning in warnings:
        logger.warning("JWT Security Warning: %s", warning)

    if not is_valid:
        # Enforce strictly only in production.
        if os.environ.get("ENVIRONMENT", "").lower() == "production":
            logger.critical(
                "JWT secret does not meet security requirements. "
                "Application startup blocked in production mode. "
                "Please configure a strong JWT_SECRET_KEY (minimum 32 characters)."
            )
            raise ValueError(
                "JWT_SECRET_KEY does not meet minimum security requirements. "
                "See logs for details."
            )
        logger.warning(
            "JWT secret does not meet security requirements. "
            "This would block startup in production mode."
        )

    if warnings:
        logger.info(
            "JWT secret validation completed with %d warning(s). "
            "Consider using: python -c \"import secrets; print(secrets.token_urlsafe(48))\" "
            "to generate a strong secret.",
            len(warnings)
        )
    else:
        logger.info("JWT secret validation passed. Secret meets security requirements.")
# CSRF Token Functions
CSRF_TOKEN_LENGTH = 32
CSRF_TOKEN_EXPIRY_SECONDS = 3600 # 1 hour
def generate_csrf_token(user_id: str) -> str:
    """
    Create a signed CSRF token bound to one user.

    Token layout is ``random:user_id:timestamp:signature``:
    - random bytes for unpredictability
    - the user ID, binding the token to one user
    - an issue timestamp, enabling expiry checks at validation time
    - an HMAC-SHA256 signature (keyed with the JWT secret) for integrity

    Args:
        user_id: The user's ID to bind the token to

    Returns:
        CSRF token string
    """
    random_part = secrets.token_urlsafe(CSRF_TOKEN_LENGTH)
    issued_at = int(datetime.now(timezone.utc).timestamp())
    payload = f"{random_part}:{user_id}:{issued_at}"

    # NOTE(review): the digest is truncated to 16 hex chars (64 bits) —
    # adequate for CSRF purposes but weaker than the full SHA-256 output.
    signature = hmac.new(
        settings.JWT_SECRET_KEY.encode(),
        payload.encode(),
        hashlib.sha256
    ).hexdigest()[:16]

    return f"{payload}:{signature}"
def validate_csrf_token(token: str, user_id: str) -> Tuple[bool, str]:
    """
    Verify a CSRF token produced by generate_csrf_token.

    Checks, in order: presence, 4-part format, user binding, expiry
    (CSRF_TOKEN_EXPIRY_SECONDS), and the HMAC signature.

    Args:
        token: The CSRF token to validate
        user_id: The expected user ID

    Returns:
        Tuple of (is_valid, error_message); error_message is empty on success
    """
    if not token:
        return False, "CSRF token is required"

    try:
        parts = token.split(":")
        if len(parts) != 4:
            return False, "Invalid CSRF token format"
        random_part, token_user_id, timestamp_str, signature = parts

        # The token must belong to the requesting user.
        if token_user_id != user_id:
            return False, "CSRF token user mismatch"

        # Reject tokens older than the expiry window.
        age = int(datetime.now(timezone.utc).timestamp()) - int(timestamp_str)
        if age > CSRF_TOKEN_EXPIRY_SECONDS:
            return False, "CSRF token expired"

        # Recompute the signature and compare in constant time.
        signed_payload = f"{random_part}:{token_user_id}:{timestamp_str}"
        expected = hmac.new(
            settings.JWT_SECRET_KEY.encode(),
            signed_payload.encode(),
            hashlib.sha256
        ).hexdigest()[:16]
        if not hmac.compare_digest(signature, expected):
            return False, "CSRF token signature invalid"

        return True, ""
    except (ValueError, IndexError) as e:
        return False, f"CSRF token validation error: {str(e)}"
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
"""

View File

@@ -20,6 +20,12 @@ async def lifespan(app: FastAPI):
testing = os.environ.get("TESTING", "").lower() in ("true", "1", "yes")
scheduler_disabled = os.environ.get("DISABLE_SCHEDULER", "").lower() in ("true", "1", "yes")
start_background_jobs = not testing and not scheduler_disabled
# Startup security validation
if not testing:
from app.core.security import validate_jwt_secret_on_startup
validate_jwt_secret_on_startup()
# Startup
if start_background_jobs:
start_scheduler()

View File

@@ -0,0 +1,167 @@
"""
CSRF (Cross-Site Request Forgery) Protection Middleware.
This module provides CSRF protection for sensitive state-changing operations.
It validates CSRF tokens for specified protected endpoints.
"""
from fastapi import Request, HTTPException, status, Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from typing import Optional, Callable, List
from functools import wraps
import logging
from app.core.security import validate_csrf_token, generate_csrf_token
logger = logging.getLogger(__name__)
# Header name for CSRF token
CSRF_TOKEN_HEADER = "X-CSRF-Token"
# List of endpoint patterns that require CSRF protection
# These are sensitive state-changing operations
CSRF_PROTECTED_PATTERNS = [
# User operations
"/api/v1/users/{user_id}/admin", # Admin status change
"/api/users/{user_id}/admin", # Legacy
# Password changes would go here if implemented
# Delete operations
"/api/attachments/{attachment_id}", # DELETE method
"/api/tasks/{task_id}", # DELETE method (soft delete)
"/api/projects/{project_id}", # DELETE method
]
# Methods that require CSRF protection
CSRF_PROTECTED_METHODS = ["DELETE", "PUT", "PATCH"]
class CSRFProtectionError(HTTPException):
    """Raised when CSRF validation fails; rendered as HTTP 403 Forbidden."""

    def __init__(self, detail: str = "CSRF validation failed"):
        # The status code is fixed; only the detail message varies per failure.
        super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
def require_csrf_token(func: Callable) -> Callable:
    """
    Decorator enforcing CSRF token validation on an endpoint.

    The wrapped endpoint must receive the FastAPI ``Request`` (as a keyword
    or positional argument) and the authenticated user as ``current_user``.
    The X-CSRF-Token header is validated against that user before the
    endpoint body runs; any failure raises CSRFProtectionError (HTTP 403).

    Usage:
        @router.delete("/resource/{id}")
        @require_csrf_token
        async def delete_resource(request: Request, id: str, current_user: User = Depends(get_current_user)):
            ...
    """
    @wraps(func)
    async def wrapper(*args, **kwargs):
        # Locate the Request object: keyword first, then positional args.
        request: Optional[Request] = kwargs.get("request")
        if request is None:
            request = next((a for a in args if isinstance(a, Request)), None)
        if request is None:
            logger.error("CSRF validation failed: Request object not found")
            raise CSRFProtectionError("Internal error: Request not available")

        current_user = kwargs.get("current_user")
        if current_user is None:
            logger.error("CSRF validation failed: User not authenticated")
            raise CSRFProtectionError("Authentication required for CSRF-protected endpoint")

        csrf_token = request.headers.get(CSRF_TOKEN_HEADER)
        if not csrf_token:
            logger.warning(
                "CSRF validation failed: Missing token for user %s on %s %s",
                current_user.id, request.method, request.url.path
            )
            raise CSRFProtectionError("CSRF token is required")

        is_valid, error_message = validate_csrf_token(csrf_token, current_user.id)
        if not is_valid:
            logger.warning(
                "CSRF validation failed for user %s on %s %s: %s",
                current_user.id, request.method, request.url.path, error_message
            )
            raise CSRFProtectionError(error_message)

        logger.debug(
            "CSRF validation passed for user %s on %s %s",
            current_user.id, request.method, request.url.path
        )
        return await func(*args, **kwargs)

    return wrapper
def get_csrf_token_for_user(user_id: str) -> str:
    """
    Return a freshly generated CSRF token bound to *user_id*.

    Intended for login endpoints so the client receives a token it can
    later send back in the X-CSRF-Token header.

    Args:
        user_id: The user's ID.

    Returns:
        CSRF token string.
    """
    return generate_csrf_token(user_id)
async def validate_csrf_for_request(
    request: Request,
    user_id: str,
    skip_methods: Optional[List[str]] = None
) -> bool:
    """
    Validate the CSRF token carried by *request* for the given user.

    Standalone alternative to the ``require_csrf_token`` decorator for
    endpoints that want to trigger validation manually.

    Args:
        request: The FastAPI request object.
        user_id: The current user's ID.
        skip_methods: HTTP methods exempt from validation
            (defaults to GET, HEAD, OPTIONS).

    Returns:
        True when the request passes, or is exempt from, validation.

    Raises:
        CSRFProtectionError: If the token is missing or invalid.
    """
    exempt = ["GET", "HEAD", "OPTIONS"] if skip_methods is None else skip_methods

    # Safe (read-only) methods never require a token.
    if request.method.upper() in exempt:
        return True

    token = request.headers.get(CSRF_TOKEN_HEADER)
    if not token:
        raise CSRFProtectionError("CSRF token is required")

    ok, reason = validate_csrf_token(token, user_id)
    if not ok:
        raise CSRFProtectionError(reason)
    return True

View File

@@ -32,5 +32,11 @@ class TokenPayload(BaseModel):
iat: int
class CSRFTokenResponse(BaseModel):
    """Response body carrying a CSRF token for state-changing operations.

    The client must echo ``csrf_token`` back in the ``X-CSRF-Token`` header
    on protected requests; ``expires_in`` reports the token lifetime.
    """

    csrf_token: str = Field(..., description="CSRF token to include in X-CSRF-Token header")
    expires_in: int = Field(default=3600, description="Token expiry time in seconds")


# Resolve forward references now that all schemas in this module are defined.
LoginResponse.model_rebuild()

View File

@@ -286,11 +286,15 @@ class FileStorageService:
return filename.rsplit(".", 1)[-1].lower() if "." in filename else ""
@staticmethod
def validate_file(file: UploadFile) -> Tuple[str, str]:
def validate_file(file: UploadFile, validate_mime: bool = True) -> Tuple[str, str]:
"""
Validate file size and type.
Validate file size, type, and optionally MIME content.
Returns (extension, mime_type) if valid.
Raises HTTPException if invalid.
Args:
file: The uploaded file
validate_mime: If True, validate MIME type using magic bytes detection
"""
# Check file size
file.file.seek(0, 2) # Seek to end
@@ -323,7 +327,35 @@ class FileStorageService:
detail=f"File type '.{extension}' is not supported"
)
mime_type = file.content_type or "application/octet-stream"
# Validate MIME type using magic bytes detection
if validate_mime:
from app.services.mime_validation_service import mime_validation_service
# Read first 16 bytes for magic detection (enough for most signatures)
file_header = file.file.read(16)
file.file.seek(0) # Reset
is_valid, detected_mime, error_message = mime_validation_service.validate_file_content(
file_content=file_header,
declared_extension=extension,
declared_mime_type=file.content_type
)
if not is_valid:
logger.warning(
"MIME validation failed for file '%s': %s (detected: %s)",
file.filename, error_message, detected_mime
)
raise HTTPException(
status_code=400,
detail=error_message or "File type validation failed"
)
# Use detected MIME type if available, otherwise fall back to declared
mime_type = detected_mime if detected_mime else (file.content_type or "application/octet-stream")
else:
mime_type = file.content_type or "application/octet-stream"
return extension, mime_type
async def save_file(

View File

@@ -0,0 +1,314 @@
"""
MIME Type Validation Service using Magic Bytes Detection.
This module provides file content type validation by examining
the actual file content (magic bytes) rather than trusting
the file extension or Content-Type header.
"""
import logging
from typing import Optional, Tuple, Dict, Set, BinaryIO
from io import BytesIO
logger = logging.getLogger(__name__)
class MimeValidationError(Exception):
    """Base error for all MIME-type validation failures."""


class FileMismatchError(MimeValidationError):
    """The file extension does not agree with the detected content type."""


class UnsupportedMimeError(MimeValidationError):
    """The detected MIME type is not in the allowed set."""
# Magic bytes signatures for common file types
# Format: { bytes_pattern: (mime_type, extensions) }
# NOTE: dict insertion order matters — detection returns the FIRST signature
# whose bytes prefix the file content.
MAGIC_SIGNATURES: Dict[bytes, Tuple[str, Set[str]]] = {
    # Images
    b'\xFF\xD8\xFF': ('image/jpeg', {'jpg', 'jpeg', 'jpe'}),
    b'\x89PNG\r\n\x1a\n': ('image/png', {'png'}),
    b'GIF87a': ('image/gif', {'gif'}),
    b'GIF89a': ('image/gif', {'gif'}),
    b'RIFF': ('image/webp', {'webp'}),  # WebP starts with RIFF, then WEBP
    b'BM': ('image/bmp', {'bmp'}),
    # PDF
    b'%PDF': ('application/pdf', {'pdf'}),
    # Microsoft Office (Modern formats - ZIP-based)
    b'PK\x03\x04': ('application/zip', {'zip', 'docx', 'xlsx', 'pptx', 'odt', 'ods', 'odp', 'jar'}),
    # Microsoft Office (Legacy formats - Compound Document)
    b'\xD0\xCF\x11\xE0\xA1\xB1\x1A\xE1': ('application/msword', {'doc', 'xls', 'ppt', 'msi'}),
    # Archives
    b'\x1f\x8b': ('application/gzip', {'gz', 'tgz'}),
    b'\x42\x5a\x68': ('application/x-bzip2', {'bz2'}),
    b'\x37\x7A\xBC\xAF\x27\x1C': ('application/x-7z-compressed', {'7z'}),
    b'Rar!\x1a\x07': ('application/x-rar-compressed', {'rar'}),
    # Text/Data formats - these are harder to detect, usually fallback to extension
    b'<?xml': ('application/xml', {'xml', 'svg'}),
    # NOTE(review): a bare '{' or '[' prefix classifies ANY such file as JSON,
    # so e.g. a .txt file beginning with '{' will later fail the
    # extension/MIME mismatch check — confirm this strictness is intended.
    b'{': ('application/json', {'json'}),  # JSON typically starts with { or [
    b'[': ('application/json', {'json'}),
    # Executables (dangerous - should be blocked)
    b'MZ': ('application/x-executable', {'exe', 'dll', 'com', 'scr'}),
    b'\x7fELF': ('application/x-executable', {'elf', 'so', 'bin'}),
}

# Map extensions to expected MIME types
# (an extension is consistent when the detected MIME type is in its set)
EXTENSION_TO_MIME: Dict[str, Set[str]] = {
    # Images
    'jpg': {'image/jpeg'},
    'jpeg': {'image/jpeg'},
    'jpe': {'image/jpeg'},
    'png': {'image/png'},
    'gif': {'image/gif'},
    'bmp': {'image/bmp'},
    'webp': {'image/webp'},
    'svg': {'image/svg+xml', 'application/xml', 'text/xml'},
    # Documents
    'pdf': {'application/pdf'},
    'doc': {'application/msword'},
    'docx': {'application/vnd.openxmlformats-officedocument.wordprocessingml.document', 'application/zip'},
    # legacy OLE containers (xls/ppt) detect as application/msword above
    'xls': {'application/vnd.ms-excel', 'application/msword'},
    'xlsx': {'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', 'application/zip'},
    'ppt': {'application/vnd.ms-powerpoint', 'application/msword'},
    'pptx': {'application/vnd.openxmlformats-officedocument.presentationml.presentation', 'application/zip'},
    # Text
    'txt': {'text/plain'},
    'csv': {'text/csv', 'text/plain'},
    'json': {'application/json', 'text/plain'},
    'xml': {'application/xml', 'text/xml', 'text/plain'},
    'yaml': {'application/yaml', 'text/plain'},
    'yml': {'application/yaml', 'text/plain'},
    # Archives
    'zip': {'application/zip'},
    'rar': {'application/x-rar-compressed'},
    '7z': {'application/x-7z-compressed'},
    'tar': {'application/x-tar'},
    'gz': {'application/gzip'},
}

# MIME types that should always be blocked (dangerous executables)
BLOCKED_MIME_TYPES: Set[str] = {
    'application/x-executable',
    'application/x-msdownload',
    'application/x-msdos-program',
    'application/x-sh',
    'application/x-csh',
    'application/x-dosexec',
}

# Configurable allowed MIME type categories
# (MimeValidationService flattens the selected categories into one allow-set)
ALLOWED_MIME_CATEGORIES: Dict[str, Set[str]] = {
    'images': {
        'image/jpeg', 'image/png', 'image/gif', 'image/bmp', 'image/webp', 'image/svg+xml'
    },
    'documents': {
        'application/pdf',
        'application/msword',
        'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
        'application/vnd.ms-excel',
        'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
        'application/vnd.ms-powerpoint',
        'application/vnd.openxmlformats-officedocument.presentationml.presentation',
        'text/plain', 'text/csv',
    },
    'archives': {
        'application/zip', 'application/x-rar-compressed',
        'application/x-7z-compressed', 'application/gzip',
        'application/x-tar',
    },
    'data': {
        'application/json', 'application/xml', 'text/xml',
        'application/yaml', 'text/plain',
    },
}
class MimeValidationService:
    """Service for validating file MIME types using magic bytes.

    Compares the actual leading bytes of an upload against known signatures,
    rejects blocked (executable) types, enforces a configurable allow-set,
    and checks that the claimed extension is consistent with the detected
    content type.
    """

    def __init__(
        self,
        allowed_categories: Optional[Set[str]] = None,
        bypass_for_trusted: bool = False
    ):
        """
        Initialize the MIME validation service.

        Args:
            allowed_categories: Set of allowed MIME categories ('images',
                'documents', etc.). If None, all categories are allowed.
                Unknown category names are silently ignored.
            bypass_for_trusted: If True, validation can be bypassed for
                trusted sources.
        """
        self.bypass_for_trusted = bypass_for_trusted
        # Flatten the chosen categories into one allow-set of MIME types.
        if allowed_categories is None:
            selected = list(ALLOWED_MIME_CATEGORIES.values())
        else:
            selected = [
                ALLOWED_MIME_CATEGORIES[name]
                for name in allowed_categories
                if name in ALLOWED_MIME_CATEGORIES
            ]
        self.allowed_mime_types: Set[str] = set().union(*selected) if selected else set()

    def detect_mime_type(self, file_content: bytes) -> Optional[str]:
        """
        Detect MIME type from file content using magic bytes.

        Args:
            file_content: The raw file bytes (the first 16 are sufficient
                for every signature in MAGIC_SIGNATURES).

        Returns:
            Detected MIME type, or None if no signature matches.
        """
        if len(file_content) < 2:
            return None

        for signature, (mime_type, _exts) in MAGIC_SIGNATURES.items():
            if not file_content.startswith(signature):
                continue
            # RIFF is a generic container: only bytes 8..12 == b'WEBP' is WebP.
            # NOTE(review): a RIFF file shorter than 12 bytes skips this check
            # and is still reported as image/webp — confirm this is intended.
            if signature == b'RIFF' and len(file_content) >= 12:
                if file_content[8:12] == b'WEBP':
                    return 'image/webp'
                continue  # some other RIFF payload — keep scanning
            return mime_type
        return None

    def validate_file_content(
        self,
        file_content: bytes,
        declared_extension: str,
        declared_mime_type: Optional[str] = None,
        trusted_source: bool = False
    ) -> Tuple[bool, str, Optional[str]]:
        """
        Validate file content against declared extension and MIME type.

        Args:
            file_content: The raw file bytes.
            declared_extension: The file extension (without dot).
            declared_mime_type: The Content-Type header value (optional).
            trusted_source: If True and bypass_for_trusted is enabled,
                validation is skipped entirely.

        Returns:
            Tuple of (is_valid, detected_mime_type, error_message).
        """
        # Optional bypass for trusted upload paths.
        if trusted_source and self.bypass_for_trusted:
            logger.debug("MIME validation bypassed for trusted source")
            return True, declared_mime_type or 'application/octet-stream', None

        detected_mime = self.detect_mime_type(file_content)
        ext_lower = declared_extension.lower()

        # Hard stop: executables are never acceptable, whatever the extension.
        if detected_mime in BLOCKED_MIME_TYPES:
            logger.warning(
                "Blocked dangerous file type detected: %s (claimed extension: %s)",
                detected_mime, ext_lower
            )
            return False, detected_mime, "File type not allowed for security reasons"

        if detected_mime is None:
            # Detection is unreliable for plain text/data formats; fall back
            # to trusting a known extension whose MIME types are allowed.
            expected_mimes = EXTENSION_TO_MIME.get(ext_lower)
            if expected_mimes and (expected_mimes & self.allowed_mime_types):
                logger.debug(
                    "MIME detection inconclusive for extension %s, allowing based on extension",
                    ext_lower
                )
                return True, next(iter(expected_mimes)), None
            # Unknown extension (or none of its MIME types allowed):
            # accept permissively as a generic binary stream.
            logger.warning(
                "Could not detect MIME type for file with extension: %s",
                ext_lower
            )
            return True, 'application/octet-stream', None

        # A detected type must sit inside the configured allow-set.
        if detected_mime not in self.allowed_mime_types:
            logger.warning(
                "Unsupported MIME type detected: %s (extension: %s)",
                detected_mime, ext_lower
            )
            return False, detected_mime, f"Unsupported file type: {detected_mime}"

        # Finally, the extension must be consistent with the detected type.
        expected_mimes = EXTENSION_TO_MIME.get(ext_lower)
        if expected_mimes is not None:
            # Office/OpenDocument files are ZIP containers under the hood.
            if detected_mime == 'application/zip' and ext_lower in {'docx', 'xlsx', 'pptx', 'odt', 'ods', 'odp'}:
                return True, detected_mime, None
            if detected_mime not in expected_mimes:
                logger.warning(
                    "File type mismatch: extension '%s' but detected '%s'",
                    ext_lower, detected_mime
                )
                return False, detected_mime, f"File type mismatch: extension indicates {ext_lower} but content is {detected_mime}"

        return True, detected_mime, None

    async def validate_upload_file(
        self,
        file_content: bytes,
        filename: str,
        content_type: Optional[str] = None,
        trusted_source: bool = False
    ) -> Tuple[bool, str, Optional[str]]:
        """
        Validate an uploaded file by its content and filename.

        Args:
            file_content: The raw file bytes.
            filename: The uploaded filename (its extension is extracted here).
            content_type: The Content-Type header value.
            trusted_source: If True and bypass is enabled, skip validation.

        Returns:
            Tuple of (is_valid, detected_mime_type, error_message).
        """
        extension = filename.rsplit('.', 1)[-1] if '.' in filename else ''
        return self.validate_file_content(
            file_content=file_content,
            declared_extension=extension,
            declared_mime_type=content_type,
            trusted_source=trusted_source
        )
# Singleton instance with default configuration: all MIME categories allowed
# and no trusted-source bypass (the __init__ defaults above).
mime_validation_service = MimeValidationService()

View File

@@ -14,6 +14,7 @@ from sqlalchemy import event
from app.models import User, Notification, Task, Comment, Mention
from app.core.redis_pubsub import publish_notification as redis_publish, get_channel_name
from app.core.redis import get_redis_sync
from app.core.database import escape_like
logger = logging.getLogger(__name__)
@@ -427,9 +428,12 @@ class NotificationService:
# Find users by email or name
for username in mentioned_usernames:
# Escape special LIKE characters to prevent injection
escaped_username = escape_like(username)
# Try to find user by email first
user = db.query(User).filter(
(User.email.ilike(f"{username}%")) | (User.name.ilike(f"%{username}%"))
(User.email.ilike(f"{escaped_username}%", escape="\\")) |
(User.name.ilike(f"%{escaped_username}%", escape="\\"))
).first()
if user and user.id != author.id: