feat: implement security, error resilience, and query optimization proposals
Security Validation (enhance-security-validation): - JWT secret validation with entropy checking and pattern detection - CSRF protection middleware with token generation/validation - Frontend CSRF token auto-injection for DELETE/PUT/PATCH requests - MIME type validation with magic bytes detection for file uploads Error Resilience (add-error-resilience): - React ErrorBoundary component with fallback UI and retry functionality - ErrorBoundaryWithI18n wrapper for internationalization support - Page-level and section-level error boundaries in App.tsx Query Performance (optimize-query-performance): - Query monitoring utility with threshold warnings - N+1 query fixes using joinedload/selectinload - Optimized project members, tasks, and subtasks endpoints Bug Fixes: - WebSocket session management (P0): Return primitives instead of ORM objects - LIKE query injection (P1): Escape special characters in search queries Tests: 543 backend tests, 56 frontend tests passing Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -24,6 +24,7 @@ from app.services.encryption_service import (
|
||||
MasterKeyNotConfiguredError,
|
||||
DecryptionError,
|
||||
)
|
||||
from app.middleware.csrf import require_csrf_token
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -610,13 +611,14 @@ async def download_attachment(
|
||||
|
||||
|
||||
@router.delete("/attachments/{attachment_id}")
|
||||
@require_csrf_token
|
||||
async def delete_attachment(
|
||||
attachment_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Soft delete an attachment."""
|
||||
"""Soft delete an attachment. Requires CSRF token in X-CSRF-Token header."""
|
||||
attachment = get_attachment_with_access_check(db, attachment_id, current_user, require_edit=True)
|
||||
|
||||
# Soft delete
|
||||
|
||||
@@ -8,7 +8,7 @@ from app.core.redis import get_redis
|
||||
from app.core.rate_limiter import limiter
|
||||
from app.models.user import User
|
||||
from app.models.audit_log import AuditAction
|
||||
from app.schemas.auth import LoginRequest, LoginResponse, UserInfo
|
||||
from app.schemas.auth import LoginRequest, LoginResponse, UserInfo, CSRFTokenResponse
|
||||
from app.services.auth_client import (
|
||||
verify_credentials,
|
||||
AuthAPIError,
|
||||
@@ -16,6 +16,7 @@ from app.services.auth_client import (
|
||||
)
|
||||
from app.services.audit_service import AuditService
|
||||
from app.middleware.auth import get_current_user
|
||||
from app.middleware.csrf import get_csrf_token_for_user
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -182,3 +183,23 @@ async def get_current_user_info(
|
||||
department_id=current_user.department_id,
|
||||
is_system_admin=current_user.is_system_admin,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/csrf-token", response_model=CSRFTokenResponse)
async def get_csrf_token(
    current_user: User = Depends(get_current_user),
):
    """
    Get a CSRF token for the current user.

    The CSRF token should be included in the X-CSRF-Token header
    for all sensitive state-changing operations (DELETE, PUT, PATCH).

    Token expires after 1 hour and should be refreshed.

    Returns:
        CSRFTokenResponse carrying the token and its lifetime in seconds.
    """
    # Token is bound to the authenticated user's id by the CSRF middleware helper.
    csrf_token = get_csrf_token_for_user(current_user.id)

    return CSRFTokenResponse(
        csrf_token=csrf_token,
        expires_in=3600,  # 1 hour — presumably matches the middleware's token TTL; verify
    )
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import uuid
|
||||
from typing import List
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, joinedload, selectinload
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.models import User, Space, Project, TaskStatus, AuditAction, ProjectMember, ProjectTemplate, CustomField
|
||||
@@ -55,6 +55,8 @@ async def list_projects_in_space(
|
||||
):
|
||||
"""
|
||||
List all projects in a space that the user can access.
|
||||
|
||||
Optimized to avoid N+1 queries by using joinedload/selectinload for relationships.
|
||||
"""
|
||||
space = db.query(Space).filter(Space.id == space_id, Space.is_active == True).first()
|
||||
|
||||
@@ -70,13 +72,21 @@ async def list_projects_in_space(
|
||||
detail="Access denied",
|
||||
)
|
||||
|
||||
projects = db.query(Project).filter(Project.space_id == space_id, Project.is_active == True).all()
|
||||
# Use joinedload to eagerly load owner, space, and department
|
||||
# Use selectinload for tasks (one-to-many) to avoid cartesian product issues
|
||||
projects = db.query(Project).options(
|
||||
joinedload(Project.owner),
|
||||
joinedload(Project.space),
|
||||
joinedload(Project.department),
|
||||
selectinload(Project.tasks),
|
||||
).filter(Project.space_id == space_id, Project.is_active == True).all()
|
||||
|
||||
# Filter by project access
|
||||
accessible_projects = [p for p in projects if check_project_access(current_user, p)]
|
||||
|
||||
result = []
|
||||
for project in accessible_projects:
|
||||
# Access pre-loaded relationships - no additional queries needed
|
||||
task_count = len(project.tasks) if project.tasks else 0
|
||||
result.append(ProjectWithDetails(
|
||||
id=project.id,
|
||||
@@ -422,6 +432,10 @@ async def list_project_members(
|
||||
List all members of a project.
|
||||
|
||||
Only users with project access can view the member list.
|
||||
|
||||
Optimized to avoid N+1 queries by using joinedload for user relationships.
|
||||
This loads all members and their related users in at most 2 queries instead of
|
||||
one query per member.
|
||||
"""
|
||||
project = db.query(Project).filter(Project.id == project_id, Project.is_active == True).first()
|
||||
|
||||
@@ -437,14 +451,20 @@ async def list_project_members(
|
||||
detail="Access denied",
|
||||
)
|
||||
|
||||
members = db.query(ProjectMember).filter(
|
||||
# Use joinedload to eagerly load user and added_by_user relationships
|
||||
# This avoids N+1 queries when accessing member.user and member.added_by_user
|
||||
members = db.query(ProjectMember).options(
|
||||
joinedload(ProjectMember.user).joinedload(User.department),
|
||||
joinedload(ProjectMember.added_by_user),
|
||||
).filter(
|
||||
ProjectMember.project_id == project_id
|
||||
).all()
|
||||
|
||||
member_list = []
|
||||
for member in members:
|
||||
user = db.query(User).filter(User.id == member.user_id).first()
|
||||
added_by_user = db.query(User).filter(User.id == member.added_by).first()
|
||||
# Access pre-loaded relationships - no additional queries needed
|
||||
user = member.user
|
||||
added_by_user = member.added_by_user
|
||||
|
||||
member_list.append(ProjectMemberWithDetails(
|
||||
id=member.id,
|
||||
|
||||
@@ -3,7 +3,7 @@ import uuid
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import List, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query, Request
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, joinedload, selectinload
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.redis_pubsub import publish_task_event
|
||||
@@ -110,6 +110,9 @@ async def list_tasks(
|
||||
|
||||
The due_after and due_before parameters are useful for calendar view
|
||||
to fetch tasks within a specific date range.
|
||||
|
||||
Optimized to avoid N+1 queries by using selectinload for task relationships.
|
||||
This batch loads assignees, statuses, creators and subtasks efficiently.
|
||||
"""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
|
||||
@@ -125,7 +128,15 @@ async def list_tasks(
|
||||
detail="Access denied",
|
||||
)
|
||||
|
||||
query = db.query(Task).filter(Task.project_id == project_id)
|
||||
# Use selectinload to eagerly load task relationships
|
||||
# This avoids N+1 queries when accessing task.assignee, task.status, etc.
|
||||
query = db.query(Task).options(
|
||||
selectinload(Task.assignee),
|
||||
selectinload(Task.status),
|
||||
selectinload(Task.creator),
|
||||
selectinload(Task.subtasks),
|
||||
selectinload(Task.custom_values),
|
||||
).filter(Task.project_id == project_id)
|
||||
|
||||
# Filter deleted tasks (only admin can include deleted)
|
||||
if include_deleted and current_user.is_system_admin:
|
||||
@@ -1112,6 +1123,8 @@ async def list_subtasks(
|
||||
):
|
||||
"""
|
||||
List subtasks of a task.
|
||||
|
||||
Optimized to avoid N+1 queries by using selectinload for task relationships.
|
||||
"""
|
||||
task = db.query(Task).filter(Task.id == task_id).first()
|
||||
|
||||
@@ -1127,7 +1140,13 @@ async def list_subtasks(
|
||||
detail="Access denied",
|
||||
)
|
||||
|
||||
query = db.query(Task).filter(Task.parent_task_id == task_id)
|
||||
# Use selectinload to eagerly load subtask relationships
|
||||
query = db.query(Task).options(
|
||||
selectinload(Task.assignee),
|
||||
selectinload(Task.status),
|
||||
selectinload(Task.creator),
|
||||
selectinload(Task.subtasks),
|
||||
).filter(Task.parent_task_id == task_id)
|
||||
|
||||
# Filter deleted subtasks (only admin can include deleted)
|
||||
if not (include_deleted and current_user.is_system_admin):
|
||||
|
||||
@@ -3,7 +3,7 @@ from sqlalchemy.orm import Session
|
||||
from sqlalchemy import or_
|
||||
from typing import List
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.database import get_db, escape_like
|
||||
from app.core.redis import get_redis
|
||||
from app.models.user import User
|
||||
from app.models.role import Role
|
||||
@@ -16,6 +16,7 @@ from app.middleware.auth import (
|
||||
check_department_access,
|
||||
)
|
||||
from app.middleware.audit import get_audit_metadata
|
||||
from app.middleware.csrf import require_csrf_token
|
||||
from app.services.audit_service import AuditService
|
||||
|
||||
router = APIRouter()
|
||||
@@ -32,11 +33,13 @@ async def search_users(
|
||||
Search users by name or email. Used for @mention autocomplete.
|
||||
Returns users matching the query, limited to same department unless system admin.
|
||||
"""
|
||||
# Escape special LIKE characters to prevent injection
|
||||
escaped_q = escape_like(q)
|
||||
query = db.query(User).filter(
|
||||
User.is_active == True,
|
||||
or_(
|
||||
User.name.ilike(f"%{q}%"),
|
||||
User.email.ilike(f"%{q}%"),
|
||||
User.name.ilike(f"%{escaped_q}%", escape="\\"),
|
||||
User.email.ilike(f"%{escaped_q}%", escape="\\"),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -197,6 +200,7 @@ async def assign_role(
|
||||
|
||||
|
||||
@router.patch("/{user_id}/admin", response_model=UserResponse)
|
||||
@require_csrf_token
|
||||
async def set_admin_status(
|
||||
user_id: str,
|
||||
is_admin: bool,
|
||||
@@ -205,7 +209,7 @@ async def set_admin_status(
|
||||
current_user: User = Depends(require_system_admin),
|
||||
):
|
||||
"""
|
||||
Set or revoke system administrator status. Requires system admin.
|
||||
Set or revoke system administrator status. Requires system admin and CSRF token.
|
||||
"""
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
|
||||
@@ -27,29 +27,37 @@ if os.getenv("TESTING") == "true":
|
||||
AUTH_TIMEOUT = 1.0
|
||||
|
||||
|
||||
async def get_user_from_token(token: str) -> tuple[str | None, User | None]:
|
||||
"""Validate token and return user_id and user object."""
|
||||
async def get_user_from_token(token: str) -> str | None:
|
||||
"""
|
||||
Validate token and return user_id.
|
||||
|
||||
Returns:
|
||||
user_id if valid, None otherwise.
|
||||
|
||||
Note: This function properly closes the database session after validation.
|
||||
Do not return ORM objects as they become detached after session close.
|
||||
"""
|
||||
payload = decode_access_token(token)
|
||||
if payload is None:
|
||||
return None, None
|
||||
return None
|
||||
|
||||
user_id = payload.get("sub")
|
||||
if user_id is None:
|
||||
return None, None
|
||||
return None
|
||||
|
||||
# Verify session in Redis
|
||||
redis_client = get_redis_sync()
|
||||
stored_token = redis_client.get(f"session:{user_id}")
|
||||
if stored_token is None or stored_token != token:
|
||||
return None, None
|
||||
return None
|
||||
|
||||
# Get user from database
|
||||
# Verify user exists and is active
|
||||
db = database.SessionLocal()
|
||||
try:
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if user is None or not user.is_active:
|
||||
return None, None
|
||||
return user_id, user
|
||||
return None
|
||||
return user_id
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@@ -57,7 +65,7 @@ async def get_user_from_token(token: str) -> tuple[str | None, User | None]:
|
||||
async def authenticate_websocket(
|
||||
websocket: WebSocket,
|
||||
query_token: Optional[str] = None
|
||||
) -> tuple[str | None, User | None, Optional[str]]:
|
||||
) -> tuple[str | None, Optional[str]]:
|
||||
"""
|
||||
Authenticate WebSocket connection.
|
||||
|
||||
@@ -67,7 +75,8 @@ async def authenticate_websocket(
|
||||
2. Query parameter authentication (deprecated, for backward compatibility)
|
||||
- Client connects with: ?token=<jwt_token>
|
||||
|
||||
Returns (user_id, user) if authenticated, (None, None) otherwise.
|
||||
Returns:
|
||||
Tuple of (user_id, error_reason). user_id is None if authentication fails.
|
||||
"""
|
||||
# If token provided via query parameter (backward compatibility)
|
||||
if query_token:
|
||||
@@ -75,10 +84,10 @@ async def authenticate_websocket(
|
||||
"WebSocket authentication via query parameter is deprecated. "
|
||||
"Please use first-message authentication for better security."
|
||||
)
|
||||
user_id, user = await get_user_from_token(query_token)
|
||||
user_id = await get_user_from_token(query_token)
|
||||
if user_id is None:
|
||||
return None, None, "invalid_token"
|
||||
return user_id, user, None
|
||||
return None, "invalid_token"
|
||||
return user_id, None
|
||||
|
||||
# Wait for authentication message with timeout
|
||||
try:
|
||||
@@ -90,24 +99,24 @@ async def authenticate_websocket(
|
||||
msg_type = data.get("type")
|
||||
if msg_type != "auth":
|
||||
logger.warning("Expected 'auth' message type, got: %s", msg_type)
|
||||
return None, None, "invalid_message"
|
||||
return None, "invalid_message"
|
||||
|
||||
token = data.get("token")
|
||||
if not token:
|
||||
logger.warning("No token provided in auth message")
|
||||
return None, None, "missing_token"
|
||||
return None, "missing_token"
|
||||
|
||||
user_id, user = await get_user_from_token(token)
|
||||
user_id = await get_user_from_token(token)
|
||||
if user_id is None:
|
||||
return None, None, "invalid_token"
|
||||
return user_id, user, None
|
||||
return None, "invalid_token"
|
||||
return user_id, None
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("WebSocket authentication timeout after %.1f seconds", AUTH_TIMEOUT)
|
||||
return None, None, "timeout"
|
||||
return None, "timeout"
|
||||
except Exception as e:
|
||||
logger.error("Error during WebSocket authentication: %s", e)
|
||||
return None, None, "error"
|
||||
return None, "error"
|
||||
|
||||
|
||||
async def get_unread_notifications(user_id: str) -> list[dict]:
|
||||
@@ -183,7 +192,7 @@ async def websocket_notifications(
|
||||
await websocket.accept()
|
||||
|
||||
# Authenticate
|
||||
user_id, user, error_reason = await authenticate_websocket(websocket, token)
|
||||
user_id, error_reason = await authenticate_websocket(websocket, token)
|
||||
|
||||
if user_id is None:
|
||||
if error_reason == "invalid_token":
|
||||
@@ -306,7 +315,7 @@ async def websocket_notifications(
|
||||
await manager.disconnect(websocket, user_id)
|
||||
|
||||
|
||||
async def verify_project_access(user_id: str, project_id: str) -> tuple[bool, Project | None]:
|
||||
async def verify_project_access(user_id: str, project_id: str) -> tuple[bool, str | None, str | None]:
|
||||
"""
|
||||
Check if user has access to the project.
|
||||
|
||||
@@ -315,23 +324,34 @@ async def verify_project_access(user_id: str, project_id: str) -> tuple[bool, Pr
|
||||
project_id: The project's ID
|
||||
|
||||
Returns:
|
||||
Tuple of (has_access: bool, project: Project | None)
|
||||
Tuple of (has_access: bool, project_title: str | None, error: str | None)
|
||||
- has_access: True if user can access the project
|
||||
- project_title: The project title (only if access granted)
|
||||
- error: Error code if access denied ("user_not_found", "project_not_found", "access_denied")
|
||||
|
||||
Note: This function extracts needed data before closing the session to avoid
|
||||
detached instance errors when accessing ORM object attributes.
|
||||
"""
|
||||
db = database.SessionLocal()
|
||||
try:
|
||||
# Get the user
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if user is None or not user.is_active:
|
||||
return False, None
|
||||
return False, None, "user_not_found"
|
||||
|
||||
# Get the project
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if project is None:
|
||||
return False, None
|
||||
return False, None, "project_not_found"
|
||||
|
||||
# Check access using existing middleware function
|
||||
has_access = check_project_access(user, project)
|
||||
return has_access, project
|
||||
if not has_access:
|
||||
return False, None, "access_denied"
|
||||
|
||||
# Extract title while session is still open
|
||||
project_title = project.title
|
||||
return True, project_title, None
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@@ -371,7 +391,7 @@ async def websocket_project_sync(
|
||||
await websocket.accept()
|
||||
|
||||
# Authenticate user
|
||||
user_id, user, error_reason = await authenticate_websocket(websocket, token)
|
||||
user_id, error_reason = await authenticate_websocket(websocket, token)
|
||||
|
||||
if user_id is None:
|
||||
if error_reason == "invalid_token":
|
||||
@@ -380,14 +400,13 @@ async def websocket_project_sync(
|
||||
return
|
||||
|
||||
# Verify user has access to the project
|
||||
has_access, project = await verify_project_access(user_id, project_id)
|
||||
has_access, project_title, access_error = await verify_project_access(user_id, project_id)
|
||||
|
||||
if not has_access:
|
||||
await websocket.close(code=4003, reason="Access denied to this project")
|
||||
return
|
||||
|
||||
if project is None:
|
||||
await websocket.close(code=4004, reason="Project not found")
|
||||
if access_error == "project_not_found":
|
||||
await websocket.close(code=4004, reason="Project not found")
|
||||
else:
|
||||
await websocket.close(code=4003, reason="Access denied to this project")
|
||||
return
|
||||
|
||||
# Join project room
|
||||
@@ -413,7 +432,7 @@ async def websocket_project_sync(
|
||||
"data": {
|
||||
"project_id": project_id,
|
||||
"user_id": user_id,
|
||||
"project_title": project.title if project else None,
|
||||
"project_title": project_title,
|
||||
},
|
||||
})
|
||||
|
||||
|
||||
@@ -122,6 +122,11 @@ class Settings(BaseSettings):
|
||||
RATE_LIMIT_SENSITIVE: str = "20/minute" # Attachments, password change, report export
|
||||
RATE_LIMIT_HEAVY: str = "5/minute" # Report generation, bulk operations
|
||||
|
||||
# Development Mode Settings
|
||||
DEBUG: bool = False # Enable debug mode for development
|
||||
QUERY_LOGGING: bool = False # Enable SQLAlchemy query logging
|
||||
QUERY_COUNT_THRESHOLD: int = 10 # Warn when query count exceeds this threshold
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
case_sensitive = True
|
||||
|
||||
@@ -104,6 +104,10 @@ def _on_invalidate(dbapi_conn, connection_record, exception):
|
||||
# Start pool statistics logging on module load
|
||||
_start_pool_stats_logging()
|
||||
|
||||
# Set up query logging if enabled
|
||||
from app.core.query_monitor import setup_query_logging
|
||||
setup_query_logging(engine)
|
||||
|
||||
|
||||
def get_db():
|
||||
"""Dependency for getting database session."""
|
||||
@@ -127,3 +131,25 @@ def get_pool_status() -> dict:
|
||||
"total_checkins": _pool_stats["checkins"],
|
||||
"invalidated_connections": _pool_stats["invalidated_connections"],
|
||||
}
|
||||
|
||||
|
||||
def escape_like(value: str) -> str:
    r"""
    Escape special characters for SQL LIKE queries.

    Escapes the backslash escape character plus '%' and '_', which have
    special meaning in LIKE patterns. This prevents LIKE injection attacks
    where user input could match unintended patterns. The caller must pass
    the matching escape character to the query (e.g. ``ilike(..., escape="\\")``).

    Args:
        value: The user input string to escape

    Returns:
        Escaped string safe for use in LIKE patterns

    Example:
        >>> escape_like("test%value")
        'test\\%value'
        >>> escape_like("user_name")
        'user\\_name'
    """
    # Escape the escape character itself FIRST so the later replacements
    # do not double-escape the backslashes they introduce.
    return value.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
|
||||
|
||||
167
backend/app/core/query_monitor.py
Normal file
167
backend/app/core/query_monitor.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""
|
||||
Query monitoring utilities for detecting N+1 queries and performance issues.
|
||||
|
||||
This module provides:
|
||||
1. Query counting per request in development mode
|
||||
2. SQLAlchemy event listeners for query logging
|
||||
3. Threshold-based warnings for excessive queries
|
||||
"""
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, Callable, Any
|
||||
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Thread-local storage for per-request query counting
|
||||
_query_context = threading.local()
|
||||
|
||||
|
||||
class QueryCounter:
    """
    Context manager for counting database queries within a request.

    Usage:
        with QueryCounter() as counter:
            # ... execute queries ...
            print(f"Executed {counter.count} queries")

    Counters may nest: entering an inner counter temporarily replaces the
    thread-local active counter and exiting restores the outer one (the
    previous version reset it to None, silently breaking the outer scope).
    """

    def __init__(self, threshold: Optional[int] = None, context_name: str = "request"):
        # Fall back to the configured default threshold when none is given.
        self.threshold = threshold or settings.QUERY_COUNT_THRESHOLD
        self.context_name = context_name
        self.count = 0
        self.queries = []  # list of (statement, duration) pairs, DEBUG mode only
        self.start_time: Optional[float] = None
        self.total_time = 0.0
        self._previous: Optional["QueryCounter"] = None  # counter active before __enter__

    def __enter__(self):
        self.count = 0
        self.queries = []
        self.start_time = time.time()
        # Remember any counter already active on this thread so nested
        # QueryCounter blocks restore it instead of clobbering it.
        self._previous = getattr(_query_context, 'counter', None)
        _query_context.counter = self
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Guard against __exit__ being invoked without a prior __enter__.
        self.total_time = time.time() - self.start_time if self.start_time is not None else 0.0
        # Restore the previously active counter (None at top level).
        _query_context.counter = self._previous

        # Log warning if threshold exceeded
        if self.count > self.threshold:
            logger.warning(
                "Query count threshold exceeded in %s: %d queries (threshold: %d, time: %.3fs)",
                self.context_name,
                self.count,
                self.threshold,
                self.total_time,
            )
            if settings.DEBUG:
                # In debug mode, also log the individual queries (capped at 20)
                for i, (sql, duration) in enumerate(self.queries[:20], 1):
                    logger.debug(" Query %d (%.3fs): %s", i, duration, sql[:200])
                if len(self.queries) > 20:
                    logger.debug(" ... and %d more queries", len(self.queries) - 20)
        elif settings.DEBUG and self.count > 0:
            logger.debug(
                "Query count for %s: %d queries in %.3fs",
                self.context_name,
                self.count,
                self.total_time,
            )

        # Never suppress exceptions raised inside the with-block.
        return False

    def record_query(self, statement: str, duration: float):
        """Record a query execution (statement text kept only in DEBUG mode)."""
        self.count += 1
        if settings.DEBUG:
            self.queries.append((statement, duration))
|
||||
|
||||
|
||||
def get_current_counter() -> Optional[QueryCounter]:
    """Return the query counter active on this thread, or None if absent."""
    try:
        return _query_context.counter
    except AttributeError:
        # No counter has ever been installed on this thread.
        return None
|
||||
|
||||
|
||||
def setup_query_logging(engine: Engine):
    """
    Attach SQLAlchemy cursor-execute listeners that time each query.

    Call once during application startup; it is a no-op unless the
    QUERY_LOGGING setting is enabled.
    """
    if not settings.QUERY_LOGGING:
        logger.info("Query logging is disabled")
        return

    logger.info("Setting up query logging with threshold=%d", settings.QUERY_COUNT_THRESHOLD)

    @event.listens_for(engine, "before_cursor_execute")
    def _push_start_time(conn, cursor, statement, parameters, context, executemany):
        # Stack of start times on the connection supports nested executions.
        conn.info.setdefault('query_start_time', []).append(time.time())

    @event.listens_for(engine, "after_cursor_execute")
    def _pop_and_record(conn, cursor, statement, parameters, context, executemany):
        stack = conn.info.get('query_start_time', [])
        if stack:
            elapsed = time.time() - stack.pop()
        else:
            # Unbalanced event (no matching before-hook); report zero duration.
            elapsed = 0.0

        # Feed the active per-request counter, if one is installed.
        active = get_current_counter()
        if active:
            active.record_query(statement, elapsed)

        # Also log individual queries when running in debug mode.
        if settings.DEBUG:
            logger.debug("SQL (%.3fs): %s", elapsed, statement[:500])
|
||||
|
||||
|
||||
@contextmanager
def count_queries(context_name: str = "operation", threshold: Optional[int] = None):
    """
    Count database queries executed inside a specific operation.

    Args:
        context_name: Name used in log messages.
        threshold: Optional override for the warning threshold.

    Usage:
        with count_queries("list_members") as counter:
            members = db.query(ProjectMember).all()
            for member in members:
                print(member.user.name)  # N+1 query!

        # After the block a warning is logged if the threshold was exceeded.
        print(f"Total queries: {counter.count}")
    """
    qc = QueryCounter(threshold=threshold, context_name=context_name)
    with qc:
        yield qc
|
||||
|
||||
|
||||
def assert_query_count(max_queries: int):
    """
    Decorator for testing that asserts a maximum query count.

    Args:
        max_queries: Largest number of queries the wrapped callable may issue.

    Raises:
        AssertionError: If the wrapped callable exceeds max_queries.

    Usage in tests:
        @assert_query_count(5)
        def test_list_members():
            # Should use at most 5 queries
            response = client.get("/api/projects/xxx/members")
    """
    # Local import keeps the module's dependency list unchanged.
    from functools import wraps

    def decorator(func: Callable) -> Callable:
        # functools.wraps preserves func.__name__/__doc__ so pytest reports
        # the real test name instead of "wrapper".
        @wraps(func)
        def wrapper(*args, **kwargs):
            with QueryCounter(threshold=max_queries, context_name=func.__name__) as counter:
                result = func(*args, **kwargs)
            if counter.count > max_queries:
                raise AssertionError(
                    f"Query count {counter.count} exceeded maximum {max_queries} "
                    f"in {func.__name__}"
                )
            return result
        return wrapper
    return decorator
|
||||
@@ -1,8 +1,283 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional, Any
|
||||
from typing import Optional, Any, Tuple
|
||||
from jose import jwt, JWTError
|
||||
import logging
|
||||
import math
|
||||
import hashlib
|
||||
import secrets
|
||||
import hmac
|
||||
from collections import Counter
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Constants for JWT secret validation
|
||||
MIN_SECRET_LENGTH = 32
|
||||
MIN_ENTROPY_BITS = 128 # Minimum entropy in bits for a secure secret
|
||||
COMMON_WEAK_PATTERNS = [
|
||||
"password", "secret", "changeme", "admin", "test", "demo",
|
||||
"123456", "qwerty", "abc123", "letmein", "welcome",
|
||||
]
|
||||
|
||||
|
||||
def calculate_entropy(data: str) -> float:
    """
    Return the Shannon entropy of *data*, in total bits.

    A more random string yields a higher value; a perfectly random string
    of length n drawn from k symbols carries n * log2(k) bits. The result
    is the per-character Shannon entropy scaled by the string length.

    Args:
        data: The string to measure.

    Returns:
        Entropy value in bits (0.0 for the empty string).
    """
    if not data:
        return 0.0

    n = len(data)
    # Per-character Shannon entropy from the symbol frequency distribution.
    per_char = -sum(
        (count / n) * math.log2(count / n)
        for count in Counter(data).values()
    )
    # Scale to total bits for the whole string.
    return per_char * n
|
||||
|
||||
|
||||
def has_repeating_patterns(secret: str) -> bool:
    """
    Detect obvious repetition in a candidate secret.

    Two heuristics are applied: (1) the string's leading prefix of width
    2..len/3 repeated at least three times reproduces the corresponding
    prefix of the string; (2) a run of one identical character covers at
    least half the string.

    Args:
        secret: The secret string to check.

    Returns:
        True when a repeating pattern is detected; always False for
        strings shorter than 8 characters.
    """
    n = len(secret)
    if n < 8:
        return False

    # Heuristic 1: prefix repeated >= 3 times.
    for width in range(2, n // 3 + 1):
        repeats = n // width
        if secret[:width] * repeats == secret[:width * repeats] and repeats >= 3:
            return True

    # Heuristic 2: a single character run spanning >= half the string.
    run = 1
    for prev, cur in zip(secret, secret[1:]):
        if cur == prev:
            run += 1
            if run >= n // 2:
                return True
        else:
            run = 1

    return False
|
||||
|
||||
|
||||
def validate_jwt_secret_strength(secret: str) -> Tuple[bool, list]:
    """
    Assess the strength of a JWT secret key.

    Four checks run: minimum length (32 chars), Shannon entropy
    (128-bit minimum), known weak substrings, and repeating patterns.
    Only the length check can make the secret invalid; the other three
    contribute warnings only.

    Args:
        secret: The JWT secret to validate.

    Returns:
        Tuple of (is_valid, list_of_warnings).
    """
    warnings = []
    valid = True

    # Length is the only hard requirement.
    if len(secret) < MIN_SECRET_LENGTH:
        valid = False
        warnings.append(
            f"JWT secret is too short ({len(secret)} chars). "
            f"Minimum recommended length is {MIN_SECRET_LENGTH} characters."
        )

    # Entropy shortfall is advisory only.
    entropy = calculate_entropy(secret)
    if entropy < MIN_ENTROPY_BITS:
        warnings.append(
            f"JWT secret has low entropy ({entropy:.1f} bits). "
            f"Minimum recommended entropy is {MIN_ENTROPY_BITS} bits. "
            "Consider using a cryptographically random secret."
        )

    # Only the first matching weak substring is reported.
    lowered = secret.lower()
    weak = next((p for p in COMMON_WEAK_PATTERNS if p in lowered), None)
    if weak is not None:
        warnings.append(
            f"JWT secret contains common weak pattern: '{weak}'. "
            "Use a cryptographically random secret."
        )

    if has_repeating_patterns(secret):
        warnings.append(
            "JWT secret contains repeating patterns. "
            "Use a cryptographically random secret."
        )

    return valid, warnings
|
||||
|
||||
|
||||
def validate_jwt_secret_on_startup() -> None:
    """
    Validate the configured JWT secret when the application boots.

    Every strength warning is logged. A secret failing the hard checks
    raises ValueError when ENVIRONMENT=production (blocking startup);
    outside production the failure is logged as a warning only.
    """
    import os

    is_valid, warnings = validate_jwt_secret_strength(settings.JWT_SECRET_KEY)

    # Surface every individual finding first.
    for message in warnings:
        logger.warning("JWT Security Warning: %s", message)

    in_production = os.environ.get("ENVIRONMENT", "").lower() == "production"

    if not is_valid and in_production:
        # Hard failure: refuse to start with a weak secret in production.
        logger.critical(
            "JWT secret does not meet security requirements. "
            "Application startup blocked in production mode. "
            "Please configure a strong JWT_SECRET_KEY (minimum 32 characters)."
        )
        raise ValueError(
            "JWT_SECRET_KEY does not meet minimum security requirements. "
            "See logs for details."
        )
    if not is_valid:
        logger.warning(
            "JWT secret does not meet security requirements. "
            "This would block startup in production mode."
        )

    if not warnings:
        logger.info("JWT secret validation passed. Secret meets security requirements.")
        return
    logger.info(
        "JWT secret validation completed with %d warning(s). "
        "Consider using: python -c \"import secrets; print(secrets.token_urlsafe(48))\" "
        "to generate a strong secret.",
        len(warnings)
    )
|
||||
|
||||
|
||||
# CSRF Token Functions
# Number of random bytes fed to secrets.token_urlsafe for the token's
# unpredictable component (the resulting string is longer than 32 chars).
CSRF_TOKEN_LENGTH = 32
# Tokens older than this many seconds are rejected by validate_csrf_token.
CSRF_TOKEN_EXPIRY_SECONDS = 3600  # 1 hour
def generate_csrf_token(user_id: str) -> str:
    """
    Generate a signed, user-bound CSRF token.

    Token layout is ``<random>:<user_id>:<timestamp>:<signature>`` where the
    signature is a truncated HMAC-SHA256 over the first three fields, keyed
    with the JWT secret. The random part provides unpredictability, the
    user ID binding prevents reuse across users, and the timestamp lets
    validation enforce expiry.

    Args:
        user_id: The user's ID to bind the token to.

    Returns:
        CSRF token string.
    """
    # Unpredictable component (URL-safe, so it can never contain ':').
    nonce = secrets.token_urlsafe(CSRF_TOKEN_LENGTH)

    # Issue time, used by validate_csrf_token for expiry checks.
    issued_at = int(datetime.now(timezone.utc).timestamp())

    body = f"{nonce}:{user_id}:{issued_at}"

    # Integrity tag: truncated (16 hex chars) HMAC keyed with the JWT secret.
    mac = hmac.new(
        settings.JWT_SECRET_KEY.encode(),
        body.encode(),
        hashlib.sha256
    ).hexdigest()[:16]

    return f"{body}:{mac}"
def validate_csrf_token(token: str, user_id: str) -> Tuple[bool, str]:
    """
    Validate a CSRF token produced by :func:`generate_csrf_token`.

    Checks, in order: presence, four-part format, user binding, expiry
    (``CSRF_TOKEN_EXPIRY_SECONDS``), and HMAC signature using a
    constant-time comparison.

    Args:
        token: The CSRF token to validate.
        user_id: The expected user ID.

    Returns:
        Tuple of (is_valid, error_message); the message is empty on success.
    """
    if not token:
        return False, "CSRF token is required"

    try:
        pieces = token.split(":")
        if len(pieces) != 4:
            return False, "Invalid CSRF token format"

        nonce, bound_user, issued_at_raw, provided_mac = pieces

        # The token must belong to the authenticated user.
        if bound_user != user_id:
            return False, "CSRF token user mismatch"

        # Reject tokens past their lifetime.
        issued_at = int(issued_at_raw)
        now = int(datetime.now(timezone.utc).timestamp())
        if now - issued_at > CSRF_TOKEN_EXPIRY_SECONDS:
            return False, "CSRF token expired"

        # Recompute the truncated HMAC over the same payload.
        body = f"{nonce}:{bound_user}:{issued_at_raw}"
        expected_mac = hmac.new(
            settings.JWT_SECRET_KEY.encode(),
            body.encode(),
            hashlib.sha256
        ).hexdigest()[:16]

        # Constant-time comparison to avoid timing side channels.
        if not hmac.compare_digest(provided_mac, expected_mac):
            return False, "CSRF token signature invalid"

        return True, ""

    except (ValueError, IndexError) as exc:
        return False, f"CSRF token validation error: {str(exc)}"
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
||||
"""
|
||||
|
||||
@@ -20,6 +20,12 @@ async def lifespan(app: FastAPI):
|
||||
testing = os.environ.get("TESTING", "").lower() in ("true", "1", "yes")
|
||||
scheduler_disabled = os.environ.get("DISABLE_SCHEDULER", "").lower() in ("true", "1", "yes")
|
||||
start_background_jobs = not testing and not scheduler_disabled
|
||||
|
||||
# Startup security validation
|
||||
if not testing:
|
||||
from app.core.security import validate_jwt_secret_on_startup
|
||||
validate_jwt_secret_on_startup()
|
||||
|
||||
# Startup
|
||||
if start_background_jobs:
|
||||
start_scheduler()
|
||||
|
||||
167
backend/app/middleware/csrf.py
Normal file
167
backend/app/middleware/csrf.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""
|
||||
CSRF (Cross-Site Request Forgery) Protection Middleware.
|
||||
|
||||
This module provides CSRF protection for sensitive state-changing operations.
|
||||
It validates CSRF tokens for specified protected endpoints.
|
||||
"""
|
||||
|
||||
from fastapi import Request, HTTPException, status, Depends
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from typing import Optional, Callable, List
|
||||
from functools import wraps
|
||||
import logging
|
||||
|
||||
from app.core.security import validate_csrf_token, generate_csrf_token
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Header name for CSRF token
CSRF_TOKEN_HEADER = "X-CSRF-Token"

# List of endpoint patterns that require CSRF protection
# These are sensitive state-changing operations
# NOTE(review): this list is documentation of intent; enforcement visible in
# this module happens via the require_csrf_token decorator applied
# per-endpoint, not by matching request paths against these patterns --
# confirm nothing else consumes this list before relying on it.
CSRF_PROTECTED_PATTERNS = [
    # User operations
    "/api/v1/users/{user_id}/admin",  # Admin status change
    "/api/users/{user_id}/admin",  # Legacy
    # Password changes would go here if implemented
    # Delete operations
    "/api/attachments/{attachment_id}",  # DELETE method
    "/api/tasks/{task_id}",  # DELETE method (soft delete)
    "/api/projects/{project_id}",  # DELETE method
]

# Methods that require CSRF protection
CSRF_PROTECTED_METHODS = ["DELETE", "PUT", "PATCH"]
class CSRFProtectionError(HTTPException):
    """Custom exception for CSRF validation failures."""

    def __init__(self, detail: str = "CSRF validation failed"):
        # Always a 403: the request was understood but is refused.
        super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
def require_csrf_token(func: Callable) -> Callable:
    """
    Decorator enforcing CSRF token validation on an endpoint.

    Usage:
        @router.delete("/resource/{id}")
        @require_csrf_token
        async def delete_resource(request: Request, id: str, current_user: User = Depends(get_current_user)):
            ...

    The wrapped endpoint must receive ``request`` (by keyword or
    positionally) and ``current_user`` (by keyword); the X-CSRF-Token
    header is validated against the current user before the endpoint runs.
    """
    @wraps(func)
    async def wrapper(*args, **kwargs):
        req: Optional[Request] = kwargs.get("request")
        user = kwargs.get("current_user")

        # Fall back to scanning positional args for the Request object.
        if req is None:
            req = next((a for a in args if isinstance(a, Request)), None)

        if req is None:
            logger.error("CSRF validation failed: Request object not found")
            raise CSRFProtectionError("Internal error: Request not available")

        if user is None:
            logger.error("CSRF validation failed: User not authenticated")
            raise CSRFProtectionError("Authentication required for CSRF-protected endpoint")

        # Token travels in a custom header, never in the body or cookies.
        token = req.headers.get(CSRF_TOKEN_HEADER)
        if not token:
            logger.warning(
                "CSRF validation failed: Missing token for user %s on %s %s",
                user.id, req.method, req.url.path
            )
            raise CSRFProtectionError("CSRF token is required")

        ok, reason = validate_csrf_token(token, user.id)
        if not ok:
            logger.warning(
                "CSRF validation failed for user %s on %s %s: %s",
                user.id, req.method, req.url.path, reason
            )
            raise CSRFProtectionError(reason)

        logger.debug(
            "CSRF validation passed for user %s on %s %s",
            user.id, req.method, req.url.path
        )

        return await func(*args, **kwargs)

    return wrapper
def get_csrf_token_for_user(user_id: str) -> str:
    """
    Issue a fresh CSRF token bound to *user_id*.

    Intended to be called from login endpoints so the client receives a
    token it can later echo back in the X-CSRF-Token header.

    Args:
        user_id: The user's ID.

    Returns:
        CSRF token string.
    """
    token = generate_csrf_token(user_id)
    return token
async def validate_csrf_for_request(
    request: Request,
    user_id: str,
    skip_methods: Optional[List[str]] = None
) -> bool:
    """
    Validate the CSRF token carried by *request*.

    Utility for endpoints that prefer a direct call over the
    ``require_csrf_token`` decorator. Safe (read-only) methods are exempt.

    Args:
        request: The FastAPI request object.
        user_id: The current user's ID.
        skip_methods: HTTP methods to skip validation for
            (default: GET, HEAD, OPTIONS).

    Returns:
        True if validation passes.

    Raises:
        CSRFProtectionError: If the token is missing or invalid.
    """
    exempt = ["GET", "HEAD", "OPTIONS"] if skip_methods is None else skip_methods

    # Safe methods never require a CSRF token.
    if request.method.upper() in exempt:
        return True

    token = request.headers.get(CSRF_TOKEN_HEADER)
    if not token:
        raise CSRFProtectionError("CSRF token is required")

    ok, reason = validate_csrf_token(token, user_id)
    if not ok:
        raise CSRFProtectionError(reason)

    return True
@@ -32,5 +32,11 @@ class TokenPayload(BaseModel):
|
||||
iat: int
|
||||
|
||||
|
||||
class CSRFTokenResponse(BaseModel):
    """Response containing a CSRF token for state-changing operations."""
    # Token the client must echo back in the X-CSRF-Token header.
    csrf_token: str = Field(..., description="CSRF token to include in X-CSRF-Token header")
    # Lifetime hint for clients; matches the server-side
    # CSRF_TOKEN_EXPIRY_SECONDS of one hour.
    expires_in: int = Field(default=3600, description="Token expiry time in seconds")
||||
|
||||
# Update forward reference
# Resolve LoginResponse's forward-referenced field types now that all
# models in this module are defined (pydantic v2 API).
LoginResponse.model_rebuild()
||||
|
||||
@@ -286,11 +286,15 @@ class FileStorageService:
|
||||
return filename.rsplit(".", 1)[-1].lower() if "." in filename else ""
|
||||
|
||||
@staticmethod
|
||||
def validate_file(file: UploadFile) -> Tuple[str, str]:
|
||||
def validate_file(file: UploadFile, validate_mime: bool = True) -> Tuple[str, str]:
|
||||
"""
|
||||
Validate file size and type.
|
||||
Validate file size, type, and optionally MIME content.
|
||||
Returns (extension, mime_type) if valid.
|
||||
Raises HTTPException if invalid.
|
||||
|
||||
Args:
|
||||
file: The uploaded file
|
||||
validate_mime: If True, validate MIME type using magic bytes detection
|
||||
"""
|
||||
# Check file size
|
||||
file.file.seek(0, 2) # Seek to end
|
||||
@@ -323,7 +327,35 @@ class FileStorageService:
|
||||
detail=f"File type '.{extension}' is not supported"
|
||||
)
|
||||
|
||||
mime_type = file.content_type or "application/octet-stream"
|
||||
# Validate MIME type using magic bytes detection
|
||||
if validate_mime:
|
||||
from app.services.mime_validation_service import mime_validation_service
|
||||
|
||||
# Read first 16 bytes for magic detection (enough for most signatures)
|
||||
file_header = file.file.read(16)
|
||||
file.file.seek(0) # Reset
|
||||
|
||||
is_valid, detected_mime, error_message = mime_validation_service.validate_file_content(
|
||||
file_content=file_header,
|
||||
declared_extension=extension,
|
||||
declared_mime_type=file.content_type
|
||||
)
|
||||
|
||||
if not is_valid:
|
||||
logger.warning(
|
||||
"MIME validation failed for file '%s': %s (detected: %s)",
|
||||
file.filename, error_message, detected_mime
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=error_message or "File type validation failed"
|
||||
)
|
||||
|
||||
# Use detected MIME type if available, otherwise fall back to declared
|
||||
mime_type = detected_mime if detected_mime else (file.content_type or "application/octet-stream")
|
||||
else:
|
||||
mime_type = file.content_type or "application/octet-stream"
|
||||
|
||||
return extension, mime_type
|
||||
|
||||
async def save_file(
|
||||
|
||||
314
backend/app/services/mime_validation_service.py
Normal file
314
backend/app/services/mime_validation_service.py
Normal file
@@ -0,0 +1,314 @@
|
||||
"""
|
||||
MIME Type Validation Service using Magic Bytes Detection.
|
||||
|
||||
This module provides file content type validation by examining
|
||||
the actual file content (magic bytes) rather than trusting
|
||||
the file extension or Content-Type header.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, Tuple, Dict, Set, BinaryIO
|
||||
from io import BytesIO
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MimeValidationError(Exception):
    """Raised when MIME type validation fails."""


class FileMismatchError(MimeValidationError):
    """Raised when file extension doesn't match actual content type."""


class UnsupportedMimeError(MimeValidationError):
    """Raised when file has an unsupported MIME type."""
||||
|
||||
# Magic bytes signatures for common file types
# Format: { bytes_pattern: (mime_type, extensions) }
# NOTE: signatures are matched as prefixes in dict insertion order by
# detect_mime_type; the single-byte '{' and '[' JSON entries are weak
# heuristics that will match any text beginning with those characters.
MAGIC_SIGNATURES: Dict[bytes, Tuple[str, Set[str]]] = {
    # Images
    b'\xFF\xD8\xFF': ('image/jpeg', {'jpg', 'jpeg', 'jpe'}),
    b'\x89PNG\r\n\x1a\n': ('image/png', {'png'}),
    b'GIF87a': ('image/gif', {'gif'}),
    b'GIF89a': ('image/gif', {'gif'}),
    b'RIFF': ('image/webp', {'webp'}),  # WebP starts with RIFF, then WEBP
    b'BM': ('image/bmp', {'bmp'}),

    # PDF
    b'%PDF': ('application/pdf', {'pdf'}),

    # Microsoft Office (Modern formats - ZIP-based)
    b'PK\x03\x04': ('application/zip', {'zip', 'docx', 'xlsx', 'pptx', 'odt', 'ods', 'odp', 'jar'}),

    # Microsoft Office (Legacy formats - Compound Document)
    b'\xD0\xCF\x11\xE0\xA1\xB1\x1A\xE1': ('application/msword', {'doc', 'xls', 'ppt', 'msi'}),

    # Archives
    b'\x1f\x8b': ('application/gzip', {'gz', 'tgz'}),
    b'\x42\x5a\x68': ('application/x-bzip2', {'bz2'}),
    b'\x37\x7A\xBC\xAF\x27\x1C': ('application/x-7z-compressed', {'7z'}),
    b'Rar!\x1a\x07': ('application/x-rar-compressed', {'rar'}),

    # Text/Data formats - these are harder to detect, usually fallback to extension
    b'<?xml': ('application/xml', {'xml', 'svg'}),
    b'{': ('application/json', {'json'}),  # JSON typically starts with { or [
    b'[': ('application/json', {'json'}),

    # Executables (dangerous - should be blocked)
    b'MZ': ('application/x-executable', {'exe', 'dll', 'com', 'scr'}),
    b'\x7fELF': ('application/x-executable', {'elf', 'so', 'bin'}),
}
|
||||
# Map extensions to expected MIME types
# Used by validate_file_content both to cross-check a detected MIME type
# against the declared extension and as the fallback source of a MIME type
# when magic-byte detection is inconclusive.
EXTENSION_TO_MIME: Dict[str, Set[str]] = {
    # Images
    'jpg': {'image/jpeg'},
    'jpeg': {'image/jpeg'},
    'jpe': {'image/jpeg'},
    'png': {'image/png'},
    'gif': {'image/gif'},
    'bmp': {'image/bmp'},
    'webp': {'image/webp'},
    'svg': {'image/svg+xml', 'application/xml', 'text/xml'},

    # Documents
    'pdf': {'application/pdf'},
    'doc': {'application/msword'},
    'docx': {'application/vnd.openxmlformats-officedocument.wordprocessingml.document', 'application/zip'},
    'xls': {'application/vnd.ms-excel', 'application/msword'},
    'xlsx': {'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', 'application/zip'},
    'ppt': {'application/vnd.ms-powerpoint', 'application/msword'},
    'pptx': {'application/vnd.openxmlformats-officedocument.presentationml.presentation', 'application/zip'},

    # Text
    'txt': {'text/plain'},
    'csv': {'text/csv', 'text/plain'},
    'json': {'application/json', 'text/plain'},
    'xml': {'application/xml', 'text/xml', 'text/plain'},
    'yaml': {'application/yaml', 'text/plain'},
    'yml': {'application/yaml', 'text/plain'},

    # Archives
    'zip': {'application/zip'},
    'rar': {'application/x-rar-compressed'},
    '7z': {'application/x-7z-compressed'},
    'tar': {'application/x-tar'},
    'gz': {'application/gzip'},
}
|
||||
# MIME types that should always be blocked (dangerous executables)
# A file whose magic bytes resolve to any of these is rejected outright,
# regardless of the declared extension or Content-Type header.
BLOCKED_MIME_TYPES: Set[str] = {
    'application/x-executable',
    'application/x-msdownload',
    'application/x-msdos-program',
    'application/x-sh',
    'application/x-csh',
    'application/x-dosexec',
}
|
||||
# Configurable allowed MIME type categories
# MimeValidationService unions the selected categories into its allow-list;
# passing allowed_categories=None to the service enables all of them.
ALLOWED_MIME_CATEGORIES: Dict[str, Set[str]] = {
    'images': {
        'image/jpeg', 'image/png', 'image/gif', 'image/bmp', 'image/webp', 'image/svg+xml'
    },
    'documents': {
        'application/pdf',
        'application/msword',
        'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
        'application/vnd.ms-excel',
        'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
        'application/vnd.ms-powerpoint',
        'application/vnd.openxmlformats-officedocument.presentationml.presentation',
        'text/plain', 'text/csv',
    },
    'archives': {
        'application/zip', 'application/x-rar-compressed',
        'application/x-7z-compressed', 'application/gzip',
        'application/x-tar',
    },
    'data': {
        'application/json', 'application/xml', 'text/xml',
        'application/yaml', 'text/plain',
    },
}
|
||||
|
||||
class MimeValidationService:
    """Service for validating file MIME types using magic bytes.

    Inspects the leading bytes of uploaded content instead of trusting the
    file extension or Content-Type header, blocks known-dangerous
    executable signatures, and cross-checks the detected type against the
    declared extension.
    """

    def __init__(
        self,
        allowed_categories: Optional[Set[str]] = None,
        bypass_for_trusted: bool = False
    ):
        """
        Initialize the MIME validation service.

        Args:
            allowed_categories: Set of allowed MIME categories ('images', 'documents', etc.)
                If None, all categories are allowed. Unknown category names
                are silently ignored.
            bypass_for_trusted: If True, validation can be bypassed for trusted sources.
        """
        self.bypass_for_trusted = bypass_for_trusted

        # Build set of allowed MIME types
        if allowed_categories is None:
            # No restriction requested: union every known category.
            self.allowed_mime_types = set()
            for category_mimes in ALLOWED_MIME_CATEGORIES.values():
                self.allowed_mime_types.update(category_mimes)
        else:
            # Union only the requested categories; names not present in
            # ALLOWED_MIME_CATEGORIES contribute nothing.
            self.allowed_mime_types = set()
            for category in allowed_categories:
                if category in ALLOWED_MIME_CATEGORIES:
                    self.allowed_mime_types.update(ALLOWED_MIME_CATEGORIES[category])

    def detect_mime_type(self, file_content: bytes) -> Optional[str]:
        """
        Detect MIME type from file content using magic bytes.

        Args:
            file_content: The raw file bytes (at least first 16 bytes needed;
                the WebP check reads up to offset 12).

        Returns:
            Detected MIME type or None if unknown. Matching is by prefix,
            in MAGIC_SIGNATURES insertion order; the first signature that
            matches wins.
        """
        if len(file_content) < 2:
            return None

        # Check each magic signature
        for magic_bytes, (mime_type, _) in MAGIC_SIGNATURES.items():
            if file_content.startswith(magic_bytes):
                # Special case for WebP: check for WEBP after RIFF.
                # NOTE: if fewer than 12 bytes are available, a RIFF prefix
                # falls through and is reported as image/webp unverified.
                if magic_bytes == b'RIFF' and len(file_content) >= 12:
                    if file_content[8:12] == b'WEBP':
                        return 'image/webp'
                    else:
                        continue  # Not WebP, might be something else

                return mime_type

        return None

    def validate_file_content(
        self,
        file_content: bytes,
        declared_extension: str,
        declared_mime_type: Optional[str] = None,
        trusted_source: bool = False
    ) -> Tuple[bool, str, Optional[str]]:
        """
        Validate file content against declared extension and MIME type.

        Validation cascade: trusted-source bypass, blocked-type check,
        extension fallback when detection is inconclusive, allow-list
        check, then extension/content cross-check.

        Args:
            file_content: The raw file bytes
            declared_extension: The file extension (without dot)
            declared_mime_type: The Content-Type header value (optional).
                Only consulted on the trusted-source bypass path.
            trusted_source: If True and bypass_for_trusted is enabled, skip validation

        Returns:
            Tuple of (is_valid, detected_mime_type, error_message)
        """
        # Bypass for trusted sources if configured
        if trusted_source and self.bypass_for_trusted:
            logger.debug("MIME validation bypassed for trusted source")
            return True, declared_mime_type or 'application/octet-stream', None

        # Detect actual MIME type
        detected_mime = self.detect_mime_type(file_content)
        ext_lower = declared_extension.lower()

        # Check if detected MIME is blocked (dangerous executable)
        if detected_mime in BLOCKED_MIME_TYPES:
            logger.warning(
                "Blocked dangerous file type detected: %s (claimed extension: %s)",
                detected_mime, ext_lower
            )
            return False, detected_mime, "File type not allowed for security reasons"

        # If we couldn't detect the MIME type, fall back to extension-based check
        if detected_mime is None:
            # For text/data files, detection is unreliable
            # Trust the extension if it's in our allowed list
            if ext_lower in EXTENSION_TO_MIME:
                expected_mimes = EXTENSION_TO_MIME[ext_lower]
                # Check if any expected MIME is in allowed set
                if expected_mimes & self.allowed_mime_types:
                    logger.debug(
                        "MIME detection inconclusive for extension %s, allowing based on extension",
                        ext_lower
                    )
                    # Return the first expected MIME type.
                    # NOTE: set iteration order is arbitrary, so which of
                    # several expected MIME types is reported can vary.
                    return True, next(iter(expected_mimes)), None

            # Unknown extension or MIME type: permissive fallback --
            # accepted as a generic binary rather than rejected.
            logger.warning(
                "Could not detect MIME type for file with extension: %s",
                ext_lower
            )
            return True, 'application/octet-stream', None

        # Check if detected MIME is in allowed set
        if detected_mime not in self.allowed_mime_types:
            logger.warning(
                "Unsupported MIME type detected: %s (extension: %s)",
                detected_mime, ext_lower
            )
            return False, detected_mime, f"Unsupported file type: {detected_mime}"

        # Verify extension matches detected MIME type
        if ext_lower in EXTENSION_TO_MIME:
            expected_mimes = EXTENSION_TO_MIME[ext_lower]

            # Special handling for ZIP-based formats (docx, xlsx, pptx)
            if detected_mime == 'application/zip' and ext_lower in {'docx', 'xlsx', 'pptx', 'odt', 'ods', 'odp'}:
                # These are valid - ZIP container with specific extension
                return True, detected_mime, None

            # Check if detected MIME matches any expected MIME for this extension
            if detected_mime not in expected_mimes:
                # Mismatch detected!
                logger.warning(
                    "File type mismatch: extension '%s' but detected '%s'",
                    ext_lower, detected_mime
                )
                return False, detected_mime, f"File type mismatch: extension indicates {ext_lower} but content is {detected_mime}"

        # Detected type is allowed and (if the extension is known) consistent.
        return True, detected_mime, None

    async def validate_upload_file(
        self,
        file_content: bytes,
        filename: str,
        content_type: Optional[str] = None,
        trusted_source: bool = False
    ) -> Tuple[bool, str, Optional[str]]:
        """
        Validate an uploaded file.

        Thin async wrapper over validate_file_content that derives the
        extension from the filename (empty string when there is no dot).

        Args:
            file_content: The raw file bytes
            filename: The uploaded filename
            content_type: The Content-Type header value
            trusted_source: If True and bypass is enabled, skip validation

        Returns:
            Tuple of (is_valid, detected_mime_type, error_message)
        """
        # Extract extension
        extension = filename.rsplit('.', 1)[-1] if '.' in filename else ''

        return self.validate_file_content(
            file_content=file_content,
            declared_extension=extension,
            declared_mime_type=content_type,
            trusted_source=trusted_source
        )
|
||||
|
||||
# Singleton instance with default configuration
|
||||
mime_validation_service = MimeValidationService()
|
||||
@@ -14,6 +14,7 @@ from sqlalchemy import event
|
||||
from app.models import User, Notification, Task, Comment, Mention
|
||||
from app.core.redis_pubsub import publish_notification as redis_publish, get_channel_name
|
||||
from app.core.redis import get_redis_sync
|
||||
from app.core.database import escape_like
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -427,9 +428,12 @@ class NotificationService:
|
||||
|
||||
# Find users by email or name
|
||||
for username in mentioned_usernames:
|
||||
# Escape special LIKE characters to prevent injection
|
||||
escaped_username = escape_like(username)
|
||||
# Try to find user by email first
|
||||
user = db.query(User).filter(
|
||||
(User.email.ilike(f"{username}%")) | (User.name.ilike(f"%{username}%"))
|
||||
(User.email.ilike(f"{escaped_username}%", escape="\\")) |
|
||||
(User.name.ilike(f"%{escaped_username}%", escape="\\"))
|
||||
).first()
|
||||
|
||||
if user and user.id != author.id:
|
||||
|
||||
Reference in New Issue
Block a user