feat: implement security, error resilience, and query optimization proposals

Security Validation (enhance-security-validation):
- JWT secret validation with entropy checking and pattern detection
- CSRF protection middleware with token generation/validation
- Frontend CSRF token auto-injection for DELETE/PUT/PATCH requests
- MIME type validation with magic bytes detection for file uploads

Error Resilience (add-error-resilience):
- React ErrorBoundary component with fallback UI and retry functionality
- ErrorBoundaryWithI18n wrapper for internationalization support
- Page-level and section-level error boundaries in App.tsx

Query Performance (optimize-query-performance):
- Query monitoring utility with threshold warnings
- N+1 query fixes using joinedload/selectinload
- Optimized project members, tasks, and subtasks endpoints

Bug Fixes:
- WebSocket session management (P0): Return primitives instead of ORM objects
- LIKE query injection (P1): Escape special characters in search queries

Tests: 543 backend tests, 56 frontend tests passing

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
beabigegg
2026-01-11 18:41:19 +08:00
parent 2cb591ef23
commit 679b89ae4c
41 changed files with 3673 additions and 153 deletions

View File

@@ -24,6 +24,7 @@ from app.services.encryption_service import (
MasterKeyNotConfiguredError,
DecryptionError,
)
from app.middleware.csrf import require_csrf_token
logger = logging.getLogger(__name__)
@@ -610,13 +611,14 @@ async def download_attachment(
@router.delete("/attachments/{attachment_id}")
@require_csrf_token
async def delete_attachment(
attachment_id: str,
request: Request,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Soft delete an attachment."""
"""Soft delete an attachment. Requires CSRF token in X-CSRF-Token header."""
attachment = get_attachment_with_access_check(db, attachment_id, current_user, require_edit=True)
# Soft delete

View File

@@ -8,7 +8,7 @@ from app.core.redis import get_redis
from app.core.rate_limiter import limiter
from app.models.user import User
from app.models.audit_log import AuditAction
from app.schemas.auth import LoginRequest, LoginResponse, UserInfo
from app.schemas.auth import LoginRequest, LoginResponse, UserInfo, CSRFTokenResponse
from app.services.auth_client import (
verify_credentials,
AuthAPIError,
@@ -16,6 +16,7 @@ from app.services.auth_client import (
)
from app.services.audit_service import AuditService
from app.middleware.auth import get_current_user
from app.middleware.csrf import get_csrf_token_for_user
router = APIRouter()
@@ -182,3 +183,23 @@ async def get_current_user_info(
department_id=current_user.department_id,
is_system_admin=current_user.is_system_admin,
)
@router.get("/csrf-token", response_model=CSRFTokenResponse)
async def get_csrf_token(
current_user: User = Depends(get_current_user),
):
"""
Get a CSRF token for the current user.
The CSRF token should be included in the X-CSRF-Token header
for all sensitive state-changing operations (DELETE, PUT, PATCH).
Token expires after 1 hour and should be refreshed.
"""
csrf_token = get_csrf_token_for_user(current_user.id)
return CSRFTokenResponse(
csrf_token=csrf_token,
expires_in=3600, # 1 hour
)

View File

@@ -1,7 +1,7 @@
import uuid
from typing import List
from fastapi import APIRouter, Depends, HTTPException, status, Request
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, joinedload, selectinload
from app.core.database import get_db
from app.models import User, Space, Project, TaskStatus, AuditAction, ProjectMember, ProjectTemplate, CustomField
@@ -55,6 +55,8 @@ async def list_projects_in_space(
):
"""
List all projects in a space that the user can access.
Optimized to avoid N+1 queries by using joinedload/selectinload for relationships.
"""
space = db.query(Space).filter(Space.id == space_id, Space.is_active == True).first()
@@ -70,13 +72,21 @@ async def list_projects_in_space(
detail="Access denied",
)
projects = db.query(Project).filter(Project.space_id == space_id, Project.is_active == True).all()
# Use joinedload to eagerly load owner, space, and department
# Use selectinload for tasks (one-to-many) to avoid cartesian product issues
projects = db.query(Project).options(
joinedload(Project.owner),
joinedload(Project.space),
joinedload(Project.department),
selectinload(Project.tasks),
).filter(Project.space_id == space_id, Project.is_active == True).all()
# Filter by project access
accessible_projects = [p for p in projects if check_project_access(current_user, p)]
result = []
for project in accessible_projects:
# Access pre-loaded relationships - no additional queries needed
task_count = len(project.tasks) if project.tasks else 0
result.append(ProjectWithDetails(
id=project.id,
@@ -422,6 +432,10 @@ async def list_project_members(
List all members of a project.
Only users with project access can view the member list.
Optimized to avoid N+1 queries by using joinedload for user relationships.
This loads all members and their related users in at most 2 queries instead of
one query per member.
"""
project = db.query(Project).filter(Project.id == project_id, Project.is_active == True).first()
@@ -437,14 +451,20 @@ async def list_project_members(
detail="Access denied",
)
members = db.query(ProjectMember).filter(
# Use joinedload to eagerly load user and added_by_user relationships
# This avoids N+1 queries when accessing member.user and member.added_by_user
members = db.query(ProjectMember).options(
joinedload(ProjectMember.user).joinedload(User.department),
joinedload(ProjectMember.added_by_user),
).filter(
ProjectMember.project_id == project_id
).all()
member_list = []
for member in members:
user = db.query(User).filter(User.id == member.user_id).first()
added_by_user = db.query(User).filter(User.id == member.added_by).first()
# Access pre-loaded relationships - no additional queries needed
user = member.user
added_by_user = member.added_by_user
member_list.append(ProjectMemberWithDetails(
id=member.id,

View File

@@ -3,7 +3,7 @@ import uuid
from datetime import datetime, timezone, timedelta
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, status, Query, Request
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, joinedload, selectinload
from app.core.database import get_db
from app.core.redis_pubsub import publish_task_event
@@ -110,6 +110,9 @@ async def list_tasks(
The due_after and due_before parameters are useful for calendar view
to fetch tasks within a specific date range.
Optimized to avoid N+1 queries by using selectinload for task relationships.
This batch loads assignees, statuses, creators and subtasks efficiently.
"""
project = db.query(Project).filter(Project.id == project_id).first()
@@ -125,7 +128,15 @@ async def list_tasks(
detail="Access denied",
)
query = db.query(Task).filter(Task.project_id == project_id)
# Use selectinload to eagerly load task relationships
# This avoids N+1 queries when accessing task.assignee, task.status, etc.
query = db.query(Task).options(
selectinload(Task.assignee),
selectinload(Task.status),
selectinload(Task.creator),
selectinload(Task.subtasks),
selectinload(Task.custom_values),
).filter(Task.project_id == project_id)
# Filter deleted tasks (only admin can include deleted)
if include_deleted and current_user.is_system_admin:
@@ -1112,6 +1123,8 @@ async def list_subtasks(
):
"""
List subtasks of a task.
Optimized to avoid N+1 queries by using selectinload for task relationships.
"""
task = db.query(Task).filter(Task.id == task_id).first()
@@ -1127,7 +1140,13 @@ async def list_subtasks(
detail="Access denied",
)
query = db.query(Task).filter(Task.parent_task_id == task_id)
# Use selectinload to eagerly load subtask relationships
query = db.query(Task).options(
selectinload(Task.assignee),
selectinload(Task.status),
selectinload(Task.creator),
selectinload(Task.subtasks),
).filter(Task.parent_task_id == task_id)
# Filter deleted subtasks (only admin can include deleted)
if not (include_deleted and current_user.is_system_admin):

View File

@@ -3,7 +3,7 @@ from sqlalchemy.orm import Session
from sqlalchemy import or_
from typing import List
from app.core.database import get_db
from app.core.database import get_db, escape_like
from app.core.redis import get_redis
from app.models.user import User
from app.models.role import Role
@@ -16,6 +16,7 @@ from app.middleware.auth import (
check_department_access,
)
from app.middleware.audit import get_audit_metadata
from app.middleware.csrf import require_csrf_token
from app.services.audit_service import AuditService
router = APIRouter()
@@ -32,11 +33,13 @@ async def search_users(
Search users by name or email. Used for @mention autocomplete.
Returns users matching the query, limited to same department unless system admin.
"""
# Escape special LIKE characters to prevent injection
escaped_q = escape_like(q)
query = db.query(User).filter(
User.is_active == True,
or_(
User.name.ilike(f"%{q}%"),
User.email.ilike(f"%{q}%"),
User.name.ilike(f"%{escaped_q}%", escape="\\"),
User.email.ilike(f"%{escaped_q}%", escape="\\"),
)
)
@@ -197,6 +200,7 @@ async def assign_role(
@router.patch("/{user_id}/admin", response_model=UserResponse)
@require_csrf_token
async def set_admin_status(
user_id: str,
is_admin: bool,
@@ -205,7 +209,7 @@ async def set_admin_status(
current_user: User = Depends(require_system_admin),
):
"""
Set or revoke system administrator status. Requires system admin.
Set or revoke system administrator status. Requires system admin and CSRF token.
"""
user = db.query(User).filter(User.id == user_id).first()
if not user:

View File

@@ -27,29 +27,37 @@ if os.getenv("TESTING") == "true":
AUTH_TIMEOUT = 1.0
async def get_user_from_token(token: str) -> tuple[str | None, User | None]:
"""Validate token and return user_id and user object."""
async def get_user_from_token(token: str) -> str | None:
"""
Validate token and return user_id.
Returns:
user_id if valid, None otherwise.
Note: This function properly closes the database session after validation.
Do not return ORM objects as they become detached after session close.
"""
payload = decode_access_token(token)
if payload is None:
return None, None
return None
user_id = payload.get("sub")
if user_id is None:
return None, None
return None
# Verify session in Redis
redis_client = get_redis_sync()
stored_token = redis_client.get(f"session:{user_id}")
if stored_token is None or stored_token != token:
return None, None
return None
# Get user from database
# Verify user exists and is active
db = database.SessionLocal()
try:
user = db.query(User).filter(User.id == user_id).first()
if user is None or not user.is_active:
return None, None
return user_id, user
return None
return user_id
finally:
db.close()
@@ -57,7 +65,7 @@ async def get_user_from_token(token: str) -> tuple[str | None, User | None]:
async def authenticate_websocket(
websocket: WebSocket,
query_token: Optional[str] = None
) -> tuple[str | None, User | None, Optional[str]]:
) -> tuple[str | None, Optional[str]]:
"""
Authenticate WebSocket connection.
@@ -67,7 +75,8 @@ async def authenticate_websocket(
2. Query parameter authentication (deprecated, for backward compatibility)
- Client connects with: ?token=<jwt_token>
Returns (user_id, user) if authenticated, (None, None) otherwise.
Returns:
Tuple of (user_id, error_reason). user_id is None if authentication fails.
"""
# If token provided via query parameter (backward compatibility)
if query_token:
@@ -75,10 +84,10 @@ async def authenticate_websocket(
"WebSocket authentication via query parameter is deprecated. "
"Please use first-message authentication for better security."
)
user_id, user = await get_user_from_token(query_token)
user_id = await get_user_from_token(query_token)
if user_id is None:
return None, None, "invalid_token"
return user_id, user, None
return None, "invalid_token"
return user_id, None
# Wait for authentication message with timeout
try:
@@ -90,24 +99,24 @@ async def authenticate_websocket(
msg_type = data.get("type")
if msg_type != "auth":
logger.warning("Expected 'auth' message type, got: %s", msg_type)
return None, None, "invalid_message"
return None, "invalid_message"
token = data.get("token")
if not token:
logger.warning("No token provided in auth message")
return None, None, "missing_token"
return None, "missing_token"
user_id, user = await get_user_from_token(token)
user_id = await get_user_from_token(token)
if user_id is None:
return None, None, "invalid_token"
return user_id, user, None
return None, "invalid_token"
return user_id, None
except asyncio.TimeoutError:
logger.warning("WebSocket authentication timeout after %.1f seconds", AUTH_TIMEOUT)
return None, None, "timeout"
return None, "timeout"
except Exception as e:
logger.error("Error during WebSocket authentication: %s", e)
return None, None, "error"
return None, "error"
async def get_unread_notifications(user_id: str) -> list[dict]:
@@ -183,7 +192,7 @@ async def websocket_notifications(
await websocket.accept()
# Authenticate
user_id, user, error_reason = await authenticate_websocket(websocket, token)
user_id, error_reason = await authenticate_websocket(websocket, token)
if user_id is None:
if error_reason == "invalid_token":
@@ -306,7 +315,7 @@ async def websocket_notifications(
await manager.disconnect(websocket, user_id)
async def verify_project_access(user_id: str, project_id: str) -> tuple[bool, Project | None]:
async def verify_project_access(user_id: str, project_id: str) -> tuple[bool, str | None, str | None]:
"""
Check if user has access to the project.
@@ -315,23 +324,34 @@ async def verify_project_access(user_id: str, project_id: str) -> tuple[bool, Pr
project_id: The project's ID
Returns:
Tuple of (has_access: bool, project: Project | None)
Tuple of (has_access: bool, project_title: str | None, error: str | None)
- has_access: True if user can access the project
- project_title: The project title (only if access granted)
- error: Error code if access denied ("user_not_found", "project_not_found", "access_denied")
Note: This function extracts needed data before closing the session to avoid
detached instance errors when accessing ORM object attributes.
"""
db = database.SessionLocal()
try:
# Get the user
user = db.query(User).filter(User.id == user_id).first()
if user is None or not user.is_active:
return False, None
return False, None, "user_not_found"
# Get the project
project = db.query(Project).filter(Project.id == project_id).first()
if project is None:
return False, None
return False, None, "project_not_found"
# Check access using existing middleware function
has_access = check_project_access(user, project)
return has_access, project
if not has_access:
return False, None, "access_denied"
# Extract title while session is still open
project_title = project.title
return True, project_title, None
finally:
db.close()
@@ -371,7 +391,7 @@ async def websocket_project_sync(
await websocket.accept()
# Authenticate user
user_id, user, error_reason = await authenticate_websocket(websocket, token)
user_id, error_reason = await authenticate_websocket(websocket, token)
if user_id is None:
if error_reason == "invalid_token":
@@ -380,14 +400,13 @@ async def websocket_project_sync(
return
# Verify user has access to the project
has_access, project = await verify_project_access(user_id, project_id)
has_access, project_title, access_error = await verify_project_access(user_id, project_id)
if not has_access:
await websocket.close(code=4003, reason="Access denied to this project")
return
if project is None:
await websocket.close(code=4004, reason="Project not found")
if access_error == "project_not_found":
await websocket.close(code=4004, reason="Project not found")
else:
await websocket.close(code=4003, reason="Access denied to this project")
return
# Join project room
@@ -413,7 +432,7 @@ async def websocket_project_sync(
"data": {
"project_id": project_id,
"user_id": user_id,
"project_title": project.title if project else None,
"project_title": project_title,
},
})