feat: implement 5 QA-driven security and quality proposals

Implemented proposals from comprehensive QA review:

1. extend-csrf-protection
   - Add POST to CSRF protected methods in frontend
   - Global CSRF middleware for all state-changing operations
   - Update tests with CSRF token fixtures

2. tighten-cors-websocket-security
   - Replace wildcard CORS with explicit method/header lists
   - Disable query parameter auth in production (code 4002)
   - Add per-user WebSocket connection limit (max 5, code 4005)

3. shorten-jwt-expiry
   - Reduce JWT expiry from 7 days to 60 minutes
   - Add refresh token support with 7-day expiry
   - Implement token rotation on refresh
   - Frontend auto-refresh when token near expiry (<5 min)

4. fix-frontend-quality
   - Add React.lazy() code splitting for all pages
   - Fix useCallback dependency arrays (Dashboard, Comments)
   - Add localStorage data validation in AuthContext
   - Complete i18n for AttachmentUpload component

5. enhance-backend-validation
   - Add SecurityAuditMiddleware for access denied logging
   - Add ErrorSanitizerMiddleware for production error messages
   - Protect /health/detailed with admin authentication
   - Add input length validation (comment 5000, desc 10000)

All 521 backend tests passing. Frontend builds successfully.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
beabigegg
2026-01-12 23:19:05 +08:00
parent df50d5e7f8
commit 35c90fe76b
48 changed files with 2132 additions and 403 deletions

View File

