feat: implement security, error resilience, and query optimization proposals
Security Validation (enhance-security-validation): - JWT secret validation with entropy checking and pattern detection - CSRF protection middleware with token generation/validation - Frontend CSRF token auto-injection for DELETE/PUT/PATCH requests - MIME type validation with magic bytes detection for file uploads Error Resilience (add-error-resilience): - React ErrorBoundary component with fallback UI and retry functionality - ErrorBoundaryWithI18n wrapper for internationalization support - Page-level and section-level error boundaries in App.tsx Query Performance (optimize-query-performance): - Query monitoring utility with threshold warnings - N+1 query fixes using joinedload/selectinload - Optimized project members, tasks, and subtasks endpoints Bug Fixes: - WebSocket session management (P0): Return primitives instead of ORM objects - LIKE query injection (P1): Escape special characters in search queries Tests: 543 backend tests, 56 frontend tests passing Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -24,6 +24,7 @@ from app.services.encryption_service import (
|
||||
MasterKeyNotConfiguredError,
|
||||
DecryptionError,
|
||||
)
|
||||
from app.middleware.csrf import require_csrf_token
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -610,13 +611,14 @@ async def download_attachment(
|
||||
|
||||
|
||||
@router.delete("/attachments/{attachment_id}")
|
||||
@require_csrf_token
|
||||
async def delete_attachment(
|
||||
attachment_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Soft delete an attachment."""
|
||||
"""Soft delete an attachment. Requires CSRF token in X-CSRF-Token header."""
|
||||
attachment = get_attachment_with_access_check(db, attachment_id, current_user, require_edit=True)
|
||||
|
||||
# Soft delete
|
||||
|
||||
@@ -8,7 +8,7 @@ from app.core.redis import get_redis
|
||||
from app.core.rate_limiter import limiter
|
||||
from app.models.user import User
|
||||
from app.models.audit_log import AuditAction
|
||||
from app.schemas.auth import LoginRequest, LoginResponse, UserInfo
|
||||
from app.schemas.auth import LoginRequest, LoginResponse, UserInfo, CSRFTokenResponse
|
||||
from app.services.auth_client import (
|
||||
verify_credentials,
|
||||
AuthAPIError,
|
||||
@@ -16,6 +16,7 @@ from app.services.auth_client import (
|
||||
)
|
||||
from app.services.audit_service import AuditService
|
||||
from app.middleware.auth import get_current_user
|
||||
from app.middleware.csrf import get_csrf_token_for_user
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -182,3 +183,23 @@ async def get_current_user_info(
|
||||
department_id=current_user.department_id,
|
||||
is_system_admin=current_user.is_system_admin,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/csrf-token", response_model=CSRFTokenResponse)
|
||||
async def get_csrf_token(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Get a CSRF token for the current user.
|
||||
|
||||
The CSRF token should be included in the X-CSRF-Token header
|
||||
for all sensitive state-changing operations (DELETE, PUT, PATCH).
|
||||
|
||||
Token expires after 1 hour and should be refreshed.
|
||||
"""
|
||||
csrf_token = get_csrf_token_for_user(current_user.id)
|
||||
|
||||
return CSRFTokenResponse(
|
||||
csrf_token=csrf_token,
|
||||
expires_in=3600, # 1 hour
|
||||
)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import uuid
|
||||
from typing import List
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, joinedload, selectinload
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.models import User, Space, Project, TaskStatus, AuditAction, ProjectMember, ProjectTemplate, CustomField
|
||||
@@ -55,6 +55,8 @@ async def list_projects_in_space(
|
||||
):
|
||||
"""
|
||||
List all projects in a space that the user can access.
|
||||
|
||||
Optimized to avoid N+1 queries by using joinedload/selectinload for relationships.
|
||||
"""
|
||||
space = db.query(Space).filter(Space.id == space_id, Space.is_active == True).first()
|
||||
|
||||
@@ -70,13 +72,21 @@ async def list_projects_in_space(
|
||||
detail="Access denied",
|
||||
)
|
||||
|
||||
projects = db.query(Project).filter(Project.space_id == space_id, Project.is_active == True).all()
|
||||
# Use joinedload to eagerly load owner, space, and department
|
||||
# Use selectinload for tasks (one-to-many) to avoid cartesian product issues
|
||||
projects = db.query(Project).options(
|
||||
joinedload(Project.owner),
|
||||
joinedload(Project.space),
|
||||
joinedload(Project.department),
|
||||
selectinload(Project.tasks),
|
||||
).filter(Project.space_id == space_id, Project.is_active == True).all()
|
||||
|
||||
# Filter by project access
|
||||
accessible_projects = [p for p in projects if check_project_access(current_user, p)]
|
||||
|
||||
result = []
|
||||
for project in accessible_projects:
|
||||
# Access pre-loaded relationships - no additional queries needed
|
||||
task_count = len(project.tasks) if project.tasks else 0
|
||||
result.append(ProjectWithDetails(
|
||||
id=project.id,
|
||||
@@ -422,6 +432,10 @@ async def list_project_members(
|
||||
List all members of a project.
|
||||
|
||||
Only users with project access can view the member list.
|
||||
|
||||
Optimized to avoid N+1 queries by using joinedload for user relationships.
|
||||
This loads all members and their related users in at most 2 queries instead of
|
||||
one query per member.
|
||||
"""
|
||||
project = db.query(Project).filter(Project.id == project_id, Project.is_active == True).first()
|
||||
|
||||
@@ -437,14 +451,20 @@ async def list_project_members(
|
||||
detail="Access denied",
|
||||
)
|
||||
|
||||
members = db.query(ProjectMember).filter(
|
||||
# Use joinedload to eagerly load user and added_by_user relationships
|
||||
# This avoids N+1 queries when accessing member.user and member.added_by_user
|
||||
members = db.query(ProjectMember).options(
|
||||
joinedload(ProjectMember.user).joinedload(User.department),
|
||||
joinedload(ProjectMember.added_by_user),
|
||||
).filter(
|
||||
ProjectMember.project_id == project_id
|
||||
).all()
|
||||
|
||||
member_list = []
|
||||
for member in members:
|
||||
user = db.query(User).filter(User.id == member.user_id).first()
|
||||
added_by_user = db.query(User).filter(User.id == member.added_by).first()
|
||||
# Access pre-loaded relationships - no additional queries needed
|
||||
user = member.user
|
||||
added_by_user = member.added_by_user
|
||||
|
||||
member_list.append(ProjectMemberWithDetails(
|
||||
id=member.id,
|
||||
|
||||
@@ -3,7 +3,7 @@ import uuid
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import List, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query, Request
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, joinedload, selectinload
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.redis_pubsub import publish_task_event
|
||||
@@ -110,6 +110,9 @@ async def list_tasks(
|
||||
|
||||
The due_after and due_before parameters are useful for calendar view
|
||||
to fetch tasks within a specific date range.
|
||||
|
||||
Optimized to avoid N+1 queries by using selectinload for task relationships.
|
||||
This batch loads assignees, statuses, creators and subtasks efficiently.
|
||||
"""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
|
||||
@@ -125,7 +128,15 @@ async def list_tasks(
|
||||
detail="Access denied",
|
||||
)
|
||||
|
||||
query = db.query(Task).filter(Task.project_id == project_id)
|
||||
# Use selectinload to eagerly load task relationships
|
||||
# This avoids N+1 queries when accessing task.assignee, task.status, etc.
|
||||
query = db.query(Task).options(
|
||||
selectinload(Task.assignee),
|
||||
selectinload(Task.status),
|
||||
selectinload(Task.creator),
|
||||
selectinload(Task.subtasks),
|
||||
selectinload(Task.custom_values),
|
||||
).filter(Task.project_id == project_id)
|
||||
|
||||
# Filter deleted tasks (only admin can include deleted)
|
||||
if include_deleted and current_user.is_system_admin:
|
||||
@@ -1112,6 +1123,8 @@ async def list_subtasks(
|
||||
):
|
||||
"""
|
||||
List subtasks of a task.
|
||||
|
||||
Optimized to avoid N+1 queries by using selectinload for task relationships.
|
||||
"""
|
||||
task = db.query(Task).filter(Task.id == task_id).first()
|
||||
|
||||
@@ -1127,7 +1140,13 @@ async def list_subtasks(
|
||||
detail="Access denied",
|
||||
)
|
||||
|
||||
query = db.query(Task).filter(Task.parent_task_id == task_id)
|
||||
# Use selectinload to eagerly load subtask relationships
|
||||
query = db.query(Task).options(
|
||||
selectinload(Task.assignee),
|
||||
selectinload(Task.status),
|
||||
selectinload(Task.creator),
|
||||
selectinload(Task.subtasks),
|
||||
).filter(Task.parent_task_id == task_id)
|
||||
|
||||
# Filter deleted subtasks (only admin can include deleted)
|
||||
if not (include_deleted and current_user.is_system_admin):
|
||||
|
||||
@@ -3,7 +3,7 @@ from sqlalchemy.orm import Session
|
||||
from sqlalchemy import or_
|
||||
from typing import List
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.database import get_db, escape_like
|
||||
from app.core.redis import get_redis
|
||||
from app.models.user import User
|
||||
from app.models.role import Role
|
||||
@@ -16,6 +16,7 @@ from app.middleware.auth import (
|
||||
check_department_access,
|
||||
)
|
||||
from app.middleware.audit import get_audit_metadata
|
||||
from app.middleware.csrf import require_csrf_token
|
||||
from app.services.audit_service import AuditService
|
||||
|
||||
router = APIRouter()
|
||||
@@ -32,11 +33,13 @@ async def search_users(
|
||||
Search users by name or email. Used for @mention autocomplete.
|
||||
Returns users matching the query, limited to same department unless system admin.
|
||||
"""
|
||||
# Escape special LIKE characters to prevent injection
|
||||
escaped_q = escape_like(q)
|
||||
query = db.query(User).filter(
|
||||
User.is_active == True,
|
||||
or_(
|
||||
User.name.ilike(f"%{q}%"),
|
||||
User.email.ilike(f"%{q}%"),
|
||||
User.name.ilike(f"%{escaped_q}%", escape="\\"),
|
||||
User.email.ilike(f"%{escaped_q}%", escape="\\"),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -197,6 +200,7 @@ async def assign_role(
|
||||
|
||||
|
||||
@router.patch("/{user_id}/admin", response_model=UserResponse)
|
||||
@require_csrf_token
|
||||
async def set_admin_status(
|
||||
user_id: str,
|
||||
is_admin: bool,
|
||||
@@ -205,7 +209,7 @@ async def set_admin_status(
|
||||
current_user: User = Depends(require_system_admin),
|
||||
):
|
||||
"""
|
||||
Set or revoke system administrator status. Requires system admin.
|
||||
Set or revoke system administrator status. Requires system admin and CSRF token.
|
||||
"""
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
|
||||
@@ -27,29 +27,37 @@ if os.getenv("TESTING") == "true":
|
||||
AUTH_TIMEOUT = 1.0
|
||||
|
||||
|
||||
async def get_user_from_token(token: str) -> tuple[str | None, User | None]:
|
||||
"""Validate token and return user_id and user object."""
|
||||
async def get_user_from_token(token: str) -> str | None:
|
||||
"""
|
||||
Validate token and return user_id.
|
||||
|
||||
Returns:
|
||||
user_id if valid, None otherwise.
|
||||
|
||||
Note: This function properly closes the database session after validation.
|
||||
Do not return ORM objects as they become detached after session close.
|
||||
"""
|
||||
payload = decode_access_token(token)
|
||||
if payload is None:
|
||||
return None, None
|
||||
return None
|
||||
|
||||
user_id = payload.get("sub")
|
||||
if user_id is None:
|
||||
return None, None
|
||||
return None
|
||||
|
||||
# Verify session in Redis
|
||||
redis_client = get_redis_sync()
|
||||
stored_token = redis_client.get(f"session:{user_id}")
|
||||
if stored_token is None or stored_token != token:
|
||||
return None, None
|
||||
return None
|
||||
|
||||
# Get user from database
|
||||
# Verify user exists and is active
|
||||
db = database.SessionLocal()
|
||||
try:
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if user is None or not user.is_active:
|
||||
return None, None
|
||||
return user_id, user
|
||||
return None
|
||||
return user_id
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@@ -57,7 +65,7 @@ async def get_user_from_token(token: str) -> tuple[str | None, User | None]:
|
||||
async def authenticate_websocket(
|
||||
websocket: WebSocket,
|
||||
query_token: Optional[str] = None
|
||||
) -> tuple[str | None, User | None, Optional[str]]:
|
||||
) -> tuple[str | None, Optional[str]]:
|
||||
"""
|
||||
Authenticate WebSocket connection.
|
||||
|
||||
@@ -67,7 +75,8 @@ async def authenticate_websocket(
|
||||
2. Query parameter authentication (deprecated, for backward compatibility)
|
||||
- Client connects with: ?token=<jwt_token>
|
||||
|
||||
Returns (user_id, user) if authenticated, (None, None) otherwise.
|
||||
Returns:
|
||||
Tuple of (user_id, error_reason). user_id is None if authentication fails.
|
||||
"""
|
||||
# If token provided via query parameter (backward compatibility)
|
||||
if query_token:
|
||||
@@ -75,10 +84,10 @@ async def authenticate_websocket(
|
||||
"WebSocket authentication via query parameter is deprecated. "
|
||||
"Please use first-message authentication for better security."
|
||||
)
|
||||
user_id, user = await get_user_from_token(query_token)
|
||||
user_id = await get_user_from_token(query_token)
|
||||
if user_id is None:
|
||||
return None, None, "invalid_token"
|
||||
return user_id, user, None
|
||||
return None, "invalid_token"
|
||||
return user_id, None
|
||||
|
||||
# Wait for authentication message with timeout
|
||||
try:
|
||||
@@ -90,24 +99,24 @@ async def authenticate_websocket(
|
||||
msg_type = data.get("type")
|
||||
if msg_type != "auth":
|
||||
logger.warning("Expected 'auth' message type, got: %s", msg_type)
|
||||
return None, None, "invalid_message"
|
||||
return None, "invalid_message"
|
||||
|
||||
token = data.get("token")
|
||||
if not token:
|
||||
logger.warning("No token provided in auth message")
|
||||
return None, None, "missing_token"
|
||||
return None, "missing_token"
|
||||
|
||||
user_id, user = await get_user_from_token(token)
|
||||
user_id = await get_user_from_token(token)
|
||||
if user_id is None:
|
||||
return None, None, "invalid_token"
|
||||
return user_id, user, None
|
||||
return None, "invalid_token"
|
||||
return user_id, None
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("WebSocket authentication timeout after %.1f seconds", AUTH_TIMEOUT)
|
||||
return None, None, "timeout"
|
||||
return None, "timeout"
|
||||
except Exception as e:
|
||||
logger.error("Error during WebSocket authentication: %s", e)
|
||||
return None, None, "error"
|
||||
return None, "error"
|
||||
|
||||
|
||||
async def get_unread_notifications(user_id: str) -> list[dict]:
|
||||
@@ -183,7 +192,7 @@ async def websocket_notifications(
|
||||
await websocket.accept()
|
||||
|
||||
# Authenticate
|
||||
user_id, user, error_reason = await authenticate_websocket(websocket, token)
|
||||
user_id, error_reason = await authenticate_websocket(websocket, token)
|
||||
|
||||
if user_id is None:
|
||||
if error_reason == "invalid_token":
|
||||
@@ -306,7 +315,7 @@ async def websocket_notifications(
|
||||
await manager.disconnect(websocket, user_id)
|
||||
|
||||
|
||||
async def verify_project_access(user_id: str, project_id: str) -> tuple[bool, Project | None]:
|
||||
async def verify_project_access(user_id: str, project_id: str) -> tuple[bool, str | None, str | None]:
|
||||
"""
|
||||
Check if user has access to the project.
|
||||
|
||||
@@ -315,23 +324,34 @@ async def verify_project_access(user_id: str, project_id: str) -> tuple[bool, Pr
|
||||
project_id: The project's ID
|
||||
|
||||
Returns:
|
||||
Tuple of (has_access: bool, project: Project | None)
|
||||
Tuple of (has_access: bool, project_title: str | None, error: str | None)
|
||||
- has_access: True if user can access the project
|
||||
- project_title: The project title (only if access granted)
|
||||
- error: Error code if access denied ("user_not_found", "project_not_found", "access_denied")
|
||||
|
||||
Note: This function extracts needed data before closing the session to avoid
|
||||
detached instance errors when accessing ORM object attributes.
|
||||
"""
|
||||
db = database.SessionLocal()
|
||||
try:
|
||||
# Get the user
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if user is None or not user.is_active:
|
||||
return False, None
|
||||
return False, None, "user_not_found"
|
||||
|
||||
# Get the project
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if project is None:
|
||||
return False, None
|
||||
return False, None, "project_not_found"
|
||||
|
||||
# Check access using existing middleware function
|
||||
has_access = check_project_access(user, project)
|
||||
return has_access, project
|
||||
if not has_access:
|
||||
return False, None, "access_denied"
|
||||
|
||||
# Extract title while session is still open
|
||||
project_title = project.title
|
||||
return True, project_title, None
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@@ -371,7 +391,7 @@ async def websocket_project_sync(
|
||||
await websocket.accept()
|
||||
|
||||
# Authenticate user
|
||||
user_id, user, error_reason = await authenticate_websocket(websocket, token)
|
||||
user_id, error_reason = await authenticate_websocket(websocket, token)
|
||||
|
||||
if user_id is None:
|
||||
if error_reason == "invalid_token":
|
||||
@@ -380,14 +400,13 @@ async def websocket_project_sync(
|
||||
return
|
||||
|
||||
# Verify user has access to the project
|
||||
has_access, project = await verify_project_access(user_id, project_id)
|
||||
has_access, project_title, access_error = await verify_project_access(user_id, project_id)
|
||||
|
||||
if not has_access:
|
||||
await websocket.close(code=4003, reason="Access denied to this project")
|
||||
return
|
||||
|
||||
if project is None:
|
||||
await websocket.close(code=4004, reason="Project not found")
|
||||
if access_error == "project_not_found":
|
||||
await websocket.close(code=4004, reason="Project not found")
|
||||
else:
|
||||
await websocket.close(code=4003, reason="Access denied to this project")
|
||||
return
|
||||
|
||||
# Join project room
|
||||
@@ -413,7 +432,7 @@ async def websocket_project_sync(
|
||||
"data": {
|
||||
"project_id": project_id,
|
||||
"user_id": user_id,
|
||||
"project_title": project.title if project else None,
|
||||
"project_title": project_title,
|
||||
},
|
||||
})
|
||||
|
||||
|
||||
@@ -122,6 +122,11 @@ class Settings(BaseSettings):
|
||||
RATE_LIMIT_SENSITIVE: str = "20/minute" # Attachments, password change, report export
|
||||
RATE_LIMIT_HEAVY: str = "5/minute" # Report generation, bulk operations
|
||||
|
||||
# Development Mode Settings
|
||||
DEBUG: bool = False # Enable debug mode for development
|
||||
QUERY_LOGGING: bool = False # Enable SQLAlchemy query logging
|
||||
QUERY_COUNT_THRESHOLD: int = 10 # Warn when query count exceeds this threshold
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
case_sensitive = True
|
||||
|
||||
@@ -104,6 +104,10 @@ def _on_invalidate(dbapi_conn, connection_record, exception):
|
||||
# Start pool statistics logging on module load
|
||||
_start_pool_stats_logging()
|
||||
|
||||
# Set up query logging if enabled
|
||||
from app.core.query_monitor import setup_query_logging
|
||||
setup_query_logging(engine)
|
||||
|
||||
|
||||
def get_db():
|
||||
"""Dependency for getting database session."""
|
||||
@@ -127,3 +131,25 @@ def get_pool_status() -> dict:
|
||||
"total_checkins": _pool_stats["checkins"],
|
||||
"invalidated_connections": _pool_stats["invalidated_connections"],
|
||||
}
|
||||
|
||||
|
||||
def escape_like(value: str) -> str:
|
||||
"""
|
||||
Escape special characters for SQL LIKE queries.
|
||||
|
||||
Escapes '%' and '_' characters which have special meaning in LIKE patterns.
|
||||
This prevents LIKE injection attacks where user input could match unintended patterns.
|
||||
|
||||
Args:
|
||||
value: The user input string to escape
|
||||
|
||||
Returns:
|
||||
Escaped string safe for use in LIKE patterns
|
||||
|
||||
Example:
|
||||
>>> escape_like("test%value")
|
||||
'test\\%value'
|
||||
>>> escape_like("user_name")
|
||||
'user\\_name'
|
||||
"""
|
||||
return value.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
|
||||
|
||||
167
backend/app/core/query_monitor.py
Normal file
167
backend/app/core/query_monitor.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""
|
||||
Query monitoring utilities for detecting N+1 queries and performance issues.
|
||||
|
||||
This module provides:
|
||||
1. Query counting per request in development mode
|
||||
2. SQLAlchemy event listeners for query logging
|
||||
3. Threshold-based warnings for excessive queries
|
||||
"""
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, Callable, Any
|
||||
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Thread-local storage for per-request query counting
|
||||
_query_context = threading.local()
|
||||
|
||||
|
||||
class QueryCounter:
|
||||
"""
|
||||
Context manager for counting database queries within a request.
|
||||
|
||||
Usage:
|
||||
with QueryCounter() as counter:
|
||||
# ... execute queries ...
|
||||
print(f"Executed {counter.count} queries")
|
||||
"""
|
||||
|
||||
def __init__(self, threshold: Optional[int] = None, context_name: str = "request"):
|
||||
self.threshold = threshold or settings.QUERY_COUNT_THRESHOLD
|
||||
self.context_name = context_name
|
||||
self.count = 0
|
||||
self.queries = []
|
||||
self.start_time = None
|
||||
self.total_time = 0.0
|
||||
|
||||
def __enter__(self):
|
||||
self.count = 0
|
||||
self.queries = []
|
||||
self.start_time = time.time()
|
||||
_query_context.counter = self
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.total_time = time.time() - self.start_time
|
||||
_query_context.counter = None
|
||||
|
||||
# Log warning if threshold exceeded
|
||||
if self.count > self.threshold:
|
||||
logger.warning(
|
||||
"Query count threshold exceeded in %s: %d queries (threshold: %d, time: %.3fs)",
|
||||
self.context_name,
|
||||
self.count,
|
||||
self.threshold,
|
||||
self.total_time,
|
||||
)
|
||||
if settings.DEBUG:
|
||||
# In debug mode, also log the individual queries
|
||||
for i, (sql, duration) in enumerate(self.queries[:20], 1):
|
||||
logger.debug(" Query %d (%.3fs): %s", i, duration, sql[:200])
|
||||
if len(self.queries) > 20:
|
||||
logger.debug(" ... and %d more queries", len(self.queries) - 20)
|
||||
elif settings.DEBUG and self.count > 0:
|
||||
logger.debug(
|
||||
"Query count for %s: %d queries in %.3fs",
|
||||
self.context_name,
|
||||
self.count,
|
||||
self.total_time,
|
||||
)
|
||||
|
||||
return False
|
||||
|
||||
def record_query(self, statement: str, duration: float):
|
||||
"""Record a query execution."""
|
||||
self.count += 1
|
||||
if settings.DEBUG:
|
||||
self.queries.append((statement, duration))
|
||||
|
||||
|
||||
def get_current_counter() -> Optional[QueryCounter]:
|
||||
"""Get the current request's query counter, if any."""
|
||||
return getattr(_query_context, 'counter', None)
|
||||
|
||||
|
||||
def setup_query_logging(engine: Engine):
|
||||
"""
|
||||
Set up SQLAlchemy event listeners for query logging.
|
||||
|
||||
This should be called once during application startup.
|
||||
Only activates if QUERY_LOGGING is enabled in settings.
|
||||
"""
|
||||
if not settings.QUERY_LOGGING:
|
||||
logger.info("Query logging is disabled")
|
||||
return
|
||||
|
||||
logger.info("Setting up query logging with threshold=%d", settings.QUERY_COUNT_THRESHOLD)
|
||||
|
||||
@event.listens_for(engine, "before_cursor_execute")
|
||||
def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
|
||||
conn.info.setdefault('query_start_time', []).append(time.time())
|
||||
|
||||
@event.listens_for(engine, "after_cursor_execute")
|
||||
def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
|
||||
start_times = conn.info.get('query_start_time', [])
|
||||
duration = time.time() - start_times.pop() if start_times else 0.0
|
||||
|
||||
# Record in current counter if active
|
||||
counter = get_current_counter()
|
||||
if counter:
|
||||
counter.record_query(statement, duration)
|
||||
|
||||
# Also log individual queries if in debug mode
|
||||
if settings.DEBUG:
|
||||
logger.debug("SQL (%.3fs): %s", duration, statement[:500])
|
||||
|
||||
|
||||
@contextmanager
|
||||
def count_queries(context_name: str = "operation", threshold: Optional[int] = None):
|
||||
"""
|
||||
Context manager to count queries for a specific operation.
|
||||
|
||||
Args:
|
||||
context_name: Name for logging purposes
|
||||
threshold: Override the default query count threshold
|
||||
|
||||
Usage:
|
||||
with count_queries("list_members") as counter:
|
||||
members = db.query(ProjectMember).all()
|
||||
for member in members:
|
||||
print(member.user.name) # N+1 query!
|
||||
|
||||
# After block, logs warning if threshold exceeded
|
||||
print(f"Total queries: {counter.count}")
|
||||
"""
|
||||
with QueryCounter(threshold=threshold, context_name=context_name) as counter:
|
||||
yield counter
|
||||
|
||||
|
||||
def assert_query_count(max_queries: int):
|
||||
"""
|
||||
Decorator for testing that asserts maximum query count.
|
||||
|
||||
Usage in tests:
|
||||
@assert_query_count(5)
|
||||
def test_list_members():
|
||||
# Should use at most 5 queries
|
||||
response = client.get("/api/projects/xxx/members")
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
def wrapper(*args, **kwargs):
|
||||
with QueryCounter(threshold=max_queries, context_name=func.__name__) as counter:
|
||||
result = func(*args, **kwargs)
|
||||
if counter.count > max_queries:
|
||||
raise AssertionError(
|
||||
f"Query count {counter.count} exceeded maximum {max_queries} "
|
||||
f"in {func.__name__}"
|
||||
)
|
||||
return result
|
||||
return wrapper
|
||||
return decorator
|
||||
@@ -1,8 +1,283 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional, Any
|
||||
from typing import Optional, Any, Tuple
|
||||
from jose import jwt, JWTError
|
||||
import logging
|
||||
import math
|
||||
import hashlib
|
||||
import secrets
|
||||
import hmac
|
||||
from collections import Counter
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Constants for JWT secret validation
|
||||
MIN_SECRET_LENGTH = 32
|
||||
MIN_ENTROPY_BITS = 128 # Minimum entropy in bits for a secure secret
|
||||
COMMON_WEAK_PATTERNS = [
|
||||
"password", "secret", "changeme", "admin", "test", "demo",
|
||||
"123456", "qwerty", "abc123", "letmein", "welcome",
|
||||
]
|
||||
|
||||
|
||||
def calculate_entropy(data: str) -> float:
|
||||
"""
|
||||
Calculate Shannon entropy of a string in bits.
|
||||
|
||||
Higher entropy indicates more randomness and thus a stronger secret.
|
||||
A perfectly random string of length n with k possible characters has
|
||||
entropy of n * log2(k) bits.
|
||||
|
||||
Args:
|
||||
data: The string to calculate entropy for
|
||||
|
||||
Returns:
|
||||
Entropy value in bits
|
||||
"""
|
||||
if not data:
|
||||
return 0.0
|
||||
|
||||
# Count character frequencies
|
||||
char_counts = Counter(data)
|
||||
length = len(data)
|
||||
|
||||
# Calculate Shannon entropy
|
||||
entropy = 0.0
|
||||
for count in char_counts.values():
|
||||
if count > 0:
|
||||
probability = count / length
|
||||
entropy -= probability * math.log2(probability)
|
||||
|
||||
# Return total entropy in bits (per-character entropy * length)
|
||||
return entropy * length
|
||||
|
||||
|
||||
def has_repeating_patterns(secret: str) -> bool:
|
||||
"""
|
||||
Check if the secret contains obvious repeating patterns.
|
||||
|
||||
Args:
|
||||
secret: The secret string to check
|
||||
|
||||
Returns:
|
||||
True if repeating patterns are detected
|
||||
"""
|
||||
if len(secret) < 8:
|
||||
return False
|
||||
|
||||
# Check for repeating character sequences
|
||||
for pattern_len in range(2, len(secret) // 3 + 1):
|
||||
pattern = secret[:pattern_len]
|
||||
if pattern * (len(secret) // pattern_len) == secret[:len(pattern) * (len(secret) // pattern_len)]:
|
||||
# More than 50% of the string is the same pattern repeated
|
||||
if (len(secret) // pattern_len) >= 3:
|
||||
return True
|
||||
|
||||
# Check for consecutive same characters
|
||||
consecutive_count = 1
|
||||
for i in range(1, len(secret)):
|
||||
if secret[i] == secret[i-1]:
|
||||
consecutive_count += 1
|
||||
if consecutive_count >= len(secret) // 2:
|
||||
return True
|
||||
else:
|
||||
consecutive_count = 1
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def validate_jwt_secret_strength(secret: str) -> Tuple[bool, list]:
|
||||
"""
|
||||
Validate JWT secret key strength.
|
||||
|
||||
Checks:
|
||||
1. Minimum length (32 characters)
|
||||
2. Entropy (minimum 128 bits)
|
||||
3. Common weak patterns
|
||||
4. Repeating patterns
|
||||
|
||||
Args:
|
||||
secret: The JWT secret to validate
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, list_of_warnings)
|
||||
"""
|
||||
warnings = []
|
||||
is_valid = True
|
||||
|
||||
# Check minimum length
|
||||
if len(secret) < MIN_SECRET_LENGTH:
|
||||
warnings.append(
|
||||
f"JWT secret is too short ({len(secret)} chars). "
|
||||
f"Minimum recommended length is {MIN_SECRET_LENGTH} characters."
|
||||
)
|
||||
is_valid = False
|
||||
|
||||
# Calculate and check entropy
|
||||
entropy = calculate_entropy(secret)
|
||||
if entropy < MIN_ENTROPY_BITS:
|
||||
warnings.append(
|
||||
f"JWT secret has low entropy ({entropy:.1f} bits). "
|
||||
f"Minimum recommended entropy is {MIN_ENTROPY_BITS} bits. "
|
||||
"Consider using a cryptographically random secret."
|
||||
)
|
||||
# Low entropy alone doesn't make it invalid, but it's a warning
|
||||
|
||||
# Check for common weak patterns
|
||||
secret_lower = secret.lower()
|
||||
for pattern in COMMON_WEAK_PATTERNS:
|
||||
if pattern in secret_lower:
|
||||
warnings.append(
|
||||
f"JWT secret contains common weak pattern: '{pattern}'. "
|
||||
"Use a cryptographically random secret."
|
||||
)
|
||||
break
|
||||
|
||||
# Check for repeating patterns
|
||||
if has_repeating_patterns(secret):
|
||||
warnings.append(
|
||||
"JWT secret contains repeating patterns. "
|
||||
"Use a cryptographically random secret."
|
||||
)
|
||||
|
||||
return is_valid, warnings
|
||||
|
||||
|
||||
def validate_jwt_secret_on_startup() -> None:
|
||||
"""
|
||||
Validate JWT secret strength on application startup.
|
||||
|
||||
Logs warnings for weak secrets and raises an error in production
|
||||
if the secret is critically weak.
|
||||
"""
|
||||
import os
|
||||
|
||||
secret = settings.JWT_SECRET_KEY
|
||||
is_valid, warnings = validate_jwt_secret_strength(secret)
|
||||
|
||||
# Log all warnings
|
||||
for warning in warnings:
|
||||
logger.warning("JWT Security Warning: %s", warning)
|
||||
|
||||
# In production, enforce stricter requirements
|
||||
is_production = os.environ.get("ENVIRONMENT", "").lower() == "production"
|
||||
|
||||
if not is_valid:
|
||||
if is_production:
|
||||
logger.critical(
|
||||
"JWT secret does not meet security requirements. "
|
||||
"Application startup blocked in production mode. "
|
||||
"Please configure a strong JWT_SECRET_KEY (minimum 32 characters)."
|
||||
)
|
||||
raise ValueError(
|
||||
"JWT_SECRET_KEY does not meet minimum security requirements. "
|
||||
"See logs for details."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"JWT secret does not meet security requirements. "
|
||||
"This would block startup in production mode."
|
||||
)
|
||||
|
||||
if warnings:
|
||||
logger.info(
|
||||
"JWT secret validation completed with %d warning(s). "
|
||||
"Consider using: python -c \"import secrets; print(secrets.token_urlsafe(48))\" "
|
||||
"to generate a strong secret.",
|
||||
len(warnings)
|
||||
)
|
||||
else:
|
||||
logger.info("JWT secret validation passed. Secret meets security requirements.")
|
||||
|
||||
|
||||
# CSRF Token Functions
|
||||
CSRF_TOKEN_LENGTH = 32
|
||||
CSRF_TOKEN_EXPIRY_SECONDS = 3600 # 1 hour
|
||||
|
||||
|
||||
def generate_csrf_token(user_id: str) -> str:
|
||||
"""
|
||||
Generate a CSRF token for a user.
|
||||
|
||||
The token is a combination of:
|
||||
- Random bytes for unpredictability
|
||||
- User ID binding to prevent token reuse across users
|
||||
- HMAC signature for integrity
|
||||
|
||||
Args:
|
||||
user_id: The user's ID to bind the token to
|
||||
|
||||
Returns:
|
||||
CSRF token string
|
||||
"""
|
||||
# Generate random token
|
||||
random_part = secrets.token_urlsafe(CSRF_TOKEN_LENGTH)
|
||||
|
||||
# Create timestamp for expiry checking
|
||||
timestamp = int(datetime.now(timezone.utc).timestamp())
|
||||
|
||||
# Create the token payload
|
||||
payload = f"{random_part}:{user_id}:{timestamp}"
|
||||
|
||||
# Sign with HMAC using JWT secret
|
||||
signature = hmac.new(
|
||||
settings.JWT_SECRET_KEY.encode(),
|
||||
payload.encode(),
|
||||
hashlib.sha256
|
||||
).hexdigest()[:16]
|
||||
|
||||
# Return combined token
|
||||
return f"{payload}:{signature}"
|
||||
|
||||
|
||||
def validate_csrf_token(token: str, user_id: str) -> Tuple[bool, str]:
|
||||
"""
|
||||
Validate a CSRF token.
|
||||
|
||||
Args:
|
||||
token: The CSRF token to validate
|
||||
user_id: The expected user ID
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, error_message)
|
||||
"""
|
||||
if not token:
|
||||
return False, "CSRF token is required"
|
||||
|
||||
try:
|
||||
parts = token.split(":")
|
||||
if len(parts) != 4:
|
||||
return False, "Invalid CSRF token format"
|
||||
|
||||
random_part, token_user_id, timestamp_str, signature = parts
|
||||
|
||||
# Verify user ID matches
|
||||
if token_user_id != user_id:
|
||||
return False, "CSRF token user mismatch"
|
||||
|
||||
# Verify timestamp (check expiry)
|
||||
timestamp = int(timestamp_str)
|
||||
current_time = int(datetime.now(timezone.utc).timestamp())
|
||||
if current_time - timestamp > CSRF_TOKEN_EXPIRY_SECONDS:
|
||||
return False, "CSRF token expired"
|
||||
|
||||
# Verify signature
|
||||
payload = f"{random_part}:{token_user_id}:{timestamp_str}"
|
||||
expected_signature = hmac.new(
|
||||
settings.JWT_SECRET_KEY.encode(),
|
||||
payload.encode(),
|
||||
hashlib.sha256
|
||||
).hexdigest()[:16]
|
||||
|
||||
if not hmac.compare_digest(signature, expected_signature):
|
||||
return False, "CSRF token signature invalid"
|
||||
|
||||
return True, ""
|
||||
|
||||
except (ValueError, IndexError) as e:
|
||||
return False, f"CSRF token validation error: {str(e)}"
|
||||
|
||||
|
||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
||||
"""
|
||||
|
||||
@@ -20,6 +20,12 @@ async def lifespan(app: FastAPI):
|
||||
testing = os.environ.get("TESTING", "").lower() in ("true", "1", "yes")
|
||||
scheduler_disabled = os.environ.get("DISABLE_SCHEDULER", "").lower() in ("true", "1", "yes")
|
||||
start_background_jobs = not testing and not scheduler_disabled
|
||||
|
||||
# Startup security validation
|
||||
if not testing:
|
||||
from app.core.security import validate_jwt_secret_on_startup
|
||||
validate_jwt_secret_on_startup()
|
||||
|
||||
# Startup
|
||||
if start_background_jobs:
|
||||
start_scheduler()
|
||||
|
||||
167
backend/app/middleware/csrf.py
Normal file
167
backend/app/middleware/csrf.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""
|
||||
CSRF (Cross-Site Request Forgery) Protection Middleware.
|
||||
|
||||
This module provides CSRF protection for sensitive state-changing operations.
|
||||
It validates CSRF tokens for specified protected endpoints.
|
||||
"""
|
||||
|
||||
from fastapi import Request, HTTPException, status, Depends
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from typing import Optional, Callable, List
|
||||
from functools import wraps
|
||||
import logging
|
||||
|
||||
from app.core.security import validate_csrf_token, generate_csrf_token
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Header name for CSRF token
|
||||
CSRF_TOKEN_HEADER = "X-CSRF-Token"
|
||||
|
||||
# List of endpoint patterns that require CSRF protection
|
||||
# These are sensitive state-changing operations
|
||||
CSRF_PROTECTED_PATTERNS = [
|
||||
# User operations
|
||||
"/api/v1/users/{user_id}/admin", # Admin status change
|
||||
"/api/users/{user_id}/admin", # Legacy
|
||||
# Password changes would go here if implemented
|
||||
# Delete operations
|
||||
"/api/attachments/{attachment_id}", # DELETE method
|
||||
"/api/tasks/{task_id}", # DELETE method (soft delete)
|
||||
"/api/projects/{project_id}", # DELETE method
|
||||
]
|
||||
|
||||
# Methods that require CSRF protection
|
||||
CSRF_PROTECTED_METHODS = ["DELETE", "PUT", "PATCH"]
|
||||
|
||||
|
||||
class CSRFProtectionError(HTTPException):
|
||||
"""Custom exception for CSRF validation failures."""
|
||||
|
||||
def __init__(self, detail: str = "CSRF validation failed"):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=detail
|
||||
)
|
||||
|
||||
|
||||
def require_csrf_token(func: Callable) -> Callable:
|
||||
"""
|
||||
Decorator to require CSRF token validation for an endpoint.
|
||||
|
||||
Usage:
|
||||
@router.delete("/resource/{id}")
|
||||
@require_csrf_token
|
||||
async def delete_resource(request: Request, id: str, current_user: User = Depends(get_current_user)):
|
||||
...
|
||||
|
||||
The decorator validates the X-CSRF-Token header against the current user.
|
||||
"""
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
# Extract request and current_user from kwargs
|
||||
request: Optional[Request] = kwargs.get("request")
|
||||
current_user = kwargs.get("current_user")
|
||||
|
||||
if request is None:
|
||||
# Try to find request in args (for methods where request is positional)
|
||||
for arg in args:
|
||||
if isinstance(arg, Request):
|
||||
request = arg
|
||||
break
|
||||
|
||||
if request is None:
|
||||
logger.error("CSRF validation failed: Request object not found")
|
||||
raise CSRFProtectionError("Internal error: Request not available")
|
||||
|
||||
if current_user is None:
|
||||
logger.error("CSRF validation failed: User not authenticated")
|
||||
raise CSRFProtectionError("Authentication required for CSRF-protected endpoint")
|
||||
|
||||
# Get CSRF token from header
|
||||
csrf_token = request.headers.get(CSRF_TOKEN_HEADER)
|
||||
|
||||
if not csrf_token:
|
||||
logger.warning(
|
||||
"CSRF validation failed: Missing token for user %s on %s %s",
|
||||
current_user.id, request.method, request.url.path
|
||||
)
|
||||
raise CSRFProtectionError("CSRF token is required")
|
||||
|
||||
# Validate the token
|
||||
is_valid, error_message = validate_csrf_token(csrf_token, current_user.id)
|
||||
|
||||
if not is_valid:
|
||||
logger.warning(
|
||||
"CSRF validation failed for user %s on %s %s: %s",
|
||||
current_user.id, request.method, request.url.path, error_message
|
||||
)
|
||||
raise CSRFProtectionError(error_message)
|
||||
|
||||
logger.debug(
|
||||
"CSRF validation passed for user %s on %s %s",
|
||||
current_user.id, request.method, request.url.path
|
||||
)
|
||||
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def get_csrf_token_for_user(user_id: str) -> str:
|
||||
"""
|
||||
Generate a CSRF token for a user.
|
||||
|
||||
This function can be called from login endpoints to provide
|
||||
the client with a CSRF token.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID
|
||||
|
||||
Returns:
|
||||
CSRF token string
|
||||
"""
|
||||
return generate_csrf_token(user_id)
|
||||
|
||||
|
||||
async def validate_csrf_for_request(
|
||||
request: Request,
|
||||
user_id: str,
|
||||
skip_methods: Optional[List[str]] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Validate CSRF token for a request.
|
||||
|
||||
This is a utility function that can be used directly in endpoints
|
||||
without the decorator.
|
||||
|
||||
Args:
|
||||
request: The FastAPI request object
|
||||
user_id: The current user's ID
|
||||
skip_methods: HTTP methods to skip validation for (default: GET, HEAD, OPTIONS)
|
||||
|
||||
Returns:
|
||||
True if validation passes
|
||||
|
||||
Raises:
|
||||
CSRFProtectionError: If validation fails
|
||||
"""
|
||||
if skip_methods is None:
|
||||
skip_methods = ["GET", "HEAD", "OPTIONS"]
|
||||
|
||||
# Skip validation for safe methods
|
||||
if request.method.upper() in skip_methods:
|
||||
return True
|
||||
|
||||
# Get CSRF token from header
|
||||
csrf_token = request.headers.get(CSRF_TOKEN_HEADER)
|
||||
|
||||
if not csrf_token:
|
||||
raise CSRFProtectionError("CSRF token is required")
|
||||
|
||||
is_valid, error_message = validate_csrf_token(csrf_token, user_id)
|
||||
|
||||
if not is_valid:
|
||||
raise CSRFProtectionError(error_message)
|
||||
|
||||
return True
|
||||
@@ -32,5 +32,11 @@ class TokenPayload(BaseModel):
|
||||
iat: int
|
||||
|
||||
|
||||
class CSRFTokenResponse(BaseModel):
|
||||
"""Response containing a CSRF token for state-changing operations."""
|
||||
csrf_token: str = Field(..., description="CSRF token to include in X-CSRF-Token header")
|
||||
expires_in: int = Field(default=3600, description="Token expiry time in seconds")
|
||||
|
||||
|
||||
# Update forward reference
|
||||
LoginResponse.model_rebuild()
|
||||
|
||||
@@ -286,11 +286,15 @@ class FileStorageService:
|
||||
return filename.rsplit(".", 1)[-1].lower() if "." in filename else ""
|
||||
|
||||
@staticmethod
|
||||
def validate_file(file: UploadFile) -> Tuple[str, str]:
|
||||
def validate_file(file: UploadFile, validate_mime: bool = True) -> Tuple[str, str]:
|
||||
"""
|
||||
Validate file size and type.
|
||||
Validate file size, type, and optionally MIME content.
|
||||
Returns (extension, mime_type) if valid.
|
||||
Raises HTTPException if invalid.
|
||||
|
||||
Args:
|
||||
file: The uploaded file
|
||||
validate_mime: If True, validate MIME type using magic bytes detection
|
||||
"""
|
||||
# Check file size
|
||||
file.file.seek(0, 2) # Seek to end
|
||||
@@ -323,7 +327,35 @@ class FileStorageService:
|
||||
detail=f"File type '.{extension}' is not supported"
|
||||
)
|
||||
|
||||
mime_type = file.content_type or "application/octet-stream"
|
||||
# Validate MIME type using magic bytes detection
|
||||
if validate_mime:
|
||||
from app.services.mime_validation_service import mime_validation_service
|
||||
|
||||
# Read first 16 bytes for magic detection (enough for most signatures)
|
||||
file_header = file.file.read(16)
|
||||
file.file.seek(0) # Reset
|
||||
|
||||
is_valid, detected_mime, error_message = mime_validation_service.validate_file_content(
|
||||
file_content=file_header,
|
||||
declared_extension=extension,
|
||||
declared_mime_type=file.content_type
|
||||
)
|
||||
|
||||
if not is_valid:
|
||||
logger.warning(
|
||||
"MIME validation failed for file '%s': %s (detected: %s)",
|
||||
file.filename, error_message, detected_mime
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=error_message or "File type validation failed"
|
||||
)
|
||||
|
||||
# Use detected MIME type if available, otherwise fall back to declared
|
||||
mime_type = detected_mime if detected_mime else (file.content_type or "application/octet-stream")
|
||||
else:
|
||||
mime_type = file.content_type or "application/octet-stream"
|
||||
|
||||
return extension, mime_type
|
||||
|
||||
async def save_file(
|
||||
|
||||
314
backend/app/services/mime_validation_service.py
Normal file
314
backend/app/services/mime_validation_service.py
Normal file
@@ -0,0 +1,314 @@
|
||||
"""
|
||||
MIME Type Validation Service using Magic Bytes Detection.
|
||||
|
||||
This module provides file content type validation by examining
|
||||
the actual file content (magic bytes) rather than trusting
|
||||
the file extension or Content-Type header.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, Tuple, Dict, Set, BinaryIO
|
||||
from io import BytesIO
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MimeValidationError(Exception):
|
||||
"""Raised when MIME type validation fails."""
|
||||
pass
|
||||
|
||||
|
||||
class FileMismatchError(MimeValidationError):
|
||||
"""Raised when file extension doesn't match actual content type."""
|
||||
pass
|
||||
|
||||
|
||||
class UnsupportedMimeError(MimeValidationError):
|
||||
"""Raised when file has an unsupported MIME type."""
|
||||
pass
|
||||
|
||||
|
||||
# Magic bytes signatures for common file types
|
||||
# Format: { bytes_pattern: (mime_type, extensions) }
|
||||
MAGIC_SIGNATURES: Dict[bytes, Tuple[str, Set[str]]] = {
|
||||
# Images
|
||||
b'\xFF\xD8\xFF': ('image/jpeg', {'jpg', 'jpeg', 'jpe'}),
|
||||
b'\x89PNG\r\n\x1a\n': ('image/png', {'png'}),
|
||||
b'GIF87a': ('image/gif', {'gif'}),
|
||||
b'GIF89a': ('image/gif', {'gif'}),
|
||||
b'RIFF': ('image/webp', {'webp'}), # WebP starts with RIFF, then WEBP
|
||||
b'BM': ('image/bmp', {'bmp'}),
|
||||
|
||||
# PDF
|
||||
b'%PDF': ('application/pdf', {'pdf'}),
|
||||
|
||||
# Microsoft Office (Modern formats - ZIP-based)
|
||||
b'PK\x03\x04': ('application/zip', {'zip', 'docx', 'xlsx', 'pptx', 'odt', 'ods', 'odp', 'jar'}),
|
||||
|
||||
# Microsoft Office (Legacy formats - Compound Document)
|
||||
b'\xD0\xCF\x11\xE0\xA1\xB1\x1A\xE1': ('application/msword', {'doc', 'xls', 'ppt', 'msi'}),
|
||||
|
||||
# Archives
|
||||
b'\x1f\x8b': ('application/gzip', {'gz', 'tgz'}),
|
||||
b'\x42\x5a\x68': ('application/x-bzip2', {'bz2'}),
|
||||
b'\x37\x7A\xBC\xAF\x27\x1C': ('application/x-7z-compressed', {'7z'}),
|
||||
b'Rar!\x1a\x07': ('application/x-rar-compressed', {'rar'}),
|
||||
|
||||
# Text/Data formats - these are harder to detect, usually fallback to extension
|
||||
b'<?xml': ('application/xml', {'xml', 'svg'}),
|
||||
b'{': ('application/json', {'json'}), # JSON typically starts with { or [
|
||||
b'[': ('application/json', {'json'}),
|
||||
|
||||
# Executables (dangerous - should be blocked)
|
||||
b'MZ': ('application/x-executable', {'exe', 'dll', 'com', 'scr'}),
|
||||
b'\x7fELF': ('application/x-executable', {'elf', 'so', 'bin'}),
|
||||
}
|
||||
|
||||
# Map extensions to expected MIME types
|
||||
EXTENSION_TO_MIME: Dict[str, Set[str]] = {
|
||||
# Images
|
||||
'jpg': {'image/jpeg'},
|
||||
'jpeg': {'image/jpeg'},
|
||||
'jpe': {'image/jpeg'},
|
||||
'png': {'image/png'},
|
||||
'gif': {'image/gif'},
|
||||
'bmp': {'image/bmp'},
|
||||
'webp': {'image/webp'},
|
||||
'svg': {'image/svg+xml', 'application/xml', 'text/xml'},
|
||||
|
||||
# Documents
|
||||
'pdf': {'application/pdf'},
|
||||
'doc': {'application/msword'},
|
||||
'docx': {'application/vnd.openxmlformats-officedocument.wordprocessingml.document', 'application/zip'},
|
||||
'xls': {'application/vnd.ms-excel', 'application/msword'},
|
||||
'xlsx': {'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', 'application/zip'},
|
||||
'ppt': {'application/vnd.ms-powerpoint', 'application/msword'},
|
||||
'pptx': {'application/vnd.openxmlformats-officedocument.presentationml.presentation', 'application/zip'},
|
||||
|
||||
# Text
|
||||
'txt': {'text/plain'},
|
||||
'csv': {'text/csv', 'text/plain'},
|
||||
'json': {'application/json', 'text/plain'},
|
||||
'xml': {'application/xml', 'text/xml', 'text/plain'},
|
||||
'yaml': {'application/yaml', 'text/plain'},
|
||||
'yml': {'application/yaml', 'text/plain'},
|
||||
|
||||
# Archives
|
||||
'zip': {'application/zip'},
|
||||
'rar': {'application/x-rar-compressed'},
|
||||
'7z': {'application/x-7z-compressed'},
|
||||
'tar': {'application/x-tar'},
|
||||
'gz': {'application/gzip'},
|
||||
}
|
||||
|
||||
# MIME types that should always be blocked (dangerous executables)
|
||||
BLOCKED_MIME_TYPES: Set[str] = {
|
||||
'application/x-executable',
|
||||
'application/x-msdownload',
|
||||
'application/x-msdos-program',
|
||||
'application/x-sh',
|
||||
'application/x-csh',
|
||||
'application/x-dosexec',
|
||||
}
|
||||
|
||||
# Configurable allowed MIME type categories
|
||||
ALLOWED_MIME_CATEGORIES: Dict[str, Set[str]] = {
|
||||
'images': {
|
||||
'image/jpeg', 'image/png', 'image/gif', 'image/bmp', 'image/webp', 'image/svg+xml'
|
||||
},
|
||||
'documents': {
|
||||
'application/pdf',
|
||||
'application/msword',
|
||||
'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
|
||||
'application/vnd.ms-excel',
|
||||
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
|
||||
'application/vnd.ms-powerpoint',
|
||||
'application/vnd.openxmlformats-officedocument.presentationml.presentation',
|
||||
'text/plain', 'text/csv',
|
||||
},
|
||||
'archives': {
|
||||
'application/zip', 'application/x-rar-compressed',
|
||||
'application/x-7z-compressed', 'application/gzip',
|
||||
'application/x-tar',
|
||||
},
|
||||
'data': {
|
||||
'application/json', 'application/xml', 'text/xml',
|
||||
'application/yaml', 'text/plain',
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class MimeValidationService:
|
||||
"""Service for validating file MIME types using magic bytes."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
allowed_categories: Optional[Set[str]] = None,
|
||||
bypass_for_trusted: bool = False
|
||||
):
|
||||
"""
|
||||
Initialize the MIME validation service.
|
||||
|
||||
Args:
|
||||
allowed_categories: Set of allowed MIME categories ('images', 'documents', etc.)
|
||||
If None, all categories are allowed.
|
||||
bypass_for_trusted: If True, validation can be bypassed for trusted sources.
|
||||
"""
|
||||
self.bypass_for_trusted = bypass_for_trusted
|
||||
|
||||
# Build set of allowed MIME types
|
||||
if allowed_categories is None:
|
||||
self.allowed_mime_types = set()
|
||||
for category_mimes in ALLOWED_MIME_CATEGORIES.values():
|
||||
self.allowed_mime_types.update(category_mimes)
|
||||
else:
|
||||
self.allowed_mime_types = set()
|
||||
for category in allowed_categories:
|
||||
if category in ALLOWED_MIME_CATEGORIES:
|
||||
self.allowed_mime_types.update(ALLOWED_MIME_CATEGORIES[category])
|
||||
|
||||
def detect_mime_type(self, file_content: bytes) -> Optional[str]:
|
||||
"""
|
||||
Detect MIME type from file content using magic bytes.
|
||||
|
||||
Args:
|
||||
file_content: The raw file bytes (at least first 16 bytes needed)
|
||||
|
||||
Returns:
|
||||
Detected MIME type or None if unknown
|
||||
"""
|
||||
if len(file_content) < 2:
|
||||
return None
|
||||
|
||||
# Check each magic signature
|
||||
for magic_bytes, (mime_type, _) in MAGIC_SIGNATURES.items():
|
||||
if file_content.startswith(magic_bytes):
|
||||
# Special case for WebP: check for WEBP after RIFF
|
||||
if magic_bytes == b'RIFF' and len(file_content) >= 12:
|
||||
if file_content[8:12] == b'WEBP':
|
||||
return 'image/webp'
|
||||
else:
|
||||
continue # Not WebP, might be something else
|
||||
|
||||
return mime_type
|
||||
|
||||
return None
|
||||
|
||||
def validate_file_content(
|
||||
self,
|
||||
file_content: bytes,
|
||||
declared_extension: str,
|
||||
declared_mime_type: Optional[str] = None,
|
||||
trusted_source: bool = False
|
||||
) -> Tuple[bool, str, Optional[str]]:
|
||||
"""
|
||||
Validate file content against declared extension and MIME type.
|
||||
|
||||
Args:
|
||||
file_content: The raw file bytes
|
||||
declared_extension: The file extension (without dot)
|
||||
declared_mime_type: The Content-Type header value (optional)
|
||||
trusted_source: If True and bypass_for_trusted is enabled, skip validation
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, detected_mime_type, error_message)
|
||||
"""
|
||||
# Bypass for trusted sources if configured
|
||||
if trusted_source and self.bypass_for_trusted:
|
||||
logger.debug("MIME validation bypassed for trusted source")
|
||||
return True, declared_mime_type or 'application/octet-stream', None
|
||||
|
||||
# Detect actual MIME type
|
||||
detected_mime = self.detect_mime_type(file_content)
|
||||
ext_lower = declared_extension.lower()
|
||||
|
||||
# Check if detected MIME is blocked (dangerous executable)
|
||||
if detected_mime in BLOCKED_MIME_TYPES:
|
||||
logger.warning(
|
||||
"Blocked dangerous file type detected: %s (claimed extension: %s)",
|
||||
detected_mime, ext_lower
|
||||
)
|
||||
return False, detected_mime, "File type not allowed for security reasons"
|
||||
|
||||
# If we couldn't detect the MIME type, fall back to extension-based check
|
||||
if detected_mime is None:
|
||||
# For text/data files, detection is unreliable
|
||||
# Trust the extension if it's in our allowed list
|
||||
if ext_lower in EXTENSION_TO_MIME:
|
||||
expected_mimes = EXTENSION_TO_MIME[ext_lower]
|
||||
# Check if any expected MIME is in allowed set
|
||||
if expected_mimes & self.allowed_mime_types:
|
||||
logger.debug(
|
||||
"MIME detection inconclusive for extension %s, allowing based on extension",
|
||||
ext_lower
|
||||
)
|
||||
# Return the first expected MIME type
|
||||
return True, next(iter(expected_mimes)), None
|
||||
|
||||
# Unknown extension or MIME type
|
||||
logger.warning(
|
||||
"Could not detect MIME type for file with extension: %s",
|
||||
ext_lower
|
||||
)
|
||||
return True, 'application/octet-stream', None
|
||||
|
||||
# Check if detected MIME is in allowed set
|
||||
if detected_mime not in self.allowed_mime_types:
|
||||
logger.warning(
|
||||
"Unsupported MIME type detected: %s (extension: %s)",
|
||||
detected_mime, ext_lower
|
||||
)
|
||||
return False, detected_mime, f"Unsupported file type: {detected_mime}"
|
||||
|
||||
# Verify extension matches detected MIME type
|
||||
if ext_lower in EXTENSION_TO_MIME:
|
||||
expected_mimes = EXTENSION_TO_MIME[ext_lower]
|
||||
|
||||
# Special handling for ZIP-based formats (docx, xlsx, pptx)
|
||||
if detected_mime == 'application/zip' and ext_lower in {'docx', 'xlsx', 'pptx', 'odt', 'ods', 'odp'}:
|
||||
# These are valid - ZIP container with specific extension
|
||||
return True, detected_mime, None
|
||||
|
||||
# Check if detected MIME matches any expected MIME for this extension
|
||||
if detected_mime not in expected_mimes:
|
||||
# Mismatch detected!
|
||||
logger.warning(
|
||||
"File type mismatch: extension '%s' but detected '%s'",
|
||||
ext_lower, detected_mime
|
||||
)
|
||||
return False, detected_mime, f"File type mismatch: extension indicates {ext_lower} but content is {detected_mime}"
|
||||
|
||||
return True, detected_mime, None
|
||||
|
||||
async def validate_upload_file(
|
||||
self,
|
||||
file_content: bytes,
|
||||
filename: str,
|
||||
content_type: Optional[str] = None,
|
||||
trusted_source: bool = False
|
||||
) -> Tuple[bool, str, Optional[str]]:
|
||||
"""
|
||||
Validate an uploaded file.
|
||||
|
||||
Args:
|
||||
file_content: The raw file bytes
|
||||
filename: The uploaded filename
|
||||
content_type: The Content-Type header value
|
||||
trusted_source: If True and bypass is enabled, skip validation
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, detected_mime_type, error_message)
|
||||
"""
|
||||
# Extract extension
|
||||
extension = filename.rsplit('.', 1)[-1] if '.' in filename else ''
|
||||
|
||||
return self.validate_file_content(
|
||||
file_content=file_content,
|
||||
declared_extension=extension,
|
||||
declared_mime_type=content_type,
|
||||
trusted_source=trusted_source
|
||||
)
|
||||
|
||||
|
||||
# Singleton instance with default configuration
|
||||
mime_validation_service = MimeValidationService()
|
||||
@@ -14,6 +14,7 @@ from sqlalchemy import event
|
||||
from app.models import User, Notification, Task, Comment, Mention
|
||||
from app.core.redis_pubsub import publish_notification as redis_publish, get_channel_name
|
||||
from app.core.redis import get_redis_sync
|
||||
from app.core.database import escape_like
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -427,9 +428,12 @@ class NotificationService:
|
||||
|
||||
# Find users by email or name
|
||||
for username in mentioned_usernames:
|
||||
# Escape special LIKE characters to prevent injection
|
||||
escaped_username = escape_like(username)
|
||||
# Try to find user by email first
|
||||
user = db.query(User).filter(
|
||||
(User.email.ilike(f"{username}%")) | (User.name.ilike(f"%{username}%"))
|
||||
(User.email.ilike(f"{escaped_username}%", escape="\\")) |
|
||||
(User.name.ilike(f"%{escaped_username}%", escape="\\"))
|
||||
).first()
|
||||
|
||||
if user and user.id != author.id:
|
||||
|
||||
@@ -239,6 +239,8 @@ class TestAttachmentAPI:
|
||||
|
||||
def test_delete_attachment(self, client, test_user_token, test_task, db):
|
||||
"""Test soft deleting an attachment."""
|
||||
from app.core.security import generate_csrf_token
|
||||
|
||||
attachment = Attachment(
|
||||
id=str(uuid.uuid4()),
|
||||
task_id=test_task.id,
|
||||
@@ -252,9 +254,15 @@ class TestAttachmentAPI:
|
||||
db.add(attachment)
|
||||
db.commit()
|
||||
|
||||
# Generate CSRF token for the user
|
||||
csrf_token = generate_csrf_token(test_task.created_by)
|
||||
|
||||
response = client.delete(
|
||||
f"/api/attachments/{attachment.id}",
|
||||
headers={"Authorization": f"Bearer {test_user_token}"},
|
||||
headers={
|
||||
"Authorization": f"Bearer {test_user_token}",
|
||||
"X-CSRF-Token": csrf_token,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
408
backend/tests/test_query_performance.py
Normal file
408
backend/tests/test_query_performance.py
Normal file
@@ -0,0 +1,408 @@
|
||||
"""
|
||||
Tests for query performance optimization.
|
||||
|
||||
These tests verify that N+1 query patterns have been eliminated by checking
|
||||
that endpoints execute within expected query count limits.
|
||||
"""
|
||||
import uuid
|
||||
import pytest
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models import User, Space, Project, Task, TaskStatus, ProjectMember, Department
|
||||
|
||||
|
||||
class QueryCounter:
|
||||
"""Helper to count SQL queries during a test."""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
self.count = 0
|
||||
self.queries = []
|
||||
self._before_handler = None
|
||||
self._after_handler = None
|
||||
|
||||
def __enter__(self):
|
||||
self.count = 0
|
||||
self.queries = []
|
||||
|
||||
engine = self.db.get_bind()
|
||||
|
||||
def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
|
||||
conn.info.setdefault('query_start', []).append(statement)
|
||||
|
||||
def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
|
||||
self.count += 1
|
||||
self.queries.append(statement)
|
||||
|
||||
self._before_handler = before_cursor_execute
|
||||
self._after_handler = after_cursor_execute
|
||||
|
||||
event.listen(engine, "before_cursor_execute", before_cursor_execute)
|
||||
event.listen(engine, "after_cursor_execute", after_cursor_execute)
|
||||
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
engine = self.db.get_bind()
|
||||
event.remove(engine, "before_cursor_execute", self._before_handler)
|
||||
event.remove(engine, "after_cursor_execute", self._after_handler)
|
||||
return False
|
||||
|
||||
|
||||
def create_test_department(db: Session) -> Department:
|
||||
"""Create a test department."""
|
||||
dept = Department(
|
||||
id=str(uuid.uuid4()),
|
||||
name=f"Test Department {uuid.uuid4().hex[:8]}",
|
||||
)
|
||||
db.add(dept)
|
||||
db.commit()
|
||||
return dept
|
||||
|
||||
|
||||
def create_test_user(db: Session, department_id: str = None, name: str = None) -> User:
|
||||
"""Create a test user."""
|
||||
user = User(
|
||||
id=str(uuid.uuid4()),
|
||||
email=f"user_{uuid.uuid4().hex[:8]}@test.com",
|
||||
name=name or f"Test User {uuid.uuid4().hex[:8]}",
|
||||
department_id=department_id,
|
||||
is_active=True,
|
||||
)
|
||||
db.add(user)
|
||||
db.commit()
|
||||
return user
|
||||
|
||||
|
||||
def create_test_space(db: Session, owner_id: str) -> Space:
|
||||
"""Create a test space."""
|
||||
space = Space(
|
||||
id=str(uuid.uuid4()),
|
||||
name=f"Test Space {uuid.uuid4().hex[:8]}",
|
||||
owner_id=owner_id,
|
||||
is_active=True,
|
||||
)
|
||||
db.add(space)
|
||||
db.commit()
|
||||
return space
|
||||
|
||||
|
||||
def create_test_project(db: Session, space_id: str, owner_id: str, department_id: str = None) -> Project:
|
||||
"""Create a test project."""
|
||||
project = Project(
|
||||
id=str(uuid.uuid4()),
|
||||
space_id=space_id,
|
||||
title=f"Test Project {uuid.uuid4().hex[:8]}",
|
||||
owner_id=owner_id,
|
||||
department_id=department_id,
|
||||
is_active=True,
|
||||
security_level="public",
|
||||
)
|
||||
db.add(project)
|
||||
db.commit()
|
||||
|
||||
# Create default task status
|
||||
status = TaskStatus(
|
||||
id=str(uuid.uuid4()),
|
||||
project_id=project.id,
|
||||
name="To Do",
|
||||
color="#0000FF",
|
||||
position=0,
|
||||
is_done=False,
|
||||
)
|
||||
db.add(status)
|
||||
db.commit()
|
||||
|
||||
return project
|
||||
|
||||
|
||||
def create_test_task(db: Session, project_id: str, status_id: str, assignee_id: str = None, creator_id: str = None) -> Task:
|
||||
"""Create a test task."""
|
||||
task = Task(
|
||||
id=str(uuid.uuid4()),
|
||||
project_id=project_id,
|
||||
title=f"Test Task {uuid.uuid4().hex[:8]}",
|
||||
status_id=status_id,
|
||||
assignee_id=assignee_id,
|
||||
created_by=creator_id,
|
||||
priority="medium",
|
||||
position=0,
|
||||
)
|
||||
db.add(task)
|
||||
db.commit()
|
||||
return task
|
||||
|
||||
|
||||
class TestProjectMemberQueryOptimization:
|
||||
"""Tests for project member list query optimization."""
|
||||
|
||||
def test_list_members_query_count_with_many_members(self, client, db, admin_token):
|
||||
"""
|
||||
Test that listing project members uses bounded number of queries.
|
||||
|
||||
Before optimization: 1 + 2*N queries (N members, 2 queries each for user details)
|
||||
After optimization: at most 3 queries (members, users, added_by_users)
|
||||
"""
|
||||
# Setup: Create a department, multiple users, project, and members
|
||||
dept = create_test_department(db)
|
||||
admin = db.query(User).filter(User.email == "ymirliu@panjit.com.tw").first()
|
||||
space = create_test_space(db, admin.id)
|
||||
project = create_test_project(db, space.id, admin.id, dept.id)
|
||||
|
||||
# Create 10 project members
|
||||
member_count = 10
|
||||
for i in range(member_count):
|
||||
user = create_test_user(db, dept.id, f"Member {i}")
|
||||
member = ProjectMember(
|
||||
id=str(uuid.uuid4()),
|
||||
project_id=project.id,
|
||||
user_id=user.id,
|
||||
role="member",
|
||||
added_by=admin.id,
|
||||
)
|
||||
db.add(member)
|
||||
db.commit()
|
||||
|
||||
# Make the request
|
||||
response = client.get(
|
||||
f"/api/projects/{project.id}/members",
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] == member_count
|
||||
assert len(data["members"]) == member_count
|
||||
|
||||
# Verify all member details are loaded
|
||||
for member in data["members"]:
|
||||
assert member["user_name"] is not None
|
||||
assert member["added_by_name"] is not None
|
||||
|
||||
def test_list_members_includes_department_info(self, client, db, admin_token):
|
||||
"""Test that member listing includes department information without extra queries."""
|
||||
dept = create_test_department(db)
|
||||
admin = db.query(User).filter(User.email == "ymirliu@panjit.com.tw").first()
|
||||
space = create_test_space(db, admin.id)
|
||||
project = create_test_project(db, space.id, admin.id, dept.id)
|
||||
|
||||
# Create user with department
|
||||
user = create_test_user(db, dept.id, "User with Department")
|
||||
member = ProjectMember(
|
||||
id=str(uuid.uuid4()),
|
||||
project_id=project.id,
|
||||
user_id=user.id,
|
||||
role="member",
|
||||
added_by=admin.id,
|
||||
)
|
||||
db.add(member)
|
||||
db.commit()
|
||||
|
||||
response = client.get(
|
||||
f"/api/projects/{project.id}/members",
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["members"]) == 1
|
||||
assert data["members"][0]["user_department_id"] == dept.id
|
||||
assert data["members"][0]["user_department_name"] == dept.name
|
||||
|
||||
|
||||
class TestProjectListQueryOptimization:
|
||||
"""Tests for project list query optimization."""
|
||||
|
||||
def test_list_projects_query_count_with_many_projects(self, client, db, admin_token):
|
||||
"""
|
||||
Test that listing projects in a space uses bounded number of queries.
|
||||
|
||||
Before optimization: 1 + 4*N queries (N projects, 4 queries each for owner/space/dept/tasks)
|
||||
After optimization: at most 5 queries (projects, owners, spaces, departments, tasks)
|
||||
"""
|
||||
dept = create_test_department(db)
|
||||
admin = db.query(User).filter(User.email == "ymirliu@panjit.com.tw").first()
|
||||
space = create_test_space(db, admin.id)
|
||||
|
||||
# Create 5 projects with tasks
|
||||
project_count = 5
|
||||
for i in range(project_count):
|
||||
project = create_test_project(db, space.id, admin.id, dept.id)
|
||||
# Add a task to each project
|
||||
status = db.query(TaskStatus).filter(TaskStatus.project_id == project.id).first()
|
||||
create_test_task(db, project.id, status.id, admin.id, admin.id)
|
||||
|
||||
response = client.get(
|
||||
f"/api/spaces/{space.id}/projects",
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data) == project_count
|
||||
|
||||
# Verify all project details are loaded
|
||||
for project in data:
|
||||
assert project["owner_name"] is not None
|
||||
assert project["space_name"] is not None
|
||||
assert project["department_name"] is not None
|
||||
assert project["task_count"] >= 1
|
||||
|
||||
|
||||
class TestTaskListQueryOptimization:
|
||||
"""Tests for task list query optimization."""
|
||||
|
||||
def test_list_tasks_query_count_with_many_tasks(self, client, db, admin_token):
|
||||
"""
|
||||
Test that listing tasks uses bounded number of queries.
|
||||
|
||||
Before optimization: 1 + 4*N queries (N tasks, queries for assignee/status/creator/subtasks)
|
||||
After optimization: at most 6 queries (tasks, assignees, statuses, creators, subtasks, custom_values)
|
||||
"""
|
||||
dept = create_test_department(db)
|
||||
admin = db.query(User).filter(User.email == "ymirliu@panjit.com.tw").first()
|
||||
space = create_test_space(db, admin.id)
|
||||
project = create_test_project(db, space.id, admin.id, dept.id)
|
||||
status = db.query(TaskStatus).filter(TaskStatus.project_id == project.id).first()
|
||||
|
||||
# Create multiple users for assignment
|
||||
users = [create_test_user(db, dept.id, f"User {i}") for i in range(5)]
|
||||
|
||||
# Create 10 tasks with different assignees
|
||||
task_count = 10
|
||||
for i in range(task_count):
|
||||
create_test_task(db, project.id, status.id, users[i % 5].id, admin.id)
|
||||
|
||||
response = client.get(
|
||||
f"/api/projects/{project.id}/tasks",
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] == task_count
|
||||
|
||||
# Verify all task details are loaded
|
||||
for task in data["tasks"]:
|
||||
assert task["assignee_name"] is not None
|
||||
assert task["status_name"] is not None
|
||||
assert task["creator_name"] is not None
|
||||
|
||||
def test_list_tasks_with_subtasks(self, client, db, admin_token):
|
||||
"""Test that subtask counts are efficiently loaded."""
|
||||
dept = create_test_department(db)
|
||||
admin = db.query(User).filter(User.email == "ymirliu@panjit.com.tw").first()
|
||||
space = create_test_space(db, admin.id)
|
||||
project = create_test_project(db, space.id, admin.id, dept.id)
|
||||
status = db.query(TaskStatus).filter(TaskStatus.project_id == project.id).first()
|
||||
|
||||
# Create parent task with subtasks
|
||||
parent_task = create_test_task(db, project.id, status.id, admin.id, admin.id)
|
||||
|
||||
# Create 5 subtasks
|
||||
subtask_count = 5
|
||||
for i in range(subtask_count):
|
||||
subtask = Task(
|
||||
id=str(uuid.uuid4()),
|
||||
project_id=project.id,
|
||||
parent_task_id=parent_task.id,
|
||||
title=f"Subtask {i}",
|
||||
status_id=status.id,
|
||||
created_by=admin.id,
|
||||
priority="medium",
|
||||
position=i,
|
||||
)
|
||||
db.add(subtask)
|
||||
db.commit()
|
||||
|
||||
response = client.get(
|
||||
f"/api/projects/{project.id}/tasks",
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] == 1 # Only root tasks
|
||||
assert data["tasks"][0]["subtask_count"] == subtask_count
|
||||
|
||||
|
||||
class TestSubtaskListQueryOptimization:
|
||||
"""Tests for subtask list query optimization."""
|
||||
|
||||
def test_list_subtasks_efficient_loading(self, client, db, admin_token):
|
||||
"""Test that subtask listing uses efficient queries."""
|
||||
dept = create_test_department(db)
|
||||
admin = db.query(User).filter(User.email == "ymirliu@panjit.com.tw").first()
|
||||
space = create_test_space(db, admin.id)
|
||||
project = create_test_project(db, space.id, admin.id, dept.id)
|
||||
status = db.query(TaskStatus).filter(TaskStatus.project_id == project.id).first()
|
||||
|
||||
# Create parent task
|
||||
parent_task = create_test_task(db, project.id, status.id, admin.id, admin.id)
|
||||
|
||||
# Create multiple users
|
||||
users = [create_test_user(db, dept.id, f"User {i}") for i in range(3)]
|
||||
|
||||
# Create subtasks with different assignees
|
||||
subtask_count = 5
|
||||
for i in range(subtask_count):
|
||||
subtask = Task(
|
||||
id=str(uuid.uuid4()),
|
||||
project_id=project.id,
|
||||
parent_task_id=parent_task.id,
|
||||
title=f"Subtask {i}",
|
||||
status_id=status.id,
|
||||
assignee_id=users[i % 3].id,
|
||||
created_by=admin.id,
|
||||
priority="medium",
|
||||
position=i,
|
||||
)
|
||||
db.add(subtask)
|
||||
db.commit()
|
||||
|
||||
response = client.get(
|
||||
f"/api/tasks/{parent_task.id}/subtasks",
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] == subtask_count
|
||||
|
||||
# Verify all subtask details are loaded
|
||||
for subtask in data["tasks"]:
|
||||
assert subtask["assignee_name"] is not None
|
||||
assert subtask["status_name"] is not None
|
||||
|
||||
|
||||
class TestQueryMonitorIntegration:
|
||||
"""Tests for query monitoring utility.
|
||||
|
||||
Note: These tests use the local QueryCounter class which sets up its own
|
||||
event listeners, rather than the app's count_queries which requires
|
||||
QUERY_LOGGING to be enabled at startup.
|
||||
"""
|
||||
|
||||
def test_query_counter_context_manager(self, db):
|
||||
"""Test that QueryCounter correctly counts queries."""
|
||||
# Use the local QueryCounter which sets up its own event listeners
|
||||
with QueryCounter(db) as counter:
|
||||
# Execute some queries
|
||||
db.query(User).all()
|
||||
db.query(User).filter(User.is_active == True).all()
|
||||
|
||||
# Should have counted at least 2 queries
|
||||
assert counter.count >= 2
|
||||
|
||||
def test_query_counter_threshold_warning(self, db, caplog):
|
||||
"""Test that QueryCounter correctly counts queries for threshold testing."""
|
||||
# Use the local QueryCounter which sets up its own event listeners
|
||||
with QueryCounter(db) as counter:
|
||||
# Execute multiple queries
|
||||
db.query(User).all()
|
||||
db.query(User).all()
|
||||
db.query(User).all()
|
||||
|
||||
# Should have counted at least 3 queries
|
||||
assert counter.count >= 3
|
||||
402
backend/tests/test_security_validation.py
Normal file
402
backend/tests/test_security_validation.py
Normal file
@@ -0,0 +1,402 @@
|
||||
"""
|
||||
Tests for security validation features:
|
||||
1. JWT secret validation (length and entropy)
|
||||
2. CSRF protection
|
||||
3. MIME type validation
|
||||
|
||||
Run with:
|
||||
eval "$(/Users/egg/miniconda3/bin/conda shell.zsh hook)" && conda activate pjctrl && python -m pytest tests/test_security_validation.py -v
|
||||
"""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
import time
|
||||
from unittest.mock import patch, MagicMock
|
||||
from io import BytesIO
|
||||
|
||||
# Set testing environment before importing app modules
|
||||
os.environ["TESTING"] = "true"
|
||||
|
||||
from fastapi import Request
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
class TestJWTSecretValidation:
|
||||
"""Tests for JWT secret validation functionality."""
|
||||
|
||||
def test_calculate_entropy_empty_string(self):
|
||||
"""Test entropy calculation for empty string."""
|
||||
from app.core.security import calculate_entropy
|
||||
assert calculate_entropy("") == 0.0
|
||||
|
||||
def test_calculate_entropy_single_char(self):
|
||||
"""Test entropy for string with single repeated character."""
|
||||
from app.core.security import calculate_entropy
|
||||
# All same characters = 0 entropy per character
|
||||
entropy = calculate_entropy("aaaaaaa")
|
||||
assert entropy == 0.0
|
||||
|
||||
def test_calculate_entropy_random_string(self):
|
||||
"""Test entropy for a random-looking string."""
|
||||
from app.core.security import calculate_entropy
|
||||
# A string with high variability should have high entropy
|
||||
entropy = calculate_entropy("aB3$xY9!qW2@eR5#")
|
||||
assert entropy > 50 # Should be reasonably high
|
||||
|
||||
def test_calculate_entropy_alphanumeric(self):
|
||||
"""Test entropy for alphanumeric string."""
|
||||
from app.core.security import calculate_entropy
|
||||
# Standard alphanumeric has moderate entropy
|
||||
entropy = calculate_entropy("abcdefghijklmnop")
|
||||
assert entropy > 30
|
||||
|
||||
def test_has_repeating_patterns_true(self):
|
||||
"""Test detection of repeating patterns."""
|
||||
from app.core.security import has_repeating_patterns
|
||||
assert has_repeating_patterns("abcabcabcabc") is True
|
||||
assert has_repeating_patterns("aaaaaaaaaaaa") is True
|
||||
assert has_repeating_patterns("xyzxyzxyzxyz") is True
|
||||
|
||||
def test_has_repeating_patterns_false(self):
|
||||
"""Test non-repeating patterns."""
|
||||
from app.core.security import has_repeating_patterns
|
||||
assert has_repeating_patterns("abcdefghijkl") is False
|
||||
assert has_repeating_patterns("X8k#2pL!9mNq") is False
|
||||
|
||||
def test_has_repeating_patterns_short_string(self):
|
||||
"""Test short strings (less than 8 chars)."""
|
||||
from app.core.security import has_repeating_patterns
|
||||
assert has_repeating_patterns("abc") is False
|
||||
assert has_repeating_patterns("ab") is False
|
||||
|
||||
def test_validate_jwt_secret_strength_short(self):
|
||||
"""Test validation rejects short secrets."""
|
||||
from app.core.security import validate_jwt_secret_strength, MIN_SECRET_LENGTH
|
||||
is_valid, warnings = validate_jwt_secret_strength("short")
|
||||
assert is_valid is False
|
||||
assert any("too short" in w for w in warnings)
|
||||
|
||||
def test_validate_jwt_secret_strength_weak_pattern(self):
|
||||
"""Test validation warns about weak patterns."""
|
||||
from app.core.security import validate_jwt_secret_strength
|
||||
is_valid, warnings = validate_jwt_secret_strength("my-super-secret-password-here-for-testing")
|
||||
# Should have warnings about weak patterns
|
||||
assert any("weak pattern" in w.lower() for w in warnings)
|
||||
|
||||
def test_validate_jwt_secret_strength_strong(self):
|
||||
"""Test validation accepts strong secrets."""
|
||||
from app.core.security import validate_jwt_secret_strength
|
||||
import secrets
|
||||
strong_secret = secrets.token_urlsafe(48) # 64+ chars with high entropy
|
||||
is_valid, warnings = validate_jwt_secret_strength(strong_secret)
|
||||
assert is_valid is True
|
||||
# May still have low entropy warning depending on randomness, but length is valid
|
||||
|
||||
def test_validate_jwt_secret_strength_repeating(self):
|
||||
"""Test validation detects repeating patterns."""
|
||||
from app.core.security import validate_jwt_secret_strength
|
||||
is_valid, warnings = validate_jwt_secret_strength("abcdabcdabcdabcdabcdabcdabcdabcd")
|
||||
assert any("repeating" in w.lower() for w in warnings)
|
||||
|
||||
def test_validate_jwt_secret_on_startup_non_production(self):
|
||||
"""Test startup validation doesn't raise in non-production."""
|
||||
from app.core.security import validate_jwt_secret_on_startup
|
||||
# In testing mode, should not raise even for weak secrets
|
||||
with patch.dict(os.environ, {"ENVIRONMENT": "development"}):
|
||||
# Should not raise
|
||||
validate_jwt_secret_on_startup()
|
||||
|
||||
def test_validate_jwt_secret_on_startup_production_weak(self):
|
||||
"""Test startup validation raises in production for weak secret."""
|
||||
from app.core.security import validate_jwt_secret_on_startup
|
||||
from app.core.config import settings
|
||||
|
||||
# Save original and set weak secret
|
||||
original_secret = settings.JWT_SECRET_KEY
|
||||
|
||||
try:
|
||||
# Mock a weak secret
|
||||
with patch.object(settings, 'JWT_SECRET_KEY', 'weak'):
|
||||
with patch.dict(os.environ, {"ENVIRONMENT": "production"}):
|
||||
with pytest.raises(ValueError):
|
||||
validate_jwt_secret_on_startup()
|
||||
finally:
|
||||
# Restore
|
||||
pass
|
||||
|
||||
|
||||
class TestCSRFProtection:
|
||||
"""Tests for CSRF token generation and validation."""
|
||||
|
||||
def test_generate_csrf_token(self):
|
||||
"""Test CSRF token generation."""
|
||||
from app.core.security import generate_csrf_token
|
||||
user_id = "test-user-123"
|
||||
token = generate_csrf_token(user_id)
|
||||
|
||||
assert token is not None
|
||||
assert len(token) > 50 # Should be substantial
|
||||
assert ":" in token # Contains separator
|
||||
|
||||
def test_generate_csrf_token_unique(self):
|
||||
"""Test that CSRF tokens are unique."""
|
||||
from app.core.security import generate_csrf_token
|
||||
user_id = "test-user-123"
|
||||
token1 = generate_csrf_token(user_id)
|
||||
token2 = generate_csrf_token(user_id)
|
||||
|
||||
assert token1 != token2 # Each generation is unique
|
||||
|
||||
def test_validate_csrf_token_valid(self):
|
||||
"""Test validation of valid CSRF token."""
|
||||
from app.core.security import generate_csrf_token, validate_csrf_token
|
||||
user_id = "test-user-123"
|
||||
token = generate_csrf_token(user_id)
|
||||
|
||||
is_valid, error = validate_csrf_token(token, user_id)
|
||||
assert is_valid is True
|
||||
assert error == ""
|
||||
|
||||
def test_validate_csrf_token_wrong_user(self):
|
||||
"""Test validation fails for wrong user."""
|
||||
from app.core.security import generate_csrf_token, validate_csrf_token
|
||||
token = generate_csrf_token("user-1")
|
||||
|
||||
is_valid, error = validate_csrf_token(token, "user-2")
|
||||
assert is_valid is False
|
||||
assert "mismatch" in error.lower()
|
||||
|
||||
def test_validate_csrf_token_expired(self):
|
||||
"""Test validation fails for expired token."""
|
||||
from app.core.security import generate_csrf_token, validate_csrf_token, CSRF_TOKEN_EXPIRY_SECONDS
|
||||
from datetime import datetime, timezone
|
||||
import hmac
|
||||
import hashlib
|
||||
import secrets
|
||||
from app.core.config import settings
|
||||
|
||||
user_id = "test-user-123"
|
||||
|
||||
# Create an expired token manually
|
||||
random_part = secrets.token_urlsafe(32)
|
||||
expired_timestamp = int(datetime.now(timezone.utc).timestamp()) - CSRF_TOKEN_EXPIRY_SECONDS - 100
|
||||
payload = f"{random_part}:{user_id}:{expired_timestamp}"
|
||||
signature = hmac.new(
|
||||
settings.JWT_SECRET_KEY.encode(),
|
||||
payload.encode(),
|
||||
hashlib.sha256
|
||||
).hexdigest()[:16]
|
||||
expired_token = f"{payload}:{signature}"
|
||||
|
||||
is_valid, error = validate_csrf_token(expired_token, user_id)
|
||||
assert is_valid is False
|
||||
assert "expired" in error.lower()
|
||||
|
||||
def test_validate_csrf_token_invalid_format(self):
|
||||
"""Test validation fails for invalid format."""
|
||||
from app.core.security import validate_csrf_token
|
||||
is_valid, error = validate_csrf_token("invalid-token", "user-1")
|
||||
assert is_valid is False
|
||||
|
||||
def test_validate_csrf_token_empty(self):
|
||||
"""Test validation fails for empty token."""
|
||||
from app.core.security import validate_csrf_token
|
||||
is_valid, error = validate_csrf_token("", "user-1")
|
||||
assert is_valid is False
|
||||
assert "required" in error.lower()
|
||||
|
||||
def test_validate_csrf_token_tampered_signature(self):
|
||||
"""Test validation fails for tampered signature."""
|
||||
from app.core.security import generate_csrf_token, validate_csrf_token
|
||||
user_id = "test-user-123"
|
||||
token = generate_csrf_token(user_id)
|
||||
|
||||
# Tamper with the signature
|
||||
parts = token.split(":")
|
||||
parts[-1] = "tamperedsig123"
|
||||
tampered_token = ":".join(parts)
|
||||
|
||||
is_valid, error = validate_csrf_token(tampered_token, user_id)
|
||||
assert is_valid is False
|
||||
assert "signature" in error.lower() or "invalid" in error.lower()
|
||||
|
||||
|
||||
class TestMimeValidation:
|
||||
"""Tests for MIME type validation using magic bytes."""
|
||||
|
||||
def test_detect_jpeg(self):
|
||||
"""Test detection of JPEG files."""
|
||||
from app.services.mime_validation_service import MimeValidationService
|
||||
service = MimeValidationService()
|
||||
|
||||
# JPEG magic bytes
|
||||
jpeg_content = b'\xFF\xD8\xFF\xE0' + b'\x00' * 100
|
||||
mime = service.detect_mime_type(jpeg_content)
|
||||
assert mime == 'image/jpeg'
|
||||
|
||||
def test_detect_png(self):
|
||||
"""Test detection of PNG files."""
|
||||
from app.services.mime_validation_service import MimeValidationService
|
||||
service = MimeValidationService()
|
||||
|
||||
# PNG magic bytes
|
||||
png_content = b'\x89PNG\r\n\x1a\n' + b'\x00' * 100
|
||||
mime = service.detect_mime_type(png_content)
|
||||
assert mime == 'image/png'
|
||||
|
||||
def test_detect_pdf(self):
|
||||
"""Test detection of PDF files."""
|
||||
from app.services.mime_validation_service import MimeValidationService
|
||||
service = MimeValidationService()
|
||||
|
||||
# PDF magic bytes
|
||||
pdf_content = b'%PDF-1.4' + b'\x00' * 100
|
||||
mime = service.detect_mime_type(pdf_content)
|
||||
assert mime == 'application/pdf'
|
||||
|
||||
def test_detect_gif(self):
|
||||
"""Test detection of GIF files."""
|
||||
from app.services.mime_validation_service import MimeValidationService
|
||||
service = MimeValidationService()
|
||||
|
||||
# GIF87a magic bytes
|
||||
gif_content = b'GIF87a' + b'\x00' * 100
|
||||
mime = service.detect_mime_type(gif_content)
|
||||
assert mime == 'image/gif'
|
||||
|
||||
# GIF89a magic bytes
|
||||
gif89_content = b'GIF89a' + b'\x00' * 100
|
||||
mime = service.detect_mime_type(gif89_content)
|
||||
assert mime == 'image/gif'
|
||||
|
||||
def test_detect_zip(self):
|
||||
"""Test detection of ZIP files."""
|
||||
from app.services.mime_validation_service import MimeValidationService
|
||||
service = MimeValidationService()
|
||||
|
||||
# ZIP magic bytes
|
||||
zip_content = b'PK\x03\x04' + b'\x00' * 100
|
||||
mime = service.detect_mime_type(zip_content)
|
||||
assert mime == 'application/zip'
|
||||
|
||||
def test_detect_executable_blocked(self):
|
||||
"""Test that executable files are blocked."""
|
||||
from app.services.mime_validation_service import MimeValidationService
|
||||
service = MimeValidationService()
|
||||
|
||||
# Windows executable magic bytes
|
||||
exe_content = b'MZ' + b'\x00' * 100
|
||||
is_valid, detected, error = service.validate_file_content(exe_content, "test")
|
||||
assert is_valid is False
|
||||
assert "not allowed" in error.lower() or "security" in error.lower()
|
||||
|
||||
def test_validate_matching_extension(self):
|
||||
"""Test validation passes when extension matches content."""
|
||||
from app.services.mime_validation_service import MimeValidationService
|
||||
service = MimeValidationService()
|
||||
|
||||
jpeg_content = b'\xFF\xD8\xFF\xE0' + b'\x00' * 100
|
||||
is_valid, detected, error = service.validate_file_content(jpeg_content, "jpg")
|
||||
assert is_valid is True
|
||||
assert detected == 'image/jpeg'
|
||||
assert error is None
|
||||
|
||||
def test_validate_mismatched_extension(self):
|
||||
"""Test validation fails when extension doesn't match content."""
|
||||
from app.services.mime_validation_service import MimeValidationService
|
||||
service = MimeValidationService()
|
||||
|
||||
# PNG content but .jpg extension
|
||||
png_content = b'\x89PNG\r\n\x1a\n' + b'\x00' * 100
|
||||
is_valid, detected, error = service.validate_file_content(png_content, "jpg")
|
||||
assert is_valid is False
|
||||
assert "mismatch" in error.lower()
|
||||
|
||||
def test_validate_unknown_content(self):
|
||||
"""Test validation handles unknown content gracefully."""
|
||||
from app.services.mime_validation_service import MimeValidationService
|
||||
service = MimeValidationService()
|
||||
|
||||
# Random bytes with no known signature
|
||||
unknown_content = b'\x00\x01\x02\x03\x04\x05' + b'\x00' * 100
|
||||
is_valid, detected, error = service.validate_file_content(unknown_content, "dat")
|
||||
# Should allow with generic type for unknown extensions
|
||||
assert is_valid is True
|
||||
|
||||
def test_validate_docx_as_zip(self):
|
||||
"""Test that .docx files (ZIP-based) are accepted."""
|
||||
from app.services.mime_validation_service import MimeValidationService
|
||||
service = MimeValidationService()
|
||||
|
||||
# DOCX is a ZIP container
|
||||
docx_content = b'PK\x03\x04' + b'\x00' * 100
|
||||
is_valid, detected, error = service.validate_file_content(docx_content, "docx")
|
||||
assert is_valid is True
|
||||
|
||||
def test_validate_trusted_source_bypass(self):
|
||||
"""Test validation bypass for trusted sources."""
|
||||
from app.services.mime_validation_service import MimeValidationService
|
||||
service = MimeValidationService(bypass_for_trusted=True)
|
||||
|
||||
# Even suspicious content should pass for trusted source
|
||||
suspicious_content = b'MZ' + b'\x00' * 100
|
||||
is_valid, detected, error = service.validate_file_content(
|
||||
suspicious_content, "test", trusted_source=True
|
||||
)
|
||||
assert is_valid is True
|
||||
|
||||
def test_validate_upload_file_async(self):
|
||||
"""Test async validation of upload file."""
|
||||
import asyncio
|
||||
from app.services.mime_validation_service import MimeValidationService
|
||||
service = MimeValidationService()
|
||||
|
||||
async def test():
|
||||
jpeg_content = b'\xFF\xD8\xFF\xE0' + b'\x00' * 100
|
||||
is_valid, detected, error = await service.validate_upload_file(
|
||||
jpeg_content, "photo.jpg", "image/jpeg"
|
||||
)
|
||||
assert is_valid is True
|
||||
assert detected == 'image/jpeg'
|
||||
|
||||
asyncio.run(test())
|
||||
|
||||
def test_detect_webp(self):
|
||||
"""Test detection of WebP files."""
|
||||
from app.services.mime_validation_service import MimeValidationService
|
||||
service = MimeValidationService()
|
||||
|
||||
# WebP magic bytes: RIFF....WEBP
|
||||
webp_content = b'RIFF\x00\x00\x00\x00WEBP' + b'\x00' * 100
|
||||
mime = service.detect_mime_type(webp_content)
|
||||
assert mime == 'image/webp'
|
||||
|
||||
|
||||
class TestCSRFMiddleware:
|
||||
"""Integration tests for CSRF middleware."""
|
||||
|
||||
def test_csrf_token_endpoint(self, client, admin_token):
|
||||
"""Test CSRF token endpoint returns token."""
|
||||
response = client.get(
|
||||
"/api/auth/csrf-token",
|
||||
headers={"Authorization": f"Bearer {admin_token}"}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "csrf_token" in data
|
||||
assert "expires_in" in data
|
||||
assert data["expires_in"] == 3600
|
||||
|
||||
def test_csrf_token_endpoint_v1(self, client, admin_token):
|
||||
"""Test CSRF token endpoint on v1 namespace."""
|
||||
response = client.get(
|
||||
"/api/v1/auth/csrf-token",
|
||||
headers={"Authorization": f"Bearer {admin_token}"}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "csrf_token" in data
|
||||
|
||||
|
||||
# Import fixtures from conftest
|
||||
from tests.conftest import db, mock_redis, client, admin_token
|
||||
Reference in New Issue
Block a user