@@ -3,12 +3,28 @@ from sqlalchemy.orm import Session
from app.core.config import settings from app.core.config import settings
from app.core.database import get_db from app.core.database import get_db
from app.core.security import create_access_token, create_token_payload from app.core.security import (
create_access_token,
create_token_payload,
generate_refresh_token,
store_refresh_token,
validate_refresh_token,
invalidate_refresh_token,
invalidate_all_user_refresh_tokens,
decode_refresh_token_user_id,
)
from app.core.redis import get_redis from app.core.redis import get_redis
from app.core.rate_limiter import limiter from app.core.rate_limiter import limiter
from app.models.user import User from app.models.user import User
from app.models.audit_log import AuditAction from app.models.audit_log import AuditAction
from app.schemas.auth import LoginRequest, LoginResponse, UserInfo, CSRFTokenResponse from app.schemas.auth import (
LoginRequest,
LoginResponse,
UserInfo,
CSRFTokenResponse,
RefreshTokenRequest,
RefreshTokenResponse,
)
from app.services.auth_client import ( from app.services.auth_client import (
verify_credentials, verify_credentials,
AuthAPIError, AuthAPIError,
@@ -119,6 +135,9 @@ async def login(
# Create access token # Create access token
access_token = create_access_token(token_data) access_token = create_access_token(token_data)
# Generate refresh token
refresh_token = generate_refresh_token()
# Store session in Redis (sync with JWT expiry) # Store session in Redis (sync with JWT expiry)
redis_client.setex( redis_client.setex(
f"session:{user.id}", f"session:{user.id}",
@@ -126,6 +145,9 @@ async def login(
access_token, access_token,
) )
# Store refresh token in Redis with user binding
store_refresh_token(redis_client, user.id, refresh_token)
# Log successful login # Log successful login
AuditService.log_event( AuditService.log_event(
db=db, db=db,
@@ -141,6 +163,8 @@ async def login(
return LoginResponse( return LoginResponse(
access_token=access_token, access_token=access_token,
refresh_token=refresh_token,
expires_in=settings.JWT_EXPIRE_MINUTES * 60,
user=UserInfo( user=UserInfo(
id=user.id, id=user.id,
email=user.email, email=user.email,
@@ -158,14 +182,114 @@ async def logout(
redis_client=Depends(get_redis), redis_client=Depends(get_redis),
): ):
""" """
Logout user and invalidate session. Logout user and invalidate session and all refresh tokens.
""" """
# Remove session from Redis # Remove session from Redis
redis_client.delete(f"session:{current_user.id}") redis_client.delete(f"session:{current_user.id}")
# Invalidate all refresh tokens for this user
invalidate_all_user_refresh_tokens(redis_client, current_user.id)
return {"detail": "Successfully logged out"} return {"detail": "Successfully logged out"}
@router.post("/refresh", response_model=RefreshTokenResponse)
@limiter.limit("10/minute")
async def refresh_access_token(
request: Request,
refresh_request: RefreshTokenRequest,
db: Session = Depends(get_db),
redis_client=Depends(get_redis),
):
"""
Refresh access token using a valid refresh token.
This endpoint implements refresh token rotation:
- Validates the provided refresh token
- Issues a new access token
- Issues a new refresh token (rotating the old one)
- Invalidates the old refresh token
This provides enhanced security by ensuring refresh tokens are single-use.
"""
old_refresh_token = refresh_request.refresh_token
# Find the user ID associated with this refresh token
user_id = decode_refresh_token_user_id(old_refresh_token, redis_client)
if user_id is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or expired refresh token",
headers={"WWW-Authenticate": "Bearer"},
)
# Validate the refresh token is still valid and bound to this user
if not validate_refresh_token(redis_client, user_id, old_refresh_token):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or expired refresh token",
headers={"WWW-Authenticate": "Bearer"},
)
# Get user from database
user = db.query(User).filter(User.id == user_id).first()
if user is None:
# Invalidate the token since user no longer exists
invalidate_refresh_token(redis_client, user_id, old_refresh_token)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User not found",
headers={"WWW-Authenticate": "Bearer"},
)
if not user.is_active:
# Invalidate all tokens for disabled user
invalidate_all_user_refresh_tokens(redis_client, user_id)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="User account is disabled",
)
# Invalidate the old refresh token (rotation)
invalidate_refresh_token(redis_client, user_id, old_refresh_token)
# Get role name
role_name = user.role.name if user.role else None
# Create new token payload
token_data = create_token_payload(
user_id=user.id,
email=user.email,
role=role_name,
department_id=user.department_id,
is_system_admin=user.is_system_admin,
)
# Create new access token
new_access_token = create_access_token(token_data)
# Generate new refresh token (rotation)
new_refresh_token = generate_refresh_token()
# Store new session in Redis
redis_client.setex(
f"session:{user.id}",
settings.JWT_EXPIRE_MINUTES * 60,
new_access_token,
)
# Store new refresh token
store_refresh_token(redis_client, user.id, new_refresh_token)
return RefreshTokenResponse(
access_token=new_access_token,
refresh_token=new_refresh_token,
expires_in=settings.JWT_EXPIRE_MINUTES * 60,
)
@router.get("/me", response_model=UserInfo) @router.get("/me", response_model=UserInfo)
async def get_current_user_info( async def get_current_user_info(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),

View File

@@ -9,6 +9,7 @@ from sqlalchemy.orm import Session
from app.core import database from app.core import database
from app.core.security import decode_access_token from app.core.security import decode_access_token
from app.core.redis import get_redis_sync from app.core.redis import get_redis_sync
from app.core.config import settings
from app.models import User, Notification, Project from app.models import User, Notification, Project
from app.services.websocket_manager import manager from app.services.websocket_manager import manager
from app.core.redis_pubsub import NotificationSubscriber, ProjectTaskSubscriber from app.core.redis_pubsub import NotificationSubscriber, ProjectTaskSubscriber
@@ -72,14 +73,24 @@ async def authenticate_websocket(
Supports two authentication methods: Supports two authentication methods:
1. First message authentication (preferred, more secure) 1. First message authentication (preferred, more secure)
- Client sends: {"type": "auth", "token": "<jwt_token>"} - Client sends: {"type": "auth", "token": "<jwt_token>"}
2. Query parameter authentication (deprecated, for backward compatibility) 2. Query parameter authentication (disabled in production, for backward compatibility only)
- Client connects with: ?token=<jwt_token> - Client connects with: ?token=<jwt_token>
Returns: Returns:
Tuple of (user_id, error_reason). user_id is None if authentication fails. Tuple of (user_id, error_reason). user_id is None if authentication fails.
Error reasons: "invalid_token", "invalid_message", "missing_token",
"timeout", "error", "query_auth_disabled"
""" """
# If token provided via query parameter (backward compatibility) # If token provided via query parameter (backward compatibility)
if query_token: if query_token:
# Reject query parameter auth in production for security
if settings.ENVIRONMENT == "production":
logger.warning(
"WebSocket query parameter authentication attempted in production environment. "
"This is disabled for security reasons."
)
return None, "query_auth_disabled"
logger.warning( logger.warning(
"WebSocket authentication via query parameter is deprecated. " "WebSocket authentication via query parameter is deprecated. "
"Please use first-message authentication for better security." "Please use first-message authentication for better security."
@@ -195,9 +206,21 @@ async def websocket_notifications(
user_id, error_reason = await authenticate_websocket(websocket, token) user_id, error_reason = await authenticate_websocket(websocket, token)
if user_id is None: if user_id is None:
if error_reason == "invalid_token": if error_reason == "query_auth_disabled":
await websocket.send_json({"type": "error", "message": "Query parameter auth disabled in production"})
await websocket.close(code=4002, reason="Query parameter auth disabled in production")
elif error_reason == "invalid_token":
await websocket.send_json({"type": "error", "message": "Invalid or expired token"}) await websocket.send_json({"type": "error", "message": "Invalid or expired token"})
await websocket.close(code=4001, reason="Invalid or expired token") await websocket.close(code=4001, reason="Invalid or expired token")
else:
await websocket.close(code=4001, reason="Invalid or expired token")
return
# Check connection limit before accepting
can_connect, reject_reason = await manager.check_connection_limit(user_id)
if not can_connect:
await websocket.send_json({"type": "error", "message": reject_reason})
await websocket.close(code=4005, reason=reject_reason)
return return
await manager.connect(websocket, user_id) await manager.connect(websocket, user_id)
@@ -394,9 +417,21 @@ async def websocket_project_sync(
user_id, error_reason = await authenticate_websocket(websocket, token) user_id, error_reason = await authenticate_websocket(websocket, token)
if user_id is None: if user_id is None:
if error_reason == "invalid_token": if error_reason == "query_auth_disabled":
await websocket.send_json({"type": "error", "message": "Query parameter auth disabled in production"})
await websocket.close(code=4002, reason="Query parameter auth disabled in production")
elif error_reason == "invalid_token":
await websocket.send_json({"type": "error", "message": "Invalid or expired token"}) await websocket.send_json({"type": "error", "message": "Invalid or expired token"})
await websocket.close(code=4001, reason="Invalid or expired token") await websocket.close(code=4001, reason="Invalid or expired token")
else:
await websocket.close(code=4001, reason="Invalid or expired token")
return
# Check connection limit before accepting
can_connect, reject_reason = await manager.check_connection_limit(user_id)
if not can_connect:
await websocket.send_json({"type": "error", "message": reject_reason})
await websocket.close(code=4005, reason=reject_reason)
return return
# Verify user has access to the project # Verify user has access to the project

View File

@@ -28,7 +28,8 @@ class Settings(BaseSettings):
# JWT - Must be set in environment, no default allowed # JWT - Must be set in environment, no default allowed
JWT_SECRET_KEY: str = "" JWT_SECRET_KEY: str = ""
JWT_ALGORITHM: str = "HS256" JWT_ALGORITHM: str = "HS256"
JWT_EXPIRE_MINUTES: int = 10080 # 7 days JWT_EXPIRE_MINUTES: int = 60 # 1 hour (short-lived access token)
REFRESH_TOKEN_EXPIRE_DAYS: int = 7 # Refresh token valid for 7 days
@field_validator("JWT_SECRET_KEY") @field_validator("JWT_SECRET_KEY")
@classmethod @classmethod
@@ -127,6 +128,12 @@ class Settings(BaseSettings):
QUERY_LOGGING: bool = False # Enable SQLAlchemy query logging QUERY_LOGGING: bool = False # Enable SQLAlchemy query logging
QUERY_COUNT_THRESHOLD: int = 10 # Warn when query count exceeds this threshold QUERY_COUNT_THRESHOLD: int = 10 # Warn when query count exceeds this threshold
# Environment
ENVIRONMENT: str = "development" # Options: development, staging, production
# WebSocket Settings
MAX_WEBSOCKET_CONNECTIONS_PER_USER: int = 5 # Maximum concurrent WebSocket connections per user
class Config: class Config:
env_file = ".env" env_file = ".env"
case_sensitive = True case_sensitive = True

View File

@@ -356,3 +356,140 @@ def create_token_payload(
"department_id": department_id, "department_id": department_id,
"is_system_admin": is_system_admin, "is_system_admin": is_system_admin,
} }
# Refresh Token Functions
REFRESH_TOKEN_BYTES = 32
def generate_refresh_token() -> str:
"""
Generate a cryptographically secure refresh token.
Returns:
A URL-safe base64-encoded random token
"""
return secrets.token_urlsafe(REFRESH_TOKEN_BYTES)
def get_refresh_token_key(user_id: str, token: str) -> str:
"""
Generate the Redis key for a refresh token.
Args:
user_id: The user's ID
token: The refresh token
Returns:
Redis key string
"""
# Hash the token to avoid storing it directly as a key
token_hash = hashlib.sha256(token.encode()).hexdigest()[:16]
return f"refresh_token:{user_id}:{token_hash}"
def store_refresh_token(redis_client, user_id: str, token: str) -> None:
"""
Store a refresh token in Redis with user binding.
Args:
redis_client: Redis client instance
user_id: The user's ID
token: The refresh token to store
"""
key = get_refresh_token_key(user_id, token)
# Store with TTL based on REFRESH_TOKEN_EXPIRE_DAYS
ttl_seconds = settings.REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60
redis_client.setex(key, ttl_seconds, user_id)
def validate_refresh_token(redis_client, user_id: str, token: str) -> bool:
"""
Validate a refresh token exists in Redis and is bound to the user.
Args:
redis_client: Redis client instance
user_id: The expected user ID
token: The refresh token to validate
Returns:
True if token is valid, False otherwise
"""
key = get_refresh_token_key(user_id, token)
stored_user_id = redis_client.get(key)
if stored_user_id is None:
return False
# Handle Redis bytes type
if isinstance(stored_user_id, bytes):
stored_user_id = stored_user_id.decode("utf-8")
return stored_user_id == user_id
def invalidate_refresh_token(redis_client, user_id: str, token: str) -> bool:
"""
Invalidate (delete) a refresh token from Redis.
Args:
redis_client: Redis client instance
user_id: The user's ID
token: The refresh token to invalidate
Returns:
True if token was deleted, False if it didn't exist
"""
key = get_refresh_token_key(user_id, token)
result = redis_client.delete(key)
return result > 0 if isinstance(result, int) else bool(result)
def invalidate_all_user_refresh_tokens(redis_client, user_id: str) -> int:
"""
Invalidate all refresh tokens for a user.
Args:
redis_client: Redis client instance
user_id: The user's ID
Returns:
Number of tokens invalidated
"""
pattern = f"refresh_token:{user_id}:*"
count = 0
for key in redis_client.scan_iter(match=pattern):
redis_client.delete(key)
count += 1
return count
def decode_refresh_token_user_id(token: str, redis_client) -> Optional[str]:
"""
Find the user ID associated with a refresh token by searching Redis.
This is used when we only have the token and need to find which user it belongs to.
Note: This is less efficient but necessary for refresh token validation when
the user_id is not provided in the request.
Args:
token: The refresh token
redis_client: Redis client instance
Returns:
User ID if found, None otherwise
"""
# We need to search for the token across all users
# This is done by checking the token hash pattern
token_hash = hashlib.sha256(token.encode()).hexdigest()[:16]
pattern = f"refresh_token:*:{token_hash}"
for key in redis_client.scan_iter(match=pattern):
# Extract user_id from key format: refresh_token:{user_id}:{token_hash}
if isinstance(key, bytes):
key = key.decode("utf-8")
parts = key.split(":")
if len(parts) == 3:
return parts[1]
return None

View File

@@ -1,7 +1,7 @@
import os import os
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from datetime import datetime from datetime import datetime
from fastapi import FastAPI, Request, APIRouter from fastapi import FastAPI, Request, APIRouter, Depends
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from slowapi import _rate_limit_exceeded_handler from slowapi import _rate_limit_exceeded_handler
@@ -9,6 +9,9 @@ from slowapi.errors import RateLimitExceeded
from sqlalchemy import text from sqlalchemy import text
from app.middleware.audit import AuditMiddleware from app.middleware.audit import AuditMiddleware
from app.middleware.csrf import CSRFMiddleware
from app.middleware.security_audit import SecurityAuditMiddleware
from app.middleware.error_sanitizer import ErrorSanitizerMiddleware
from app.core.scheduler import start_scheduler, shutdown_scheduler, scheduler from app.core.scheduler import start_scheduler, shutdown_scheduler, scheduler
from app.core.rate_limiter import limiter from app.core.rate_limiter import limiter
from app.core.deprecation import DeprecationMiddleware from app.core.deprecation import DeprecationMiddleware
@@ -61,6 +64,8 @@ from app.core.database import get_pool_status, engine
from app.core.redis import redis_client from app.core.redis import redis_client
from app.services.notification_service import get_redis_fallback_status from app.services.notification_service import get_redis_fallback_status
from app.services.file_storage_service import file_storage_service from app.services.file_storage_service import file_storage_service
from app.middleware.auth import require_system_admin
from app.models import User
app = FastAPI( app = FastAPI(
title="Project Control API", title="Project Control API",
@@ -73,18 +78,28 @@ app = FastAPI(
app.state.limiter = limiter app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# CORS middleware # CORS middleware - Explicit methods and headers for security
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=settings.CORS_ORIGINS, allow_origins=settings.CORS_ORIGINS,
allow_credentials=True, allow_credentials=True,
allow_methods=["*"], allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
allow_headers=["*"], allow_headers=["Content-Type", "Authorization", "X-CSRF-Token", "X-Request-ID"],
) )
# Error sanitizer middleware - sanitizes error messages in production
# Must be first in the chain to intercept all error responses
app.add_middleware(ErrorSanitizerMiddleware)
# Audit middleware - extracts request metadata for audit logging # Audit middleware - extracts request metadata for audit logging
app.add_middleware(AuditMiddleware) app.add_middleware(AuditMiddleware)
# Security audit middleware - logs 401/403 responses to audit trail
app.add_middleware(SecurityAuditMiddleware)
# CSRF middleware - validates CSRF tokens for state-changing requests
app.add_middleware(CSRFMiddleware)
# Deprecation middleware - adds deprecation headers to legacy /api/ routes # Deprecation middleware - adds deprecation headers to legacy /api/ routes
app.add_middleware(DeprecationMiddleware) app.add_middleware(DeprecationMiddleware)
@@ -252,14 +267,20 @@ async def readiness_check():
@app.get("/health/detailed") @app.get("/health/detailed")
async def detailed_health_check(): async def detailed_health_check(
"""Detailed health check endpoint. current_user: User = Depends(require_system_admin),
):
"""Detailed health check endpoint (requires system admin).
Returns comprehensive status of all system components: Returns comprehensive status of all system components:
- database: Connection pool status and connectivity - database: Connection pool status and connectivity
- redis: Connection status and fallback queue status - redis: Connection status and fallback queue status
- storage: File storage validation status - storage: File storage validation status
- scheduler: Background job scheduler status - scheduler: Background job scheduler status
Note: This endpoint requires system admin authentication because it exposes
sensitive infrastructure details including connection pool statistics and
internal service states.
""" """
db_health = check_database_health() db_health = check_database_health()
redis_health = check_redis_health() redis_health = check_redis_health()

View File

@@ -1,38 +1,55 @@
""" """
CSRF (Cross-Site Request Forgery) Protection Middleware. CSRF (Cross-Site Request Forgery) Protection Middleware.
This module provides CSRF protection for sensitive state-changing operations. This module provides CSRF protection for all state-changing operations.
It validates CSRF tokens for specified protected endpoints. It validates CSRF tokens globally for authenticated POST, PUT, PATCH, DELETE requests.
""" """
from fastapi import Request, HTTPException, status, Depends from starlette.middleware.base import BaseHTTPMiddleware
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from starlette.requests import Request
from typing import Optional, Callable, List from starlette.responses import JSONResponse
from fastapi import HTTPException, status
from typing import Optional, Callable, List, Set
from functools import wraps from functools import wraps
import logging import logging
from app.core.security import validate_csrf_token, generate_csrf_token from app.core.security import validate_csrf_token, generate_csrf_token, decode_access_token
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Header name for CSRF token # Header name for CSRF token
CSRF_TOKEN_HEADER = "X-CSRF-Token" CSRF_TOKEN_HEADER = "X-CSRF-Token"
# List of endpoint patterns that require CSRF protection # Methods that require CSRF protection (all state-changing operations)
# These are sensitive state-changing operations CSRF_PROTECTED_METHODS = {"POST", "PUT", "PATCH", "DELETE"}
CSRF_PROTECTED_PATTERNS = [
# User operations
"/api/v1/users/{user_id}/admin", # Admin status change
"/api/users/{user_id}/admin", # Legacy
# Password changes would go here if implemented
# Delete operations
"/api/attachments/{attachment_id}", # DELETE method
"/api/tasks/{task_id}", # DELETE method (soft delete)
"/api/projects/{project_id}", # DELETE method
]
# Methods that require CSRF protection # Safe methods that don't require CSRF protection
CSRF_PROTECTED_METHODS = ["DELETE", "PUT", "PATCH"] CSRF_SAFE_METHODS = {"GET", "HEAD", "OPTIONS"}
# Public endpoints that don't require CSRF validation
# These are endpoints that either:
# 1. Don't require authentication (login, health checks)
# 2. Are not state-changing in a security-sensitive way
CSRF_EXCLUDED_PATHS: Set[str] = {
# Authentication endpoints (unauthenticated)
"/api/auth/login",
"/api/v1/auth/login",
# Health check endpoints (unauthenticated)
"/health",
"/health/live",
"/health/ready",
"/health/detailed",
# WebSocket endpoints (use different auth mechanism)
"/api/ws",
"/ws",
}
# Path prefixes that are excluded from CSRF validation
CSRF_EXCLUDED_PREFIXES: List[str] = [
# WebSocket paths
"/api/ws/",
"/ws/",
]
class CSRFProtectionError(HTTPException): class CSRFProtectionError(HTTPException):
@@ -45,6 +62,114 @@ class CSRFProtectionError(HTTPException):
) )
class CSRFMiddleware(BaseHTTPMiddleware):
"""
Global CSRF protection middleware.
Validates CSRF tokens for all authenticated state-changing requests
(POST, PUT, PATCH, DELETE) except for explicitly excluded endpoints.
"""
async def dispatch(self, request: Request, call_next):
"""Process the request and validate CSRF token if needed."""
method = request.method.upper()
path = request.url.path
# Skip CSRF validation for safe methods
if method in CSRF_SAFE_METHODS:
return await call_next(request)
# Skip CSRF validation for excluded paths
if self._is_excluded_path(path):
logger.debug("CSRF validation skipped for excluded path: %s", path)
return await call_next(request)
# Try to extract user ID from the Authorization header
user_id = self._extract_user_id_from_token(request)
# If no user ID (unauthenticated request), skip CSRF validation
# The authentication middleware will handle unauthorized access
if user_id is None:
logger.debug(
"CSRF validation skipped (no auth token): %s %s",
method, path
)
return await call_next(request)
# Get CSRF token from header
csrf_token = request.headers.get(CSRF_TOKEN_HEADER)
if not csrf_token:
logger.warning(
"CSRF validation failed: Missing token for user %s on %s %s",
user_id, method, path
)
return JSONResponse(
status_code=status.HTTP_403_FORBIDDEN,
content={"detail": "CSRF token is required"}
)
# Validate the token
is_valid, error_message = validate_csrf_token(csrf_token, user_id)
if not is_valid:
logger.warning(
"CSRF validation failed for user %s on %s %s: %s",
user_id, method, path, error_message
)
return JSONResponse(
status_code=status.HTTP_403_FORBIDDEN,
content={"detail": error_message}
)
logger.debug(
"CSRF validation passed for user %s on %s %s",
user_id, method, path
)
return await call_next(request)
def _is_excluded_path(self, path: str) -> bool:
"""Check if the path is excluded from CSRF validation."""
# Check exact path matches
if path in CSRF_EXCLUDED_PATHS:
return True
# Check path prefixes
for prefix in CSRF_EXCLUDED_PREFIXES:
if path.startswith(prefix):
return True
return False
def _extract_user_id_from_token(self, request: Request) -> Optional[str]:
"""
Extract user ID from the Authorization header.
Returns None if no valid token is found (unauthenticated request).
"""
auth_header = request.headers.get("Authorization")
if not auth_header:
return None
# Parse Bearer token
parts = auth_header.split()
if len(parts) != 2 or parts[0].lower() != "bearer":
return None
token = parts[1]
# Decode the token to get user ID
try:
payload = decode_access_token(token)
if payload is None:
return None
return payload.get("sub")
except Exception as e:
logger.debug("Failed to decode token for CSRF validation: %s", e)
return None
def require_csrf_token(func: Callable) -> Callable: def require_csrf_token(func: Callable) -> Callable:
""" """
Decorator to require CSRF token validation for an endpoint. Decorator to require CSRF token validation for an endpoint.

View File

@@ -0,0 +1,187 @@
"""Error message sanitization middleware for production environments.
This middleware intercepts error responses and sanitizes them to prevent
information disclosure in production environments. Detailed error messages
are only shown when DEBUG mode is enabled.
"""
import json
import logging
from typing import Optional
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response, JSONResponse
from app.core.config import settings
logger = logging.getLogger(__name__)
# Generic error messages for production
GENERIC_ERROR_MESSAGES = {
400: "Bad Request",
401: "Authentication required",
403: "Access denied",
404: "Resource not found",
405: "Method not allowed",
409: "Request conflict",
422: "Validation error",
429: "Too many requests",
500: "Internal server error",
502: "Service unavailable",
503: "Service temporarily unavailable",
504: "Request timeout",
}
# Status codes that should preserve their original message even in production
# These are typically user-facing validation errors that don't leak sensitive info
PRESERVE_MESSAGE_CODES = {
400, # Bad request - users need to know what's wrong with their request
401, # Unauthorized - users need to know why auth failed
403, # Forbidden - users need to know what permission they lack
404, # Not found - usually safe to preserve
409, # Conflict - users need to know about conflicts
422, # Validation errors - users need to know what to fix
}
# Patterns that indicate sensitive information in error messages
SENSITIVE_PATTERNS = [
"traceback",
"stack trace",
"file path",
"/usr/",
"/var/",
"/home/",
"connection refused",
"connection error",
"timeout connecting",
"database error",
"sql",
"query failed",
"password",
"secret",
"token",
"key=",
"credentials",
".py line",
"exception in",
]
def _contains_sensitive_info(message: str) -> bool:
"""Check if an error message contains potentially sensitive information."""
if not message:
return False
message_lower = message.lower()
return any(pattern.lower() in message_lower for pattern in SENSITIVE_PATTERNS)
def _sanitize_detail(detail: any, status_code: int) -> any:
"""Sanitize error detail, removing sensitive information in production.
Args:
detail: The error detail (can be string, list, or dict)
status_code: The HTTP status code
Returns:
Sanitized detail for production, or original detail for debug mode
"""
# In debug mode, return original detail
if settings.DEBUG:
return detail
# For preserved status codes, keep the detail if it doesn't contain sensitive info
if status_code in PRESERVE_MESSAGE_CODES:
if isinstance(detail, str) and not _contains_sensitive_info(detail):
return detail
if isinstance(detail, list):
# For validation errors (list of dicts), keep the structure but sanitize
sanitized = []
for item in detail:
if isinstance(item, dict):
# Keep loc, msg, type for pydantic validation errors
sanitized_item = {}
if 'loc' in item:
sanitized_item['loc'] = item['loc']
if 'msg' in item and not _contains_sensitive_info(str(item['msg'])):
sanitized_item['msg'] = item['msg']
else:
sanitized_item['msg'] = 'Validation failed'
if 'type' in item:
sanitized_item['type'] = item['type']
sanitized.append(sanitized_item)
else:
sanitized.append(item if not _contains_sensitive_info(str(item)) else 'Invalid value')
return sanitized
return detail
# For other status codes, use generic message
return GENERIC_ERROR_MESSAGES.get(status_code, "An error occurred")
class ErrorSanitizerMiddleware(BaseHTTPMiddleware):
"""Middleware to sanitize error responses in production.
This middleware:
1. Intercepts error responses (4xx and 5xx status codes)
2. Parses JSON response bodies
3. Sanitizes the 'detail' field to remove sensitive information
4. Returns the sanitized response
In DEBUG mode, original error messages are preserved for development.
"""
async def dispatch(self, request: Request, call_next) -> Response:
response = await call_next(request)
# Only process error responses with JSON content
if response.status_code < 400:
return response
content_type = response.headers.get("content-type", "")
if "application/json" not in content_type:
return response
# Read the response body
body = b""
async for chunk in response.body_iterator:
body += chunk
if not body:
return response
try:
data = json.loads(body)
except (json.JSONDecodeError, UnicodeDecodeError):
# Not valid JSON, return as-is
return Response(
content=body,
status_code=response.status_code,
headers=dict(response.headers),
media_type=response.media_type,
)
# Sanitize the detail field if present
if "detail" in data:
original_detail = data["detail"]
data["detail"] = _sanitize_detail(original_detail, response.status_code)
# Log the original error in production for debugging
if not settings.DEBUG and original_detail != data["detail"]:
logger.warning(
"Sanitized error response",
extra={
"status_code": response.status_code,
"path": str(request.url.path),
"method": request.method,
"original_detail_length": len(str(original_detail)),
}
)
# Return the sanitized response
return JSONResponse(
content=data,
status_code=response.status_code,
headers={
k: v for k, v in response.headers.items()
if k.lower() not in ("content-length", "content-type")
},
)

View File

@@ -0,0 +1,215 @@
"""Security audit middleware for logging access denials and suspicious auth patterns."""
import time
from collections import defaultdict
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Tuple
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
from sqlalchemy.orm import Session
from app.core.database import SessionLocal
from app.models import AuditLog, AuditAction
from app.services.audit_service import AuditService
# In-memory storage for tracking auth failures
# Structure: {ip_address: [(timestamp, path), ...]}
_auth_failure_tracker: Dict[str, List[Tuple[float, str]]] = defaultdict(list)
# Configuration constants
AUTH_FAILURE_THRESHOLD = 5 # Number of failures to trigger suspicious pattern alert
AUTH_FAILURE_WINDOW_SECONDS = 600 # 10 minutes window
def _cleanup_old_failures(ip: str) -> None:
"""Remove auth failures older than the tracking window."""
if ip not in _auth_failure_tracker:
return
cutoff = time.time() - AUTH_FAILURE_WINDOW_SECONDS
_auth_failure_tracker[ip] = [
(ts, path) for ts, path in _auth_failure_tracker[ip]
if ts > cutoff
]
# Clean up empty entries
if not _auth_failure_tracker[ip]:
del _auth_failure_tracker[ip]
def _track_auth_failure(ip: str, path: str) -> int:
"""Track an auth failure and return the count in the window."""
_cleanup_old_failures(ip)
_auth_failure_tracker[ip].append((time.time(), path))
return len(_auth_failure_tracker[ip])
def _get_recent_failures(ip: str) -> List[str]:
"""Get list of paths that failed auth for this IP."""
_cleanup_old_failures(ip)
return [path for _, path in _auth_failure_tracker.get(ip, [])]
class SecurityAuditMiddleware(BaseHTTPMiddleware):
"""Middleware to audit security-related events like 401/403 responses."""
async def dispatch(self, request: Request, call_next) -> Response:
response = await call_next(request)
# Only process 401 and 403 responses
if response.status_code not in (401, 403):
return response
# Get client IP from audit metadata if available
ip_address = self._get_client_ip(request)
path = str(request.url.path)
method = request.method
# Get user_id if available from request state (set by auth middleware)
user_id = getattr(request.state, 'user_id', None)
db: Session = SessionLocal()
try:
if response.status_code == 403:
self._log_access_denied(db, ip_address, path, method, user_id, request)
elif response.status_code == 401:
self._log_auth_failure(db, ip_address, path, method, request)
db.commit()
except Exception:
db.rollback()
# Don't fail the request due to audit logging errors
finally:
db.close()
return response
def _get_client_ip(self, request: Request) -> str:
"""Get the real client IP address from request."""
# Check for audit metadata first (set by AuditMiddleware)
audit_metadata = getattr(request.state, 'audit_metadata', None)
if audit_metadata and 'ip_address' in audit_metadata:
return audit_metadata['ip_address']
# Fallback to checking headers directly
forwarded = request.headers.get("x-forwarded-for")
if forwarded:
return forwarded.split(",")[0].strip()
real_ip = request.headers.get("x-real-ip")
if real_ip:
return real_ip
if request.client:
return request.client.host
return "unknown"
def _log_access_denied(
self,
db: Session,
ip_address: str,
path: str,
method: str,
user_id: str | None,
request: Request,
) -> None:
"""Log a 403 Forbidden response to the audit trail."""
request_metadata = {
"ip_address": ip_address,
"user_agent": request.headers.get("user-agent", ""),
"method": method,
"path": path,
}
# Try to extract resource info from path
resource_type = self._extract_resource_type(path)
AuditService.log_event(
db=db,
event_type="security.access_denied",
resource_type=resource_type,
action=AuditAction.ACCESS_DENIED,
user_id=user_id,
resource_id=None,
changes=[{
"attempted_path": path,
"attempted_method": method,
"ip_address": ip_address,
}],
request_metadata=request_metadata,
)
def _log_auth_failure(
self,
db: Session,
ip_address: str,
path: str,
method: str,
request: Request,
) -> None:
"""Log a 401 Unauthorized response and check for suspicious patterns."""
# Track this failure
failure_count = _track_auth_failure(ip_address, path)
request_metadata = {
"ip_address": ip_address,
"user_agent": request.headers.get("user-agent", ""),
"method": method,
"path": path,
}
resource_type = self._extract_resource_type(path)
# Log the auth failure
AuditService.log_event(
db=db,
event_type="security.auth_failed",
resource_type=resource_type,
action=AuditAction.AUTH_FAILED,
user_id=None, # No user for 401
resource_id=None,
changes=[{
"attempted_path": path,
"attempted_method": method,
"ip_address": ip_address,
"failure_count_in_window": failure_count,
}],
request_metadata=request_metadata,
)
# Check for suspicious pattern
if failure_count >= AUTH_FAILURE_THRESHOLD:
recent_paths = _get_recent_failures(ip_address)
AuditService.log_event(
db=db,
event_type="security.suspicious_auth_pattern",
resource_type="security",
action=AuditAction.AUTH_FAILED,
user_id=None,
resource_id=None,
changes=[{
"ip_address": ip_address,
"failure_count": failure_count,
"window_minutes": AUTH_FAILURE_WINDOW_SECONDS // 60,
"attempted_paths": list(set(recent_paths)),
}],
request_metadata=request_metadata,
)
def _extract_resource_type(self, path: str) -> str:
"""Extract resource type from path for audit logging."""
# Remove /api/ or /api/v1/ prefix
clean_path = path
if clean_path.startswith("/api/v1/"):
clean_path = clean_path[8:]
elif clean_path.startswith("/api/"):
clean_path = clean_path[5:]
# Get the first path segment as resource type
parts = clean_path.strip("/").split("/")
if parts and parts[0]:
return parts[0]
return "unknown"

View File

@@ -13,6 +13,8 @@ class AuditAction(str, enum.Enum):
RESTORE = "restore" RESTORE = "restore"
LOGIN = "login" LOGIN = "login"
LOGOUT = "logout" LOGOUT = "logout"
ACCESS_DENIED = "access_denied"
AUTH_FAILED = "auth_failed"
class SensitivityLevel(str, enum.Enum): class SensitivityLevel(str, enum.Enum):
@@ -42,10 +44,20 @@ EVENT_SENSITIVITY = {
"attachment.upload": SensitivityLevel.LOW, "attachment.upload": SensitivityLevel.LOW,
"attachment.download": SensitivityLevel.LOW, "attachment.download": SensitivityLevel.LOW,
"attachment.delete": SensitivityLevel.MEDIUM, "attachment.delete": SensitivityLevel.MEDIUM,
# Security events
"security.access_denied": SensitivityLevel.MEDIUM,
"security.auth_failed": SensitivityLevel.MEDIUM,
"security.suspicious_auth_pattern": SensitivityLevel.HIGH,
} }
# Events that should trigger alerts # Events that should trigger alerts
ALERT_EVENTS = {"project.delete", "user.permission_change", "user.admin_change", "role.permission_change"} ALERT_EVENTS = {
"project.delete",
"user.permission_change",
"user.admin_change",
"role.permission_change",
"security.suspicious_auth_pattern",
}
class AuditLog(Base): class AuditLog(Base):
@@ -57,7 +69,7 @@ class AuditLog(Base):
resource_id = Column(String(36), nullable=True) resource_id = Column(String(36), nullable=True)
user_id = Column(String(36), ForeignKey("pjctrl_users.id", ondelete="SET NULL"), nullable=True) user_id = Column(String(36), ForeignKey("pjctrl_users.id", ondelete="SET NULL"), nullable=True)
action = Column( action = Column(
Enum("create", "update", "delete", "restore", "login", "logout", name="audit_action_enum"), Enum("create", "update", "delete", "restore", "login", "logout", "access_denied", "auth_failed", name="audit_action_enum"),
nullable=False nullable=False
) )
changes = Column(JSON, nullable=True) changes = Column(JSON, nullable=True)

View File

@@ -9,10 +9,25 @@ class LoginRequest(BaseModel):
class LoginResponse(BaseModel): class LoginResponse(BaseModel):
access_token: str access_token: str
refresh_token: str
token_type: str = "bearer" token_type: str = "bearer"
expires_in: int = Field(default=3600, description="Access token expiry in seconds")
user: "UserInfo" user: "UserInfo"
class RefreshTokenRequest(BaseModel):
"""Request body for refresh token endpoint."""
refresh_token: str = Field(..., description="The refresh token to use for obtaining a new access token")
class RefreshTokenResponse(BaseModel):
"""Response for refresh token endpoint."""
access_token: str
refresh_token: str # New refresh token (rotation)
token_type: str = "bearer"
expires_in: int = Field(default=3600, description="Access token expiry in seconds")
class UserInfo(BaseModel): class UserInfo(BaseModel):
id: str id: str
email: str email: str

View File

@@ -4,12 +4,12 @@ from pydantic import BaseModel, Field
class CommentCreate(BaseModel): class CommentCreate(BaseModel):
content: str = Field(..., min_length=1, max_length=10000) content: str = Field(..., min_length=1, max_length=5000)
parent_comment_id: Optional[str] = None parent_comment_id: Optional[str] = None
class CommentUpdate(BaseModel): class CommentUpdate(BaseModel):
content: str = Field(..., min_length=1, max_length=10000) content: str = Field(..., min_length=1, max_length=5000)
class CommentAuthor(BaseModel): class CommentAuthor(BaseModel):

View File

@@ -25,7 +25,7 @@ class CustomFieldDefinition(BaseModel):
class ProjectTemplateBase(BaseModel): class ProjectTemplateBase(BaseModel):
"""Base schema for project template.""" """Base schema for project template."""
name: str = Field(..., min_length=1, max_length=200) name: str = Field(..., min_length=1, max_length=200)
description: Optional[str] = None description: Optional[str] = Field(None, max_length=2000)
is_public: bool = Field(default=False) is_public: bool = Field(default=False)
task_statuses: Optional[List[TaskStatusDefinition]] = None task_statuses: Optional[List[TaskStatusDefinition]] = None
custom_fields: Optional[List[CustomFieldDefinition]] = None custom_fields: Optional[List[CustomFieldDefinition]] = None
@@ -43,7 +43,7 @@ class ProjectTemplateCreate(ProjectTemplateBase):
class ProjectTemplateUpdate(BaseModel): class ProjectTemplateUpdate(BaseModel):
"""Schema for updating a project template.""" """Schema for updating a project template."""
name: Optional[str] = Field(None, min_length=1, max_length=200) name: Optional[str] = Field(None, min_length=1, max_length=200)
description: Optional[str] = None description: Optional[str] = Field(None, max_length=2000)
is_public: Optional[bool] = None is_public: Optional[bool] = None
task_statuses: Optional[List[TaskStatusDefinition]] = None task_statuses: Optional[List[TaskStatusDefinition]] = None
custom_fields: Optional[List[CustomFieldDefinition]] = None custom_fields: Optional[List[CustomFieldDefinition]] = None

View File

@@ -4,6 +4,7 @@ import logging
from typing import Dict, Set, Optional, Tuple from typing import Dict, Set, Optional, Tuple
from fastapi import WebSocket from fastapi import WebSocket
from app.core.redis import get_redis_sync from app.core.redis import get_redis_sync
from app.core.config import settings
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -19,13 +20,48 @@ class ConnectionManager:
self._lock = asyncio.Lock() self._lock = asyncio.Lock()
self._project_lock = asyncio.Lock() self._project_lock = asyncio.Lock()
async def check_connection_limit(self, user_id: str) -> Tuple[bool, Optional[str]]:
"""
Check if user can create a new WebSocket connection.
Args:
user_id: The user's ID
Returns:
Tuple of (can_connect: bool, reject_reason: str | None)
- can_connect: True if user is within connection limit
- reject_reason: Error message if connection should be rejected
"""
max_connections = settings.MAX_WEBSOCKET_CONNECTIONS_PER_USER
async with self._lock:
current_count = len(self.active_connections.get(user_id, set()))
if current_count >= max_connections:
logger.warning(
f"User {user_id} exceeded WebSocket connection limit "
f"({current_count}/{max_connections})"
)
return False, "Too many connections"
return True, None
def get_user_connection_count(self, user_id: str) -> int:
"""Get the current number of WebSocket connections for a user."""
return len(self.active_connections.get(user_id, set()))
async def connect(self, websocket: WebSocket, user_id: str): async def connect(self, websocket: WebSocket, user_id: str):
"""Accept and track a new WebSocket connection.""" """
await websocket.accept() Track a new WebSocket connection.
Note: WebSocket must already be accepted before calling this method.
Connection limit should be checked via check_connection_limit() before calling.
"""
async with self._lock: async with self._lock:
if user_id not in self.active_connections: if user_id not in self.active_connections:
self.active_connections[user_id] = set() self.active_connections[user_id] = set()
self.active_connections[user_id].add(websocket) self.active_connections[user_id].add(websocket)
logger.debug(
f"User {user_id} connected. Total connections: "
f"{len(self.active_connections[user_id])}"
)
async def disconnect(self, websocket: WebSocket, user_id: str): async def disconnect(self, websocket: WebSocket, user_id: str):
"""Remove a WebSocket connection.""" """Remove a WebSocket connection."""

View File

@@ -166,3 +166,20 @@ def admin_token(client, mock_redis):
mock_redis.setex("session:00000000-0000-0000-0000-000000000001", 900, token) mock_redis.setex("session:00000000-0000-0000-0000-000000000001", 900, token)
return token return token
@pytest.fixture
def csrf_token():
"""Generate a CSRF token for the admin user."""
from app.core.security import generate_csrf_token
return generate_csrf_token("00000000-0000-0000-0000-000000000001")
@pytest.fixture
def auth_headers(admin_token, csrf_token):
"""Get complete auth headers including both Authorization and CSRF token."""
return {
"Authorization": f"Bearer {admin_token}",
"X-CSRF-Token": csrf_token,
}

View File

@@ -173,7 +173,7 @@ class TestProjectTemplates:
# Should return list of templates # Should return list of templates
assert "templates" in data or isinstance(data, list) assert "templates" in data or isinstance(data, list)
def test_create_template(self, client, admin_token, db): def test_create_template(self, client, auth_headers, db):
"""Test creating a new project template.""" """Test creating a new project template."""
from app.models import Space from app.models import Space
@@ -192,14 +192,14 @@ class TestProjectTemplates:
{"name": "Done", "color": "#00FF00"} {"name": "Done", "color": "#00FF00"}
] ]
}, },
headers={"Authorization": f"Bearer {admin_token}"} headers=auth_headers
) )
assert response.status_code in [200, 201] assert response.status_code in [200, 201]
data = response.json() data = response.json()
assert data.get("name") == "Test Template" assert data.get("name") == "Test Template"
def test_create_project_from_template(self, client, admin_token, db): def test_create_project_from_template(self, client, auth_headers, db):
"""Test creating a project from a template.""" """Test creating a project from a template."""
from app.models import Space, ProjectTemplate from app.models import Space, ProjectTemplate
@@ -228,14 +228,14 @@ class TestProjectTemplates:
"description": "Created from template", "description": "Created from template",
"template_id": "test-template-id" "template_id": "test-template-id"
}, },
headers={"Authorization": f"Bearer {admin_token}"} headers=auth_headers
) )
assert response.status_code in [200, 201] assert response.status_code in [200, 201]
data = response.json() data = response.json()
assert data.get("name") == "Project from Template" assert data.get("name") == "Project from Template"
def test_delete_template(self, client, admin_token, db): def test_delete_template(self, client, auth_headers, db):
"""Test deleting a project template.""" """Test deleting a project template."""
from app.models import ProjectTemplate from app.models import ProjectTemplate
@@ -251,7 +251,7 @@ class TestProjectTemplates:
response = client.delete( response = client.delete(
"/api/templates/delete-template-id", "/api/templates/delete-template-id",
headers={"Authorization": f"Bearer {admin_token}"} headers=auth_headers
) )
assert response.status_code in [200, 204] assert response.status_code in [200, 204]

View File

@@ -42,6 +42,22 @@ def test_user_token(client, mock_redis, test_user):
return token return token
@pytest.fixture
def test_user_csrf_token(test_user):
"""Generate a CSRF token for the test user."""
from app.core.security import generate_csrf_token
return generate_csrf_token(test_user.id)
@pytest.fixture
def test_user_auth_headers(test_user_token, test_user_csrf_token):
"""Get complete auth headers for test user."""
return {
"Authorization": f"Bearer {test_user_token}",
"X-CSRF-Token": test_user_csrf_token,
}
@pytest.fixture @pytest.fixture
def test_space(db, test_user): def test_space(db, test_user):
"""Create a test space.""" """Create a test space."""
@@ -154,7 +170,7 @@ class TestFileStorageService:
class TestAttachmentAPI: class TestAttachmentAPI:
"""Tests for Attachment API endpoints.""" """Tests for Attachment API endpoints."""
def test_upload_attachment(self, client, test_user_token, test_task, db, monkeypatch, temp_upload_dir): def test_upload_attachment(self, client, test_user_auth_headers, test_task, db, monkeypatch, temp_upload_dir):
"""Test uploading an attachment.""" """Test uploading an attachment."""
monkeypatch.setattr("app.core.config.settings.UPLOAD_DIR", temp_upload_dir) monkeypatch.setattr("app.core.config.settings.UPLOAD_DIR", temp_upload_dir)
@@ -163,7 +179,7 @@ class TestAttachmentAPI:
response = client.post( response = client.post(
f"/api/tasks/{test_task.id}/attachments", f"/api/tasks/{test_task.id}/attachments",
headers={"Authorization": f"Bearer {test_user_token}"}, headers=test_user_auth_headers,
files=files, files=files,
) )
@@ -271,14 +287,14 @@ class TestAttachmentAPI:
db.refresh(attachment) db.refresh(attachment)
assert attachment.is_deleted == True assert attachment.is_deleted == True
def test_upload_blocked_file_type(self, client, test_user_token, test_task): def test_upload_blocked_file_type(self, client, test_user_auth_headers, test_task):
"""Test that blocked file types are rejected.""" """Test that blocked file types are rejected."""
content = b"malicious content" content = b"malicious content"
files = {"file": ("virus.exe", BytesIO(content), "application/octet-stream")} files = {"file": ("virus.exe", BytesIO(content), "application/octet-stream")}
response = client.post( response = client.post(
f"/api/tasks/{test_task.id}/attachments", f"/api/tasks/{test_task.id}/attachments",
headers={"Authorization": f"Bearer {test_user_token}"}, headers=test_user_auth_headers,
files=files, files=files,
) )
@@ -322,7 +338,7 @@ class TestAttachmentAPI:
assert data["total"] == 2 assert data["total"] == 2
assert len(data["versions"]) == 2 assert len(data["versions"]) == 2
def test_restore_version(self, client, test_user_token, test_task, db): def test_restore_version(self, client, test_user_auth_headers, test_task, db):
"""Test restoring to a previous version.""" """Test restoring to a previous version."""
attachment = Attachment( attachment = Attachment(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
@@ -351,7 +367,7 @@ class TestAttachmentAPI:
response = client.post( response = client.post(
f"/api/attachments/{attachment.id}/restore/1", f"/api/attachments/{attachment.id}/restore/1",
headers={"Authorization": f"Bearer {test_user_token}"}, headers=test_user_auth_headers,
) )
assert response.status_code == 200 assert response.status_code == 200

View File

@@ -253,7 +253,7 @@ class TestAuditAPI:
assert data["total"] == 3 assert data["total"] == 3
assert all(log["resource_id"] == resource_id for log in data["logs"]) assert all(log["resource_id"] == resource_id for log in data["logs"])
def test_verify_integrity(self, client, admin_token, db): def test_verify_integrity(self, client, auth_headers, db):
"""Test integrity verification.""" """Test integrity verification."""
now = datetime.utcnow() now = datetime.utcnow()
@@ -270,7 +270,7 @@ class TestAuditAPI:
response = client.post( response = client.post(
"/api/audit-logs/verify-integrity", "/api/audit-logs/verify-integrity",
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
json={ json={
"start_date": (now - timedelta(hours=1)).isoformat(), "start_date": (now - timedelta(hours=1)).isoformat(),
"end_date": (now + timedelta(hours=1)).isoformat(), "end_date": (now + timedelta(hours=1)).isoformat(),
@@ -281,7 +281,7 @@ class TestAuditAPI:
assert data["total_checked"] >= 1 assert data["total_checked"] >= 1
assert data["invalid_count"] == 0 assert data["invalid_count"] == 0
def test_acknowledge_alert(self, client, admin_token, db): def test_acknowledge_alert(self, client, auth_headers, db):
"""Test acknowledging an alert.""" """Test acknowledging an alert."""
# Create a log and alert # Create a log and alert
log = AuditLog( log = AuditLog(
@@ -309,7 +309,7 @@ class TestAuditAPI:
response = client.put( response = client.put(
f"/api/audit-alerts/{alert.id}/acknowledge", f"/api/audit-alerts/{alert.id}/acknowledge",
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
) )
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()

View File

@@ -1,5 +1,16 @@
import pytest import pytest
from app.core.security import create_access_token, decode_access_token, create_token_payload from app.core.security import (
create_access_token,
decode_access_token,
create_token_payload,
generate_refresh_token,
store_refresh_token,
validate_refresh_token,
invalidate_refresh_token,
invalidate_all_user_refresh_tokens,
decode_refresh_token_user_id,
get_refresh_token_key,
)
class TestJWT: class TestJWT:
@@ -59,7 +70,7 @@ class TestAuthEndpoints:
def test_get_me_without_auth(self, client): def test_get_me_without_auth(self, client):
"""Test accessing /me without authentication.""" """Test accessing /me without authentication."""
response = client.get("/api/auth/me") response = client.get("/api/auth/me")
assert response.status_code == 403 assert response.status_code == 401 # 401 for unauthenticated, 403 for unauthorized
def test_get_me_with_auth(self, client, admin_token): def test_get_me_with_auth(self, client, admin_token):
"""Test accessing /me with valid authentication.""" """Test accessing /me with valid authentication."""
@@ -72,13 +83,196 @@ class TestAuthEndpoints:
assert data["email"] == "ymirliu@panjit.com.tw" assert data["email"] == "ymirliu@panjit.com.tw"
assert data["is_system_admin"] is True assert data["is_system_admin"] is True
def test_logout(self, client, admin_token, mock_redis): def test_logout(self, client, auth_headers, mock_redis):
"""Test logout endpoint.""" """Test logout endpoint."""
response = client.post( response = client.post(
"/api/auth/logout", "/api/auth/logout",
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
) )
assert response.status_code == 200 assert response.status_code == 200
# Verify session is removed # Verify session is removed
assert mock_redis.get("session:00000000-0000-0000-0000-000000000001") is None assert mock_redis.get("session:00000000-0000-0000-0000-000000000001") is None
class TestRefreshToken:
"""Test refresh token functionality."""
def test_generate_refresh_token(self):
"""Test that refresh tokens are generated correctly."""
token = generate_refresh_token()
assert token is not None
assert isinstance(token, str)
assert len(token) > 20 # URL-safe base64 encoded 32 bytes
def test_generate_unique_refresh_tokens(self):
"""Test that each generated token is unique."""
tokens = [generate_refresh_token() for _ in range(100)]
assert len(set(tokens)) == 100 # All tokens should be unique
def test_store_and_validate_refresh_token(self, mock_redis):
"""Test storing and validating refresh tokens."""
user_id = "test-user-123"
token = generate_refresh_token()
# Store the token
store_refresh_token(mock_redis, user_id, token)
# Validate the token
assert validate_refresh_token(mock_redis, user_id, token) is True
# Wrong user should fail
assert validate_refresh_token(mock_redis, "wrong-user", token) is False
# Wrong token should fail
assert validate_refresh_token(mock_redis, user_id, "wrong-token") is False
def test_invalidate_refresh_token(self, mock_redis):
"""Test invalidating a refresh token."""
user_id = "test-user-123"
token = generate_refresh_token()
# Store and verify
store_refresh_token(mock_redis, user_id, token)
assert validate_refresh_token(mock_redis, user_id, token) is True
# Invalidate
result = invalidate_refresh_token(mock_redis, user_id, token)
assert result is True
# Should no longer be valid
assert validate_refresh_token(mock_redis, user_id, token) is False
def test_invalidate_all_user_refresh_tokens(self, mock_redis):
"""Test invalidating all refresh tokens for a user."""
user_id = "test-user-123"
tokens = [generate_refresh_token() for _ in range(3)]
# Store multiple tokens
for token in tokens:
store_refresh_token(mock_redis, user_id, token)
# Verify all are valid
for token in tokens:
assert validate_refresh_token(mock_redis, user_id, token) is True
# Invalidate all
count = invalidate_all_user_refresh_tokens(mock_redis, user_id)
assert count == 3
# All should be invalid now
for token in tokens:
assert validate_refresh_token(mock_redis, user_id, token) is False
def test_decode_refresh_token_user_id(self, mock_redis):
"""Test finding user ID from refresh token."""
user_id = "test-user-456"
token = generate_refresh_token()
# Store the token
store_refresh_token(mock_redis, user_id, token)
# Find user ID
found_user_id = decode_refresh_token_user_id(token, mock_redis)
assert found_user_id == user_id
# Invalid token should return None
assert decode_refresh_token_user_id("invalid-token", mock_redis) is None
class TestRefreshTokenEndpoint:
"""Test the refresh token API endpoint."""
def test_refresh_token_success(self, client, db, mock_redis):
"""Test successful token refresh."""
user_id = "00000000-0000-0000-0000-000000000001"
# Generate and store a refresh token
refresh_token = generate_refresh_token()
store_refresh_token(mock_redis, user_id, refresh_token)
# Call refresh endpoint
response = client.post(
"/api/auth/refresh",
json={"refresh_token": refresh_token},
)
assert response.status_code == 200
data = response.json()
assert "access_token" in data
assert "refresh_token" in data
assert data["token_type"] == "bearer"
assert data["expires_in"] > 0
# Old refresh token should be invalidated (rotation)
assert validate_refresh_token(mock_redis, user_id, refresh_token) is False
# New refresh token should be valid
assert validate_refresh_token(mock_redis, user_id, data["refresh_token"]) is True
def test_refresh_token_invalid(self, client, mock_redis):
"""Test refresh with invalid token."""
response = client.post(
"/api/auth/refresh",
json={"refresh_token": "invalid-token"},
)
assert response.status_code == 401
assert "Invalid or expired refresh token" in response.json()["detail"]
def test_refresh_token_rotation(self, client, db, mock_redis):
"""Test that refresh tokens are rotated (old one invalidated)."""
user_id = "00000000-0000-0000-0000-000000000001"
# Generate and store initial refresh token
initial_token = generate_refresh_token()
store_refresh_token(mock_redis, user_id, initial_token)
# First refresh
response1 = client.post(
"/api/auth/refresh",
json={"refresh_token": initial_token},
)
assert response1.status_code == 200
new_token = response1.json()["refresh_token"]
# Try to reuse the old token (should fail due to rotation)
response2 = client.post(
"/api/auth/refresh",
json={"refresh_token": initial_token},
)
assert response2.status_code == 401
# New token should still work
response3 = client.post(
"/api/auth/refresh",
json={"refresh_token": new_token},
)
assert response3.status_code == 200
def test_refresh_token_disabled_user(self, client, db, mock_redis):
"""Test that disabled users cannot refresh tokens."""
from app.models.user import User
# Create a disabled user
disabled_user = User(
id="disabled-user-123",
email="disabled@example.com",
name="Disabled User",
is_active=False,
)
db.add(disabled_user)
db.commit()
# Generate and store refresh token for disabled user
refresh_token = generate_refresh_token()
store_refresh_token(mock_redis, disabled_user.id, refresh_token)
# Try to refresh
response = client.post(
"/api/auth/refresh",
json={"refresh_token": refresh_token},
)
assert response.status_code == 403
assert "disabled" in response.json()["detail"].lower()

View File

@@ -128,7 +128,7 @@ class TestRedisFailover:
class TestBlockerDeletionCheck: class TestBlockerDeletionCheck:
"""Test blocker check before task deletion.""" """Test blocker check before task deletion."""
def test_delete_task_with_blockers_warning(self, client, admin_token, db): def test_delete_task_with_blockers_warning(self, client, admin_token, csrf_token, db):
"""Test that deleting task with blockers shows warning.""" """Test that deleting task with blockers shows warning."""
from app.models import Space, Project, Task, TaskStatus, TaskDependency from app.models import Space, Project, Task, TaskStatus, TaskDependency
@@ -174,7 +174,7 @@ class TestBlockerDeletionCheck:
# Try to delete without force # Try to delete without force
response = client.delete( response = client.delete(
"/api/tasks/blocker-task", "/api/tasks/blocker-task",
headers={"Authorization": f"Bearer {admin_token}"} headers={"Authorization": f"Bearer {admin_token}", "X-CSRF-Token": csrf_token}
) )
# Should return warning or require confirmation # Should return warning or require confirmation
@@ -185,7 +185,7 @@ class TestBlockerDeletionCheck:
if "warning" in data or "blocker_count" in data: if "warning" in data or "blocker_count" in data:
assert data.get("blocker_count", 0) >= 1 or "blocker" in str(data).lower() assert data.get("blocker_count", 0) >= 1 or "blocker" in str(data).lower()
def test_force_delete_resolves_blockers(self, client, admin_token, db): def test_force_delete_resolves_blockers(self, client, admin_token, csrf_token, db):
"""Test that force delete resolves blockers.""" """Test that force delete resolves blockers."""
from app.models import Space, Project, Task, TaskStatus, TaskDependency from app.models import Space, Project, Task, TaskStatus, TaskDependency
@@ -231,7 +231,7 @@ class TestBlockerDeletionCheck:
# Force delete # Force delete
response = client.delete( response = client.delete(
"/api/tasks/force-del-task?force_delete=true", "/api/tasks/force-del-task?force_delete=true",
headers={"Authorization": f"Bearer {admin_token}"} headers={"Authorization": f"Bearer {admin_token}", "X-CSRF-Token": csrf_token}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -240,7 +240,7 @@ class TestBlockerDeletionCheck:
db.refresh(task_to_delete) db.refresh(task_to_delete)
assert task_to_delete.is_deleted is True assert task_to_delete.is_deleted is True
def test_delete_task_without_blockers(self, client, admin_token, db): def test_delete_task_without_blockers(self, client, admin_token, csrf_token, db):
"""Test deleting task without blockers succeeds normally.""" """Test deleting task without blockers succeeds normally."""
from app.models import Space, Project, Task, TaskStatus from app.models import Space, Project, Task, TaskStatus
@@ -267,7 +267,7 @@ class TestBlockerDeletionCheck:
# Delete should succeed without warning # Delete should succeed without warning
response = client.delete( response = client.delete(
"/api/tasks/no-blocker-task", "/api/tasks/no-blocker-task",
headers={"Authorization": f"Bearer {admin_token}"} headers={"Authorization": f"Bearer {admin_token}", "X-CSRF-Token": csrf_token}
) )
assert response.status_code == 200 assert response.status_code == 200

View File

@@ -36,6 +36,13 @@ def user_token(client, mock_redis, test_user):
return token return token
@pytest.fixture
def user_csrf_token(test_user):
"""Generate a CSRF token for the test user."""
from app.core.security import generate_csrf_token
return generate_csrf_token(test_user.id)
@pytest.fixture @pytest.fixture
def test_space(db): def test_space(db):
"""Create a test space.""" """Create a test space."""
@@ -100,11 +107,11 @@ def test_task(db, test_project, test_status):
class TestComments: class TestComments:
"""Tests for Comments API.""" """Tests for Comments API."""
def test_create_comment(self, client, admin_token, test_task): def test_create_comment(self, client, auth_headers, test_task):
"""Test creating a comment.""" """Test creating a comment."""
response = client.post( response = client.post(
f"/api/tasks/{test_task.id}/comments", f"/api/tasks/{test_task.id}/comments",
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
json={"content": "This is a test comment"}, json={"content": "This is a test comment"},
) )
assert response.status_code == 201 assert response.status_code == 201
@@ -136,7 +143,7 @@ class TestComments:
assert len(data["comments"]) == 1 assert len(data["comments"]) == 1
assert data["comments"][0]["content"] == "Test comment" assert data["comments"][0]["content"] == "Test comment"
def test_update_comment(self, client, admin_token, db, test_task): def test_update_comment(self, client, auth_headers, db, test_task):
"""Test updating a comment.""" """Test updating a comment."""
comment = Comment( comment = Comment(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
@@ -149,7 +156,7 @@ class TestComments:
response = client.put( response = client.put(
f"/api/comments/{comment.id}", f"/api/comments/{comment.id}",
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
json={"content": "Updated content"}, json={"content": "Updated content"},
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -157,7 +164,7 @@ class TestComments:
assert data["content"] == "Updated content" assert data["content"] == "Updated content"
assert data["is_edited"] is True assert data["is_edited"] is True
def test_delete_comment(self, client, admin_token, db, test_task): def test_delete_comment(self, client, auth_headers, db, test_task):
"""Test deleting a comment (soft delete).""" """Test deleting a comment (soft delete)."""
comment = Comment( comment = Comment(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
@@ -170,7 +177,7 @@ class TestComments:
response = client.delete( response = client.delete(
f"/api/comments/{comment.id}", f"/api/comments/{comment.id}",
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
) )
assert response.status_code == 204 assert response.status_code == 204
@@ -178,13 +185,13 @@ class TestComments:
db.refresh(comment) db.refresh(comment)
assert comment.is_deleted is True assert comment.is_deleted is True
def test_mention_limit(self, client, admin_token, test_task): def test_mention_limit(self, client, auth_headers, test_task):
"""Test that @mention limit is enforced.""" """Test that @mention limit is enforced."""
# Create content with more than 10 mentions # Create content with more than 10 mentions
mentions = " ".join([f"@user{i}" for i in range(15)]) mentions = " ".join([f"@user{i}" for i in range(15)])
response = client.post( response = client.post(
f"/api/tasks/{test_task.id}/comments", f"/api/tasks/{test_task.id}/comments",
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
json={"content": f"Test with many mentions: {mentions}"}, json={"content": f"Test with many mentions: {mentions}"},
) )
assert response.status_code == 400 assert response.status_code == 400
@@ -218,7 +225,7 @@ class TestNotifications:
assert data["total"] >= 1 assert data["total"] >= 1
assert data["unread_count"] >= 1 assert data["unread_count"] >= 1
def test_mark_notification_as_read(self, client, admin_token, db): def test_mark_notification_as_read(self, client, auth_headers, db):
"""Test marking a notification as read.""" """Test marking a notification as read."""
notification = Notification( notification = Notification(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
@@ -233,14 +240,14 @@ class TestNotifications:
response = client.put( response = client.put(
f"/api/notifications/{notification.id}/read", f"/api/notifications/{notification.id}/read",
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
) )
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["is_read"] is True assert data["is_read"] is True
assert data["read_at"] is not None assert data["read_at"] is not None
def test_mark_all_as_read(self, client, admin_token, db): def test_mark_all_as_read(self, client, auth_headers, db):
"""Test marking all notifications as read.""" """Test marking all notifications as read."""
# Create multiple unread notifications # Create multiple unread notifications
for i in range(3): for i in range(3):
@@ -257,7 +264,7 @@ class TestNotifications:
response = client.put( response = client.put(
"/api/notifications/read-all", "/api/notifications/read-all",
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
) )
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
@@ -290,11 +297,11 @@ class TestNotifications:
class TestBlockers: class TestBlockers:
"""Tests for Blockers API.""" """Tests for Blockers API."""
def test_create_blocker(self, client, admin_token, test_task): def test_create_blocker(self, client, auth_headers, test_task):
"""Test creating a blocker.""" """Test creating a blocker."""
response = client.post( response = client.post(
f"/api/tasks/{test_task.id}/blockers", f"/api/tasks/{test_task.id}/blockers",
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
json={"reason": "Waiting for external dependency"}, json={"reason": "Waiting for external dependency"},
) )
assert response.status_code == 201 assert response.status_code == 201
@@ -302,7 +309,7 @@ class TestBlockers:
assert data["reason"] == "Waiting for external dependency" assert data["reason"] == "Waiting for external dependency"
assert data["resolved_at"] is None assert data["resolved_at"] is None
def test_resolve_blocker(self, client, admin_token, db, test_task): def test_resolve_blocker(self, client, auth_headers, db, test_task):
"""Test resolving a blocker.""" """Test resolving a blocker."""
blocker = Blocker( blocker = Blocker(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
@@ -316,7 +323,7 @@ class TestBlockers:
response = client.put( response = client.put(
f"/api/blockers/{blocker.id}/resolve", f"/api/blockers/{blocker.id}/resolve",
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
json={"resolution_note": "Issue resolved by updating config"}, json={"resolution_note": "Issue resolved by updating config"},
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -348,7 +355,7 @@ class TestBlockers:
assert data["total"] == 1 assert data["total"] == 1
assert data["blockers"][0]["reason"] == "Test blocker" assert data["blockers"][0]["reason"] == "Test blocker"
def test_cannot_create_duplicate_active_blocker(self, client, admin_token, db, test_task): def test_cannot_create_duplicate_active_blocker(self, client, auth_headers, db, test_task):
"""Test that duplicate active blockers are prevented.""" """Test that duplicate active blockers are prevented."""
# Create first blocker # Create first blocker
blocker = Blocker( blocker = Blocker(
@@ -363,7 +370,7 @@ class TestBlockers:
# Try to create second blocker # Try to create second blocker
response = client.post( response = client.post(
f"/api/tasks/{test_task.id}/blockers", f"/api/tasks/{test_task.id}/blockers",
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
json={"reason": "Second blocker"}, json={"reason": "Second blocker"},
) )
assert response.status_code == 400 assert response.status_code == 400

View File

@@ -18,7 +18,7 @@ from datetime import datetime, timedelta
class TestOptimisticLocking: class TestOptimisticLocking:
"""Test optimistic locking for concurrent updates.""" """Test optimistic locking for concurrent updates."""
def test_version_increments_on_update(self, client, admin_token, db): def test_version_increments_on_update(self, client, admin_token, csrf_token, db):
"""Test that task version increments on successful update.""" """Test that task version increments on successful update."""
from app.models import Space, Project, Task, TaskStatus from app.models import Space, Project, Task, TaskStatus
@@ -47,7 +47,7 @@ class TestOptimisticLocking:
response = client.patch( response = client.patch(
"/api/tasks/task-1", "/api/tasks/task-1",
json={"title": "Updated Task", "version": 1}, json={"title": "Updated Task", "version": 1},
headers={"Authorization": f"Bearer {admin_token}"} headers={"Authorization": f"Bearer {admin_token}", "X-CSRF-Token": csrf_token}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -55,7 +55,7 @@ class TestOptimisticLocking:
assert data["title"] == "Updated Task" assert data["title"] == "Updated Task"
assert data["version"] == 2 # Version should increment assert data["version"] == 2 # Version should increment
def test_version_conflict_returns_409(self, client, admin_token, db): def test_version_conflict_returns_409(self, client, admin_token, csrf_token, db):
"""Test that stale version returns 409 Conflict.""" """Test that stale version returns 409 Conflict."""
from app.models import Space, Project, Task, TaskStatus from app.models import Space, Project, Task, TaskStatus
@@ -84,7 +84,7 @@ class TestOptimisticLocking:
response = client.patch( response = client.patch(
"/api/tasks/task-2", "/api/tasks/task-2",
json={"title": "Stale Update", "version": 1}, json={"title": "Stale Update", "version": 1},
headers={"Authorization": f"Bearer {admin_token}"} headers={"Authorization": f"Bearer {admin_token}", "X-CSRF-Token": csrf_token}
) )
assert response.status_code == 409 assert response.status_code == 409
@@ -94,7 +94,7 @@ class TestOptimisticLocking:
assert detail.get("current_version") == 5 assert detail.get("current_version") == 5
assert detail.get("provided_version") == 1 assert detail.get("provided_version") == 1
def test_update_without_version_succeeds(self, client, admin_token, db): def test_update_without_version_succeeds(self, client, admin_token, csrf_token, db):
"""Test that update without version (for backward compatibility) still works.""" """Test that update without version (for backward compatibility) still works."""
from app.models import Space, Project, Task, TaskStatus from app.models import Space, Project, Task, TaskStatus
@@ -123,7 +123,7 @@ class TestOptimisticLocking:
response = client.patch( response = client.patch(
"/api/tasks/task-3", "/api/tasks/task-3",
json={"title": "No Version Update"}, json={"title": "No Version Update"},
headers={"Authorization": f"Bearer {admin_token}"} headers={"Authorization": f"Bearer {admin_token}", "X-CSRF-Token": csrf_token}
) )
# Should succeed (backward compatibility) # Should succeed (backward compatibility)
@@ -179,7 +179,7 @@ class TestTriggerRetryMechanism:
class TestCascadeRestore: class TestCascadeRestore:
"""Test cascade restore for soft-deleted tasks.""" """Test cascade restore for soft-deleted tasks."""
def test_restore_parent_with_children(self, client, admin_token, db): def test_restore_parent_with_children(self, client, admin_token, csrf_token, db):
"""Test restoring parent task also restores children deleted at same time.""" """Test restoring parent task also restores children deleted at same time."""
from app.models import Space, Project, Task, TaskStatus from app.models import Space, Project, Task, TaskStatus
from datetime import datetime from datetime import datetime
@@ -236,7 +236,7 @@ class TestCascadeRestore:
response = client.post( response = client.post(
"/api/tasks/parent-task/restore", "/api/tasks/parent-task/restore",
json={"cascade": True}, json={"cascade": True},
headers={"Authorization": f"Bearer {admin_token}"} headers={"Authorization": f"Bearer {admin_token}", "X-CSRF-Token": csrf_token}
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -254,7 +254,7 @@ class TestCascadeRestore:
assert child_task1.is_deleted is False assert child_task1.is_deleted is False
assert child_task2.is_deleted is False assert child_task2.is_deleted is False
def test_restore_parent_only(self, client, admin_token, db): def test_restore_parent_only(self, client, admin_token, csrf_token, db):
"""Test restoring parent task without cascade leaves children deleted.""" """Test restoring parent task without cascade leaves children deleted."""
from app.models import Space, Project, Task, TaskStatus from app.models import Space, Project, Task, TaskStatus
from datetime import datetime from datetime import datetime
@@ -299,7 +299,7 @@ class TestCascadeRestore:
response = client.post( response = client.post(
"/api/tasks/parent-task-2/restore", "/api/tasks/parent-task-2/restore",
json={"cascade": False}, json={"cascade": False},
headers={"Authorization": f"Bearer {admin_token}"} headers={"Authorization": f"Bearer {admin_token}", "X-CSRF-Token": csrf_token}
) )
assert response.status_code == 200 assert response.status_code == 200

View File

@@ -39,7 +39,7 @@ class TestCustomFieldsCRUD:
db.commit() db.commit()
return project return project
def test_create_text_field(self, client, db, admin_token): def test_create_text_field(self, client, db, auth_headers):
"""Test creating a text custom field.""" """Test creating a text custom field."""
project = self.setup_project(db, "00000000-0000-0000-0000-000000000001") project = self.setup_project(db, "00000000-0000-0000-0000-000000000001")
@@ -50,7 +50,7 @@ class TestCustomFieldsCRUD:
"field_type": "text", "field_type": "text",
"is_required": False, "is_required": False,
}, },
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
) )
assert response.status_code == 201 assert response.status_code == 201
@@ -59,7 +59,7 @@ class TestCustomFieldsCRUD:
assert data["field_type"] == "text" assert data["field_type"] == "text"
assert data["is_required"] is False assert data["is_required"] is False
def test_create_number_field(self, client, db, admin_token): def test_create_number_field(self, client, db, auth_headers):
"""Test creating a number custom field.""" """Test creating a number custom field."""
project = self.setup_project(db, "00000000-0000-0000-0000-000000000001") project = self.setup_project(db, "00000000-0000-0000-0000-000000000001")
@@ -70,7 +70,7 @@ class TestCustomFieldsCRUD:
"field_type": "number", "field_type": "number",
"is_required": True, "is_required": True,
}, },
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
) )
assert response.status_code == 201 assert response.status_code == 201
@@ -79,7 +79,7 @@ class TestCustomFieldsCRUD:
assert data["field_type"] == "number" assert data["field_type"] == "number"
assert data["is_required"] is True assert data["is_required"] is True
def test_create_dropdown_field(self, client, db, admin_token): def test_create_dropdown_field(self, client, db, auth_headers):
"""Test creating a dropdown custom field.""" """Test creating a dropdown custom field."""
project = self.setup_project(db, "00000000-0000-0000-0000-000000000001") project = self.setup_project(db, "00000000-0000-0000-0000-000000000001")
@@ -91,7 +91,7 @@ class TestCustomFieldsCRUD:
"options": ["Frontend", "Backend", "Database", "API"], "options": ["Frontend", "Backend", "Database", "API"],
"is_required": False, "is_required": False,
}, },
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
) )
assert response.status_code == 201 assert response.status_code == 201
@@ -100,7 +100,7 @@ class TestCustomFieldsCRUD:
assert data["field_type"] == "dropdown" assert data["field_type"] == "dropdown"
assert data["options"] == ["Frontend", "Backend", "Database", "API"] assert data["options"] == ["Frontend", "Backend", "Database", "API"]
def test_create_dropdown_field_without_options_fails(self, client, db, admin_token): def test_create_dropdown_field_without_options_fails(self, client, db, auth_headers):
"""Test that creating a dropdown field without options fails.""" """Test that creating a dropdown field without options fails."""
project = self.setup_project(db, "00000000-0000-0000-0000-000000000001") project = self.setup_project(db, "00000000-0000-0000-0000-000000000001")
@@ -111,12 +111,12 @@ class TestCustomFieldsCRUD:
"field_type": "dropdown", "field_type": "dropdown",
"options": [], "options": [],
}, },
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
) )
assert response.status_code == 422 # Validation error assert response.status_code == 422 # Validation error
def test_create_formula_field(self, client, db, admin_token): def test_create_formula_field(self, client, db, auth_headers):
"""Test creating a formula custom field.""" """Test creating a formula custom field."""
project = self.setup_project(db, "00000000-0000-0000-0000-000000000001") project = self.setup_project(db, "00000000-0000-0000-0000-000000000001")
@@ -127,7 +127,7 @@ class TestCustomFieldsCRUD:
"name": "hours_worked", "name": "hours_worked",
"field_type": "number", "field_type": "number",
}, },
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
) )
# Create formula field # Create formula field
@@ -138,7 +138,7 @@ class TestCustomFieldsCRUD:
"field_type": "formula", "field_type": "formula",
"formula": "{time_spent} / {original_estimate} * 100", "formula": "{time_spent} / {original_estimate} * 100",
}, },
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
) )
assert response.status_code == 201 assert response.status_code == 201
@@ -147,7 +147,7 @@ class TestCustomFieldsCRUD:
assert data["field_type"] == "formula" assert data["field_type"] == "formula"
assert "{time_spent}" in data["formula"] assert "{time_spent}" in data["formula"]
def test_list_custom_fields(self, client, db, admin_token): def test_list_custom_fields(self, client, db, auth_headers):
"""Test listing custom fields for a project.""" """Test listing custom fields for a project."""
project = self.setup_project(db, "00000000-0000-0000-0000-000000000001") project = self.setup_project(db, "00000000-0000-0000-0000-000000000001")
@@ -155,17 +155,17 @@ class TestCustomFieldsCRUD:
client.post( client.post(
f"/api/projects/{project.id}/custom-fields", f"/api/projects/{project.id}/custom-fields",
json={"name": "Field 1", "field_type": "text"}, json={"name": "Field 1", "field_type": "text"},
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
) )
client.post( client.post(
f"/api/projects/{project.id}/custom-fields", f"/api/projects/{project.id}/custom-fields",
json={"name": "Field 2", "field_type": "number"}, json={"name": "Field 2", "field_type": "number"},
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
) )
response = client.get( response = client.get(
f"/api/projects/{project.id}/custom-fields", f"/api/projects/{project.id}/custom-fields",
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -173,7 +173,7 @@ class TestCustomFieldsCRUD:
assert data["total"] == 2 assert data["total"] == 2
assert len(data["fields"]) == 2 assert len(data["fields"]) == 2
def test_update_custom_field(self, client, db, admin_token): def test_update_custom_field(self, client, db, auth_headers):
"""Test updating a custom field.""" """Test updating a custom field."""
project = self.setup_project(db, "00000000-0000-0000-0000-000000000001") project = self.setup_project(db, "00000000-0000-0000-0000-000000000001")
@@ -181,7 +181,7 @@ class TestCustomFieldsCRUD:
create_response = client.post( create_response = client.post(
f"/api/projects/{project.id}/custom-fields", f"/api/projects/{project.id}/custom-fields",
json={"name": "Original Name", "field_type": "text"}, json={"name": "Original Name", "field_type": "text"},
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
) )
field_id = create_response.json()["id"] field_id = create_response.json()["id"]
@@ -189,7 +189,7 @@ class TestCustomFieldsCRUD:
response = client.put( response = client.put(
f"/api/custom-fields/{field_id}", f"/api/custom-fields/{field_id}",
json={"name": "Updated Name", "is_required": True}, json={"name": "Updated Name", "is_required": True},
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -197,7 +197,7 @@ class TestCustomFieldsCRUD:
assert data["name"] == "Updated Name" assert data["name"] == "Updated Name"
assert data["is_required"] is True assert data["is_required"] is True
def test_delete_custom_field(self, client, db, admin_token): def test_delete_custom_field(self, client, db, auth_headers):
"""Test deleting a custom field.""" """Test deleting a custom field."""
project = self.setup_project(db, "00000000-0000-0000-0000-000000000001") project = self.setup_project(db, "00000000-0000-0000-0000-000000000001")
@@ -205,14 +205,14 @@ class TestCustomFieldsCRUD:
create_response = client.post( create_response = client.post(
f"/api/projects/{project.id}/custom-fields", f"/api/projects/{project.id}/custom-fields",
json={"name": "To Delete", "field_type": "text"}, json={"name": "To Delete", "field_type": "text"},
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
) )
field_id = create_response.json()["id"] field_id = create_response.json()["id"]
# Delete it # Delete it
response = client.delete( response = client.delete(
f"/api/custom-fields/{field_id}", f"/api/custom-fields/{field_id}",
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
) )
assert response.status_code == 204 assert response.status_code == 204
@@ -220,11 +220,11 @@ class TestCustomFieldsCRUD:
# Verify it's gone # Verify it's gone
get_response = client.get( get_response = client.get(
f"/api/custom-fields/{field_id}", f"/api/custom-fields/{field_id}",
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
) )
assert get_response.status_code == 404 assert get_response.status_code == 404
def test_max_fields_limit(self, client, db, admin_token): def test_max_fields_limit(self, client, db, auth_headers):
"""Test that maximum 20 custom fields per project is enforced.""" """Test that maximum 20 custom fields per project is enforced."""
project = self.setup_project(db, "00000000-0000-0000-0000-000000000001") project = self.setup_project(db, "00000000-0000-0000-0000-000000000001")
@@ -233,7 +233,7 @@ class TestCustomFieldsCRUD:
response = client.post( response = client.post(
f"/api/projects/{project.id}/custom-fields", f"/api/projects/{project.id}/custom-fields",
json={"name": f"Field {i}", "field_type": "text"}, json={"name": f"Field {i}", "field_type": "text"},
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
) )
assert response.status_code == 201 assert response.status_code == 201
@@ -241,12 +241,12 @@ class TestCustomFieldsCRUD:
response = client.post( response = client.post(
f"/api/projects/{project.id}/custom-fields", f"/api/projects/{project.id}/custom-fields",
json={"name": "Field 21", "field_type": "text"}, json={"name": "Field 21", "field_type": "text"},
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
) )
assert response.status_code == 400 assert response.status_code == 400
assert "Maximum" in response.json()["detail"] assert "Maximum" in response.json()["detail"]
def test_duplicate_name_rejected(self, client, db, admin_token): def test_duplicate_name_rejected(self, client, db, auth_headers):
"""Test that duplicate field names are rejected.""" """Test that duplicate field names are rejected."""
project = self.setup_project(db, "00000000-0000-0000-0000-000000000001") project = self.setup_project(db, "00000000-0000-0000-0000-000000000001")
@@ -254,14 +254,14 @@ class TestCustomFieldsCRUD:
client.post( client.post(
f"/api/projects/{project.id}/custom-fields", f"/api/projects/{project.id}/custom-fields",
json={"name": "Unique Name", "field_type": "text"}, json={"name": "Unique Name", "field_type": "text"},
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
) )
# Try to create another with same name # Try to create another with same name
response = client.post( response = client.post(
f"/api/projects/{project.id}/custom-fields", f"/api/projects/{project.id}/custom-fields",
json={"name": "Unique Name", "field_type": "number"}, json={"name": "Unique Name", "field_type": "number"},
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
) )
assert response.status_code == 400 assert response.status_code == 400
assert "already exists" in response.json()["detail"] assert "already exists" in response.json()["detail"]
@@ -311,7 +311,7 @@ class TestFormulaService:
class TestCustomValuesWithTasks: class TestCustomValuesWithTasks:
"""Test custom values integration with tasks.""" """Test custom values integration with tasks."""
def setup_project_with_fields(self, db, client, admin_token, owner_id: str): def setup_project_with_fields(self, db, client, auth_headers, owner_id: str):
"""Create a project with custom fields for testing.""" """Create a project with custom fields for testing."""
space = Space( space = Space(
id="test-space-002", id="test-space-002",
@@ -342,23 +342,23 @@ class TestCustomValuesWithTasks:
text_response = client.post( text_response = client.post(
f"/api/projects/{project.id}/custom-fields", f"/api/projects/{project.id}/custom-fields",
json={"name": "sprint_number", "field_type": "text"}, json={"name": "sprint_number", "field_type": "text"},
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
) )
text_field_id = text_response.json()["id"] text_field_id = text_response.json()["id"]
number_response = client.post( number_response = client.post(
f"/api/projects/{project.id}/custom-fields", f"/api/projects/{project.id}/custom-fields",
json={"name": "story_points", "field_type": "number"}, json={"name": "story_points", "field_type": "number"},
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
) )
number_field_id = number_response.json()["id"] number_field_id = number_response.json()["id"]
return project, text_field_id, number_field_id return project, text_field_id, number_field_id
def test_create_task_with_custom_values(self, client, db, admin_token): def test_create_task_with_custom_values(self, client, db, auth_headers):
"""Test creating a task with custom values.""" """Test creating a task with custom values."""
project, text_field_id, number_field_id = self.setup_project_with_fields( project, text_field_id, number_field_id = self.setup_project_with_fields(
db, client, admin_token, "00000000-0000-0000-0000-000000000001" db, client, auth_headers, "00000000-0000-0000-0000-000000000001"
) )
response = client.post( response = client.post(
@@ -370,15 +370,15 @@ class TestCustomValuesWithTasks:
{"field_id": number_field_id, "value": "8"}, {"field_id": number_field_id, "value": "8"},
], ],
}, },
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
) )
assert response.status_code == 201 assert response.status_code == 201
def test_get_task_includes_custom_values(self, client, db, admin_token): def test_get_task_includes_custom_values(self, client, db, auth_headers):
"""Test that getting a task includes custom values.""" """Test that getting a task includes custom values."""
project, text_field_id, number_field_id = self.setup_project_with_fields( project, text_field_id, number_field_id = self.setup_project_with_fields(
db, client, admin_token, "00000000-0000-0000-0000-000000000001" db, client, auth_headers, "00000000-0000-0000-0000-000000000001"
) )
# Create task with custom values # Create task with custom values
@@ -391,14 +391,14 @@ class TestCustomValuesWithTasks:
{"field_id": number_field_id, "value": "8"}, {"field_id": number_field_id, "value": "8"},
], ],
}, },
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
) )
task_id = create_response.json()["id"] task_id = create_response.json()["id"]
# Get task and check custom values # Get task and check custom values
get_response = client.get( get_response = client.get(
f"/api/tasks/{task_id}", f"/api/tasks/{task_id}",
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
) )
assert get_response.status_code == 200 assert get_response.status_code == 200
@@ -406,10 +406,10 @@ class TestCustomValuesWithTasks:
assert data["custom_values"] is not None assert data["custom_values"] is not None
assert len(data["custom_values"]) >= 2 assert len(data["custom_values"]) >= 2
def test_update_task_custom_values(self, client, db, admin_token): def test_update_task_custom_values(self, client, db, auth_headers):
"""Test updating custom values on a task.""" """Test updating custom values on a task."""
project, text_field_id, number_field_id = self.setup_project_with_fields( project, text_field_id, number_field_id = self.setup_project_with_fields(
db, client, admin_token, "00000000-0000-0000-0000-000000000001" db, client, auth_headers, "00000000-0000-0000-0000-000000000001"
) )
# Create task # Create task
@@ -421,7 +421,7 @@ class TestCustomValuesWithTasks:
{"field_id": text_field_id, "value": "Sprint 5"}, {"field_id": text_field_id, "value": "Sprint 5"},
], ],
}, },
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
) )
task_id = create_response.json()["id"] task_id = create_response.json()["id"]
@@ -434,7 +434,7 @@ class TestCustomValuesWithTasks:
{"field_id": number_field_id, "value": "13"}, {"field_id": number_field_id, "value": "13"},
], ],
}, },
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
) )
assert update_response.status_code == 200 assert update_response.status_code == 200

View File

@@ -619,7 +619,7 @@ class TestDashboardAPI:
def test_dashboard_unauthorized(self, client, db): def test_dashboard_unauthorized(self, client, db):
"""Unauthenticated requests should fail.""" """Unauthenticated requests should fail."""
response = client.get("/api/dashboard") response = client.get("/api/dashboard")
assert response.status_code == 403 assert response.status_code == 401 # 401 for unauthenticated, 403 for unauthorized
def test_dashboard_with_user_tasks(self, client, db, admin_token): def test_dashboard_with_user_tasks(self, client, db, admin_token):
"""Dashboard should reflect user's tasks correctly.""" """Dashboard should reflect user's tasks correctly."""

View File

@@ -312,6 +312,13 @@ class TestConfidentialProjectUpload:
mock_redis.setex(f"session:{test_user.id}", 900, token) mock_redis.setex(f"session:{test_user.id}", 900, token)
return token return token
@pytest.fixture
def test_user_csrf_token(self, test_user):
"""Generate a CSRF token for the test user."""
from app.core.security import generate_csrf_token
return generate_csrf_token(test_user.id)
@pytest.fixture @pytest.fixture
def test_space(self, db, test_user): def test_space(self, db, test_user):
"""Create a test space.""" """Create a test space."""
@@ -364,7 +371,7 @@ class TestConfidentialProjectUpload:
return task return task
def test_upload_confidential_project_encryption_unavailable( def test_upload_confidential_project_encryption_unavailable(
self, client, test_user_token, test_task, db self, client, test_user_token, test_user_csrf_token, test_task, db
): ):
"""Test that uploading to confidential project returns 400 when encryption is unavailable.""" """Test that uploading to confidential project returns 400 when encryption is unavailable."""
from io import BytesIO from io import BytesIO
@@ -378,7 +385,7 @@ class TestConfidentialProjectUpload:
response = client.post( response = client.post(
f"/api/tasks/{test_task.id}/attachments", f"/api/tasks/{test_task.id}/attachments",
headers={"Authorization": f"Bearer {test_user_token}"}, headers={"Authorization": f"Bearer {test_user_token}", "X-CSRF-Token": test_user_csrf_token},
files=files, files=files,
) )
@@ -387,7 +394,7 @@ class TestConfidentialProjectUpload:
assert "environment variable" in response.json()["detail"] assert "environment variable" in response.json()["detail"]
def test_upload_confidential_project_no_active_key( def test_upload_confidential_project_no_active_key(
self, client, test_user_token, test_task, db self, client, test_user_token, test_user_csrf_token, test_task, db
): ):
"""Test that uploading to confidential project returns 400 when no active encryption key exists.""" """Test that uploading to confidential project returns 400 when no active encryption key exists."""
from io import BytesIO from io import BytesIO
@@ -408,7 +415,7 @@ class TestConfidentialProjectUpload:
response = client.post( response = client.post(
f"/api/tasks/{test_task.id}/attachments", f"/api/tasks/{test_task.id}/attachments",
headers={"Authorization": f"Bearer {test_user_token}"}, headers={"Authorization": f"Bearer {test_user_token}", "X-CSRF-Token": test_user_csrf_token},
files=files, files=files,
) )

View File

@@ -614,7 +614,7 @@ class TestHealthAPI:
def test_unauthorized_access(self, client, db): def test_unauthorized_access(self, client, db):
"""Unauthenticated requests should fail.""" """Unauthenticated requests should fail."""
response = client.get("/api/projects/health/dashboard") response = client.get("/api/projects/health/dashboard")
assert response.status_code == 403 assert response.status_code == 401 # 401 for unauthenticated, 403 for unauthorized
def test_dashboard_with_status_filter(self, client, db, admin_token): def test_dashboard_with_status_filter(self, client, db, admin_token):
"""Dashboard should respect status filter.""" """Dashboard should respect status filter."""

View File

@@ -38,6 +38,14 @@ def test_user_token(client, mock_redis, test_user):
return token return token
@pytest.fixture
def test_user_csrf_token(test_user):
"""Generate a CSRF token for the test user."""
from app.core.security import generate_csrf_token
return generate_csrf_token(test_user.id)
@pytest.fixture @pytest.fixture
def test_space(db, test_user): def test_space(db, test_user):
"""Create a test space.""" """Create a test space."""
@@ -284,11 +292,11 @@ class TestReportAPI:
assert "projects" in data assert "projects" in data
assert data["summary"]["total_tasks"] == 3 assert data["summary"]["total_tasks"] == 3
def test_generate_weekly_report_api(self, client, test_user_token, test_project, test_tasks, test_statuses): def test_generate_weekly_report_api(self, client, test_user_token, test_user_csrf_token, test_project, test_tasks, test_statuses):
"""Test generating weekly report via API.""" """Test generating weekly report via API."""
response = client.post( response = client.post(
"/api/reports/weekly/generate", "/api/reports/weekly/generate",
headers={"Authorization": f"Bearer {test_user_token}"}, headers={"Authorization": f"Bearer {test_user_token}", "X-CSRF-Token": test_user_csrf_token},
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -297,7 +305,7 @@ class TestReportAPI:
assert "report_id" in data assert "report_id" in data
assert "summary" in data assert "summary" in data
def test_weekly_report_subscription_toggle(self, client, test_user_token, db, test_user): def test_weekly_report_subscription_toggle(self, client, test_user_token, test_user_csrf_token, db, test_user):
"""Test weekly report subscription toggle endpoints.""" """Test weekly report subscription toggle endpoints."""
response = client.get( response = client.get(
"/api/reports/weekly/subscription", "/api/reports/weekly/subscription",
@@ -308,7 +316,7 @@ class TestReportAPI:
response = client.put( response = client.put(
"/api/reports/weekly/subscription", "/api/reports/weekly/subscription",
headers={"Authorization": f"Bearer {test_user_token}"}, headers={"Authorization": f"Bearer {test_user_token}", "X-CSRF-Token": test_user_csrf_token},
json={"is_active": True}, json={"is_active": True},
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -323,7 +331,7 @@ class TestReportAPI:
response = client.put( response = client.put(
"/api/reports/weekly/subscription", "/api/reports/weekly/subscription",
headers={"Authorization": f"Bearer {test_user_token}"}, headers={"Authorization": f"Bearer {test_user_token}", "X-CSRF-Token": test_user_csrf_token},
json={"is_active": False}, json={"is_active": False},
) )
assert response.status_code == 200 assert response.status_code == 200

View File

@@ -52,6 +52,14 @@ def test_user_token(client, mock_redis, test_user):
return token return token
@pytest.fixture
def test_user_csrf_token(test_user):
"""Generate a CSRF token for the test user."""
from app.core.security import generate_csrf_token
return generate_csrf_token(test_user.id)
@pytest.fixture @pytest.fixture
def test_space(db, test_user): def test_space(db, test_user):
"""Create a test space.""" """Create a test space."""
@@ -445,11 +453,11 @@ class TestDeadlineReminderLogic:
class TestScheduleTriggerAPI: class TestScheduleTriggerAPI:
"""Tests for Schedule Trigger API endpoints.""" """Tests for Schedule Trigger API endpoints."""
def test_create_cron_trigger(self, client, test_user_token, test_project): def test_create_cron_trigger(self, client, test_user_token, test_user_csrf_token, test_project):
"""Test creating a schedule trigger with cron expression.""" """Test creating a schedule trigger with cron expression."""
response = client.post( response = client.post(
f"/api/projects/{test_project.id}/triggers", f"/api/projects/{test_project.id}/triggers",
headers={"Authorization": f"Bearer {test_user_token}"}, headers={"Authorization": f"Bearer {test_user_token}", "X-CSRF-Token": test_user_csrf_token},
json={ json={
"name": "Weekly Monday Reminder", "name": "Weekly Monday Reminder",
"description": "Remind every Monday at 9am", "description": "Remind every Monday at 9am",
@@ -471,11 +479,11 @@ class TestScheduleTriggerAPI:
assert data["trigger_type"] == "schedule" assert data["trigger_type"] == "schedule"
assert data["conditions"]["cron_expression"] == "0 9 * * 1" assert data["conditions"]["cron_expression"] == "0 9 * * 1"
def test_create_deadline_trigger(self, client, test_user_token, test_project): def test_create_deadline_trigger(self, client, test_user_token, test_user_csrf_token, test_project):
"""Test creating a schedule trigger with deadline reminder.""" """Test creating a schedule trigger with deadline reminder."""
response = client.post( response = client.post(
f"/api/projects/{test_project.id}/triggers", f"/api/projects/{test_project.id}/triggers",
headers={"Authorization": f"Bearer {test_user_token}"}, headers={"Authorization": f"Bearer {test_user_token}", "X-CSRF-Token": test_user_csrf_token},
json={ json={
"name": "Deadline Reminder", "name": "Deadline Reminder",
"description": "Remind 5 days before deadline", "description": "Remind 5 days before deadline",
@@ -494,11 +502,11 @@ class TestScheduleTriggerAPI:
data = response.json() data = response.json()
assert data["conditions"]["deadline_reminder_days"] == 5 assert data["conditions"]["deadline_reminder_days"] == 5
def test_create_schedule_trigger_invalid_cron(self, client, test_user_token, test_project): def test_create_schedule_trigger_invalid_cron(self, client, test_user_token, test_user_csrf_token, test_project):
"""Test creating a schedule trigger with invalid cron expression.""" """Test creating a schedule trigger with invalid cron expression."""
response = client.post( response = client.post(
f"/api/projects/{test_project.id}/triggers", f"/api/projects/{test_project.id}/triggers",
headers={"Authorization": f"Bearer {test_user_token}"}, headers={"Authorization": f"Bearer {test_user_token}", "X-CSRF-Token": test_user_csrf_token},
json={ json={
"name": "Invalid Cron Trigger", "name": "Invalid Cron Trigger",
"trigger_type": "schedule", "trigger_type": "schedule",
@@ -512,11 +520,11 @@ class TestScheduleTriggerAPI:
assert response.status_code == 400 assert response.status_code == 400
assert "Invalid cron expression" in response.json()["detail"] assert "Invalid cron expression" in response.json()["detail"]
def test_create_schedule_trigger_missing_condition(self, client, test_user_token, test_project): def test_create_schedule_trigger_missing_condition(self, client, test_user_token, test_user_csrf_token, test_project):
"""Test creating a schedule trigger without cron or deadline condition.""" """Test creating a schedule trigger without cron or deadline condition."""
response = client.post( response = client.post(
f"/api/projects/{test_project.id}/triggers", f"/api/projects/{test_project.id}/triggers",
headers={"Authorization": f"Bearer {test_user_token}"}, headers={"Authorization": f"Bearer {test_user_token}", "X-CSRF-Token": test_user_csrf_token},
json={ json={
"name": "Empty Schedule Trigger", "name": "Empty Schedule Trigger",
"trigger_type": "schedule", "trigger_type": "schedule",
@@ -528,11 +536,11 @@ class TestScheduleTriggerAPI:
assert response.status_code == 400 assert response.status_code == 400
assert "require either cron_expression or deadline_reminder_days" in response.json()["detail"] assert "require either cron_expression or deadline_reminder_days" in response.json()["detail"]
def test_update_schedule_trigger_cron(self, client, test_user_token, cron_trigger): def test_update_schedule_trigger_cron(self, client, test_user_token, test_user_csrf_token, cron_trigger):
"""Test updating a schedule trigger's cron expression.""" """Test updating a schedule trigger's cron expression."""
response = client.put( response = client.put(
f"/api/triggers/{cron_trigger.id}", f"/api/triggers/{cron_trigger.id}",
headers={"Authorization": f"Bearer {test_user_token}"}, headers={"Authorization": f"Bearer {test_user_token}", "X-CSRF-Token": test_user_csrf_token},
json={ json={
"conditions": { "conditions": {
"cron_expression": "0 10 * * *", # Changed to 10am "cron_expression": "0 10 * * *", # Changed to 10am
@@ -544,11 +552,11 @@ class TestScheduleTriggerAPI:
data = response.json() data = response.json()
assert data["conditions"]["cron_expression"] == "0 10 * * *" assert data["conditions"]["cron_expression"] == "0 10 * * *"
def test_update_schedule_trigger_invalid_cron(self, client, test_user_token, cron_trigger): def test_update_schedule_trigger_invalid_cron(self, client, test_user_token, test_user_csrf_token, cron_trigger):
"""Test updating a schedule trigger with invalid cron expression.""" """Test updating a schedule trigger with invalid cron expression."""
response = client.put( response = client.put(
f"/api/triggers/{cron_trigger.id}", f"/api/triggers/{cron_trigger.id}",
headers={"Authorization": f"Bearer {test_user_token}"}, headers={"Authorization": f"Bearer {test_user_token}", "X-CSRF-Token": test_user_csrf_token},
json={ json={
"conditions": { "conditions": {
"cron_expression": "not valid", "cron_expression": "not valid",

View File

@@ -69,6 +69,22 @@ def regular_token(client, mock_redis, test_regular_user):
return token return token
@pytest.fixture
def csrf_token(test_admin):
"""Generate a CSRF token for the test admin user."""
from app.core.security import generate_csrf_token
return generate_csrf_token(test_admin.id)
@pytest.fixture
def auth_headers(admin_token, csrf_token):
"""Get complete auth headers including both Authorization and CSRF token."""
return {
"Authorization": f"Bearer {admin_token}",
"X-CSRF-Token": csrf_token,
}
@pytest.fixture @pytest.fixture
def test_space(db, test_admin): def test_space(db, test_admin):
"""Create a test space.""" """Create a test space."""
@@ -148,11 +164,11 @@ def test_task_with_subtask(db, test_project, test_admin, test_status, test_task)
class TestSoftDelete: class TestSoftDelete:
"""Tests for soft delete functionality.""" """Tests for soft delete functionality."""
def test_delete_task_soft_deletes(self, client, admin_token, test_task, db): def test_delete_task_soft_deletes(self, client, auth_headers, test_task, db):
"""Test that DELETE soft-deletes a task.""" """Test that DELETE soft-deletes a task."""
response = client.delete( response = client.delete(
f"/api/tasks/{test_task.id}", f"/api/tasks/{test_task.id}",
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -165,36 +181,36 @@ class TestSoftDelete:
assert test_task.deleted_at is not None assert test_task.deleted_at is not None
assert test_task.deleted_by is not None assert test_task.deleted_by is not None
def test_deleted_task_not_in_list(self, client, admin_token, test_project, test_task, db): def test_deleted_task_not_in_list(self, client, auth_headers, test_project, test_task, db):
"""Test that deleted tasks are not shown in list.""" """Test that deleted tasks are not shown in list."""
# Delete the task # Delete the task
client.delete( client.delete(
f"/api/tasks/{test_task.id}", f"/api/tasks/{test_task.id}",
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
) )
# List tasks # List tasks
response = client.get( response = client.get(
f"/api/projects/{test_project.id}/tasks", f"/api/projects/{test_project.id}/tasks",
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
) )
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["total"] == 0 assert data["total"] == 0
def test_admin_can_list_deleted_with_include_deleted(self, client, admin_token, test_project, test_task, db): def test_admin_can_list_deleted_with_include_deleted(self, client, auth_headers, test_project, test_task, db):
"""Test that admin can see deleted tasks with include_deleted parameter.""" """Test that admin can see deleted tasks with include_deleted parameter."""
# Delete the task # Delete the task
client.delete( client.delete(
f"/api/tasks/{test_task.id}", f"/api/tasks/{test_task.id}",
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
) )
# List with include_deleted # List with include_deleted
response = client.get( response = client.get(
f"/api/projects/{test_project.id}/tasks?include_deleted=true", f"/api/projects/{test_project.id}/tasks?include_deleted=true",
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -202,12 +218,12 @@ class TestSoftDelete:
assert data["total"] == 1 assert data["total"] == 1
assert data["tasks"][0]["id"] == test_task.id assert data["tasks"][0]["id"] == test_task.id
def test_regular_user_cannot_see_deleted_with_include_deleted(self, client, regular_token, test_project, test_task, admin_token, db): def test_regular_user_cannot_see_deleted_with_include_deleted(self, client, regular_token, test_project, test_task, auth_headers, db, csrf_token):
"""Test that non-admin cannot see deleted tasks even with include_deleted.""" """Test that non-admin cannot see deleted tasks even with include_deleted."""
# Delete the task as admin # Delete the task as admin
client.delete( client.delete(
f"/api/tasks/{test_task.id}", f"/api/tasks/{test_task.id}",
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
) )
# Try to list with include_deleted as regular user # Try to list with include_deleted as regular user
@@ -220,12 +236,12 @@ class TestSoftDelete:
data = response.json() data = response.json()
assert data["total"] == 0 assert data["total"] == 0
def test_get_deleted_task_returns_404_for_regular_user(self, client, admin_token, regular_token, test_task, db): def test_get_deleted_task_returns_404_for_regular_user(self, client, auth_headers, regular_token, test_task, db):
"""Test that getting a deleted task returns 404 for non-admin.""" """Test that getting a deleted task returns 404 for non-admin."""
# Delete the task # Delete the task
client.delete( client.delete(
f"/api/tasks/{test_task.id}", f"/api/tasks/{test_task.id}",
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
) )
# Try to get as regular user # Try to get as regular user
@@ -236,28 +252,28 @@ class TestSoftDelete:
assert response.status_code == 404 assert response.status_code == 404
def test_admin_can_view_deleted_task(self, client, admin_token, test_task, db): def test_admin_can_view_deleted_task(self, client, auth_headers, test_task, db):
"""Test that admin can view a deleted task.""" """Test that admin can view a deleted task."""
# Delete the task # Delete the task
client.delete( client.delete(
f"/api/tasks/{test_task.id}", f"/api/tasks/{test_task.id}",
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
) )
# Get as admin # Get as admin
response = client.get( response = client.get(
f"/api/tasks/{test_task.id}", f"/api/tasks/{test_task.id}",
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
) )
assert response.status_code == 200 assert response.status_code == 200
def test_cascade_soft_delete_subtasks(self, client, admin_token, test_task, test_task_with_subtask, db): def test_cascade_soft_delete_subtasks(self, client, auth_headers, test_task, test_task_with_subtask, db):
"""Test that deleting a parent task soft-deletes its subtasks.""" """Test that deleting a parent task soft-deletes its subtasks."""
# Delete the parent task # Delete the parent task
response = client.delete( response = client.delete(
f"/api/tasks/{test_task.id}", f"/api/tasks/{test_task.id}",
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -270,18 +286,18 @@ class TestSoftDelete:
class TestRestoreTask: class TestRestoreTask:
"""Tests for task restoration functionality.""" """Tests for task restoration functionality."""
def test_restore_task(self, client, admin_token, test_task, db): def test_restore_task(self, client, auth_headers, test_task, db):
"""Test that admin can restore a deleted task.""" """Test that admin can restore a deleted task."""
# Delete the task # Delete the task
client.delete( client.delete(
f"/api/tasks/{test_task.id}", f"/api/tasks/{test_task.id}",
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
) )
# Restore the task # Restore the task
response = client.post( response = client.post(
f"/api/tasks/{test_task.id}/restore", f"/api/tasks/{test_task.id}/restore",
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -292,27 +308,29 @@ class TestRestoreTask:
assert test_task.deleted_at is None assert test_task.deleted_at is None
assert test_task.deleted_by is None assert test_task.deleted_by is None
def test_regular_user_cannot_restore(self, client, admin_token, regular_token, test_task, db): def test_regular_user_cannot_restore(self, client, auth_headers, regular_token, test_task, db, test_regular_user):
"""Test that non-admin cannot restore a deleted task.""" """Test that non-admin cannot restore a deleted task."""
from app.core.security import generate_csrf_token
# Delete the task # Delete the task
client.delete( client.delete(
f"/api/tasks/{test_task.id}", f"/api/tasks/{test_task.id}",
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
) )
# Try to restore as regular user # Try to restore as regular user
regular_csrf = generate_csrf_token(test_regular_user.id)
response = client.post( response = client.post(
f"/api/tasks/{test_task.id}/restore", f"/api/tasks/{test_task.id}/restore",
headers={"Authorization": f"Bearer {regular_token}"}, headers={"Authorization": f"Bearer {regular_token}", "X-CSRF-Token": regular_csrf},
) )
assert response.status_code == 403 assert response.status_code == 403
def test_cannot_restore_non_deleted_task(self, client, admin_token, test_task, db): def test_cannot_restore_non_deleted_task(self, client, auth_headers, test_task, db):
"""Test that restoring a non-deleted task returns error.""" """Test that restoring a non-deleted task returns error."""
response = client.post( response = client.post(
f"/api/tasks/{test_task.id}/restore", f"/api/tasks/{test_task.id}/restore",
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
) )
assert response.status_code == 400 assert response.status_code == 400
@@ -322,12 +340,12 @@ class TestRestoreTask:
class TestSubtaskCount: class TestSubtaskCount:
"""Tests for subtask count excluding deleted.""" """Tests for subtask count excluding deleted."""
def test_subtask_count_excludes_deleted(self, client, admin_token, test_task, test_task_with_subtask, db): def test_subtask_count_excludes_deleted(self, client, auth_headers, test_task, test_task_with_subtask, db):
"""Test that subtask_count excludes deleted subtasks.""" """Test that subtask_count excludes deleted subtasks."""
# Get parent task before deletion # Get parent task before deletion
response = client.get( response = client.get(
f"/api/tasks/{test_task.id}", f"/api/tasks/{test_task.id}",
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
) )
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["subtask_count"] == 1 assert response.json()["subtask_count"] == 1
@@ -335,13 +353,13 @@ class TestSubtaskCount:
# Delete subtask # Delete subtask
client.delete( client.delete(
f"/api/tasks/{test_task_with_subtask.id}", f"/api/tasks/{test_task_with_subtask.id}",
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
) )
# Get parent task after deletion # Get parent task after deletion
response = client.get( response = client.get(
f"/api/tasks/{test_task.id}", f"/api/tasks/{test_task.id}",
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
) )
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["subtask_count"] == 0 assert response.json()["subtask_count"] == 0

View File

@@ -57,7 +57,7 @@ class TestSpacesAPI:
"/api/spaces", "/api/spaces",
json={"name": "Test Space", "description": "Test"} json={"name": "Test Space", "description": "Test"}
) )
assert response.status_code == 403 # No auth header assert response.status_code == 401 # 401 for unauthenticated, 403 for unauthorized
def test_space_routes_exist(self): def test_space_routes_exist(self):
"""Test that all space routes are registered.""" """Test that all space routes are registered."""

View File

@@ -783,7 +783,7 @@ class TestDateValidation:
class TestDependencyCRUDAPI: class TestDependencyCRUDAPI:
"""Test dependency CRUD API endpoints.""" """Test dependency CRUD API endpoints."""
def test_create_dependency(self, client, db, admin_token): def test_create_dependency(self, client, db, admin_token, csrf_token):
"""Test creating a dependency via API.""" """Test creating a dependency via API."""
# Create test data # Create test data
space = Space( space = Space(
@@ -838,7 +838,7 @@ class TestDependencyCRUDAPI:
"dependency_type": "FS", "dependency_type": "FS",
"lag_days": 0 "lag_days": 0
}, },
headers={"Authorization": f"Bearer {admin_token}"}, headers={"Authorization": f"Bearer {admin_token}", "X-CSRF-Token": csrf_token},
) )
assert response.status_code == 201 assert response.status_code == 201
@@ -914,7 +914,7 @@ class TestDependencyCRUDAPI:
assert data["total"] >= 1 assert data["total"] >= 1
assert any(d["predecessor_id"] == "task-api-list-1" for d in data["dependencies"]) assert any(d["predecessor_id"] == "task-api-list-1" for d in data["dependencies"])
def test_delete_dependency(self, client, db, admin_token): def test_delete_dependency(self, client, db, admin_token, csrf_token):
"""Test deleting a dependency.""" """Test deleting a dependency."""
# Create test data # Create test data
space = Space( space = Space(
@@ -973,7 +973,7 @@ class TestDependencyCRUDAPI:
# Delete dependency # Delete dependency
response = client.delete( response = client.delete(
"/api/task-dependencies/dep-api-del", "/api/task-dependencies/dep-api-del",
headers={"Authorization": f"Bearer {admin_token}"}, headers={"Authorization": f"Bearer {admin_token}", "X-CSRF-Token": csrf_token},
) )
assert response.status_code == 204 assert response.status_code == 204
@@ -984,7 +984,7 @@ class TestDependencyCRUDAPI:
).first() ).first()
assert dep_check is None assert dep_check is None
def test_circular_dependency_rejected_via_api(self, client, db, admin_token): def test_circular_dependency_rejected_via_api(self, client, db, admin_token, csrf_token):
"""Test that circular dependencies are rejected via API.""" """Test that circular dependencies are rejected via API."""
# Create test data # Create test data
space = Space( space = Space(
@@ -1049,7 +1049,7 @@ class TestDependencyCRUDAPI:
"dependency_type": "FS", "dependency_type": "FS",
"lag_days": 0 "lag_days": 0
}, },
headers={"Authorization": f"Bearer {admin_token}"}, headers={"Authorization": f"Bearer {admin_token}", "X-CSRF-Token": csrf_token},
) )
assert response.status_code == 400 assert response.status_code == 400
@@ -1060,7 +1060,7 @@ class TestDependencyCRUDAPI:
class TestTaskDateValidationAPI: class TestTaskDateValidationAPI:
"""Test task date validation in task API.""" """Test task date validation in task API."""
def test_create_task_with_invalid_dates_rejected(self, client, db, admin_token): def test_create_task_with_invalid_dates_rejected(self, client, db, admin_token, csrf_token):
"""Test that creating a task with start_date > due_date is rejected.""" """Test that creating a task with start_date > due_date is rejected."""
# Create test data # Create test data
space = Space( space = Space(
@@ -1099,13 +1099,13 @@ class TestTaskDateValidationAPI:
"start_date": (now + timedelta(days=10)).isoformat(), "start_date": (now + timedelta(days=10)).isoformat(),
"due_date": now.isoformat(), "due_date": now.isoformat(),
}, },
headers={"Authorization": f"Bearer {admin_token}"}, headers={"Authorization": f"Bearer {admin_token}", "X-CSRF-Token": csrf_token},
) )
assert response.status_code == 400 assert response.status_code == 400
assert "Start date cannot be after due date" in response.json()["detail"] assert "Start date cannot be after due date" in response.json()["detail"]
def test_update_task_with_invalid_dates_rejected(self, client, db, admin_token): def test_update_task_with_invalid_dates_rejected(self, client, db, admin_token, csrf_token):
"""Test that updating a task to have start_date > due_date is rejected.""" """Test that updating a task to have start_date > due_date is rejected."""
# Create test data # Create test data
space = Space( space = Space(
@@ -1153,12 +1153,12 @@ class TestTaskDateValidationAPI:
json={ json={
"start_date": (now + timedelta(days=20)).isoformat(), "start_date": (now + timedelta(days=20)).isoformat(),
}, },
headers={"Authorization": f"Bearer {admin_token}"}, headers={"Authorization": f"Bearer {admin_token}", "X-CSRF-Token": csrf_token},
) )
assert response.status_code == 400 assert response.status_code == 400
def test_create_task_with_valid_dates_accepted(self, client, db, admin_token): def test_create_task_with_valid_dates_accepted(self, client, db, admin_token, csrf_token):
"""Test that creating a task with valid dates is accepted.""" """Test that creating a task with valid dates is accepted."""
# Create test data # Create test data
space = Space( space = Space(
@@ -1197,7 +1197,7 @@ class TestTaskDateValidationAPI:
"start_date": now.isoformat(), "start_date": now.isoformat(),
"due_date": (now + timedelta(days=10)).isoformat(), "due_date": (now + timedelta(days=10)).isoformat(),
}, },
headers={"Authorization": f"Bearer {admin_token}"}, headers={"Authorization": f"Bearer {admin_token}", "X-CSRF-Token": csrf_token},
) )
assert response.status_code == 201 assert response.status_code == 201
@@ -1217,7 +1217,7 @@ class TestDependencyTypes:
assert DependencyType.FF.value == "FF" assert DependencyType.FF.value == "FF"
assert DependencyType.SF.value == "SF" assert DependencyType.SF.value == "SF"
def test_create_dependency_with_different_types(self, client, db, admin_token): def test_create_dependency_with_different_types(self, client, db, admin_token, csrf_token):
"""Test creating dependencies with different types via API.""" """Test creating dependencies with different types via API."""
# Create test data # Create test data
space = Space( space = Space(
@@ -1268,7 +1268,7 @@ class TestDependencyTypes:
"dependency_type": dep_type, "dependency_type": dep_type,
"lag_days": i "lag_days": i
}, },
headers={"Authorization": f"Bearer {admin_token}"}, headers={"Authorization": f"Bearer {admin_token}", "X-CSRF-Token": csrf_token},
) )
assert response.status_code == 201 assert response.status_code == 201

View File

@@ -43,6 +43,22 @@ def test_user_token(client, mock_redis, test_user):
return token return token
@pytest.fixture
def test_user_csrf_token(test_user):
"""Generate a CSRF token for the test user."""
from app.core.security import generate_csrf_token
return generate_csrf_token(test_user.id)
@pytest.fixture
def test_user_auth_headers(test_user_token, test_user_csrf_token):
"""Get complete auth headers for test user."""
return {
"Authorization": f"Bearer {test_user_token}",
"X-CSRF-Token": test_user_csrf_token,
}
@pytest.fixture @pytest.fixture
def test_space(db, test_user): def test_space(db, test_user):
"""Create a test space.""" """Create a test space."""
@@ -513,11 +529,11 @@ class TestTriggerNotifications:
class TestTriggerAPI: class TestTriggerAPI:
"""Tests for Trigger API endpoints.""" """Tests for Trigger API endpoints."""
def test_create_trigger(self, client, test_user_token, test_project, test_status): def test_create_trigger(self, client, test_user_auth_headers, test_project, test_status):
"""Test creating a trigger.""" """Test creating a trigger."""
response = client.post( response = client.post(
f"/api/projects/{test_project.id}/triggers", f"/api/projects/{test_project.id}/triggers",
headers={"Authorization": f"Bearer {test_user_token}"}, headers=test_user_auth_headers,
json={ json={
"name": "New Trigger", "name": "New Trigger",
"description": "Test trigger", "description": "Test trigger",
@@ -563,11 +579,11 @@ class TestTriggerAPI:
assert data["id"] == test_trigger.id assert data["id"] == test_trigger.id
assert data["name"] == test_trigger.name assert data["name"] == test_trigger.name
def test_update_trigger(self, client, test_user_token, test_trigger): def test_update_trigger(self, client, test_user_auth_headers, test_trigger):
"""Test updating a trigger.""" """Test updating a trigger."""
response = client.put( response = client.put(
f"/api/triggers/{test_trigger.id}", f"/api/triggers/{test_trigger.id}",
headers={"Authorization": f"Bearer {test_user_token}"}, headers=test_user_auth_headers,
json={ json={
"name": "Updated Trigger", "name": "Updated Trigger",
"is_active": False, "is_active": False,
@@ -579,11 +595,11 @@ class TestTriggerAPI:
assert data["name"] == "Updated Trigger" assert data["name"] == "Updated Trigger"
assert data["is_active"] is False assert data["is_active"] is False
def test_delete_trigger(self, client, test_user_token, test_trigger): def test_delete_trigger(self, client, test_user_auth_headers, test_trigger, test_user_token):
"""Test deleting a trigger.""" """Test deleting a trigger."""
response = client.delete( response = client.delete(
f"/api/triggers/{test_trigger.id}", f"/api/triggers/{test_trigger.id}",
headers={"Authorization": f"Bearer {test_user_token}"}, headers=test_user_auth_headers,
) )
assert response.status_code == 204 assert response.status_code == 204
@@ -616,11 +632,11 @@ class TestTriggerAPI:
data = response.json() data = response.json()
assert data["total"] >= 1 assert data["total"] >= 1
def test_create_trigger_invalid_field(self, client, test_user_token, test_project): def test_create_trigger_invalid_field(self, client, test_user_auth_headers, test_project):
"""Test creating a trigger with invalid field.""" """Test creating a trigger with invalid field."""
response = client.post( response = client.post(
f"/api/projects/{test_project.id}/triggers", f"/api/projects/{test_project.id}/triggers",
headers={"Authorization": f"Bearer {test_user_token}"}, headers=test_user_auth_headers,
json={ json={
"name": "Invalid Trigger", "name": "Invalid Trigger",
"trigger_type": "field_change", "trigger_type": "field_change",
@@ -636,11 +652,11 @@ class TestTriggerAPI:
assert response.status_code == 400 assert response.status_code == 400
assert "Invalid condition field" in response.json()["detail"] assert "Invalid condition field" in response.json()["detail"]
def test_create_trigger_invalid_operator(self, client, test_user_token, test_project): def test_create_trigger_invalid_operator(self, client, test_user_auth_headers, test_project):
"""Test creating a trigger with invalid operator.""" """Test creating a trigger with invalid operator."""
response = client.post( response = client.post(
f"/api/projects/{test_project.id}/triggers", f"/api/projects/{test_project.id}/triggers",
headers={"Authorization": f"Bearer {test_user_token}"}, headers=test_user_auth_headers,
json={ json={
"name": "Invalid Trigger", "name": "Invalid Trigger",
"trigger_type": "field_change", "trigger_type": "field_change",

View File

@@ -1,6 +1,7 @@
import pytest import pytest
from app.models.user import User from app.models.user import User
from app.models.department import Department from app.models.department import Department
from app.core.security import generate_csrf_token
class TestUserEndpoints: class TestUserEndpoints:
@@ -35,7 +36,7 @@ class TestUserEndpoints:
) )
assert response.status_code == 404 assert response.status_code == 404
def test_update_user(self, client, admin_token, db): def test_update_user(self, client, auth_headers, db):
"""Test updating a user.""" """Test updating a user."""
# Create a test user # Create a test user
test_user = User( test_user = User(
@@ -49,7 +50,7 @@ class TestUserEndpoints:
response = client.patch( response = client.patch(
"/api/users/test-user-001", "/api/users/test-user-001",
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
json={"name": "Updated Name"}, json={"name": "Updated Name"},
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -84,9 +85,10 @@ class TestUserEndpoints:
mock_redis.setex("session:non-admin-001", 900, token) mock_redis.setex("session:non-admin-001", 900, token)
# Try to modify system admin - should fail with 403 # Try to modify system admin - should fail with 403
csrf_token = generate_csrf_token("non-admin-001")
response = client.patch( response = client.patch(
"/api/users/00000000-0000-0000-0000-000000000001", "/api/users/00000000-0000-0000-0000-000000000001",
headers={"Authorization": f"Bearer {token}"}, headers={"Authorization": f"Bearer {token}", "X-CSRF-Token": csrf_token},
json={"name": "Hacked Name"}, json={"name": "Hacked Name"},
) )
# Engineer role doesn't have users.write permission # Engineer role doesn't have users.write permission
@@ -123,16 +125,17 @@ class TestCapacityUpdate:
mock_redis.setex("session:capacity-user-001", 900, token) mock_redis.setex("session:capacity-user-001", 900, token)
# Update own capacity # Update own capacity
csrf_token = generate_csrf_token("capacity-user-001")
response = client.put( response = client.put(
"/api/users/capacity-user-001/capacity", "/api/users/capacity-user-001/capacity",
headers={"Authorization": f"Bearer {token}"}, headers={"Authorization": f"Bearer {token}", "X-CSRF-Token": csrf_token},
json={"capacity_hours": 35.5}, json={"capacity_hours": 35.5},
) )
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert float(data["capacity"]) == 35.5 assert float(data["capacity"]) == 35.5
def test_admin_can_update_other_user_capacity(self, client, admin_token, db): def test_admin_can_update_other_user_capacity(self, client, auth_headers, db):
"""Test that admin can update another user's capacity.""" """Test that admin can update another user's capacity."""
# Create a test user # Create a test user
test_user = User( test_user = User(
@@ -148,7 +151,7 @@ class TestCapacityUpdate:
# Admin updates another user's capacity # Admin updates another user's capacity
response = client.put( response = client.put(
"/api/users/capacity-user-002/capacity", "/api/users/capacity-user-002/capacity",
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
json={"capacity_hours": 20.0}, json={"capacity_hours": 20.0},
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -189,15 +192,16 @@ class TestCapacityUpdate:
mock_redis.setex("session:capacity-user-003", 900, token) mock_redis.setex("session:capacity-user-003", 900, token)
# User1 tries to update user2's capacity - should fail # User1 tries to update user2's capacity - should fail
csrf_token = generate_csrf_token("capacity-user-003")
response = client.put( response = client.put(
"/api/users/capacity-user-004/capacity", "/api/users/capacity-user-004/capacity",
headers={"Authorization": f"Bearer {token}"}, headers={"Authorization": f"Bearer {token}", "X-CSRF-Token": csrf_token},
json={"capacity_hours": 30.0}, json={"capacity_hours": 30.0},
) )
assert response.status_code == 403 assert response.status_code == 403
assert "Only admin, manager, or the user themselves" in response.json()["detail"] assert "Only admin, manager, or the user themselves" in response.json()["detail"]
def test_update_capacity_invalid_value_negative(self, client, admin_token, db): def test_update_capacity_invalid_value_negative(self, client, auth_headers, db):
"""Test that negative capacity hours are rejected.""" """Test that negative capacity hours are rejected."""
# Create a test user # Create a test user
test_user = User( test_user = User(
@@ -212,7 +216,7 @@ class TestCapacityUpdate:
response = client.put( response = client.put(
"/api/users/capacity-user-005/capacity", "/api/users/capacity-user-005/capacity",
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
json={"capacity_hours": -5.0}, json={"capacity_hours": -5.0},
) )
# Pydantic validation returns 422 Unprocessable Entity # Pydantic validation returns 422 Unprocessable Entity
@@ -221,7 +225,7 @@ class TestCapacityUpdate:
# Check validation error message in Pydantic format # Check validation error message in Pydantic format
assert any("non-negative" in str(err).lower() for err in error_detail) assert any("non-negative" in str(err).lower() for err in error_detail)
def test_update_capacity_invalid_value_too_high(self, client, admin_token, db): def test_update_capacity_invalid_value_too_high(self, client, auth_headers, db):
"""Test that capacity hours exceeding 168 are rejected.""" """Test that capacity hours exceeding 168 are rejected."""
# Create a test user # Create a test user
test_user = User( test_user = User(
@@ -236,7 +240,7 @@ class TestCapacityUpdate:
response = client.put( response = client.put(
"/api/users/capacity-user-006/capacity", "/api/users/capacity-user-006/capacity",
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
json={"capacity_hours": 200.0}, json={"capacity_hours": 200.0},
) )
# Pydantic validation returns 422 Unprocessable Entity # Pydantic validation returns 422 Unprocessable Entity
@@ -245,11 +249,11 @@ class TestCapacityUpdate:
# Check validation error message in Pydantic format # Check validation error message in Pydantic format
assert any("168" in str(err) for err in error_detail) assert any("168" in str(err) for err in error_detail)
def test_update_capacity_nonexistent_user(self, client, admin_token): def test_update_capacity_nonexistent_user(self, client, auth_headers):
"""Test updating capacity for a nonexistent user.""" """Test updating capacity for a nonexistent user."""
response = client.put( response = client.put(
"/api/users/nonexistent-user-id/capacity", "/api/users/nonexistent-user-id/capacity",
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
json={"capacity_hours": 40.0}, json={"capacity_hours": 40.0},
) )
assert response.status_code == 404 assert response.status_code == 404
@@ -303,16 +307,17 @@ class TestCapacityUpdate:
mock_redis.setex("session:manager-cap-001", 900, token) mock_redis.setex("session:manager-cap-001", 900, token)
# Manager updates regular user's capacity # Manager updates regular user's capacity
csrf_token = generate_csrf_token("manager-cap-001")
response = client.put( response = client.put(
"/api/users/regular-cap-001/capacity", "/api/users/regular-cap-001/capacity",
headers={"Authorization": f"Bearer {token}"}, headers={"Authorization": f"Bearer {token}", "X-CSRF-Token": csrf_token},
json={"capacity_hours": 30.0}, json={"capacity_hours": 30.0},
) )
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert float(data["capacity"]) == 30.0 assert float(data["capacity"]) == 30.0
def test_capacity_change_creates_audit_log(self, client, admin_token, db): def test_capacity_change_creates_audit_log(self, client, auth_headers, db):
"""Test that capacity changes are recorded in audit trail.""" """Test that capacity changes are recorded in audit trail."""
from app.models import AuditLog from app.models import AuditLog
@@ -330,7 +335,7 @@ class TestCapacityUpdate:
# Update capacity # Update capacity
response = client.put( response = client.put(
"/api/users/capacity-audit-001/capacity", "/api/users/capacity-audit-001/capacity",
headers={"Authorization": f"Bearer {admin_token}"}, headers=auth_headers,
json={"capacity_hours": 35.0}, json={"capacity_hours": 35.0},
) )
assert response.status_code == 200 assert response.status_code == 200

View File

@@ -449,7 +449,7 @@ class TestWorkloadAPI:
def test_unauthorized_access(self, client, db): def test_unauthorized_access(self, client, db):
"""Unauthenticated requests should fail.""" """Unauthenticated requests should fail."""
response = client.get("/api/workload/heatmap") response = client.get("/api/workload/heatmap")
assert response.status_code == 403 # No auth header assert response.status_code == 401 # 401 for unauthenticated, 403 for unauthorized
class TestWorkloadAccessControl: class TestWorkloadAccessControl:

View File

@@ -19,7 +19,8 @@
"emailRequired": "Email is required", "emailRequired": "Email is required",
"passwordRequired": "Password is required", "passwordRequired": "Password is required",
"invalidEmail": "Please enter a valid email address", "invalidEmail": "Please enter a valid email address",
"loginFailed": "Login failed. Please try again later." "loginFailed": "Login failed. Please try again later.",
"sessionExpired": "Your session has expired. Please sign in again."
}, },
"welcome": { "welcome": {
"title": "Project Control Center", "title": "Project Control Center",

View File

@@ -129,5 +129,11 @@
"message": "Unable to display this widget.", "message": "Unable to display this widget.",
"errorSuffix": "error" "errorSuffix": "error"
} }
},
"attachments": {
"dropzone": "Drop files here or click to upload",
"maxFileSize": "Maximum file size: {{size}}",
"uploading": "Uploading {{filename}} ({{current}}/{{total}})...",
"uploadFailed": "Upload failed"
} }
} }

View File

@@ -19,7 +19,8 @@
"emailRequired": "請輸入電子郵件", "emailRequired": "請輸入電子郵件",
"passwordRequired": "請輸入密碼", "passwordRequired": "請輸入密碼",
"invalidEmail": "請輸入有效的電子郵件地址", "invalidEmail": "請輸入有效的電子郵件地址",
"loginFailed": "登入失敗,請稍後再試" "loginFailed": "登入失敗,請稍後再試",
"sessionExpired": "您的登入時段已過期,請重新登入。"
}, },
"welcome": { "welcome": {
"title": "專案控制中心", "title": "專案控制中心",

View File

@@ -129,5 +129,11 @@
"message": "無法顯示此元件。", "message": "無法顯示此元件。",
"errorSuffix": "發生錯誤" "errorSuffix": "發生錯誤"
} }
},
"attachments": {
"dropzone": "拖曳檔案至此或點擊上傳",
"maxFileSize": "檔案大小上限:{{size}}",
"uploading": "正在上傳 {{filename}} ({{current}}/{{total}})...",
"uploadFailed": "上傳失敗"
} }
} }

View File

@@ -1,21 +1,35 @@
import { lazy, Suspense } from 'react'
import { Routes, Route, Navigate } from 'react-router-dom' import { Routes, Route, Navigate } from 'react-router-dom'
import { useAuth } from './contexts/AuthContext' import { useAuth } from './contexts/AuthContext'
import { Skeleton } from './components/Skeleton' import { Skeleton } from './components/Skeleton'
import { ErrorBoundary } from './components/ErrorBoundary' import { ErrorBoundary } from './components/ErrorBoundary'
import { SectionErrorBoundary } from './components/ErrorBoundaryWithI18n' import { SectionErrorBoundary } from './components/ErrorBoundaryWithI18n'
import Login from './pages/Login'
import Dashboard from './pages/Dashboard'
import Spaces from './pages/Spaces'
import Projects from './pages/Projects'
import Tasks from './pages/Tasks'
import ProjectSettings from './pages/ProjectSettings'
import MySettings from './pages/MySettings'
import AuditPage from './pages/AuditPage'
import WorkloadPage from './pages/WorkloadPage'
import ProjectHealthPage from './pages/ProjectHealthPage'
import ProtectedRoute from './components/ProtectedRoute' import ProtectedRoute from './components/ProtectedRoute'
import Layout from './components/Layout' import Layout from './components/Layout'
// Lazy load pages for code splitting
const Login = lazy(() => import('./pages/Login'))
const Dashboard = lazy(() => import('./pages/Dashboard'))
const Spaces = lazy(() => import('./pages/Spaces'))
const Projects = lazy(() => import('./pages/Projects'))
const Tasks = lazy(() => import('./pages/Tasks'))
const ProjectSettings = lazy(() => import('./pages/ProjectSettings'))
const MySettings = lazy(() => import('./pages/MySettings'))
const AuditPage = lazy(() => import('./pages/AuditPage'))
const WorkloadPage = lazy(() => import('./pages/WorkloadPage'))
const ProjectHealthPage = lazy(() => import('./pages/ProjectHealthPage'))
// Loading fallback component for Suspense
function PageLoadingFallback() {
return (
<div className="container" style={{ padding: '24px', maxWidth: '1200px', margin: '0 auto' }}>
<Skeleton variant="text" width={200} height={32} style={{ marginBottom: '16px' }} />
<Skeleton variant="rect" width="100%" height={200} style={{ marginBottom: '16px' }} />
<Skeleton variant="rect" width="100%" height={150} />
</div>
)
}
function App() { function App() {
const { isAuthenticated, loading } = useAuth() const { isAuthenticated, loading } = useAuth()
@@ -30,120 +44,122 @@ function App() {
return ( return (
<ErrorBoundary variant="page"> <ErrorBoundary variant="page">
<Routes> <Suspense fallback={<PageLoadingFallback />}>
<Route <Routes>
path="/login" <Route
element={isAuthenticated ? <Navigate to="/" /> : <Login />} path="/login"
/> element={isAuthenticated ? <Navigate to="/" /> : <Login />}
<Route />
path="/" <Route
element={ path="/"
<ProtectedRoute> element={
<Layout> <ProtectedRoute>
<SectionErrorBoundary sectionName="Dashboard"> <Layout>
<Dashboard /> <SectionErrorBoundary sectionName="Dashboard">
</SectionErrorBoundary> <Dashboard />
</Layout> </SectionErrorBoundary>
</ProtectedRoute> </Layout>
} </ProtectedRoute>
/> }
<Route />
path="/spaces" <Route
element={ path="/spaces"
<ProtectedRoute> element={
<Layout> <ProtectedRoute>
<SectionErrorBoundary sectionName="Spaces"> <Layout>
<Spaces /> <SectionErrorBoundary sectionName="Spaces">
</SectionErrorBoundary> <Spaces />
</Layout> </SectionErrorBoundary>
</ProtectedRoute> </Layout>
} </ProtectedRoute>
/> }
<Route />
path="/spaces/:spaceId" <Route
element={ path="/spaces/:spaceId"
<ProtectedRoute> element={
<Layout> <ProtectedRoute>
<SectionErrorBoundary sectionName="Projects"> <Layout>
<Projects /> <SectionErrorBoundary sectionName="Projects">
</SectionErrorBoundary> <Projects />
</Layout> </SectionErrorBoundary>
</ProtectedRoute> </Layout>
} </ProtectedRoute>
/> }
<Route />
path="/projects/:projectId" <Route
element={ path="/projects/:projectId"
<ProtectedRoute> element={
<Layout> <ProtectedRoute>
<SectionErrorBoundary sectionName="Tasks"> <Layout>
<Tasks /> <SectionErrorBoundary sectionName="Tasks">
</SectionErrorBoundary> <Tasks />
</Layout> </SectionErrorBoundary>
</ProtectedRoute> </Layout>
} </ProtectedRoute>
/> }
<Route />
path="/projects/:projectId/settings" <Route
element={ path="/projects/:projectId/settings"
<ProtectedRoute> element={
<Layout> <ProtectedRoute>
<SectionErrorBoundary sectionName="Project Settings"> <Layout>
<ProjectSettings /> <SectionErrorBoundary sectionName="Project Settings">
</SectionErrorBoundary> <ProjectSettings />
</Layout> </SectionErrorBoundary>
</ProtectedRoute> </Layout>
} </ProtectedRoute>
/> }
<Route />
path="/audit" <Route
element={ path="/audit"
<ProtectedRoute> element={
<Layout> <ProtectedRoute>
<SectionErrorBoundary sectionName="Audit"> <Layout>
<AuditPage /> <SectionErrorBoundary sectionName="Audit">
</SectionErrorBoundary> <AuditPage />
</Layout> </SectionErrorBoundary>
</ProtectedRoute> </Layout>
} </ProtectedRoute>
/> }
<Route />
path="/workload" <Route
element={ path="/workload"
<ProtectedRoute> element={
<Layout> <ProtectedRoute>
<SectionErrorBoundary sectionName="Workload"> <Layout>
<WorkloadPage /> <SectionErrorBoundary sectionName="Workload">
</SectionErrorBoundary> <WorkloadPage />
</Layout> </SectionErrorBoundary>
</ProtectedRoute> </Layout>
} </ProtectedRoute>
/> }
<Route />
path="/project-health" <Route
element={ path="/project-health"
<ProtectedRoute> element={
<Layout> <ProtectedRoute>
<SectionErrorBoundary sectionName="Project Health"> <Layout>
<ProjectHealthPage /> <SectionErrorBoundary sectionName="Project Health">
</SectionErrorBoundary> <ProjectHealthPage />
</Layout> </SectionErrorBoundary>
</ProtectedRoute> </Layout>
} </ProtectedRoute>
/> }
<Route />
path="/my-settings" <Route
element={ path="/my-settings"
<ProtectedRoute> element={
<Layout> <ProtectedRoute>
<SectionErrorBoundary sectionName="Settings"> <Layout>
<MySettings /> <SectionErrorBoundary sectionName="Settings">
</SectionErrorBoundary> <MySettings />
</Layout> </SectionErrorBoundary>
</ProtectedRoute> </Layout>
} </ProtectedRoute>
/> }
</Routes> />
</Routes>
</Suspense>
</ErrorBoundary> </ErrorBoundary>
) )
} }

View File

@@ -1,4 +1,5 @@
import { useState, useRef, useEffect, DragEvent, ChangeEvent } from 'react' import { useState, useRef, useEffect, DragEvent, ChangeEvent } from 'react'
import { useTranslation } from 'react-i18next'
import { attachmentService } from '../services/attachments' import { attachmentService } from '../services/attachments'
// Spinner animation keyframes - injected once via useEffect // Spinner animation keyframes - injected once via useEffect
@@ -10,6 +11,7 @@ interface AttachmentUploadProps {
} }
export function AttachmentUpload({ taskId, onUploadComplete }: AttachmentUploadProps) { export function AttachmentUpload({ taskId, onUploadComplete }: AttachmentUploadProps) {
const { t } = useTranslation('common')
const [isDragging, setIsDragging] = useState(false) const [isDragging, setIsDragging] = useState(false)
const [uploading, setUploading] = useState(false) const [uploading, setUploading] = useState(false)
const [uploadProgress, setUploadProgress] = useState<string | null>(null) const [uploadProgress, setUploadProgress] = useState<string | null>(null)
@@ -79,14 +81,20 @@ export function AttachmentUpload({ taskId, onUploadComplete }: AttachmentUploadP
try { try {
for (let i = 0; i < files.length; i++) { for (let i = 0; i < files.length; i++) {
const file = files[i] const file = files[i]
setUploadProgress(`Uploading ${file.name} (${i + 1}/${files.length})...`) setUploadProgress(
t('attachments.uploading', {
filename: file.name,
current: i + 1,
total: files.length,
})
)
await attachmentService.uploadAttachment(taskId, file) await attachmentService.uploadAttachment(taskId, file)
} }
setUploadProgress(null) setUploadProgress(null)
onUploadComplete?.() onUploadComplete?.()
} catch (err: unknown) { } catch (err: unknown) {
console.error('Upload failed:', err) console.error('Upload failed:', err)
const errorMessage = err instanceof Error ? err.message : 'Upload failed' const errorMessage = err instanceof Error ? err.message : t('attachments.uploadFailed')
setError(errorMessage) setError(errorMessage)
} finally { } finally {
setUploading(false) setUploading(false)
@@ -127,10 +135,10 @@ export function AttachmentUpload({ taskId, onUploadComplete }: AttachmentUploadP
<div style={styles.content}> <div style={styles.content}>
<span style={styles.icon}>📎</span> <span style={styles.icon}>📎</span>
<span style={styles.text}> <span style={styles.text}>
Drop files here or click to upload {t('attachments.dropzone')}
</span> </span>
<span style={styles.hint}> <span style={styles.hint}>
Maximum file size: 50MB {t('attachments.maxFileSize', { size: '50MB' })}
</span> </span>
</div> </div>
)} )}

View File

@@ -35,7 +35,7 @@ export function Comments({ taskId }: CommentsProps) {
} finally { } finally {
setLoading(false) setLoading(false)
} }
}, [taskId]) }, [taskId, t])
useEffect(() => { useEffect(() => {
fetchComments() fetchComments()

View File

@@ -1,5 +1,58 @@
import { createContext, useContext, useState, useEffect, ReactNode } from 'react' import { createContext, useContext, useState, useEffect, ReactNode } from 'react'
import { authApi, User, LoginRequest } from '../services/api' import {
authApi,
User,
LoginRequest,
storeTokens,
clearStoredTokens,
getStoredToken,
isTokenExpired,
} from '../services/api'
/**
* Validates that a parsed object has the required User properties.
* Returns the validated User object or null if validation fails.
*/
function validateUserData(data: unknown): User | null {
// Check if data is an object
if (!data || typeof data !== 'object') {
return null
}
const obj = data as Record<string, unknown>
// Validate required string fields
if (typeof obj.id !== 'string' || obj.id.length === 0) {
return null
}
if (typeof obj.email !== 'string' || obj.email.length === 0) {
return null
}
if (typeof obj.name !== 'string' || obj.name.length === 0) {
return null
}
// Validate optional/nullable fields
if (obj.role !== null && typeof obj.role !== 'string') {
return null
}
if (obj.department_id !== null && typeof obj.department_id !== 'string') {
return null
}
if (typeof obj.is_system_admin !== 'boolean') {
return null
}
// Return validated user object
return {
id: obj.id,
email: obj.email,
name: obj.name,
role: obj.role as string | null,
department_id: obj.department_id as string | null,
is_system_admin: obj.is_system_admin,
}
}
interface AuthContextType { interface AuthContextType {
user: User | null user: User | null
@@ -17,15 +70,35 @@ export function AuthProvider({ children }: { children: ReactNode }) {
useEffect(() => { useEffect(() => {
// Check for existing token on mount // Check for existing token on mount
const token = localStorage.getItem('token') const token = getStoredToken()
const storedUser = localStorage.getItem('user') const storedUser = localStorage.getItem('user')
if (token && storedUser) { if (token && storedUser) {
try { try {
setUser(JSON.parse(storedUser)) // Check if token is expired
} catch { if (isTokenExpired(token)) {
localStorage.removeItem('token') // Token is expired, clear storage and don't restore user
localStorage.removeItem('user') // The refresh will happen automatically on next API call if refresh token exists
clearStoredTokens()
} else {
// Parse and validate stored user data
const parsedUser = JSON.parse(storedUser)
const validatedUser = validateUserData(parsedUser)
if (validatedUser) {
setUser(validatedUser)
} else {
// Invalid user data structure, clear storage and redirect to login
console.warn('Invalid user data in localStorage, clearing session')
clearStoredTokens()
// Don't redirect here as we're in initial loading state
// The app will naturally show login page when user is null
}
}
} catch (err) {
// JSON parse error or other unexpected error
console.error('Error parsing stored user data:', err)
clearStoredTokens()
} }
} }
setLoading(false) setLoading(false)
@@ -33,7 +106,8 @@ export function AuthProvider({ children }: { children: ReactNode }) {
const login = async (data: LoginRequest) => { const login = async (data: LoginRequest) => {
const response = await authApi.login(data) const response = await authApi.login(data)
localStorage.setItem('token', response.access_token) // Store access token and refresh token (if provided by backend)
storeTokens(response.access_token, response.refresh_token)
localStorage.setItem('user', JSON.stringify(response.user)) localStorage.setItem('user', JSON.stringify(response.user))
setUser(response.user) setUser(response.user)
} }
@@ -44,8 +118,8 @@ export function AuthProvider({ children }: { children: ReactNode }) {
} catch { } catch {
// Ignore errors on logout // Ignore errors on logout
} finally { } finally {
localStorage.removeItem('token') // Clear all tokens (access, refresh, and user data)
localStorage.removeItem('user') clearStoredTokens()
setUser(null) setUser(null)
} }
} }

View File

@@ -29,7 +29,7 @@ export default function Dashboard() {
} finally { } finally {
setLoading(false) setLoading(false)
} }
}, []) }, [t])
useEffect(() => { useEffect(() => {
fetchDashboard() fetchDashboard()

View File

@@ -1,5 +1,5 @@
import { useState, FormEvent } from 'react' import { useState, useEffect, FormEvent } from 'react'
import { useNavigate } from 'react-router-dom' import { useNavigate, useSearchParams } from 'react-router-dom'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import { useAuth } from '../contexts/AuthContext' import { useAuth } from '../contexts/AuthContext'
import { LanguageSwitcher } from '../components/LanguageSwitcher' import { LanguageSwitcher } from '../components/LanguageSwitcher'
@@ -9,9 +9,21 @@ export default function Login() {
const [email, setEmail] = useState('') const [email, setEmail] = useState('')
const [password, setPassword] = useState('') const [password, setPassword] = useState('')
const [error, setError] = useState('') const [error, setError] = useState('')
const [info, setInfo] = useState('')
const [loading, setLoading] = useState(false) const [loading, setLoading] = useState(false)
const { login } = useAuth() const { login } = useAuth()
const navigate = useNavigate() const navigate = useNavigate()
const [searchParams, setSearchParams] = useSearchParams()
// Check for session expired redirect
useEffect(() => {
const reason = searchParams.get('reason')
if (reason === 'session_expired') {
setInfo(t('errors.sessionExpired'))
// Clean up the URL by removing the query parameter
setSearchParams({}, { replace: true })
}
}, [searchParams, setSearchParams, t])
const handleSubmit = async (e: FormEvent) => { const handleSubmit = async (e: FormEvent) => {
e.preventDefault() e.preventDefault()
@@ -45,6 +57,7 @@ export default function Login() {
<p style={styles.subtitle}>{t('login.subtitle')}</p> <p style={styles.subtitle}>{t('login.subtitle')}</p>
<form onSubmit={handleSubmit} style={styles.form}> <form onSubmit={handleSubmit} style={styles.form}>
{info && <div style={styles.info}>{info}</div>}
{error && <div style={styles.error}>{error}</div>} {error && <div style={styles.error}>{error}</div>}
<div style={styles.field}> <div style={styles.field}>
@@ -163,4 +176,11 @@ const styles: { [key: string]: React.CSSProperties } = {
borderRadius: '4px', borderRadius: '4px',
fontSize: '14px', fontSize: '14px',
}, },
info: {
backgroundColor: '#e6f4ff',
color: '#0066cc',
padding: '10px',
borderRadius: '4px',
fontSize: '14px',
},
} }

View File

@@ -1,4 +1,4 @@
import axios, { InternalAxiosRequestConfig } from 'axios' import axios, { InternalAxiosRequestConfig, AxiosError } from 'axios'
// API base URL - using legacy routes until v1 migration is complete // API base URL - using legacy routes until v1 migration is complete
// TODO: Switch to /api/v1 when all routes are migrated // TODO: Switch to /api/v1 when all routes are migrated
@@ -9,10 +9,141 @@ const API_BASE_URL = '/api'
let csrfToken: string | null = null let csrfToken: string | null = null
let csrfTokenExpiry: number | null = null let csrfTokenExpiry: number | null = null
const CSRF_TOKEN_HEADER = 'X-CSRF-Token' const CSRF_TOKEN_HEADER = 'X-CSRF-Token'
const CSRF_PROTECTED_METHODS = ['DELETE', 'PUT', 'PATCH'] const CSRF_PROTECTED_METHODS = ['POST', 'DELETE', 'PUT', 'PATCH']
// Token expires in 1 hour, refresh 5 minutes before expiry // Token expires in 1 hour, refresh 5 minutes before expiry
const CSRF_TOKEN_LIFETIME_MS = 55 * 60 * 1000 const CSRF_TOKEN_LIFETIME_MS = 55 * 60 * 1000
// JWT Token refresh configuration
// Access tokens expire in 60 minutes, refresh 5 minutes before expiry
const TOKEN_REFRESH_THRESHOLD_MS = 5 * 60 * 1000
// Token refresh state management
let isRefreshing = false
let refreshSubscribers: Array<(token: string) => void> = []
/**
* JWT Token Utilities
*/
/**
* Decode a JWT token payload without verification.
* Note: This is for reading claims only, not for security validation.
* Security validation happens on the backend.
*/
export function decodeJwtPayload(token: string): JwtPayload | null {
try {
const parts = token.split('.')
if (parts.length !== 3) {
return null
}
// Decode base64url to base64
const base64 = parts[1].replace(/-/g, '+').replace(/_/g, '/')
// Add padding if needed
const padded = base64 + '='.repeat((4 - (base64.length % 4)) % 4)
const decoded = atob(padded)
return JSON.parse(decoded)
} catch {
return null
}
}
/**
* Get the expiration time (in milliseconds since epoch) from a JWT token.
*/
export function getTokenExpiryTime(token: string): number | null {
const payload = decodeJwtPayload(token)
if (!payload || typeof payload.exp !== 'number') {
return null
}
// JWT exp is in seconds, convert to milliseconds
return payload.exp * 1000
}
/**
* Check if a token is about to expire (within threshold).
* Returns true if token will expire within the threshold or has already expired.
*/
export function isTokenExpiringSoon(
token: string,
thresholdMs: number = TOKEN_REFRESH_THRESHOLD_MS
): boolean {
const expiryTime = getTokenExpiryTime(token)
if (expiryTime === null) {
// If we can't determine expiry, assume it needs refresh
return true
}
return Date.now() >= expiryTime - thresholdMs
}
/**
* Check if a token has already expired.
*/
export function isTokenExpired(token: string): boolean {
const expiryTime = getTokenExpiryTime(token)
if (expiryTime === null) {
return true
}
return Date.now() >= expiryTime
}
interface JwtPayload {
sub: string
email: string
role?: string | null
department_id?: string | null
is_system_admin?: boolean
exp: number
iat: number
}
/**
* Token Storage Utilities
* Note: Using localStorage for token storage. While httpOnly cookies are more
* secure against XSS attacks, localStorage is acceptable for this implementation
* as long as proper XSS protections are in place (Content Security Policy, etc.).
* The refresh token mechanism limits exposure time if a token is compromised.
*/
const TOKEN_KEY = 'token'
const REFRESH_TOKEN_KEY = 'refresh_token'
const USER_KEY = 'user'
export function getStoredToken(): string | null {
return localStorage.getItem(TOKEN_KEY)
}
export function getStoredRefreshToken(): string | null {
return localStorage.getItem(REFRESH_TOKEN_KEY)
}
export function storeTokens(accessToken: string, refreshToken?: string): void {
localStorage.setItem(TOKEN_KEY, accessToken)
if (refreshToken) {
localStorage.setItem(REFRESH_TOKEN_KEY, refreshToken)
}
}
export function clearStoredTokens(): void {
localStorage.removeItem(TOKEN_KEY)
localStorage.removeItem(REFRESH_TOKEN_KEY)
localStorage.removeItem(USER_KEY)
}
/**
* Subscribe to token refresh completion.
* Used to queue requests while a refresh is in progress.
*/
function subscribeToTokenRefresh(callback: (token: string) => void): void {
refreshSubscribers.push(callback)
}
/**
* Notify all subscribers that token has been refreshed.
*/
function onTokenRefreshed(newToken: string): void {
refreshSubscribers.forEach((callback) => callback(newToken))
refreshSubscribers = []
}
const api = axios.create({ const api = axios.create({
baseURL: API_BASE_URL, baseURL: API_BASE_URL,
headers: { headers: {
@@ -77,33 +208,149 @@ export async function prefetchCsrfToken(): Promise<void> {
await fetchCsrfToken() await fetchCsrfToken()
} }
// Add token to requests and CSRF token for protected methods /**
api.interceptors.request.use(async (config: InternalAxiosRequestConfig) => { * Refresh the access token using the refresh token.
const token = localStorage.getItem('token') * This is called automatically when the access token is about to expire.
if (token) { *
config.headers.Authorization = `Bearer ${token}` * @returns The new access token, or null if refresh failed
*/
async function refreshAccessToken(): Promise<string | null> {
const refreshToken = getStoredRefreshToken()
if (!refreshToken) {
return null
}
// Add CSRF token for protected methods try {
const method = config.method?.toUpperCase() // Use axios directly to avoid interceptor loops
if (method && CSRF_PROTECTED_METHODS.includes(method)) { const response = await axios.post<{
const csrf = await getValidCsrfToken() access_token: string
if (csrf) { refresh_token?: string
config.headers[CSRF_TOKEN_HEADER] = csrf token_type: string
}>(
`${API_BASE_URL}/auth/refresh`,
{ refresh_token: refreshToken },
{
headers: {
'Content-Type': 'application/json',
},
}
)
const { access_token, refresh_token: newRefreshToken } = response.data
// Store the new tokens
storeTokens(access_token, newRefreshToken || refreshToken)
return access_token
} catch (error) {
// If refresh fails (401 or other error), the token is invalid
// Clear all tokens and let the response interceptor handle redirect
return null
}
}
/**
* Ensure we have a valid access token, refreshing if necessary.
* This implements a queue mechanism to prevent multiple simultaneous refresh requests.
*
* @returns A promise that resolves with a valid token or null if unavailable
*/
async function ensureValidToken(): Promise<string | null> {
const token = getStoredToken()
if (!token) {
return null
}
// If token is not expiring soon, use it as-is
if (!isTokenExpiringSoon(token)) {
return token
}
// If we're already refreshing, wait for it to complete
if (isRefreshing) {
return new Promise((resolve) => {
subscribeToTokenRefresh((newToken) => {
resolve(newToken)
})
})
}
// Start the refresh process
isRefreshing = true
try {
const newToken = await refreshAccessToken()
if (newToken) {
onTokenRefreshed(newToken)
return newToken
} else {
// Refresh failed - clear tokens and redirect to login
clearStoredTokens()
clearCsrfToken()
// Notify subscribers with empty token (they'll fail but won't retry)
refreshSubscribers = []
window.location.href = '/login?reason=session_expired'
return null
}
} finally {
isRefreshing = false
}
}
// Add token to requests and CSRF token for protected methods
// This interceptor ensures tokens are refreshed before they expire
api.interceptors.request.use(async (config: InternalAxiosRequestConfig) => {
// Skip token handling for auth endpoints that don't require authentication
const isAuthEndpoint =
config.url?.includes('/auth/login') || config.url?.includes('/auth/refresh')
if (!isAuthEndpoint) {
// Ensure we have a valid token (will refresh if expiring soon)
const token = await ensureValidToken()
if (token) {
config.headers.Authorization = `Bearer ${token}`
// Add CSRF token for protected methods
const method = config.method?.toUpperCase()
if (method && CSRF_PROTECTED_METHODS.includes(method)) {
const csrf = await getValidCsrfToken()
if (csrf) {
config.headers[CSRF_TOKEN_HEADER] = csrf
}
} }
} }
} }
return config return config
}) })
// Handle 401 responses // Handle 401 responses - clear tokens and redirect to login
// Note: Token refresh is handled proactively in the request interceptor
// A 401 here means either:
// 1. The token was revoked on the server
// 2. The refresh token has expired
// 3. Some other authentication issue
api.interceptors.response.use( api.interceptors.response.use(
(response) => response, (response) => response,
(error) => { (error: AxiosError) => {
if (error.response?.status === 401) { if (error.response?.status === 401) {
localStorage.removeItem('token') // Check if this is from a refresh endpoint to avoid redirect loops
localStorage.removeItem('user') const isRefreshRequest = error.config?.url?.includes('/auth/refresh')
clearCsrfToken()
window.location.href = '/login' if (!isRefreshRequest) {
// Clear all auth state
clearStoredTokens()
clearCsrfToken()
// Redirect to login with appropriate message
const currentPath = window.location.pathname
if (currentPath !== '/login') {
window.location.href = '/login?reason=session_expired'
}
}
} }
return Promise.reject(error) return Promise.reject(error)
} }
@@ -125,10 +372,17 @@ export interface User {
export interface LoginResponse { export interface LoginResponse {
access_token: string access_token: string
refresh_token?: string // Optional for backward compatibility during migration
token_type: string token_type: string
user: User user: User
} }
export interface RefreshTokenResponse {
access_token: string
refresh_token?: string // New refresh token if rotation is enabled
token_type: string
}
export const authApi = { export const authApi = {
login: async (data: LoginRequest): Promise<LoginResponse> => { login: async (data: LoginRequest): Promise<LoginResponse> => {
const response = await api.post<LoginResponse>('/auth/login', data) const response = await api.post<LoginResponse>('/auth/login', data)

View File

@@ -98,6 +98,31 @@
- **WHEN** 異常行為發生 - **WHEN** 異常行為發生
- **THEN** 系統記錄並發送警示 - **THEN** 系統記錄並發送警示
### Requirement: Security Event Logging
The system SHALL record failed access attempts for security monitoring and intrusion detection.
#### Scenario: Permission denied logged
- **WHEN** server returns 403 Forbidden for a resource access attempt
- **THEN** audit log entry is created with event_type "security.access_denied"
- **AND** entry includes user_id, resource_type, and attempted_action
#### Scenario: Repeated auth failures logged
- **WHEN** same IP has 5+ failed authentication attempts in 10 minutes
- **THEN** audit log entry is created with event_type "security.suspicious_auth_pattern"
- **AND** entry includes IP address and failure count
- **AND** alert is generated for security administrators
### Requirement: Detailed Health Endpoint Security
The detailed system health endpoint SHALL require admin authentication to prevent information disclosure.
#### Scenario: Admin accesses detailed health
- **WHEN** system administrator requests GET /health/detailed
- **THEN** full system status including connection pools is returned
#### Scenario: Non-admin accesses detailed health
- **WHEN** non-admin user or unauthenticated request to GET /health/detailed
- **THEN** request is rejected with 401 Unauthorized or 403 Forbidden
## Data Model ## Data Model
``` ```

View File

@@ -161,6 +161,33 @@ The system SHALL support project templates to standardize project creation.
- **THEN** system creates template with project's CustomField definitions - **THEN** system creates template with project's CustomField definitions
- **THEN** template is available for future project creation - **THEN** template is available for future project creation
### Requirement: Code Splitting
The application SHALL use code splitting with React.lazy() to reduce initial bundle size and improve load times.
#### Scenario: Initial page load
- **WHEN** user navigates to application
- **THEN** only core framework and current route are loaded
- **AND** other routes are loaded on demand
#### Scenario: Route-based splitting
- **WHEN** user navigates to a different page
- **THEN** that page's code chunk is loaded dynamically
- **AND** loading fallback is displayed during load
### Requirement: LocalStorage Data Validation
User data loaded from localStorage SHALL be validated before use to prevent crashes from corrupted data.
#### Scenario: Corrupted localStorage data
- **WHEN** localStorage contains malformed user JSON
- **THEN** invalid data is cleared
- **AND** user is redirected to login page
- **AND** no application crash occurs
#### Scenario: Valid localStorage data
- **WHEN** localStorage contains valid user JSON
- **THEN** user is authenticated from stored data
- **AND** application loads normally
### Requirement: Error Boundary Protection ### Requirement: Error Boundary Protection
The frontend application SHALL gracefully handle component render errors without crashing the entire application. The frontend application SHALL gracefully handle component render errors without crashing the entire application.

View File

@@ -78,6 +78,26 @@
- **WHEN** 管理者嘗試新增第 21 個欄位 - **WHEN** 管理者嘗試新增第 21 個欄位
- **THEN** 系統拒絕新增並顯示數量已達上限的訊息 - **THEN** 系統拒絕新增並顯示數量已達上限的訊息
### Requirement: Input Length Validation
All text input fields SHALL have maximum length constraints to prevent abuse and database issues.
#### Scenario: Task title exceeds limit
- **WHEN** user creates a task with title exceeding 500 characters
- **THEN** request is rejected with 422 Validation Error
- **AND** error indicates field length exceeded
#### Scenario: Description within limit
- **WHEN** user creates a task with description 10000 characters or less
- **THEN** task is created successfully
#### Scenario: Description exceeds limit
- **WHEN** user creates a task with description exceeding 10000 characters
- **THEN** request is rejected with 422 Validation Error
#### Scenario: Comment content limit
- **WHEN** user submits a comment exceeding 5000 characters
- **THEN** request is rejected with 422 Validation Error
### Requirement: Multiple Views ### Requirement: Multiple Views
系統 SHALL 支援多維視角:看板 (Kanban)、甘特圖 (Gantt)、列表 (List)、行事曆 (Calendar)。 系統 SHALL 支援多維視角:看板 (Kanban)、甘特圖 (Gantt)、列表 (List)、行事曆 (Calendar)。

View File

@@ -89,6 +89,34 @@
- **WHEN** 使用者執行登出操作 - **WHEN** 使用者執行登出操作
- **THEN** 系統銷毀 session 並清除 token - **THEN** 系統銷毀 session 並清除 token
### Requirement: Access Token Expiry
Access tokens SHALL expire within 60 minutes to limit exposure window in case of token compromise.
#### Scenario: Access token expiry
- **WHEN** an access token issued 61 minutes ago is used for API authentication
- **THEN** request is rejected with 401 Unauthorized
- **AND** error indicates "Token expired"
### Requirement: Refresh Token Support
The system SHALL support refresh tokens for seamless session continuity without requiring re-authentication.
#### Scenario: Refresh valid token
- **WHEN** POST to /api/auth/refresh with valid refresh token
- **THEN** new access token is issued
- **AND** new refresh token is issued via rotation
- **AND** old refresh token is invalidated
#### Scenario: Refresh expired token
- **WHEN** POST to /api/auth/refresh with expired refresh token
- **THEN** request is rejected with 401 Unauthorized
- **AND** user must re-authenticate via login
#### Scenario: Automatic frontend refresh
- **WHEN** access token expires in less than 5 minutes
- **AND** frontend prepares to make API call
- **THEN** token is automatically refreshed first
- **AND** original request proceeds with new token
### Requirement: API Rate Limiting ### Requirement: API Rate Limiting
The system SHALL implement rate limiting to protect against brute force attacks and DoS attempts. The system SHALL implement rate limiting to protect against brute force attacks and DoS attempts.
@@ -143,8 +171,19 @@ The system SHALL enforce maximum length limits on all user-provided string input
- **WHEN** user submits content with description under 10000 characters - **WHEN** user submits content with description under 10000 characters
- **THEN** system accepts the input and processes normally - **THEN** system accepts the input and processes normally
### Requirement: CORS Security
The system SHALL explicitly define allowed CORS methods and headers instead of using wildcards to reduce attack surface.
#### Scenario: Request with standard headers
- **WHEN** a cross-origin request includes Content-Type, Authorization, or X-CSRF-Token headers
- **THEN** the request is allowed
#### Scenario: Request with non-standard header
- **WHEN** a cross-origin request includes a non-whitelisted custom header
- **THEN** CORS preflight fails and request is rejected
### Requirement: Secure WebSocket Authentication ### Requirement: Secure WebSocket Authentication
The system SHALL authenticate WebSocket connections without exposing tokens in URL query parameters. The system SHALL authenticate WebSocket connections without exposing tokens in URL query parameters. In production environments, query parameter authentication SHALL be disabled.
#### Scenario: WebSocket connection with token in first message #### Scenario: WebSocket connection with token in first message
- **WHEN** client connects to WebSocket endpoint without a query token - **WHEN** client connects to WebSocket endpoint without a query token
@@ -161,6 +200,25 @@ The system SHALL authenticate WebSocket connections without exposing tokens in U
- **WHEN** client connects but does not send authentication within 10 seconds - **WHEN** client connects but does not send authentication within 10 seconds
- **THEN** server closes the connection with appropriate error code - **THEN** server closes the connection with appropriate error code
#### Scenario: Query parameter auth in production
- **WHEN** production environment and WebSocket connection includes token in query parameter
- **THEN** connection is rejected with code 4002
- **AND** error message indicates "Query parameter auth disabled in production"
### Requirement: WebSocket Connection Limits
The system SHALL limit each user to a maximum of 5 concurrent WebSocket connections to prevent resource exhaustion.
#### Scenario: User exceeds connection limit
- **WHEN** user already has 5 active WebSocket connections
- **AND** user attempts to open a 6th connection
- **THEN** connection is rejected with code 4005
- **AND** error message indicates "Too many connections"
#### Scenario: User within connection limit
- **WHEN** user has fewer than 5 active connections
- **AND** user opens a new WebSocket connection
- **THEN** connection is accepted
### Requirement: Path Traversal Protection ### Requirement: Path Traversal Protection
The system SHALL prevent file path traversal attacks by validating all file paths resolve within the designated storage directory. The system SHALL prevent file path traversal attacks by validating all file paths resolve within the designated storage directory.
@@ -187,22 +245,25 @@ The system SHALL validate JWT secret key strength on startup.
- **THEN** the system SHALL log a security warning - **THEN** the system SHALL log a security warning
### Requirement: CSRF Protection ### Requirement: CSRF Protection
The system SHALL protect sensitive state-changing operations with CSRF tokens. The system SHALL protect all state-changing operations (POST, PUT, PATCH, DELETE) with CSRF tokens.
#### Scenario: CSRF token required for password change #### Scenario: POST request without CSRF token
- **WHEN** a user attempts to change their password - **WHEN** an authenticated user makes a POST request without X-CSRF-Token header
- **AND** the request does not include a valid CSRF token
- **THEN** the request SHALL be rejected with 403 Forbidden - **THEN** the request SHALL be rejected with 403 Forbidden
- **AND** error message indicates "CSRF token is required"
#### Scenario: CSRF token required for account deletion #### Scenario: PUT/PATCH/DELETE request without CSRF token
- **WHEN** a user attempts to delete their account or resources - **WHEN** an authenticated user makes a PUT, PATCH, or DELETE request without X-CSRF-Token header
- **AND** the request does not include a valid CSRF token
- **THEN** the request SHALL be rejected with 403 Forbidden - **THEN** the request SHALL be rejected with 403 Forbidden
#### Scenario: Valid CSRF token accepted #### Scenario: Valid CSRF token accepted
- **WHEN** a state-changing request includes a valid CSRF token - **WHEN** a state-changing request includes a valid CSRF token
- **THEN** the request SHALL proceed normally - **THEN** the request SHALL proceed normally
#### Scenario: Public endpoints exempt from CSRF
- **WHEN** POST to /api/auth/login or other public endpoints
- **THEN** CSRF token is not required
## Data Model ## Data Model
``` ```