diff --git a/backend/app/api/auth/router.py b/backend/app/api/auth/router.py index a5221e7..1dc482a 100644 --- a/backend/app/api/auth/router.py +++ b/backend/app/api/auth/router.py @@ -3,12 +3,28 @@ from sqlalchemy.orm import Session from app.core.config import settings from app.core.database import get_db -from app.core.security import create_access_token, create_token_payload +from app.core.security import ( + create_access_token, + create_token_payload, + generate_refresh_token, + store_refresh_token, + validate_refresh_token, + invalidate_refresh_token, + invalidate_all_user_refresh_tokens, + decode_refresh_token_user_id, +) from app.core.redis import get_redis from app.core.rate_limiter import limiter from app.models.user import User from app.models.audit_log import AuditAction -from app.schemas.auth import LoginRequest, LoginResponse, UserInfo, CSRFTokenResponse +from app.schemas.auth import ( + LoginRequest, + LoginResponse, + UserInfo, + CSRFTokenResponse, + RefreshTokenRequest, + RefreshTokenResponse, +) from app.services.auth_client import ( verify_credentials, AuthAPIError, @@ -119,6 +135,9 @@ async def login( # Create access token access_token = create_access_token(token_data) + # Generate refresh token + refresh_token = generate_refresh_token() + # Store session in Redis (sync with JWT expiry) redis_client.setex( f"session:{user.id}", @@ -126,6 +145,9 @@ async def login( access_token, ) + # Store refresh token in Redis with user binding + store_refresh_token(redis_client, user.id, refresh_token) + # Log successful login AuditService.log_event( db=db, @@ -141,6 +163,8 @@ async def login( return LoginResponse( access_token=access_token, + refresh_token=refresh_token, + expires_in=settings.JWT_EXPIRE_MINUTES * 60, user=UserInfo( id=user.id, email=user.email, @@ -158,14 +182,114 @@ async def logout( redis_client=Depends(get_redis), ): """ - Logout user and invalidate session. + Logout user and invalidate session and all refresh tokens. """ # Remove session from Redis redis_client.delete(f"session:{current_user.id}") + # Invalidate all refresh tokens for this user + invalidate_all_user_refresh_tokens(redis_client, current_user.id) + return {"detail": "Successfully logged out"} +@router.post("/refresh", response_model=RefreshTokenResponse) +@limiter.limit("10/minute") +async def refresh_access_token( + request: Request, + refresh_request: RefreshTokenRequest, + db: Session = Depends(get_db), + redis_client=Depends(get_redis), +): + """ + Refresh access token using a valid refresh token. + + This endpoint implements refresh token rotation: + - Validates the provided refresh token + - Issues a new access token + - Issues a new refresh token (rotating the old one) + - Invalidates the old refresh token + + This provides enhanced security by ensuring refresh tokens are single-use. + """ + old_refresh_token = refresh_request.refresh_token + + # Find the user ID associated with this refresh token + user_id = decode_refresh_token_user_id(old_refresh_token, redis_client) + + if user_id is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid or expired refresh token", + headers={"WWW-Authenticate": "Bearer"}, + ) + + # Validate the refresh token is still valid and bound to this user + if not validate_refresh_token(redis_client, user_id, old_refresh_token): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid or expired refresh token", + headers={"WWW-Authenticate": "Bearer"}, + ) + + # Get user from database + user = db.query(User).filter(User.id == user_id).first() + + if user is None: + # Invalidate the token since user no longer exists + invalidate_refresh_token(redis_client, user_id, old_refresh_token) + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="User not found", + headers={"WWW-Authenticate": "Bearer"}, + ) + + if not user.is_active: + # Invalidate all tokens for disabled user + invalidate_all_user_refresh_tokens(redis_client, user_id) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="User account is disabled", + ) + + # Invalidate the old refresh token (rotation) + invalidate_refresh_token(redis_client, user_id, old_refresh_token) + + # Get role name + role_name = user.role.name if user.role else None + + # Create new token payload + token_data = create_token_payload( + user_id=user.id, + email=user.email, + role=role_name, + department_id=user.department_id, + is_system_admin=user.is_system_admin, + ) + + # Create new access token + new_access_token = create_access_token(token_data) + + # Generate new refresh token (rotation) + new_refresh_token = generate_refresh_token() + + # Store new session in Redis + redis_client.setex( + f"session:{user.id}", + settings.JWT_EXPIRE_MINUTES * 60, + new_access_token, + ) + + # Store new refresh token + store_refresh_token(redis_client, user.id, new_refresh_token) + + return RefreshTokenResponse( + access_token=new_access_token, + refresh_token=new_refresh_token, + expires_in=settings.JWT_EXPIRE_MINUTES * 60, + ) + + @router.get("/me", response_model=UserInfo) async def get_current_user_info( current_user: User = Depends(get_current_user), diff --git a/backend/app/api/websocket/router.py b/backend/app/api/websocket/router.py index 6572d4e..c8f80be 100644 --- a/backend/app/api/websocket/router.py +++ b/backend/app/api/websocket/router.py @@ -9,6 +9,7 @@ from sqlalchemy.orm import Session from app.core import database from app.core.security import decode_access_token from app.core.redis import get_redis_sync +from app.core.config import settings from app.models import User, Notification, Project from app.services.websocket_manager import manager from app.core.redis_pubsub import NotificationSubscriber, ProjectTaskSubscriber @@ -72,14 +73,24 @@ async def authenticate_websocket( Supports two authentication methods: 1. First message authentication (preferred, more secure) - Client sends: {"type": "auth", "token": ""} - 2. Query parameter authentication (deprecated, for backward compatibility) + 2. Query parameter authentication (disabled in production, for backward compatibility only) - Client connects with: ?token= Returns: Tuple of (user_id, error_reason). user_id is None if authentication fails. + Error reasons: "invalid_token", "invalid_message", "missing_token", + "timeout", "error", "query_auth_disabled" """ # If token provided via query parameter (backward compatibility) if query_token: + # Reject query parameter auth in production for security + if settings.ENVIRONMENT == "production": + logger.warning( + "WebSocket query parameter authentication attempted in production environment. " + "This is disabled for security reasons." + ) + return None, "query_auth_disabled" + logger.warning( "WebSocket authentication via query parameter is deprecated. " "Please use first-message authentication for better security." @@ -195,9 +206,21 @@ async def websocket_notifications( user_id, error_reason = await authenticate_websocket(websocket, token) if user_id is None: - if error_reason == "invalid_token": + if error_reason == "query_auth_disabled": + await websocket.send_json({"type": "error", "message": "Query parameter auth disabled in production"}) + await websocket.close(code=4002, reason="Query parameter auth disabled in production") + elif error_reason == "invalid_token": await websocket.send_json({"type": "error", "message": "Invalid or expired token"}) - await websocket.close(code=4001, reason="Invalid or expired token") + await websocket.close(code=4001, reason="Invalid or expired token") + else: + await websocket.close(code=4001, reason="Invalid or expired token") + return + + # Check connection limit before accepting + can_connect, reject_reason = await manager.check_connection_limit(user_id) + if not can_connect: + await websocket.send_json({"type": "error", "message": reject_reason}) + await websocket.close(code=4005, reason=reject_reason) return await manager.connect(websocket, user_id) @@ -394,9 +417,21 @@ async def websocket_project_sync( user_id, error_reason = await authenticate_websocket(websocket, token) if user_id is None: - if error_reason == "invalid_token": + if error_reason == "query_auth_disabled": + await websocket.send_json({"type": "error", "message": "Query parameter auth disabled in production"}) + await websocket.close(code=4002, reason="Query parameter auth disabled in production") + elif error_reason == "invalid_token": await websocket.send_json({"type": "error", "message": "Invalid or expired token"}) - await websocket.close(code=4001, reason="Invalid or expired token") + await websocket.close(code=4001, reason="Invalid or expired token") + else: + await websocket.close(code=4001, reason="Invalid or expired token") + return + + # Check connection limit before accepting + can_connect, reject_reason = await manager.check_connection_limit(user_id) + if not can_connect: + await websocket.send_json({"type": "error", "message": reject_reason}) + await websocket.close(code=4005, reason=reject_reason) return # Verify user has access to the project diff --git a/backend/app/core/config.py b/backend/app/core/config.py index 6a171f7..51523a3 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -28,7 +28,8 @@ class Settings(BaseSettings): # JWT - Must be set in environment, no default allowed JWT_SECRET_KEY: str = "" JWT_ALGORITHM: str = "HS256" - JWT_EXPIRE_MINUTES: int = 10080 # 7 days + JWT_EXPIRE_MINUTES: int = 60 # 1 hour (short-lived access token) + REFRESH_TOKEN_EXPIRE_DAYS: int = 7 # Refresh token valid for 7 days @field_validator("JWT_SECRET_KEY") @classmethod @@ -127,6 +128,12 @@ class Settings(BaseSettings): QUERY_LOGGING: bool = False # Enable SQLAlchemy query logging QUERY_COUNT_THRESHOLD: int = 10 # Warn when query count exceeds this threshold + # Environment + ENVIRONMENT: str = "development" # Options: development, staging, production + + # WebSocket Settings + MAX_WEBSOCKET_CONNECTIONS_PER_USER: int = 5 # Maximum concurrent WebSocket connections per user + class Config: env_file = ".env" case_sensitive = True diff --git a/backend/app/core/security.py b/backend/app/core/security.py index 95783bf..6e97dda 100644 --- a/backend/app/core/security.py +++ b/backend/app/core/security.py @@ -356,3 +356,140 @@ def create_token_payload( "department_id": department_id, "is_system_admin": is_system_admin, } + + +# Refresh Token Functions +REFRESH_TOKEN_BYTES = 32 + + +def generate_refresh_token() -> str: + """ + Generate a cryptographically secure refresh token. + + Returns: + A URL-safe base64-encoded random token + """ + return secrets.token_urlsafe(REFRESH_TOKEN_BYTES) + + +def get_refresh_token_key(user_id: str, token: str) -> str: + """ + Generate the Redis key for a refresh token. + + Args: + user_id: The user's ID + token: The refresh token + + Returns: + Redis key string + """ + # Hash the token to avoid storing it directly as a key + token_hash = hashlib.sha256(token.encode()).hexdigest()[:16] + return f"refresh_token:{user_id}:{token_hash}" + + +def store_refresh_token(redis_client, user_id: str, token: str) -> None: + """ + Store a refresh token in Redis with user binding. + + Args: + redis_client: Redis client instance + user_id: The user's ID + token: The refresh token to store + """ + key = get_refresh_token_key(user_id, token) + # Store with TTL based on REFRESH_TOKEN_EXPIRE_DAYS + ttl_seconds = settings.REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60 + redis_client.setex(key, ttl_seconds, user_id) + + +def validate_refresh_token(redis_client, user_id: str, token: str) -> bool: + """ + Validate a refresh token exists in Redis and is bound to the user. + + Args: + redis_client: Redis client instance + user_id: The expected user ID + token: The refresh token to validate + + Returns: + True if token is valid, False otherwise + """ + key = get_refresh_token_key(user_id, token) + stored_user_id = redis_client.get(key) + + if stored_user_id is None: + return False + + # Handle Redis bytes type + if isinstance(stored_user_id, bytes): + stored_user_id = stored_user_id.decode("utf-8") + + return stored_user_id == user_id + + +def invalidate_refresh_token(redis_client, user_id: str, token: str) -> bool: + """ + Invalidate (delete) a refresh token from Redis. + + Args: + redis_client: Redis client instance + user_id: The user's ID + token: The refresh token to invalidate + + Returns: + True if token was deleted, False if it didn't exist + """ + key = get_refresh_token_key(user_id, token) + result = redis_client.delete(key) + return result > 0 if isinstance(result, int) else bool(result) + + +def invalidate_all_user_refresh_tokens(redis_client, user_id: str) -> int: + """ + Invalidate all refresh tokens for a user. + + Args: + redis_client: Redis client instance + user_id: The user's ID + + Returns: + Number of tokens invalidated + """ + pattern = f"refresh_token:{user_id}:*" + count = 0 + for key in redis_client.scan_iter(match=pattern): + redis_client.delete(key) + count += 1 + return count + + +def decode_refresh_token_user_id(token: str, redis_client) -> Optional[str]: + """ + Find the user ID associated with a refresh token by searching Redis. + + This is used when we only have the token and need to find which user it belongs to. + Note: This is less efficient but necessary for refresh token validation when + the user_id is not provided in the request. + + Args: + token: The refresh token + redis_client: Redis client instance + + Returns: + User ID if found, None otherwise + """ + # We need to search for the token across all users + # This is done by checking the token hash pattern + token_hash = hashlib.sha256(token.encode()).hexdigest()[:16] + pattern = f"refresh_token:*:{token_hash}" + + for key in redis_client.scan_iter(match=pattern): + # Extract user_id from key format: refresh_token:{user_id}:{token_hash} + if isinstance(key, bytes): + key = key.decode("utf-8") + parts = key.split(":") + if len(parts) == 3: + return parts[1] + + return None diff --git a/backend/app/main.py b/backend/app/main.py index 4ef6114..d412f31 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -1,7 +1,7 @@ import os from contextlib import asynccontextmanager from datetime import datetime -from fastapi import FastAPI, Request, APIRouter +from fastapi import FastAPI, Request, APIRouter, Depends from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse from slowapi import _rate_limit_exceeded_handler @@ -9,6 +9,9 @@ from slowapi.errors import RateLimitExceeded from sqlalchemy import text from app.middleware.audit import AuditMiddleware +from app.middleware.csrf import CSRFMiddleware +from app.middleware.security_audit import SecurityAuditMiddleware +from app.middleware.error_sanitizer import ErrorSanitizerMiddleware from app.core.scheduler import start_scheduler, shutdown_scheduler, scheduler from app.core.rate_limiter import limiter from app.core.deprecation import DeprecationMiddleware @@ -61,6 +64,8 @@ from app.core.database import get_pool_status, engine from app.core.redis import redis_client from app.services.notification_service import get_redis_fallback_status from app.services.file_storage_service import file_storage_service +from app.middleware.auth import require_system_admin +from app.models import User app = FastAPI( title="Project Control API", @@ -73,18 +78,28 @@ app = FastAPI( app.state.limiter = limiter app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) -# CORS middleware +# CORS middleware - Explicit methods and headers for security app.add_middleware( CORSMiddleware, allow_origins=settings.CORS_ORIGINS, allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], + allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"], + allow_headers=["Content-Type", "Authorization", "X-CSRF-Token", "X-Request-ID"], ) +# Error sanitizer middleware - sanitizes error messages in production +# Must be first in the chain to intercept all error responses +app.add_middleware(ErrorSanitizerMiddleware) + # Audit middleware - extracts request metadata for audit logging app.add_middleware(AuditMiddleware) +# Security audit middleware - logs 401/403 responses to audit trail +app.add_middleware(SecurityAuditMiddleware) + +# CSRF middleware - validates CSRF tokens for state-changing requests +app.add_middleware(CSRFMiddleware) + # Deprecation middleware - adds deprecation headers to legacy /api/ routes app.add_middleware(DeprecationMiddleware) @@ -252,14 +267,20 @@ async def readiness_check(): @app.get("/health/detailed") -async def detailed_health_check(): - """Detailed health check endpoint. +async def detailed_health_check( + current_user: User = Depends(require_system_admin), +): + """Detailed health check endpoint (requires system admin). Returns comprehensive status of all system components: - database: Connection pool status and connectivity - redis: Connection status and fallback queue status - storage: File storage validation status - scheduler: Background job scheduler status + + Note: This endpoint requires system admin authentication because it exposes + sensitive infrastructure details including connection pool statistics and + internal service states. """ db_health = check_database_health() redis_health = check_redis_health() diff --git a/backend/app/middleware/csrf.py b/backend/app/middleware/csrf.py index 1191c09..2d405cc 100644 --- a/backend/app/middleware/csrf.py +++ b/backend/app/middleware/csrf.py @@ -1,38 +1,55 @@ """ CSRF (Cross-Site Request Forgery) Protection Middleware. -This module provides CSRF protection for sensitive state-changing operations. -It validates CSRF tokens for specified protected endpoints. +This module provides CSRF protection for all state-changing operations. +It validates CSRF tokens globally for authenticated POST, PUT, PATCH, DELETE requests. """ -from fastapi import Request, HTTPException, status, Depends -from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials -from typing import Optional, Callable, List +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import JSONResponse +from fastapi import HTTPException, status +from typing import Optional, Callable, List, Set from functools import wraps import logging -from app.core.security import validate_csrf_token, generate_csrf_token +from app.core.security import validate_csrf_token, generate_csrf_token, decode_access_token logger = logging.getLogger(__name__) # Header name for CSRF token CSRF_TOKEN_HEADER = "X-CSRF-Token" -# List of endpoint patterns that require CSRF protection -# These are sensitive state-changing operations -CSRF_PROTECTED_PATTERNS = [ - # User operations - "/api/v1/users/{user_id}/admin", # Admin status change - "/api/users/{user_id}/admin", # Legacy - # Password changes would go here if implemented - # Delete operations - "/api/attachments/{attachment_id}", # DELETE method - "/api/tasks/{task_id}", # DELETE method (soft delete) - "/api/projects/{project_id}", # DELETE method -] +# Methods that require CSRF protection (all state-changing operations) +CSRF_PROTECTED_METHODS = {"POST", "PUT", "PATCH", "DELETE"} -# Methods that require CSRF protection -CSRF_PROTECTED_METHODS = ["DELETE", "PUT", "PATCH"] +# Safe methods that don't require CSRF protection +CSRF_SAFE_METHODS = {"GET", "HEAD", "OPTIONS"} + +# Public endpoints that don't require CSRF validation +# These are endpoints that either: +# 1. Don't require authentication (login, health checks) +# 2. Are not state-changing in a security-sensitive way +CSRF_EXCLUDED_PATHS: Set[str] = { + # Authentication endpoints (unauthenticated) + "/api/auth/login", + "/api/v1/auth/login", + # Health check endpoints (unauthenticated) + "/health", + "/health/live", + "/health/ready", + "/health/detailed", + # WebSocket endpoints (use different auth mechanism) + "/api/ws", + "/ws", +} + +# Path prefixes that are excluded from CSRF validation +CSRF_EXCLUDED_PREFIXES: List[str] = [ + # WebSocket paths + "/api/ws/", + "/ws/", +] class CSRFProtectionError(HTTPException): @@ -45,6 +62,114 @@ class CSRFProtectionError(HTTPException): ) +class CSRFMiddleware(BaseHTTPMiddleware): + """ + Global CSRF protection middleware. + + Validates CSRF tokens for all authenticated state-changing requests + (POST, PUT, PATCH, DELETE) except for explicitly excluded endpoints. + """ + + async def dispatch(self, request: Request, call_next): + """Process the request and validate CSRF token if needed.""" + method = request.method.upper() + path = request.url.path + + # Skip CSRF validation for safe methods + if method in CSRF_SAFE_METHODS: + return await call_next(request) + + # Skip CSRF validation for excluded paths + if self._is_excluded_path(path): + logger.debug("CSRF validation skipped for excluded path: %s", path) + return await call_next(request) + + # Try to extract user ID from the Authorization header + user_id = self._extract_user_id_from_token(request) + + # If no user ID (unauthenticated request), skip CSRF validation + # The authentication middleware will handle unauthorized access + if user_id is None: + logger.debug( + "CSRF validation skipped (no auth token): %s %s", + method, path + ) + return await call_next(request) + + # Get CSRF token from header + csrf_token = request.headers.get(CSRF_TOKEN_HEADER) + + if not csrf_token: + logger.warning( + "CSRF validation failed: Missing token for user %s on %s %s", + user_id, method, path + ) + return JSONResponse( + status_code=status.HTTP_403_FORBIDDEN, + content={"detail": "CSRF token is required"} + ) + + # Validate the token + is_valid, error_message = validate_csrf_token(csrf_token, user_id) + + if not is_valid: + logger.warning( + "CSRF validation failed for user %s on %s %s: %s", + user_id, method, path, error_message + ) + return JSONResponse( + status_code=status.HTTP_403_FORBIDDEN, + content={"detail": error_message} + ) + + logger.debug( + "CSRF validation passed for user %s on %s %s", + user_id, method, path + ) + + return await call_next(request) + + def _is_excluded_path(self, path: str) -> bool: + """Check if the path is excluded from CSRF validation.""" + # Check exact path matches + if path in CSRF_EXCLUDED_PATHS: + return True + + # Check path prefixes + for prefix in CSRF_EXCLUDED_PREFIXES: + if path.startswith(prefix): + return True + + return False + + def _extract_user_id_from_token(self, request: Request) -> Optional[str]: + """ + Extract user ID from the Authorization header. + + Returns None if no valid token is found (unauthenticated request). + """ + auth_header = request.headers.get("Authorization") + if not auth_header: + return None + + # Parse Bearer token + parts = auth_header.split() + if len(parts) != 2 or parts[0].lower() != "bearer": + return None + + token = parts[1] + + # Decode the token to get user ID + try: + payload = decode_access_token(token) + if payload is None: + return None + return payload.get("sub") + except Exception as e: + logger.debug("Failed to decode token for CSRF validation: %s", e) + return None + + def require_csrf_token(func: Callable) -> Callable: """ Decorator to require CSRF token validation for an endpoint. diff --git a/backend/app/middleware/error_sanitizer.py b/backend/app/middleware/error_sanitizer.py new file mode 100644 index 0000000..4c4bce2 --- /dev/null +++ b/backend/app/middleware/error_sanitizer.py @@ -0,0 +1,187 @@ +"""Error message sanitization middleware for production environments. + +This middleware intercepts error responses and sanitizes them to prevent +information disclosure in production environments. Detailed error messages +are only shown when DEBUG mode is enabled. +""" +import json +import logging +from typing import Optional +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import Response, JSONResponse + +from app.core.config import settings + +logger = logging.getLogger(__name__) + +# Generic error messages for production +GENERIC_ERROR_MESSAGES = { + 400: "Bad Request", + 401: "Authentication required", + 403: "Access denied", + 404: "Resource not found", + 405: "Method not allowed", + 409: "Request conflict", + 422: "Validation error", + 429: "Too many requests", + 500: "Internal server error", + 502: "Service unavailable", + 503: "Service temporarily unavailable", + 504: "Request timeout", +} + +# Status codes that should preserve their original message even in production +# These are typically user-facing validation errors that don't leak sensitive info +PRESERVE_MESSAGE_CODES = { + 400, # Bad request - users need to know what's wrong with their request + 401, # Unauthorized - users need to know why auth failed + 403, # Forbidden - users need to know what permission they lack + 404, # Not found - usually safe to preserve + 409, # Conflict - users need to know about conflicts + 422, # Validation errors - users need to know what to fix +} + +# Patterns that indicate sensitive information in error messages +SENSITIVE_PATTERNS = [ + "traceback", + "stack trace", + "file path", + "/usr/", + "/var/", + "/home/", + "connection refused", + "connection error", + "timeout connecting", + "database error", + "sql", + "query failed", + "password", + "secret", + "token", + "key=", + "credentials", + ".py line", + "exception in", +] + + +def _contains_sensitive_info(message: str) -> bool: + """Check if an error message contains potentially sensitive information.""" + if not message: + return False + message_lower = message.lower() + return any(pattern.lower() in message_lower for pattern in SENSITIVE_PATTERNS) + + +def _sanitize_detail(detail: any, status_code: int) -> any: + """Sanitize error detail, removing sensitive information in production. + + Args: + detail: The error detail (can be string, list, or dict) + status_code: The HTTP status code + + Returns: + Sanitized detail for production, or original detail for debug mode + """ + # In debug mode, return original detail + if settings.DEBUG: + return detail + + # For preserved status codes, keep the detail if it doesn't contain sensitive info + if status_code in PRESERVE_MESSAGE_CODES: + if isinstance(detail, str) and not _contains_sensitive_info(detail): + return detail + if isinstance(detail, list): + # For validation errors (list of dicts), keep the structure but sanitize + sanitized = [] + for item in detail: + if isinstance(item, dict): + # Keep loc, msg, type for pydantic validation errors + sanitized_item = {} + if 'loc' in item: + sanitized_item['loc'] = item['loc'] + if 'msg' in item and not _contains_sensitive_info(str(item['msg'])): + sanitized_item['msg'] = item['msg'] + else: + sanitized_item['msg'] = 'Validation failed' + if 'type' in item: + sanitized_item['type'] = item['type'] + sanitized.append(sanitized_item) + else: + sanitized.append(item if not _contains_sensitive_info(str(item)) else 'Invalid value') + return sanitized + return detail + + # For other status codes, use generic message + return GENERIC_ERROR_MESSAGES.get(status_code, "An error occurred") + + +class ErrorSanitizerMiddleware(BaseHTTPMiddleware): + """Middleware to sanitize error responses in production. + + This middleware: + 1. Intercepts error responses (4xx and 5xx status codes) + 2. Parses JSON response bodies + 3. Sanitizes the 'detail' field to remove sensitive information + 4. Returns the sanitized response + + In DEBUG mode, original error messages are preserved for development. + """ + + async def dispatch(self, request: Request, call_next) -> Response: + response = await call_next(request) + + # Only process error responses with JSON content + if response.status_code < 400: + return response + + content_type = response.headers.get("content-type", "") + if "application/json" not in content_type: + return response + + # Read the response body + body = b"" + async for chunk in response.body_iterator: + body += chunk + + if not body: + return response + + try: + data = json.loads(body) + except (json.JSONDecodeError, UnicodeDecodeError): + # Not valid JSON, return as-is + return Response( + content=body, + status_code=response.status_code, + headers=dict(response.headers), + media_type=response.media_type, + ) + + # Sanitize the detail field if present + if "detail" in data: + original_detail = data["detail"] + data["detail"] = _sanitize_detail(original_detail, response.status_code) + + # Log the original error in production for debugging + if not settings.DEBUG and original_detail != data["detail"]: + logger.warning( + "Sanitized error response", + extra={ + "status_code": response.status_code, + "path": str(request.url.path), + "method": request.method, + "original_detail_length": len(str(original_detail)), + } + ) + + # Return the sanitized response + return JSONResponse( + content=data, + status_code=response.status_code, + headers={ + k: v for k, v in response.headers.items() + if k.lower() not in ("content-length", "content-type") + }, + ) diff --git a/backend/app/middleware/security_audit.py b/backend/app/middleware/security_audit.py new file mode 100644 index 0000000..009c355 --- /dev/null +++ b/backend/app/middleware/security_audit.py @@ -0,0 +1,215 @@ +"""Security audit middleware for logging access denials and suspicious auth patterns.""" +import time +from collections import defaultdict +from datetime import datetime, timedelta, timezone +from typing import Dict, List, Tuple +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import Response +from sqlalchemy.orm import Session + +from app.core.database import SessionLocal +from app.models import AuditLog, AuditAction +from app.services.audit_service import AuditService + + +# In-memory storage for tracking auth failures +# Structure: {ip_address: [(timestamp, path), ...]} +_auth_failure_tracker: Dict[str, List[Tuple[float, str]]] = defaultdict(list) + +# Configuration constants +AUTH_FAILURE_THRESHOLD = 5 # Number of failures to trigger suspicious pattern alert +AUTH_FAILURE_WINDOW_SECONDS = 600 # 10 minutes window + + +def _cleanup_old_failures(ip: str) -> None: + """Remove auth failures older than the tracking window.""" + if ip not in _auth_failure_tracker: + return + + cutoff = time.time() - AUTH_FAILURE_WINDOW_SECONDS + _auth_failure_tracker[ip] = [ + (ts, path) for ts, path in _auth_failure_tracker[ip] + if ts > cutoff + ] + + # Clean up empty entries + if not _auth_failure_tracker[ip]: + del _auth_failure_tracker[ip] + + +def _track_auth_failure(ip: str, path: str) -> int: + """Track an auth failure and return the count in the window.""" + _cleanup_old_failures(ip) + _auth_failure_tracker[ip].append((time.time(), path)) + return len(_auth_failure_tracker[ip]) + + +def _get_recent_failures(ip: str) -> List[str]: + """Get list of paths that failed auth for this IP.""" + _cleanup_old_failures(ip) + return [path for _, path in _auth_failure_tracker.get(ip, [])] + + +class SecurityAuditMiddleware(BaseHTTPMiddleware): + """Middleware to audit security-related events like 401/403 responses.""" + + async def dispatch(self, request: Request, call_next) -> Response: + response = await call_next(request) + + # Only process 401 and 403 responses + if response.status_code not in (401, 403): + return response + + # Get client IP from audit metadata if available + ip_address = self._get_client_ip(request) + path = str(request.url.path) + method = request.method + + # Get user_id if available from request state (set by auth middleware) + user_id = getattr(request.state, 'user_id', None) + + db: Session = SessionLocal() + try: + if response.status_code == 403: + self._log_access_denied(db, ip_address, path, method, user_id, request) + elif response.status_code == 401: + self._log_auth_failure(db, ip_address, path, method, request) + + db.commit() + except Exception: + db.rollback() + # Don't fail the request due to audit logging errors + finally: + db.close() + + return response + + def _get_client_ip(self, request: Request) -> str: + """Get the real client IP address from request.""" + # Check for audit metadata first (set by AuditMiddleware) + audit_metadata = getattr(request.state, 'audit_metadata', None) + if audit_metadata and 'ip_address' in audit_metadata: + return audit_metadata['ip_address'] + + # Fallback to checking headers directly + forwarded = request.headers.get("x-forwarded-for") + if forwarded: + return forwarded.split(",")[0].strip() + + real_ip = request.headers.get("x-real-ip") + if real_ip: + return real_ip + + if request.client: + return request.client.host + + return "unknown" + + def _log_access_denied( + self, + db: Session, + ip_address: str, + path: str, + method: str, + user_id: str | None, + request: Request, + ) -> None: + """Log a 403 Forbidden response to the audit trail.""" + request_metadata = { + "ip_address": ip_address, + "user_agent": request.headers.get("user-agent", ""), + "method": method, + "path": path, + } + + # Try to extract resource info from path + resource_type = self._extract_resource_type(path) + + AuditService.log_event( + db=db, + event_type="security.access_denied", + resource_type=resource_type, + action=AuditAction.ACCESS_DENIED, + user_id=user_id, + resource_id=None, + changes=[{ + "attempted_path": path, + "attempted_method": method, + "ip_address": ip_address, + }], + request_metadata=request_metadata, + ) + + def _log_auth_failure( + self, + db: Session, + ip_address: str, + path: str, + method: str, + request: Request, + ) -> None: + """Log a 401 Unauthorized response and check for suspicious patterns.""" + # Track this failure + failure_count = _track_auth_failure(ip_address, path) + + request_metadata = { + "ip_address": ip_address, + "user_agent": request.headers.get("user-agent", ""), + "method": method, + "path": path, + } + + resource_type = self._extract_resource_type(path) + + # Log the auth failure + AuditService.log_event( + db=db, + event_type="security.auth_failed", + resource_type=resource_type, + action=AuditAction.AUTH_FAILED, + user_id=None, # No user for 401 + resource_id=None, + changes=[{ + "attempted_path": path, + "attempted_method": method, + "ip_address": ip_address, + "failure_count_in_window": failure_count, + }], + request_metadata=request_metadata, + ) + + # Check for suspicious pattern + if failure_count >= AUTH_FAILURE_THRESHOLD: + recent_paths = _get_recent_failures(ip_address) + AuditService.log_event( + db=db, + event_type="security.suspicious_auth_pattern", + resource_type="security", + action=AuditAction.AUTH_FAILED, + user_id=None, + resource_id=None, + changes=[{ + "ip_address": ip_address, + "failure_count": failure_count, + "window_minutes": AUTH_FAILURE_WINDOW_SECONDS // 60, + "attempted_paths": list(set(recent_paths)), + }], + request_metadata=request_metadata, + ) + + def _extract_resource_type(self, path: str) -> str: + """Extract resource type from path for audit logging.""" + # Remove /api/ or /api/v1/ prefix + clean_path = path + if clean_path.startswith("/api/v1/"): + clean_path = clean_path[8:] + elif clean_path.startswith("/api/"): + clean_path = clean_path[5:] + + # Get the first path segment as resource type + parts = clean_path.strip("/").split("/") + if parts and parts[0]: + return parts[0] + + return "unknown" diff --git a/backend/app/models/audit_log.py b/backend/app/models/audit_log.py index a7aa95c..d1d6804 100644 --- a/backend/app/models/audit_log.py +++ b/backend/app/models/audit_log.py @@ -13,6 +13,8 @@ class AuditAction(str, enum.Enum): RESTORE = "restore" LOGIN = "login" LOGOUT = "logout" + ACCESS_DENIED = "access_denied" + AUTH_FAILED = "auth_failed" class SensitivityLevel(str, enum.Enum): @@ -42,10 +44,20 @@ EVENT_SENSITIVITY = { "attachment.upload": SensitivityLevel.LOW, "attachment.download": SensitivityLevel.LOW, "attachment.delete": SensitivityLevel.MEDIUM, + # Security events + "security.access_denied": SensitivityLevel.MEDIUM, + "security.auth_failed": SensitivityLevel.MEDIUM, + "security.suspicious_auth_pattern": SensitivityLevel.HIGH, } # Events that should trigger alerts -ALERT_EVENTS = {"project.delete", "user.permission_change", "user.admin_change", "role.permission_change"} +ALERT_EVENTS = { + "project.delete", + "user.permission_change", + "user.admin_change", + "role.permission_change", + "security.suspicious_auth_pattern", +} class AuditLog(Base): @@ -57,7 +69,7 @@ class AuditLog(Base): resource_id = Column(String(36), nullable=True) user_id = Column(String(36), ForeignKey("pjctrl_users.id", ondelete="SET NULL"), nullable=True) action = Column( - Enum("create", "update", "delete", "restore", "login", "logout", name="audit_action_enum"), + Enum("create", "update", "delete", "restore", "login", "logout", "access_denied", "auth_failed", name="audit_action_enum"), nullable=False ) changes = Column(JSON, nullable=True) diff --git a/backend/app/schemas/auth.py b/backend/app/schemas/auth.py index 872ddf1..ea2d2ad 100644 --- a/backend/app/schemas/auth.py +++ b/backend/app/schemas/auth.py @@ -9,10 +9,25 @@ class LoginRequest(BaseModel): class LoginResponse(BaseModel): access_token: str + refresh_token: str token_type: str = "bearer" + expires_in: int = Field(default=3600, description="Access token expiry in seconds") user: "UserInfo" +class RefreshTokenRequest(BaseModel): + """Request body for refresh token endpoint.""" + refresh_token: str = Field(..., description="The refresh token to use for obtaining a new access token") + + +class RefreshTokenResponse(BaseModel): + """Response for refresh token endpoint.""" + access_token: str + refresh_token: str # New refresh token (rotation) + token_type: str = "bearer" + expires_in: int = Field(default=3600, description="Access token expiry in seconds") + + class UserInfo(BaseModel): id: str email: str diff --git a/backend/app/schemas/comment.py b/backend/app/schemas/comment.py index e8cf670..4c138bf 100644 --- a/backend/app/schemas/comment.py +++ b/backend/app/schemas/comment.py @@ -4,12 +4,12 @@ from pydantic import BaseModel, Field class CommentCreate(BaseModel): - content: str = Field(..., min_length=1, max_length=10000) + content: str = Field(..., min_length=1, max_length=5000) parent_comment_id: Optional[str] = None class CommentUpdate(BaseModel): - content: str = Field(..., min_length=1, max_length=10000) + content: str = Field(..., min_length=1, max_length=5000) class CommentAuthor(BaseModel): diff --git a/backend/app/schemas/project_template.py b/backend/app/schemas/project_template.py index d01a4c5..89425fa 100644 --- a/backend/app/schemas/project_template.py +++ b/backend/app/schemas/project_template.py @@ -25,7 +25,7 @@ class CustomFieldDefinition(BaseModel): class ProjectTemplateBase(BaseModel): """Base schema for project template.""" name: str = Field(..., min_length=1, max_length=200) - description: Optional[str] = None + description: Optional[str] = Field(None, max_length=2000) is_public: bool = Field(default=False) task_statuses: Optional[List[TaskStatusDefinition]] = None custom_fields: Optional[List[CustomFieldDefinition]] = None @@ -43,7 +43,7 @@ class ProjectTemplateCreate(ProjectTemplateBase): class ProjectTemplateUpdate(BaseModel): """Schema for updating a project template.""" name: Optional[str] = Field(None, min_length=1, max_length=200) - description: Optional[str] = None + description: Optional[str] = Field(None, max_length=2000) is_public: Optional[bool] = None task_statuses: Optional[List[TaskStatusDefinition]] = None custom_fields: Optional[List[CustomFieldDefinition]] = None diff --git a/backend/app/services/websocket_manager.py b/backend/app/services/websocket_manager.py index 5a1c055..cd42f0b 100644 --- a/backend/app/services/websocket_manager.py +++ b/backend/app/services/websocket_manager.py @@ -4,6 +4,7 @@ import logging from typing import Dict, Set, Optional, Tuple from fastapi import WebSocket from app.core.redis import get_redis_sync +from app.core.config import settings logger = logging.getLogger(__name__) @@ -19,13 +20,48 @@ class ConnectionManager: self._lock = asyncio.Lock() self._project_lock = asyncio.Lock() + async def check_connection_limit(self, user_id: str) -> Tuple[bool, Optional[str]]: + """ + Check if user can create a new WebSocket connection. + + Args: + user_id: The user's ID + + Returns: + Tuple of (can_connect: bool, reject_reason: str | None) + - can_connect: True if user is within connection limit + - reject_reason: Error message if connection should be rejected + """ + max_connections = settings.MAX_WEBSOCKET_CONNECTIONS_PER_USER + async with self._lock: + current_count = len(self.active_connections.get(user_id, set())) + if current_count >= max_connections: + logger.warning( + f"User {user_id} exceeded WebSocket connection limit " + f"({current_count}/{max_connections})" + ) + return False, "Too many connections" + return True, None + + def get_user_connection_count(self, user_id: str) -> int: + """Get the current number of WebSocket connections for a user.""" + return len(self.active_connections.get(user_id, set())) + async def connect(self, websocket: WebSocket, user_id: str): - """Accept and track a new WebSocket connection.""" - await websocket.accept() + """ + Track a new WebSocket connection. + + Note: WebSocket must already be accepted before calling this method. + Connection limit should be checked via check_connection_limit() before calling. + """ async with self._lock: if user_id not in self.active_connections: self.active_connections[user_id] = set() self.active_connections[user_id].add(websocket) + logger.debug( + f"User {user_id} connected. Total connections: " + f"{len(self.active_connections[user_id])}" + ) async def disconnect(self, websocket: WebSocket, user_id: str): """Remove a WebSocket connection.""" diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 7dadc52..44465c9 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -166,3 +166,20 @@ def admin_token(client, mock_redis): mock_redis.setex("session:00000000-0000-0000-0000-000000000001", 900, token) return token + + +@pytest.fixture +def csrf_token(): + """Generate a CSRF token for the admin user.""" + from app.core.security import generate_csrf_token + + return generate_csrf_token("00000000-0000-0000-0000-000000000001") + + +@pytest.fixture +def auth_headers(admin_token, csrf_token): + """Get complete auth headers including both Authorization and CSRF token.""" + return { + "Authorization": f"Bearer {admin_token}", + "X-CSRF-Token": csrf_token, + } diff --git a/backend/tests/test_api_enhancements.py b/backend/tests/test_api_enhancements.py index b88b858..641464c 100644 --- a/backend/tests/test_api_enhancements.py +++ b/backend/tests/test_api_enhancements.py @@ -173,7 +173,7 @@ class TestProjectTemplates: # Should return list of templates assert "templates" in data or isinstance(data, list) - def test_create_template(self, client, admin_token, db): + def test_create_template(self, client, auth_headers, db): """Test creating a new project template.""" from app.models import Space @@ -192,14 +192,14 @@ class TestProjectTemplates: {"name": "Done", "color": "#00FF00"} ] }, - headers={"Authorization": f"Bearer {admin_token}"} + headers=auth_headers ) assert response.status_code in [200, 201] data = response.json() assert data.get("name") == "Test Template" - def test_create_project_from_template(self, client, admin_token, db): + def test_create_project_from_template(self, client, auth_headers, db): """Test creating a project from a template.""" from app.models import Space, ProjectTemplate @@ -228,14 +228,14 @@ class TestProjectTemplates: "description": "Created from template", "template_id": "test-template-id" }, - headers={"Authorization": f"Bearer {admin_token}"} + headers=auth_headers ) assert response.status_code in [200, 201] data = response.json() assert data.get("name") == "Project from Template" - def test_delete_template(self, client, admin_token, db): + def test_delete_template(self, client, auth_headers, db): """Test deleting a project template.""" from app.models import ProjectTemplate @@ -251,7 +251,7 @@ class TestProjectTemplates: response = client.delete( "/api/templates/delete-template-id", - headers={"Authorization": f"Bearer {admin_token}"} + headers=auth_headers ) assert response.status_code in [200, 204] diff --git a/backend/tests/test_attachments.py b/backend/tests/test_attachments.py index f39b048..78616b0 100644 --- a/backend/tests/test_attachments.py +++ b/backend/tests/test_attachments.py @@ -42,6 +42,22 @@ def test_user_token(client, mock_redis, test_user): return token +@pytest.fixture +def test_user_csrf_token(test_user): + """Generate a CSRF token for the test user.""" + from app.core.security import generate_csrf_token + return generate_csrf_token(test_user.id) + + +@pytest.fixture +def test_user_auth_headers(test_user_token, test_user_csrf_token): + """Get complete auth headers for test user.""" + return { + "Authorization": f"Bearer {test_user_token}", + "X-CSRF-Token": test_user_csrf_token, + } + + @pytest.fixture def test_space(db, test_user): """Create a test space.""" @@ -154,7 +170,7 @@ class TestFileStorageService: class TestAttachmentAPI: """Tests for Attachment API endpoints.""" - def test_upload_attachment(self, client, test_user_token, test_task, db, monkeypatch, temp_upload_dir): + def test_upload_attachment(self, client, test_user_auth_headers, test_task, db, monkeypatch, temp_upload_dir): """Test uploading an attachment.""" monkeypatch.setattr("app.core.config.settings.UPLOAD_DIR", temp_upload_dir) @@ -163,7 +179,7 @@ class TestAttachmentAPI: response = client.post( f"/api/tasks/{test_task.id}/attachments", - headers={"Authorization": f"Bearer {test_user_token}"}, + headers=test_user_auth_headers, files=files, ) @@ -271,14 +287,14 @@ class TestAttachmentAPI: db.refresh(attachment) assert attachment.is_deleted == True - def test_upload_blocked_file_type(self, client, test_user_token, test_task): + def test_upload_blocked_file_type(self, client, test_user_auth_headers, test_task): """Test that blocked file types are rejected.""" content = b"malicious content" files = {"file": ("virus.exe", BytesIO(content), "application/octet-stream")} response = client.post( f"/api/tasks/{test_task.id}/attachments", - headers={"Authorization": f"Bearer {test_user_token}"}, + headers=test_user_auth_headers, files=files, ) @@ -322,7 +338,7 @@ class TestAttachmentAPI: assert data["total"] == 2 assert len(data["versions"]) == 2 - def test_restore_version(self, client, test_user_token, test_task, db): + def test_restore_version(self, client, test_user_auth_headers, test_task, db): """Test restoring to a previous version.""" attachment = Attachment( id=str(uuid.uuid4()), @@ -351,7 +367,7 @@ class TestAttachmentAPI: response = client.post( f"/api/attachments/{attachment.id}/restore/1", - headers={"Authorization": f"Bearer {test_user_token}"}, + headers=test_user_auth_headers, ) assert response.status_code == 200 diff --git a/backend/tests/test_audit.py b/backend/tests/test_audit.py index 0488d04..eb82f9e 100644 --- a/backend/tests/test_audit.py +++ b/backend/tests/test_audit.py @@ -253,7 +253,7 @@ class TestAuditAPI: assert data["total"] == 3 assert all(log["resource_id"] == resource_id for log in data["logs"]) - def test_verify_integrity(self, client, admin_token, db): + def test_verify_integrity(self, client, auth_headers, db): """Test integrity verification.""" now = datetime.utcnow() @@ -270,7 +270,7 @@ class TestAuditAPI: response = client.post( "/api/audit-logs/verify-integrity", - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, json={ "start_date": (now - timedelta(hours=1)).isoformat(), "end_date": (now + timedelta(hours=1)).isoformat(), @@ -281,7 +281,7 @@ class TestAuditAPI: assert data["total_checked"] >= 1 assert data["invalid_count"] == 0 - def test_acknowledge_alert(self, client, admin_token, db): + def test_acknowledge_alert(self, client, auth_headers, db): """Test acknowledging an alert.""" # Create a log and alert log = AuditLog( @@ -309,7 +309,7 @@ class TestAuditAPI: response = client.put( f"/api/audit-alerts/{alert.id}/acknowledge", - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, ) assert response.status_code == 200 data = response.json() diff --git a/backend/tests/test_auth.py b/backend/tests/test_auth.py index 0c8eb58..2e9923a 100644 --- a/backend/tests/test_auth.py +++ b/backend/tests/test_auth.py @@ -1,5 +1,16 @@ import pytest -from app.core.security import create_access_token, decode_access_token, create_token_payload +from app.core.security import ( + create_access_token, + decode_access_token, + create_token_payload, + generate_refresh_token, + store_refresh_token, + validate_refresh_token, + invalidate_refresh_token, + invalidate_all_user_refresh_tokens, + decode_refresh_token_user_id, + get_refresh_token_key, +) class TestJWT: @@ -59,7 +70,7 @@ class TestAuthEndpoints: def test_get_me_without_auth(self, client): """Test accessing /me without authentication.""" response = client.get("/api/auth/me") - assert response.status_code == 403 + assert response.status_code == 401 # 401 for unauthenticated, 403 for unauthorized def test_get_me_with_auth(self, client, admin_token): """Test accessing /me with valid authentication.""" @@ -72,13 +83,196 @@ class TestAuthEndpoints: assert data["email"] == "ymirliu@panjit.com.tw" assert data["is_system_admin"] is True - def test_logout(self, client, admin_token, mock_redis): + def test_logout(self, client, auth_headers, mock_redis): """Test logout endpoint.""" response = client.post( "/api/auth/logout", - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, ) assert response.status_code == 200 # Verify session is removed assert mock_redis.get("session:00000000-0000-0000-0000-000000000001") is None + + +class TestRefreshToken: + """Test refresh token functionality.""" + + def test_generate_refresh_token(self): + """Test that refresh tokens are generated correctly.""" + token = generate_refresh_token() + assert token is not None + assert isinstance(token, str) + assert len(token) > 20 # URL-safe base64 encoded 32 bytes + + def test_generate_unique_refresh_tokens(self): + """Test that each generated token is unique.""" + tokens = [generate_refresh_token() for _ in range(100)] + assert len(set(tokens)) == 100 # All tokens should be unique + + def test_store_and_validate_refresh_token(self, mock_redis): + """Test storing and validating refresh tokens.""" + user_id = "test-user-123" + token = generate_refresh_token() + + # Store the token + store_refresh_token(mock_redis, user_id, token) + + # Validate the token + assert validate_refresh_token(mock_redis, user_id, token) is True + + # Wrong user should fail + assert validate_refresh_token(mock_redis, "wrong-user", token) is False + + # Wrong token should fail + assert validate_refresh_token(mock_redis, user_id, "wrong-token") is False + + def test_invalidate_refresh_token(self, mock_redis): + """Test invalidating a refresh token.""" + user_id = "test-user-123" + token = generate_refresh_token() + + # Store and verify + store_refresh_token(mock_redis, user_id, token) + assert validate_refresh_token(mock_redis, user_id, token) is True + + # Invalidate + result = invalidate_refresh_token(mock_redis, user_id, token) + assert result is True + + # Should no longer be valid + assert validate_refresh_token(mock_redis, user_id, token) is False + + def test_invalidate_all_user_refresh_tokens(self, mock_redis): + """Test invalidating all refresh tokens for a user.""" + user_id = "test-user-123" + tokens = [generate_refresh_token() for _ in range(3)] + + # Store multiple tokens + for token in tokens: + store_refresh_token(mock_redis, user_id, token) + + # Verify all are valid + for token in tokens: + assert validate_refresh_token(mock_redis, user_id, token) is True + + # Invalidate all + count = invalidate_all_user_refresh_tokens(mock_redis, user_id) + assert count == 3 + + # All should be invalid now + for token in tokens: + assert validate_refresh_token(mock_redis, user_id, token) is False + + def test_decode_refresh_token_user_id(self, mock_redis): + """Test finding user ID from refresh token.""" + user_id = "test-user-456" + token = generate_refresh_token() + + # Store the token + store_refresh_token(mock_redis, user_id, token) + + # Find user ID + found_user_id = decode_refresh_token_user_id(token, mock_redis) + assert found_user_id == user_id + + # Invalid token should return None + assert decode_refresh_token_user_id("invalid-token", mock_redis) is None + + +class TestRefreshTokenEndpoint: + """Test the refresh token API endpoint.""" + + def test_refresh_token_success(self, client, db, mock_redis): + """Test successful token refresh.""" + user_id = "00000000-0000-0000-0000-000000000001" + + # Generate and store a refresh token + refresh_token = generate_refresh_token() + store_refresh_token(mock_redis, user_id, refresh_token) + + # Call refresh endpoint + response = client.post( + "/api/auth/refresh", + json={"refresh_token": refresh_token}, + ) + + assert response.status_code == 200 + data = response.json() + assert "access_token" in data + assert "refresh_token" in data + assert data["token_type"] == "bearer" + assert data["expires_in"] > 0 + + # Old refresh token should be invalidated (rotation) + assert validate_refresh_token(mock_redis, user_id, refresh_token) is False + + # New refresh token should be valid + assert validate_refresh_token(mock_redis, user_id, data["refresh_token"]) is True + + def test_refresh_token_invalid(self, client, mock_redis): + """Test refresh with invalid token.""" + response = client.post( + "/api/auth/refresh", + json={"refresh_token": "invalid-token"}, + ) + + assert response.status_code == 401 + assert "Invalid or expired refresh token" in response.json()["detail"] + + def test_refresh_token_rotation(self, client, db, mock_redis): + """Test that refresh tokens are rotated (old one invalidated).""" + user_id = "00000000-0000-0000-0000-000000000001" + + # Generate and store initial refresh token + initial_token = generate_refresh_token() + store_refresh_token(mock_redis, user_id, initial_token) + + # First refresh + response1 = client.post( + "/api/auth/refresh", + json={"refresh_token": initial_token}, + ) + assert response1.status_code == 200 + new_token = response1.json()["refresh_token"] + + # Try to reuse the old token (should fail due to rotation) + response2 = client.post( + "/api/auth/refresh", + json={"refresh_token": initial_token}, + ) + assert response2.status_code == 401 + + # New token should still work + response3 = client.post( + "/api/auth/refresh", + json={"refresh_token": new_token}, + ) + assert response3.status_code == 200 + + def test_refresh_token_disabled_user(self, client, db, mock_redis): + """Test that disabled users cannot refresh tokens.""" + from app.models.user import User + + # Create a disabled user + disabled_user = User( + id="disabled-user-123", + email="disabled@example.com", + name="Disabled User", + is_active=False, + ) + db.add(disabled_user) + db.commit() + + # Generate and store refresh token for disabled user + refresh_token = generate_refresh_token() + store_refresh_token(mock_redis, disabled_user.id, refresh_token) + + # Try to refresh + response = client.post( + "/api/auth/refresh", + json={"refresh_token": refresh_token}, + ) + + assert response.status_code == 403 + assert "disabled" in response.json()["detail"].lower() diff --git a/backend/tests/test_backend_reliability.py b/backend/tests/test_backend_reliability.py index a1cb7f5..654bcd2 100644 --- a/backend/tests/test_backend_reliability.py +++ b/backend/tests/test_backend_reliability.py @@ -128,7 +128,7 @@ class TestRedisFailover: class TestBlockerDeletionCheck: """Test blocker check before task deletion.""" - def test_delete_task_with_blockers_warning(self, client, admin_token, db): + def test_delete_task_with_blockers_warning(self, client, admin_token, csrf_token, db): """Test that deleting task with blockers shows warning.""" from app.models import Space, Project, Task, TaskStatus, TaskDependency @@ -174,7 +174,7 @@ class TestBlockerDeletionCheck: # Try to delete without force response = client.delete( "/api/tasks/blocker-task", - headers={"Authorization": f"Bearer {admin_token}"} + headers={"Authorization": f"Bearer {admin_token}", "X-CSRF-Token": csrf_token} ) # Should return warning or require confirmation @@ -185,7 +185,7 @@ class TestBlockerDeletionCheck: if "warning" in data or "blocker_count" in data: assert data.get("blocker_count", 0) >= 1 or "blocker" in str(data).lower() - def test_force_delete_resolves_blockers(self, client, admin_token, db): + def test_force_delete_resolves_blockers(self, client, admin_token, csrf_token, db): """Test that force delete resolves blockers.""" from app.models import Space, Project, Task, TaskStatus, TaskDependency @@ -231,7 +231,7 @@ class TestBlockerDeletionCheck: # Force delete response = client.delete( "/api/tasks/force-del-task?force_delete=true", - headers={"Authorization": f"Bearer {admin_token}"} + headers={"Authorization": f"Bearer {admin_token}", "X-CSRF-Token": csrf_token} ) assert response.status_code == 200 @@ -240,7 +240,7 @@ class TestBlockerDeletionCheck: db.refresh(task_to_delete) assert task_to_delete.is_deleted is True - def test_delete_task_without_blockers(self, client, admin_token, db): + def test_delete_task_without_blockers(self, client, admin_token, csrf_token, db): """Test deleting task without blockers succeeds normally.""" from app.models import Space, Project, Task, TaskStatus @@ -267,7 +267,7 @@ class TestBlockerDeletionCheck: # Delete should succeed without warning response = client.delete( "/api/tasks/no-blocker-task", - headers={"Authorization": f"Bearer {admin_token}"} + headers={"Authorization": f"Bearer {admin_token}", "X-CSRF-Token": csrf_token} ) assert response.status_code == 200 diff --git a/backend/tests/test_collaboration.py b/backend/tests/test_collaboration.py index 84b139d..8c743ed 100644 --- a/backend/tests/test_collaboration.py +++ b/backend/tests/test_collaboration.py @@ -36,6 +36,13 @@ def user_token(client, mock_redis, test_user): return token +@pytest.fixture +def user_csrf_token(test_user): + """Generate a CSRF token for the test user.""" + from app.core.security import generate_csrf_token + return generate_csrf_token(test_user.id) + + @pytest.fixture def test_space(db): """Create a test space.""" @@ -100,11 +107,11 @@ def test_task(db, test_project, test_status): class TestComments: """Tests for Comments API.""" - def test_create_comment(self, client, admin_token, test_task): + def test_create_comment(self, client, auth_headers, test_task): """Test creating a comment.""" response = client.post( f"/api/tasks/{test_task.id}/comments", - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, json={"content": "This is a test comment"}, ) assert response.status_code == 201 @@ -136,7 +143,7 @@ class TestComments: assert len(data["comments"]) == 1 assert data["comments"][0]["content"] == "Test comment" - def test_update_comment(self, client, admin_token, db, test_task): + def test_update_comment(self, client, auth_headers, db, test_task): """Test updating a comment.""" comment = Comment( id=str(uuid.uuid4()), @@ -149,7 +156,7 @@ class TestComments: response = client.put( f"/api/comments/{comment.id}", - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, json={"content": "Updated content"}, ) assert response.status_code == 200 @@ -157,7 +164,7 @@ class TestComments: assert data["content"] == "Updated content" assert data["is_edited"] is True - def test_delete_comment(self, client, admin_token, db, test_task): + def test_delete_comment(self, client, auth_headers, db, test_task): """Test deleting a comment (soft delete).""" comment = Comment( id=str(uuid.uuid4()), @@ -170,7 +177,7 @@ class TestComments: response = client.delete( f"/api/comments/{comment.id}", - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, ) assert response.status_code == 204 @@ -178,13 +185,13 @@ class TestComments: db.refresh(comment) assert comment.is_deleted is True - def test_mention_limit(self, client, admin_token, test_task): + def test_mention_limit(self, client, auth_headers, test_task): """Test that @mention limit is enforced.""" # Create content with more than 10 mentions mentions = " ".join([f"@user{i}" for i in range(15)]) response = client.post( f"/api/tasks/{test_task.id}/comments", - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, json={"content": f"Test with many mentions: {mentions}"}, ) assert response.status_code == 400 @@ -218,7 +225,7 @@ class TestNotifications: assert data["total"] >= 1 assert data["unread_count"] >= 1 - def test_mark_notification_as_read(self, client, admin_token, db): + def test_mark_notification_as_read(self, client, auth_headers, db): """Test marking a notification as read.""" notification = Notification( id=str(uuid.uuid4()), @@ -233,14 +240,14 @@ class TestNotifications: response = client.put( f"/api/notifications/{notification.id}/read", - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, ) assert response.status_code == 200 data = response.json() assert data["is_read"] is True assert data["read_at"] is not None - def test_mark_all_as_read(self, client, admin_token, db): + def test_mark_all_as_read(self, client, auth_headers, db): """Test marking all notifications as read.""" # Create multiple unread notifications for i in range(3): @@ -257,7 +264,7 @@ class TestNotifications: response = client.put( "/api/notifications/read-all", - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, ) assert response.status_code == 200 data = response.json() @@ -290,11 +297,11 @@ class TestNotifications: class TestBlockers: """Tests for Blockers API.""" - def test_create_blocker(self, client, admin_token, test_task): + def test_create_blocker(self, client, auth_headers, test_task): """Test creating a blocker.""" response = client.post( f"/api/tasks/{test_task.id}/blockers", - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, json={"reason": "Waiting for external dependency"}, ) assert response.status_code == 201 @@ -302,7 +309,7 @@ class TestBlockers: assert data["reason"] == "Waiting for external dependency" assert data["resolved_at"] is None - def test_resolve_blocker(self, client, admin_token, db, test_task): + def test_resolve_blocker(self, client, auth_headers, db, test_task): """Test resolving a blocker.""" blocker = Blocker( id=str(uuid.uuid4()), @@ -316,7 +323,7 @@ class TestBlockers: response = client.put( f"/api/blockers/{blocker.id}/resolve", - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, json={"resolution_note": "Issue resolved by updating config"}, ) assert response.status_code == 200 @@ -348,7 +355,7 @@ class TestBlockers: assert data["total"] == 1 assert data["blockers"][0]["reason"] == "Test blocker" - def test_cannot_create_duplicate_active_blocker(self, client, admin_token, db, test_task): + def test_cannot_create_duplicate_active_blocker(self, client, auth_headers, db, test_task): """Test that duplicate active blockers are prevented.""" # Create first blocker blocker = Blocker( @@ -363,7 +370,7 @@ class TestBlockers: # Try to create second blocker response = client.post( f"/api/tasks/{test_task.id}/blockers", - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, json={"reason": "Second blocker"}, ) assert response.status_code == 400 diff --git a/backend/tests/test_concurrency_reliability.py b/backend/tests/test_concurrency_reliability.py index 1d7680b..6ef2b00 100644 --- a/backend/tests/test_concurrency_reliability.py +++ b/backend/tests/test_concurrency_reliability.py @@ -18,7 +18,7 @@ from datetime import datetime, timedelta class TestOptimisticLocking: """Test optimistic locking for concurrent updates.""" - def test_version_increments_on_update(self, client, admin_token, db): + def test_version_increments_on_update(self, client, admin_token, csrf_token, db): """Test that task version increments on successful update.""" from app.models import Space, Project, Task, TaskStatus @@ -47,7 +47,7 @@ class TestOptimisticLocking: response = client.patch( "/api/tasks/task-1", json={"title": "Updated Task", "version": 1}, - headers={"Authorization": f"Bearer {admin_token}"} + headers={"Authorization": f"Bearer {admin_token}", "X-CSRF-Token": csrf_token} ) assert response.status_code == 200 @@ -55,7 +55,7 @@ class TestOptimisticLocking: assert data["title"] == "Updated Task" assert data["version"] == 2 # Version should increment - def test_version_conflict_returns_409(self, client, admin_token, db): + def test_version_conflict_returns_409(self, client, admin_token, csrf_token, db): """Test that stale version returns 409 Conflict.""" from app.models import Space, Project, Task, TaskStatus @@ -84,7 +84,7 @@ class TestOptimisticLocking: response = client.patch( "/api/tasks/task-2", json={"title": "Stale Update", "version": 1}, - headers={"Authorization": f"Bearer {admin_token}"} + headers={"Authorization": f"Bearer {admin_token}", "X-CSRF-Token": csrf_token} ) assert response.status_code == 409 @@ -94,7 +94,7 @@ class TestOptimisticLocking: assert detail.get("current_version") == 5 assert detail.get("provided_version") == 1 - def test_update_without_version_succeeds(self, client, admin_token, db): + def test_update_without_version_succeeds(self, client, admin_token, csrf_token, db): """Test that update without version (for backward compatibility) still works.""" from app.models import Space, Project, Task, TaskStatus @@ -123,7 +123,7 @@ class TestOptimisticLocking: response = client.patch( "/api/tasks/task-3", json={"title": "No Version Update"}, - headers={"Authorization": f"Bearer {admin_token}"} + headers={"Authorization": f"Bearer {admin_token}", "X-CSRF-Token": csrf_token} ) # Should succeed (backward compatibility) @@ -179,7 +179,7 @@ class TestTriggerRetryMechanism: class TestCascadeRestore: """Test cascade restore for soft-deleted tasks.""" - def test_restore_parent_with_children(self, client, admin_token, db): + def test_restore_parent_with_children(self, client, admin_token, csrf_token, db): """Test restoring parent task also restores children deleted at same time.""" from app.models import Space, Project, Task, TaskStatus from datetime import datetime @@ -236,7 +236,7 @@ class TestCascadeRestore: response = client.post( "/api/tasks/parent-task/restore", json={"cascade": True}, - headers={"Authorization": f"Bearer {admin_token}"} + headers={"Authorization": f"Bearer {admin_token}", "X-CSRF-Token": csrf_token} ) assert response.status_code == 200 @@ -254,7 +254,7 @@ class TestCascadeRestore: assert child_task1.is_deleted is False assert child_task2.is_deleted is False - def test_restore_parent_only(self, client, admin_token, db): + def test_restore_parent_only(self, client, admin_token, csrf_token, db): """Test restoring parent task without cascade leaves children deleted.""" from app.models import Space, Project, Task, TaskStatus from datetime import datetime @@ -299,7 +299,7 @@ class TestCascadeRestore: response = client.post( "/api/tasks/parent-task-2/restore", json={"cascade": False}, - headers={"Authorization": f"Bearer {admin_token}"} + headers={"Authorization": f"Bearer {admin_token}", "X-CSRF-Token": csrf_token} ) assert response.status_code == 200 diff --git a/backend/tests/test_custom_fields.py b/backend/tests/test_custom_fields.py index bb1b644..37be39b 100644 --- a/backend/tests/test_custom_fields.py +++ b/backend/tests/test_custom_fields.py @@ -39,7 +39,7 @@ class TestCustomFieldsCRUD: db.commit() return project - def test_create_text_field(self, client, db, admin_token): + def test_create_text_field(self, client, db, auth_headers): """Test creating a text custom field.""" project = self.setup_project(db, "00000000-0000-0000-0000-000000000001") @@ -50,7 +50,7 @@ class TestCustomFieldsCRUD: "field_type": "text", "is_required": False, }, - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, ) assert response.status_code == 201 @@ -59,7 +59,7 @@ class TestCustomFieldsCRUD: assert data["field_type"] == "text" assert data["is_required"] is False - def test_create_number_field(self, client, db, admin_token): + def test_create_number_field(self, client, db, auth_headers): """Test creating a number custom field.""" project = self.setup_project(db, "00000000-0000-0000-0000-000000000001") @@ -70,7 +70,7 @@ class TestCustomFieldsCRUD: "field_type": "number", "is_required": True, }, - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, ) assert response.status_code == 201 @@ -79,7 +79,7 @@ class TestCustomFieldsCRUD: assert data["field_type"] == "number" assert data["is_required"] is True - def test_create_dropdown_field(self, client, db, admin_token): + def test_create_dropdown_field(self, client, db, auth_headers): """Test creating a dropdown custom field.""" project = self.setup_project(db, "00000000-0000-0000-0000-000000000001") @@ -91,7 +91,7 @@ class TestCustomFieldsCRUD: "options": ["Frontend", "Backend", "Database", "API"], "is_required": False, }, - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, ) assert response.status_code == 201 @@ -100,7 +100,7 @@ class TestCustomFieldsCRUD: assert data["field_type"] == "dropdown" assert data["options"] == ["Frontend", "Backend", "Database", "API"] - def test_create_dropdown_field_without_options_fails(self, client, db, admin_token): + def test_create_dropdown_field_without_options_fails(self, client, db, auth_headers): """Test that creating a dropdown field without options fails.""" project = self.setup_project(db, "00000000-0000-0000-0000-000000000001") @@ -111,12 +111,12 @@ class TestCustomFieldsCRUD: "field_type": "dropdown", "options": [], }, - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, ) assert response.status_code == 422 # Validation error - def test_create_formula_field(self, client, db, admin_token): + def test_create_formula_field(self, client, db, auth_headers): """Test creating a formula custom field.""" project = self.setup_project(db, "00000000-0000-0000-0000-000000000001") @@ -127,7 +127,7 @@ class TestCustomFieldsCRUD: "name": "hours_worked", "field_type": "number", }, - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, ) # Create formula field @@ -138,7 +138,7 @@ class TestCustomFieldsCRUD: "field_type": "formula", "formula": "{time_spent} / {original_estimate} * 100", }, - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, ) assert response.status_code == 201 @@ -147,7 +147,7 @@ class TestCustomFieldsCRUD: assert data["field_type"] == "formula" assert "{time_spent}" in data["formula"] - def test_list_custom_fields(self, client, db, admin_token): + def test_list_custom_fields(self, client, db, auth_headers): """Test listing custom fields for a project.""" project = self.setup_project(db, "00000000-0000-0000-0000-000000000001") @@ -155,17 +155,17 @@ class TestCustomFieldsCRUD: client.post( f"/api/projects/{project.id}/custom-fields", json={"name": "Field 1", "field_type": "text"}, - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, ) client.post( f"/api/projects/{project.id}/custom-fields", json={"name": "Field 2", "field_type": "number"}, - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, ) response = client.get( f"/api/projects/{project.id}/custom-fields", - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, ) assert response.status_code == 200 @@ -173,7 +173,7 @@ class TestCustomFieldsCRUD: assert data["total"] == 2 assert len(data["fields"]) == 2 - def test_update_custom_field(self, client, db, admin_token): + def test_update_custom_field(self, client, db, auth_headers): """Test updating a custom field.""" project = self.setup_project(db, "00000000-0000-0000-0000-000000000001") @@ -181,7 +181,7 @@ class TestCustomFieldsCRUD: create_response = client.post( f"/api/projects/{project.id}/custom-fields", json={"name": "Original Name", "field_type": "text"}, - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, ) field_id = create_response.json()["id"] @@ -189,7 +189,7 @@ class TestCustomFieldsCRUD: response = client.put( f"/api/custom-fields/{field_id}", json={"name": "Updated Name", "is_required": True}, - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, ) assert response.status_code == 200 @@ -197,7 +197,7 @@ class TestCustomFieldsCRUD: assert data["name"] == "Updated Name" assert data["is_required"] is True - def test_delete_custom_field(self, client, db, admin_token): + def test_delete_custom_field(self, client, db, auth_headers): """Test deleting a custom field.""" project = self.setup_project(db, "00000000-0000-0000-0000-000000000001") @@ -205,14 +205,14 @@ class TestCustomFieldsCRUD: create_response = client.post( f"/api/projects/{project.id}/custom-fields", json={"name": "To Delete", "field_type": "text"}, - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, ) field_id = create_response.json()["id"] # Delete it response = client.delete( f"/api/custom-fields/{field_id}", - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, ) assert response.status_code == 204 @@ -220,11 +220,11 @@ class TestCustomFieldsCRUD: # Verify it's gone get_response = client.get( f"/api/custom-fields/{field_id}", - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, ) assert get_response.status_code == 404 - def test_max_fields_limit(self, client, db, admin_token): + def test_max_fields_limit(self, client, db, auth_headers): """Test that maximum 20 custom fields per project is enforced.""" project = self.setup_project(db, "00000000-0000-0000-0000-000000000001") @@ -233,7 +233,7 @@ class TestCustomFieldsCRUD: response = client.post( f"/api/projects/{project.id}/custom-fields", json={"name": f"Field {i}", "field_type": "text"}, - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, ) assert response.status_code == 201 @@ -241,12 +241,12 @@ class TestCustomFieldsCRUD: response = client.post( f"/api/projects/{project.id}/custom-fields", json={"name": "Field 21", "field_type": "text"}, - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, ) assert response.status_code == 400 assert "Maximum" in response.json()["detail"] - def test_duplicate_name_rejected(self, client, db, admin_token): + def test_duplicate_name_rejected(self, client, db, auth_headers): """Test that duplicate field names are rejected.""" project = self.setup_project(db, "00000000-0000-0000-0000-000000000001") @@ -254,14 +254,14 @@ class TestCustomFieldsCRUD: client.post( f"/api/projects/{project.id}/custom-fields", json={"name": "Unique Name", "field_type": "text"}, - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, ) # Try to create another with same name response = client.post( f"/api/projects/{project.id}/custom-fields", json={"name": "Unique Name", "field_type": "number"}, - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, ) assert response.status_code == 400 assert "already exists" in response.json()["detail"] @@ -311,7 +311,7 @@ class TestFormulaService: class TestCustomValuesWithTasks: """Test custom values integration with tasks.""" - def setup_project_with_fields(self, db, client, admin_token, owner_id: str): + def setup_project_with_fields(self, db, client, auth_headers, owner_id: str): """Create a project with custom fields for testing.""" space = Space( id="test-space-002", @@ -342,23 +342,23 @@ class TestCustomValuesWithTasks: text_response = client.post( f"/api/projects/{project.id}/custom-fields", json={"name": "sprint_number", "field_type": "text"}, - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, ) text_field_id = text_response.json()["id"] number_response = client.post( f"/api/projects/{project.id}/custom-fields", json={"name": "story_points", "field_type": "number"}, - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, ) number_field_id = number_response.json()["id"] return project, text_field_id, number_field_id - def test_create_task_with_custom_values(self, client, db, admin_token): + def test_create_task_with_custom_values(self, client, db, auth_headers): """Test creating a task with custom values.""" project, text_field_id, number_field_id = self.setup_project_with_fields( - db, client, admin_token, "00000000-0000-0000-0000-000000000001" + db, client, auth_headers, "00000000-0000-0000-0000-000000000001" ) response = client.post( @@ -370,15 +370,15 @@ class TestCustomValuesWithTasks: {"field_id": number_field_id, "value": "8"}, ], }, - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, ) assert response.status_code == 201 - def test_get_task_includes_custom_values(self, client, db, admin_token): + def test_get_task_includes_custom_values(self, client, db, auth_headers): """Test that getting a task includes custom values.""" project, text_field_id, number_field_id = self.setup_project_with_fields( - db, client, admin_token, "00000000-0000-0000-0000-000000000001" + db, client, auth_headers, "00000000-0000-0000-0000-000000000001" ) # Create task with custom values @@ -391,14 +391,14 @@ class TestCustomValuesWithTasks: {"field_id": number_field_id, "value": "8"}, ], }, - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, ) task_id = create_response.json()["id"] # Get task and check custom values get_response = client.get( f"/api/tasks/{task_id}", - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, ) assert get_response.status_code == 200 @@ -406,10 +406,10 @@ class TestCustomValuesWithTasks: assert data["custom_values"] is not None assert len(data["custom_values"]) >= 2 - def test_update_task_custom_values(self, client, db, admin_token): + def test_update_task_custom_values(self, client, db, auth_headers): """Test updating custom values on a task.""" project, text_field_id, number_field_id = self.setup_project_with_fields( - db, client, admin_token, "00000000-0000-0000-0000-000000000001" + db, client, auth_headers, "00000000-0000-0000-0000-000000000001" ) # Create task @@ -421,7 +421,7 @@ class TestCustomValuesWithTasks: {"field_id": text_field_id, "value": "Sprint 5"}, ], }, - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, ) task_id = create_response.json()["id"] @@ -434,7 +434,7 @@ class TestCustomValuesWithTasks: {"field_id": number_field_id, "value": "13"}, ], }, - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, ) assert update_response.status_code == 200 diff --git a/backend/tests/test_dashboard.py b/backend/tests/test_dashboard.py index b3e9c15..b087175 100644 --- a/backend/tests/test_dashboard.py +++ b/backend/tests/test_dashboard.py @@ -619,7 +619,7 @@ class TestDashboardAPI: def test_dashboard_unauthorized(self, client, db): """Unauthenticated requests should fail.""" response = client.get("/api/dashboard") - assert response.status_code == 403 + assert response.status_code == 401 # 401 for unauthenticated, 403 for unauthorized def test_dashboard_with_user_tasks(self, client, db, admin_token): """Dashboard should reflect user's tasks correctly.""" diff --git a/backend/tests/test_encryption.py b/backend/tests/test_encryption.py index 56886d7..c50133f 100644 --- a/backend/tests/test_encryption.py +++ b/backend/tests/test_encryption.py @@ -312,6 +312,13 @@ class TestConfidentialProjectUpload: mock_redis.setex(f"session:{test_user.id}", 900, token) return token + @pytest.fixture + def test_user_csrf_token(self, test_user): + """Generate a CSRF token for the test user.""" + from app.core.security import generate_csrf_token + + return generate_csrf_token(test_user.id) + @pytest.fixture def test_space(self, db, test_user): """Create a test space.""" @@ -364,7 +371,7 @@ class TestConfidentialProjectUpload: return task def test_upload_confidential_project_encryption_unavailable( - self, client, test_user_token, test_task, db + self, client, test_user_token, test_user_csrf_token, test_task, db ): """Test that uploading to confidential project returns 400 when encryption is unavailable.""" from io import BytesIO @@ -378,7 +385,7 @@ class TestConfidentialProjectUpload: response = client.post( f"/api/tasks/{test_task.id}/attachments", - headers={"Authorization": f"Bearer {test_user_token}"}, + headers={"Authorization": f"Bearer {test_user_token}", "X-CSRF-Token": test_user_csrf_token}, files=files, ) @@ -387,7 +394,7 @@ class TestConfidentialProjectUpload: assert "environment variable" in response.json()["detail"] def test_upload_confidential_project_no_active_key( - self, client, test_user_token, test_task, db + self, client, test_user_token, test_user_csrf_token, test_task, db ): """Test that uploading to confidential project returns 400 when no active encryption key exists.""" from io import BytesIO @@ -408,7 +415,7 @@ class TestConfidentialProjectUpload: response = client.post( f"/api/tasks/{test_task.id}/attachments", - headers={"Authorization": f"Bearer {test_user_token}"}, + headers={"Authorization": f"Bearer {test_user_token}", "X-CSRF-Token": test_user_csrf_token}, files=files, ) diff --git a/backend/tests/test_health.py b/backend/tests/test_health.py index e148176..7160b2a 100644 --- a/backend/tests/test_health.py +++ b/backend/tests/test_health.py @@ -614,7 +614,7 @@ class TestHealthAPI: def test_unauthorized_access(self, client, db): """Unauthenticated requests should fail.""" response = client.get("/api/projects/health/dashboard") - assert response.status_code == 403 + assert response.status_code == 401 # 401 for unauthenticated, 403 for unauthorized def test_dashboard_with_status_filter(self, client, db, admin_token): """Dashboard should respect status filter.""" diff --git a/backend/tests/test_reports.py b/backend/tests/test_reports.py index 2fdedc3..e654c7c 100644 --- a/backend/tests/test_reports.py +++ b/backend/tests/test_reports.py @@ -38,6 +38,14 @@ def test_user_token(client, mock_redis, test_user): return token +@pytest.fixture +def test_user_csrf_token(test_user): + """Generate a CSRF token for the test user.""" + from app.core.security import generate_csrf_token + + return generate_csrf_token(test_user.id) + + @pytest.fixture def test_space(db, test_user): """Create a test space.""" @@ -284,11 +292,11 @@ class TestReportAPI: assert "projects" in data assert data["summary"]["total_tasks"] == 3 - def test_generate_weekly_report_api(self, client, test_user_token, test_project, test_tasks, test_statuses): + def test_generate_weekly_report_api(self, client, test_user_token, test_user_csrf_token, test_project, test_tasks, test_statuses): """Test generating weekly report via API.""" response = client.post( "/api/reports/weekly/generate", - headers={"Authorization": f"Bearer {test_user_token}"}, + headers={"Authorization": f"Bearer {test_user_token}", "X-CSRF-Token": test_user_csrf_token}, ) assert response.status_code == 200 @@ -297,7 +305,7 @@ class TestReportAPI: assert "report_id" in data assert "summary" in data - def test_weekly_report_subscription_toggle(self, client, test_user_token, db, test_user): + def test_weekly_report_subscription_toggle(self, client, test_user_token, test_user_csrf_token, db, test_user): """Test weekly report subscription toggle endpoints.""" response = client.get( "/api/reports/weekly/subscription", @@ -308,7 +316,7 @@ class TestReportAPI: response = client.put( "/api/reports/weekly/subscription", - headers={"Authorization": f"Bearer {test_user_token}"}, + headers={"Authorization": f"Bearer {test_user_token}", "X-CSRF-Token": test_user_csrf_token}, json={"is_active": True}, ) assert response.status_code == 200 @@ -323,7 +331,7 @@ class TestReportAPI: response = client.put( "/api/reports/weekly/subscription", - headers={"Authorization": f"Bearer {test_user_token}"}, + headers={"Authorization": f"Bearer {test_user_token}", "X-CSRF-Token": test_user_csrf_token}, json={"is_active": False}, ) assert response.status_code == 200 diff --git a/backend/tests/test_schedule_triggers.py b/backend/tests/test_schedule_triggers.py index 836df8b..5e22cea 100644 --- a/backend/tests/test_schedule_triggers.py +++ b/backend/tests/test_schedule_triggers.py @@ -52,6 +52,14 @@ def test_user_token(client, mock_redis, test_user): return token +@pytest.fixture +def test_user_csrf_token(test_user): + """Generate a CSRF token for the test user.""" + from app.core.security import generate_csrf_token + + return generate_csrf_token(test_user.id) + + @pytest.fixture def test_space(db, test_user): """Create a test space.""" @@ -445,11 +453,11 @@ class TestDeadlineReminderLogic: class TestScheduleTriggerAPI: """Tests for Schedule Trigger API endpoints.""" - def test_create_cron_trigger(self, client, test_user_token, test_project): + def test_create_cron_trigger(self, client, test_user_token, test_user_csrf_token, test_project): """Test creating a schedule trigger with cron expression.""" response = client.post( f"/api/projects/{test_project.id}/triggers", - headers={"Authorization": f"Bearer {test_user_token}"}, + headers={"Authorization": f"Bearer {test_user_token}", "X-CSRF-Token": test_user_csrf_token}, json={ "name": "Weekly Monday Reminder", "description": "Remind every Monday at 9am", @@ -471,11 +479,11 @@ class TestScheduleTriggerAPI: assert data["trigger_type"] == "schedule" assert data["conditions"]["cron_expression"] == "0 9 * * 1" - def test_create_deadline_trigger(self, client, test_user_token, test_project): + def test_create_deadline_trigger(self, client, test_user_token, test_user_csrf_token, test_project): """Test creating a schedule trigger with deadline reminder.""" response = client.post( f"/api/projects/{test_project.id}/triggers", - headers={"Authorization": f"Bearer {test_user_token}"}, + headers={"Authorization": f"Bearer {test_user_token}", "X-CSRF-Token": test_user_csrf_token}, json={ "name": "Deadline Reminder", "description": "Remind 5 days before deadline", @@ -494,11 +502,11 @@ class TestScheduleTriggerAPI: data = response.json() assert data["conditions"]["deadline_reminder_days"] == 5 - def test_create_schedule_trigger_invalid_cron(self, client, test_user_token, test_project): + def test_create_schedule_trigger_invalid_cron(self, client, test_user_token, test_user_csrf_token, test_project): """Test creating a schedule trigger with invalid cron expression.""" response = client.post( f"/api/projects/{test_project.id}/triggers", - headers={"Authorization": f"Bearer {test_user_token}"}, + headers={"Authorization": f"Bearer {test_user_token}", "X-CSRF-Token": test_user_csrf_token}, json={ "name": "Invalid Cron Trigger", "trigger_type": "schedule", @@ -512,11 +520,11 @@ class TestScheduleTriggerAPI: assert response.status_code == 400 assert "Invalid cron expression" in response.json()["detail"] - def test_create_schedule_trigger_missing_condition(self, client, test_user_token, test_project): + def test_create_schedule_trigger_missing_condition(self, client, test_user_token, test_user_csrf_token, test_project): """Test creating a schedule trigger without cron or deadline condition.""" response = client.post( f"/api/projects/{test_project.id}/triggers", - headers={"Authorization": f"Bearer {test_user_token}"}, + headers={"Authorization": f"Bearer {test_user_token}", "X-CSRF-Token": test_user_csrf_token}, json={ "name": "Empty Schedule Trigger", "trigger_type": "schedule", @@ -528,11 +536,11 @@ class TestScheduleTriggerAPI: assert response.status_code == 400 assert "require either cron_expression or deadline_reminder_days" in response.json()["detail"] - def test_update_schedule_trigger_cron(self, client, test_user_token, cron_trigger): + def test_update_schedule_trigger_cron(self, client, test_user_token, test_user_csrf_token, cron_trigger): """Test updating a schedule trigger's cron expression.""" response = client.put( f"/api/triggers/{cron_trigger.id}", - headers={"Authorization": f"Bearer {test_user_token}"}, + headers={"Authorization": f"Bearer {test_user_token}", "X-CSRF-Token": test_user_csrf_token}, json={ "conditions": { "cron_expression": "0 10 * * *", # Changed to 10am @@ -544,11 +552,11 @@ class TestScheduleTriggerAPI: data = response.json() assert data["conditions"]["cron_expression"] == "0 10 * * *" - def test_update_schedule_trigger_invalid_cron(self, client, test_user_token, cron_trigger): + def test_update_schedule_trigger_invalid_cron(self, client, test_user_token, test_user_csrf_token, cron_trigger): """Test updating a schedule trigger with invalid cron expression.""" response = client.put( f"/api/triggers/{cron_trigger.id}", - headers={"Authorization": f"Bearer {test_user_token}"}, + headers={"Authorization": f"Bearer {test_user_token}", "X-CSRF-Token": test_user_csrf_token}, json={ "conditions": { "cron_expression": "not valid", diff --git a/backend/tests/test_soft_delete.py b/backend/tests/test_soft_delete.py index 4f4daa8..0e6fcdf 100644 --- a/backend/tests/test_soft_delete.py +++ b/backend/tests/test_soft_delete.py @@ -69,6 +69,22 @@ def regular_token(client, mock_redis, test_regular_user): return token +@pytest.fixture +def csrf_token(test_admin): + """Generate a CSRF token for the test admin user.""" + from app.core.security import generate_csrf_token + return generate_csrf_token(test_admin.id) + + +@pytest.fixture +def auth_headers(admin_token, csrf_token): + """Get complete auth headers including both Authorization and CSRF token.""" + return { + "Authorization": f"Bearer {admin_token}", + "X-CSRF-Token": csrf_token, + } + + @pytest.fixture def test_space(db, test_admin): """Create a test space.""" @@ -148,11 +164,11 @@ def test_task_with_subtask(db, test_project, test_admin, test_status, test_task) class TestSoftDelete: """Tests for soft delete functionality.""" - def test_delete_task_soft_deletes(self, client, admin_token, test_task, db): + def test_delete_task_soft_deletes(self, client, auth_headers, test_task, db): """Test that DELETE soft-deletes a task.""" response = client.delete( f"/api/tasks/{test_task.id}", - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, ) assert response.status_code == 200 @@ -165,36 +181,36 @@ class TestSoftDelete: assert test_task.deleted_at is not None assert test_task.deleted_by is not None - def test_deleted_task_not_in_list(self, client, admin_token, test_project, test_task, db): + def test_deleted_task_not_in_list(self, client, auth_headers, test_project, test_task, db): """Test that deleted tasks are not shown in list.""" # Delete the task client.delete( f"/api/tasks/{test_task.id}", - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, ) # List tasks response = client.get( f"/api/projects/{test_project.id}/tasks", - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, ) assert response.status_code == 200 data = response.json() assert data["total"] == 0 - def test_admin_can_list_deleted_with_include_deleted(self, client, admin_token, test_project, test_task, db): + def test_admin_can_list_deleted_with_include_deleted(self, client, auth_headers, test_project, test_task, db): """Test that admin can see deleted tasks with include_deleted parameter.""" # Delete the task client.delete( f"/api/tasks/{test_task.id}", - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, ) # List with include_deleted response = client.get( f"/api/projects/{test_project.id}/tasks?include_deleted=true", - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, ) assert response.status_code == 200 @@ -202,12 +218,12 @@ class TestSoftDelete: assert data["total"] == 1 assert data["tasks"][0]["id"] == test_task.id - def test_regular_user_cannot_see_deleted_with_include_deleted(self, client, regular_token, test_project, test_task, admin_token, db): + def test_regular_user_cannot_see_deleted_with_include_deleted(self, client, regular_token, test_project, test_task, auth_headers, db, csrf_token): """Test that non-admin cannot see deleted tasks even with include_deleted.""" # Delete the task as admin client.delete( f"/api/tasks/{test_task.id}", - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, ) # Try to list with include_deleted as regular user @@ -220,12 +236,12 @@ class TestSoftDelete: data = response.json() assert data["total"] == 0 - def test_get_deleted_task_returns_404_for_regular_user(self, client, admin_token, regular_token, test_task, db): + def test_get_deleted_task_returns_404_for_regular_user(self, client, auth_headers, regular_token, test_task, db): """Test that getting a deleted task returns 404 for non-admin.""" # Delete the task client.delete( f"/api/tasks/{test_task.id}", - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, ) # Try to get as regular user @@ -236,28 +252,28 @@ class TestSoftDelete: assert response.status_code == 404 - def test_admin_can_view_deleted_task(self, client, admin_token, test_task, db): + def test_admin_can_view_deleted_task(self, client, auth_headers, test_task, db): """Test that admin can view a deleted task.""" # Delete the task client.delete( f"/api/tasks/{test_task.id}", - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, ) # Get as admin response = client.get( f"/api/tasks/{test_task.id}", - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, ) assert response.status_code == 200 - def test_cascade_soft_delete_subtasks(self, client, admin_token, test_task, test_task_with_subtask, db): + def test_cascade_soft_delete_subtasks(self, client, auth_headers, test_task, test_task_with_subtask, db): """Test that deleting a parent task soft-deletes its subtasks.""" # Delete the parent task response = client.delete( f"/api/tasks/{test_task.id}", - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, ) assert response.status_code == 200 @@ -270,18 +286,18 @@ class TestSoftDelete: class TestRestoreTask: """Tests for task restoration functionality.""" - def test_restore_task(self, client, admin_token, test_task, db): + def test_restore_task(self, client, auth_headers, test_task, db): """Test that admin can restore a deleted task.""" # Delete the task client.delete( f"/api/tasks/{test_task.id}", - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, ) # Restore the task response = client.post( f"/api/tasks/{test_task.id}/restore", - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, ) assert response.status_code == 200 @@ -292,27 +308,29 @@ class TestRestoreTask: assert test_task.deleted_at is None assert test_task.deleted_by is None - def test_regular_user_cannot_restore(self, client, admin_token, regular_token, test_task, db): + def test_regular_user_cannot_restore(self, client, auth_headers, regular_token, test_task, db, test_regular_user): """Test that non-admin cannot restore a deleted task.""" + from app.core.security import generate_csrf_token # Delete the task client.delete( f"/api/tasks/{test_task.id}", - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, ) # Try to restore as regular user + regular_csrf = generate_csrf_token(test_regular_user.id) response = client.post( f"/api/tasks/{test_task.id}/restore", - headers={"Authorization": f"Bearer {regular_token}"}, + headers={"Authorization": f"Bearer {regular_token}", "X-CSRF-Token": regular_csrf}, ) assert response.status_code == 403 - def test_cannot_restore_non_deleted_task(self, client, admin_token, test_task, db): + def test_cannot_restore_non_deleted_task(self, client, auth_headers, test_task, db): """Test that restoring a non-deleted task returns error.""" response = client.post( f"/api/tasks/{test_task.id}/restore", - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, ) assert response.status_code == 400 @@ -322,12 +340,12 @@ class TestRestoreTask: class TestSubtaskCount: """Tests for subtask count excluding deleted.""" - def test_subtask_count_excludes_deleted(self, client, admin_token, test_task, test_task_with_subtask, db): + def test_subtask_count_excludes_deleted(self, client, auth_headers, test_task, test_task_with_subtask, db): """Test that subtask_count excludes deleted subtasks.""" # Get parent task before deletion response = client.get( f"/api/tasks/{test_task.id}", - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, ) assert response.status_code == 200 assert response.json()["subtask_count"] == 1 @@ -335,13 +353,13 @@ class TestSubtaskCount: # Delete subtask client.delete( f"/api/tasks/{test_task_with_subtask.id}", - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, ) # Get parent task after deletion response = client.get( f"/api/tasks/{test_task.id}", - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, ) assert response.status_code == 200 assert response.json()["subtask_count"] == 0 diff --git a/backend/tests/test_spaces.py b/backend/tests/test_spaces.py index b8a3fe5..45240bb 100644 --- a/backend/tests/test_spaces.py +++ b/backend/tests/test_spaces.py @@ -57,7 +57,7 @@ class TestSpacesAPI: "/api/spaces", json={"name": "Test Space", "description": "Test"} ) - assert response.status_code == 403 # No auth header + assert response.status_code == 401 # 401 for unauthenticated, 403 for unauthorized def test_space_routes_exist(self): """Test that all space routes are registered.""" diff --git a/backend/tests/test_task_dependencies.py b/backend/tests/test_task_dependencies.py index 02b7b73..549ac27 100644 --- a/backend/tests/test_task_dependencies.py +++ b/backend/tests/test_task_dependencies.py @@ -783,7 +783,7 @@ class TestDateValidation: class TestDependencyCRUDAPI: """Test dependency CRUD API endpoints.""" - def test_create_dependency(self, client, db, admin_token): + def test_create_dependency(self, client, db, admin_token, csrf_token): """Test creating a dependency via API.""" # Create test data space = Space( @@ -838,7 +838,7 @@ class TestDependencyCRUDAPI: "dependency_type": "FS", "lag_days": 0 }, - headers={"Authorization": f"Bearer {admin_token}"}, + headers={"Authorization": f"Bearer {admin_token}", "X-CSRF-Token": csrf_token}, ) assert response.status_code == 201 @@ -914,7 +914,7 @@ class TestDependencyCRUDAPI: assert data["total"] >= 1 assert any(d["predecessor_id"] == "task-api-list-1" for d in data["dependencies"]) - def test_delete_dependency(self, client, db, admin_token): + def test_delete_dependency(self, client, db, admin_token, csrf_token): """Test deleting a dependency.""" # Create test data space = Space( @@ -973,7 +973,7 @@ class TestDependencyCRUDAPI: # Delete dependency response = client.delete( "/api/task-dependencies/dep-api-del", - headers={"Authorization": f"Bearer {admin_token}"}, + headers={"Authorization": f"Bearer {admin_token}", "X-CSRF-Token": csrf_token}, ) assert response.status_code == 204 @@ -984,7 +984,7 @@ class TestDependencyCRUDAPI: ).first() assert dep_check is None - def test_circular_dependency_rejected_via_api(self, client, db, admin_token): + def test_circular_dependency_rejected_via_api(self, client, db, admin_token, csrf_token): """Test that circular dependencies are rejected via API.""" # Create test data space = Space( @@ -1049,7 +1049,7 @@ class TestDependencyCRUDAPI: "dependency_type": "FS", "lag_days": 0 }, - headers={"Authorization": f"Bearer {admin_token}"}, + headers={"Authorization": f"Bearer {admin_token}", "X-CSRF-Token": csrf_token}, ) assert response.status_code == 400 @@ -1060,7 +1060,7 @@ class TestDependencyCRUDAPI: class TestTaskDateValidationAPI: """Test task date validation in task API.""" - def test_create_task_with_invalid_dates_rejected(self, client, db, admin_token): + def test_create_task_with_invalid_dates_rejected(self, client, db, admin_token, csrf_token): """Test that creating a task with start_date > due_date is rejected.""" # Create test data space = Space( @@ -1099,13 +1099,13 @@ class TestTaskDateValidationAPI: "start_date": (now + timedelta(days=10)).isoformat(), "due_date": now.isoformat(), }, - headers={"Authorization": f"Bearer {admin_token}"}, + headers={"Authorization": f"Bearer {admin_token}", "X-CSRF-Token": csrf_token}, ) assert response.status_code == 400 assert "Start date cannot be after due date" in response.json()["detail"] - def test_update_task_with_invalid_dates_rejected(self, client, db, admin_token): + def test_update_task_with_invalid_dates_rejected(self, client, db, admin_token, csrf_token): """Test that updating a task to have start_date > due_date is rejected.""" # Create test data space = Space( @@ -1153,12 +1153,12 @@ class TestTaskDateValidationAPI: json={ "start_date": (now + timedelta(days=20)).isoformat(), }, - headers={"Authorization": f"Bearer {admin_token}"}, + headers={"Authorization": f"Bearer {admin_token}", "X-CSRF-Token": csrf_token}, ) assert response.status_code == 400 - def test_create_task_with_valid_dates_accepted(self, client, db, admin_token): + def test_create_task_with_valid_dates_accepted(self, client, db, admin_token, csrf_token): """Test that creating a task with valid dates is accepted.""" # Create test data space = Space( @@ -1197,7 +1197,7 @@ class TestTaskDateValidationAPI: "start_date": now.isoformat(), "due_date": (now + timedelta(days=10)).isoformat(), }, - headers={"Authorization": f"Bearer {admin_token}"}, + headers={"Authorization": f"Bearer {admin_token}", "X-CSRF-Token": csrf_token}, ) assert response.status_code == 201 @@ -1217,7 +1217,7 @@ class TestDependencyTypes: assert DependencyType.FF.value == "FF" assert DependencyType.SF.value == "SF" - def test_create_dependency_with_different_types(self, client, db, admin_token): + def test_create_dependency_with_different_types(self, client, db, admin_token, csrf_token): """Test creating dependencies with different types via API.""" # Create test data space = Space( @@ -1268,7 +1268,7 @@ class TestDependencyTypes: "dependency_type": dep_type, "lag_days": i }, - headers={"Authorization": f"Bearer {admin_token}"}, + headers={"Authorization": f"Bearer {admin_token}", "X-CSRF-Token": csrf_token}, ) assert response.status_code == 201 diff --git a/backend/tests/test_triggers.py b/backend/tests/test_triggers.py index ade96ac..32878a1 100644 --- a/backend/tests/test_triggers.py +++ b/backend/tests/test_triggers.py @@ -43,6 +43,22 @@ def test_user_token(client, mock_redis, test_user): return token +@pytest.fixture +def test_user_csrf_token(test_user): + """Generate a CSRF token for the test user.""" + from app.core.security import generate_csrf_token + return generate_csrf_token(test_user.id) + + +@pytest.fixture +def test_user_auth_headers(test_user_token, test_user_csrf_token): + """Get complete auth headers for test user.""" + return { + "Authorization": f"Bearer {test_user_token}", + "X-CSRF-Token": test_user_csrf_token, + } + + @pytest.fixture def test_space(db, test_user): """Create a test space.""" @@ -513,11 +529,11 @@ class TestTriggerNotifications: class TestTriggerAPI: """Tests for Trigger API endpoints.""" - def test_create_trigger(self, client, test_user_token, test_project, test_status): + def test_create_trigger(self, client, test_user_auth_headers, test_project, test_status): """Test creating a trigger.""" response = client.post( f"/api/projects/{test_project.id}/triggers", - headers={"Authorization": f"Bearer {test_user_token}"}, + headers=test_user_auth_headers, json={ "name": "New Trigger", "description": "Test trigger", @@ -563,11 +579,11 @@ class TestTriggerAPI: assert data["id"] == test_trigger.id assert data["name"] == test_trigger.name - def test_update_trigger(self, client, test_user_token, test_trigger): + def test_update_trigger(self, client, test_user_auth_headers, test_trigger): """Test updating a trigger.""" response = client.put( f"/api/triggers/{test_trigger.id}", - headers={"Authorization": f"Bearer {test_user_token}"}, + headers=test_user_auth_headers, json={ "name": "Updated Trigger", "is_active": False, @@ -579,11 +595,11 @@ class TestTriggerAPI: assert data["name"] == "Updated Trigger" assert data["is_active"] is False - def test_delete_trigger(self, client, test_user_token, test_trigger): + def test_delete_trigger(self, client, test_user_auth_headers, test_trigger, test_user_token): """Test deleting a trigger.""" response = client.delete( f"/api/triggers/{test_trigger.id}", - headers={"Authorization": f"Bearer {test_user_token}"}, + headers=test_user_auth_headers, ) assert response.status_code == 204 @@ -616,11 +632,11 @@ class TestTriggerAPI: data = response.json() assert data["total"] >= 1 - def test_create_trigger_invalid_field(self, client, test_user_token, test_project): + def test_create_trigger_invalid_field(self, client, test_user_auth_headers, test_project): """Test creating a trigger with invalid field.""" response = client.post( f"/api/projects/{test_project.id}/triggers", - headers={"Authorization": f"Bearer {test_user_token}"}, + headers=test_user_auth_headers, json={ "name": "Invalid Trigger", "trigger_type": "field_change", @@ -636,11 +652,11 @@ class TestTriggerAPI: assert response.status_code == 400 assert "Invalid condition field" in response.json()["detail"] - def test_create_trigger_invalid_operator(self, client, test_user_token, test_project): + def test_create_trigger_invalid_operator(self, client, test_user_auth_headers, test_project): """Test creating a trigger with invalid operator.""" response = client.post( f"/api/projects/{test_project.id}/triggers", - headers={"Authorization": f"Bearer {test_user_token}"}, + headers=test_user_auth_headers, json={ "name": "Invalid Trigger", "trigger_type": "field_change", diff --git a/backend/tests/test_users.py b/backend/tests/test_users.py index e8086e0..9fd2ec1 100644 --- a/backend/tests/test_users.py +++ b/backend/tests/test_users.py @@ -1,6 +1,7 @@ import pytest from app.models.user import User from app.models.department import Department +from app.core.security import generate_csrf_token class TestUserEndpoints: @@ -35,7 +36,7 @@ class TestUserEndpoints: ) assert response.status_code == 404 - def test_update_user(self, client, admin_token, db): + def test_update_user(self, client, auth_headers, db): """Test updating a user.""" # Create a test user test_user = User( @@ -49,7 +50,7 @@ class TestUserEndpoints: response = client.patch( "/api/users/test-user-001", - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, json={"name": "Updated Name"}, ) assert response.status_code == 200 @@ -84,9 +85,10 @@ class TestUserEndpoints: mock_redis.setex("session:non-admin-001", 900, token) # Try to modify system admin - should fail with 403 + csrf_token = generate_csrf_token("non-admin-001") response = client.patch( "/api/users/00000000-0000-0000-0000-000000000001", - headers={"Authorization": f"Bearer {token}"}, + headers={"Authorization": f"Bearer {token}", "X-CSRF-Token": csrf_token}, json={"name": "Hacked Name"}, ) # Engineer role doesn't have users.write permission @@ -123,16 +125,17 @@ class TestCapacityUpdate: mock_redis.setex("session:capacity-user-001", 900, token) # Update own capacity + csrf_token = generate_csrf_token("capacity-user-001") response = client.put( "/api/users/capacity-user-001/capacity", - headers={"Authorization": f"Bearer {token}"}, + headers={"Authorization": f"Bearer {token}", "X-CSRF-Token": csrf_token}, json={"capacity_hours": 35.5}, ) assert response.status_code == 200 data = response.json() assert float(data["capacity"]) == 35.5 - def test_admin_can_update_other_user_capacity(self, client, admin_token, db): + def test_admin_can_update_other_user_capacity(self, client, auth_headers, db): """Test that admin can update another user's capacity.""" # Create a test user test_user = User( @@ -148,7 +151,7 @@ class TestCapacityUpdate: # Admin updates another user's capacity response = client.put( "/api/users/capacity-user-002/capacity", - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, json={"capacity_hours": 20.0}, ) assert response.status_code == 200 @@ -189,15 +192,16 @@ class TestCapacityUpdate: mock_redis.setex("session:capacity-user-003", 900, token) # User1 tries to update user2's capacity - should fail + csrf_token = generate_csrf_token("capacity-user-003") response = client.put( "/api/users/capacity-user-004/capacity", - headers={"Authorization": f"Bearer {token}"}, + headers={"Authorization": f"Bearer {token}", "X-CSRF-Token": csrf_token}, json={"capacity_hours": 30.0}, ) assert response.status_code == 403 assert "Only admin, manager, or the user themselves" in response.json()["detail"] - def test_update_capacity_invalid_value_negative(self, client, admin_token, db): + def test_update_capacity_invalid_value_negative(self, client, auth_headers, db): """Test that negative capacity hours are rejected.""" # Create a test user test_user = User( @@ -212,7 +216,7 @@ class TestCapacityUpdate: response = client.put( "/api/users/capacity-user-005/capacity", - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, json={"capacity_hours": -5.0}, ) # Pydantic validation returns 422 Unprocessable Entity @@ -221,7 +225,7 @@ class TestCapacityUpdate: # Check validation error message in Pydantic format assert any("non-negative" in str(err).lower() for err in error_detail) - def test_update_capacity_invalid_value_too_high(self, client, admin_token, db): + def test_update_capacity_invalid_value_too_high(self, client, auth_headers, db): """Test that capacity hours exceeding 168 are rejected.""" # Create a test user test_user = User( @@ -236,7 +240,7 @@ class TestCapacityUpdate: response = client.put( "/api/users/capacity-user-006/capacity", - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, json={"capacity_hours": 200.0}, ) # Pydantic validation returns 422 Unprocessable Entity @@ -245,11 +249,11 @@ class TestCapacityUpdate: # Check validation error message in Pydantic format assert any("168" in str(err) for err in error_detail) - def test_update_capacity_nonexistent_user(self, client, admin_token): + def test_update_capacity_nonexistent_user(self, client, auth_headers): """Test updating capacity for a nonexistent user.""" response = client.put( "/api/users/nonexistent-user-id/capacity", - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, json={"capacity_hours": 40.0}, ) assert response.status_code == 404 @@ -303,16 +307,17 @@ class TestCapacityUpdate: mock_redis.setex("session:manager-cap-001", 900, token) # Manager updates regular user's capacity + csrf_token = generate_csrf_token("manager-cap-001") response = client.put( "/api/users/regular-cap-001/capacity", - headers={"Authorization": f"Bearer {token}"}, + headers={"Authorization": f"Bearer {token}", "X-CSRF-Token": csrf_token}, json={"capacity_hours": 30.0}, ) assert response.status_code == 200 data = response.json() assert float(data["capacity"]) == 30.0 - def test_capacity_change_creates_audit_log(self, client, admin_token, db): + def test_capacity_change_creates_audit_log(self, client, auth_headers, db): """Test that capacity changes are recorded in audit trail.""" from app.models import AuditLog @@ -330,7 +335,7 @@ class TestCapacityUpdate: # Update capacity response = client.put( "/api/users/capacity-audit-001/capacity", - headers={"Authorization": f"Bearer {admin_token}"}, + headers=auth_headers, json={"capacity_hours": 35.0}, ) assert response.status_code == 200 diff --git a/backend/tests/test_workload.py b/backend/tests/test_workload.py index 235a51e..e793bb1 100644 --- a/backend/tests/test_workload.py +++ b/backend/tests/test_workload.py @@ -449,7 +449,7 @@ class TestWorkloadAPI: def test_unauthorized_access(self, client, db): """Unauthenticated requests should fail.""" response = client.get("/api/workload/heatmap") - assert response.status_code == 403 # No auth header + assert response.status_code == 401 # 401 for unauthenticated, 403 for unauthorized class TestWorkloadAccessControl: diff --git a/frontend/public/locales/en/auth.json b/frontend/public/locales/en/auth.json index 746d56e..bbe2e7f 100644 --- a/frontend/public/locales/en/auth.json +++ b/frontend/public/locales/en/auth.json @@ -19,7 +19,8 @@ "emailRequired": "Email is required", "passwordRequired": "Password is required", "invalidEmail": "Please enter a valid email address", - "loginFailed": "Login failed. Please try again later." + "loginFailed": "Login failed. Please try again later.", + "sessionExpired": "Your session has expired. Please sign in again." }, "welcome": { "title": "Project Control Center", diff --git a/frontend/public/locales/en/common.json b/frontend/public/locales/en/common.json index 9de2a71..c79b43c 100644 --- a/frontend/public/locales/en/common.json +++ b/frontend/public/locales/en/common.json @@ -129,5 +129,11 @@ "message": "Unable to display this widget.", "errorSuffix": "error" } + }, + "attachments": { + "dropzone": "Drop files here or click to upload", + "maxFileSize": "Maximum file size: {{size}}", + "uploading": "Uploading {{filename}} ({{current}}/{{total}})...", + "uploadFailed": "Upload failed" } } diff --git a/frontend/public/locales/zh-TW/auth.json b/frontend/public/locales/zh-TW/auth.json index eff4189..1e2435a 100644 --- a/frontend/public/locales/zh-TW/auth.json +++ b/frontend/public/locales/zh-TW/auth.json @@ -19,7 +19,8 @@ "emailRequired": "請輸入電子郵件", "passwordRequired": "請輸入密碼", "invalidEmail": "請輸入有效的電子郵件地址", - "loginFailed": "登入失敗,請稍後再試" + "loginFailed": "登入失敗,請稍後再試", + "sessionExpired": "您的登入時段已過期,請重新登入。" }, "welcome": { "title": "專案控制中心", diff --git a/frontend/public/locales/zh-TW/common.json b/frontend/public/locales/zh-TW/common.json index 6492ff2..e72f79f 100644 --- a/frontend/public/locales/zh-TW/common.json +++ b/frontend/public/locales/zh-TW/common.json @@ -129,5 +129,11 @@ "message": "無法顯示此元件。", "errorSuffix": "發生錯誤" } + }, + "attachments": { + "dropzone": "拖曳檔案至此或點擊上傳", + "maxFileSize": "檔案大小上限:{{size}}", + "uploading": "正在上傳 {{filename}} ({{current}}/{{total}})...", + "uploadFailed": "上傳失敗" } } diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 9481b3a..5c41418 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -1,21 +1,35 @@ +import { lazy, Suspense } from 'react' import { Routes, Route, Navigate } from 'react-router-dom' import { useAuth } from './contexts/AuthContext' import { Skeleton } from './components/Skeleton' import { ErrorBoundary } from './components/ErrorBoundary' import { SectionErrorBoundary } from './components/ErrorBoundaryWithI18n' -import Login from './pages/Login' -import Dashboard from './pages/Dashboard' -import Spaces from './pages/Spaces' -import Projects from './pages/Projects' -import Tasks from './pages/Tasks' -import ProjectSettings from './pages/ProjectSettings' -import MySettings from './pages/MySettings' -import AuditPage from './pages/AuditPage' -import WorkloadPage from './pages/WorkloadPage' -import ProjectHealthPage from './pages/ProjectHealthPage' import ProtectedRoute from './components/ProtectedRoute' import Layout from './components/Layout' +// Lazy load pages for code splitting +const Login = lazy(() => import('./pages/Login')) +const Dashboard = lazy(() => import('./pages/Dashboard')) +const Spaces = lazy(() => import('./pages/Spaces')) +const Projects = lazy(() => import('./pages/Projects')) +const Tasks = lazy(() => import('./pages/Tasks')) +const ProjectSettings = lazy(() => import('./pages/ProjectSettings')) +const MySettings = lazy(() => import('./pages/MySettings')) +const AuditPage = lazy(() => import('./pages/AuditPage')) +const WorkloadPage = lazy(() => import('./pages/WorkloadPage')) +const ProjectHealthPage = lazy(() => import('./pages/ProjectHealthPage')) + +// Loading fallback component for Suspense +function PageLoadingFallback() { + return ( +
+ + + +
+ ) +} + function App() { const { isAuthenticated, loading } = useAuth() @@ -30,120 +44,122 @@ function App() { return ( - - : } - /> - - - - - - - - } - /> - - - - - - - - } - /> - - - - - - - - } - /> - - - - - - - - } - /> - - - - - - - - } - /> - - - - - - - - } - /> - - - - - - - - } - /> - - - - - - - - } - /> - - - - - - - - } - /> - + }> + + : } + /> + + + + + + + + } + /> + + + + + + + + } + /> + + + + + + + + } + /> + + + + + + + + } + /> + + + + + + + + } + /> + + + + + + + + } + /> + + + + + + + + } + /> + + + + + + + + } + /> + + + + + + + + } + /> + + ) } diff --git a/frontend/src/components/AttachmentUpload.tsx b/frontend/src/components/AttachmentUpload.tsx index 868b3fd..a343538 100644 --- a/frontend/src/components/AttachmentUpload.tsx +++ b/frontend/src/components/AttachmentUpload.tsx @@ -1,4 +1,5 @@ import { useState, useRef, useEffect, DragEvent, ChangeEvent } from 'react' +import { useTranslation } from 'react-i18next' import { attachmentService } from '../services/attachments' // Spinner animation keyframes - injected once via useEffect @@ -10,6 +11,7 @@ interface AttachmentUploadProps { } export function AttachmentUpload({ taskId, onUploadComplete }: AttachmentUploadProps) { + const { t } = useTranslation('common') const [isDragging, setIsDragging] = useState(false) const [uploading, setUploading] = useState(false) const [uploadProgress, setUploadProgress] = useState(null) @@ -79,14 +81,20 @@ export function AttachmentUpload({ taskId, onUploadComplete }: AttachmentUploadP try { for (let i = 0; i < files.length; i++) { const file = files[i] - setUploadProgress(`Uploading ${file.name} (${i + 1}/${files.length})...`) + setUploadProgress( + t('attachments.uploading', { + filename: file.name, + current: i + 1, + total: files.length, + }) + ) await attachmentService.uploadAttachment(taskId, file) } setUploadProgress(null) onUploadComplete?.() } catch (err: unknown) { console.error('Upload failed:', err) - const errorMessage = err instanceof Error ? err.message : 'Upload failed' + const errorMessage = err instanceof Error ? err.message : t('attachments.uploadFailed') setError(errorMessage) } finally { setUploading(false) @@ -127,10 +135,10 @@ export function AttachmentUpload({ taskId, onUploadComplete }: AttachmentUploadP
📎 - Drop files here or click to upload + {t('attachments.dropzone')} - Maximum file size: 50MB + {t('attachments.maxFileSize', { size: '50MB' })}
)} diff --git a/frontend/src/components/Comments.tsx b/frontend/src/components/Comments.tsx index ce16574..7f37f66 100644 --- a/frontend/src/components/Comments.tsx +++ b/frontend/src/components/Comments.tsx @@ -35,7 +35,7 @@ export function Comments({ taskId }: CommentsProps) { } finally { setLoading(false) } - }, [taskId]) + }, [taskId, t]) useEffect(() => { fetchComments() diff --git a/frontend/src/contexts/AuthContext.tsx b/frontend/src/contexts/AuthContext.tsx index bcc3b64..4d71e5f 100644 --- a/frontend/src/contexts/AuthContext.tsx +++ b/frontend/src/contexts/AuthContext.tsx @@ -1,5 +1,58 @@ import { createContext, useContext, useState, useEffect, ReactNode } from 'react' -import { authApi, User, LoginRequest } from '../services/api' +import { + authApi, + User, + LoginRequest, + storeTokens, + clearStoredTokens, + getStoredToken, + isTokenExpired, +} from '../services/api' + +/** + * Validates that a parsed object has the required User properties. + * Returns the validated User object or null if validation fails. + */ +function validateUserData(data: unknown): User | null { + // Check if data is an object + if (!data || typeof data !== 'object') { + return null + } + + const obj = data as Record + + // Validate required string fields + if (typeof obj.id !== 'string' || obj.id.length === 0) { + return null + } + if (typeof obj.email !== 'string' || obj.email.length === 0) { + return null + } + if (typeof obj.name !== 'string' || obj.name.length === 0) { + return null + } + + // Validate optional/nullable fields + if (obj.role !== null && typeof obj.role !== 'string') { + return null + } + if (obj.department_id !== null && typeof obj.department_id !== 'string') { + return null + } + if (typeof obj.is_system_admin !== 'boolean') { + return null + } + + // Return validated user object + return { + id: obj.id, + email: obj.email, + name: obj.name, + role: obj.role as string | null, + department_id: obj.department_id as string | null, + is_system_admin: obj.is_system_admin, + } +} interface AuthContextType { user: User | null @@ -17,15 +70,35 @@ export function AuthProvider({ children }: { children: ReactNode }) { useEffect(() => { // Check for existing token on mount - const token = localStorage.getItem('token') + const token = getStoredToken() const storedUser = localStorage.getItem('user') if (token && storedUser) { try { - setUser(JSON.parse(storedUser)) - } catch { - localStorage.removeItem('token') - localStorage.removeItem('user') + // Check if token is expired + if (isTokenExpired(token)) { + // Token is expired, clear storage and don't restore user + // The refresh will happen automatically on next API call if refresh token exists + clearStoredTokens() + } else { + // Parse and validate stored user data + const parsedUser = JSON.parse(storedUser) + const validatedUser = validateUserData(parsedUser) + + if (validatedUser) { + setUser(validatedUser) + } else { + // Invalid user data structure, clear storage and redirect to login + console.warn('Invalid user data in localStorage, clearing session') + clearStoredTokens() + // Don't redirect here as we're in initial loading state + // The app will naturally show login page when user is null + } + } + } catch (err) { + // JSON parse error or other unexpected error + console.error('Error parsing stored user data:', err) + clearStoredTokens() } } setLoading(false) @@ -33,7 +106,8 @@ export function AuthProvider({ children }: { children: ReactNode }) { const login = async (data: LoginRequest) => { const response = await authApi.login(data) - localStorage.setItem('token', response.access_token) + // Store access token and refresh token (if provided by backend) + storeTokens(response.access_token, response.refresh_token) localStorage.setItem('user', JSON.stringify(response.user)) setUser(response.user) } @@ -44,8 +118,8 @@ export function AuthProvider({ children }: { children: ReactNode }) { } catch { // Ignore errors on logout } finally { - localStorage.removeItem('token') - localStorage.removeItem('user') + // Clear all tokens (access, refresh, and user data) + clearStoredTokens() setUser(null) } } diff --git a/frontend/src/pages/Dashboard.tsx b/frontend/src/pages/Dashboard.tsx index 3d13ae3..c88b65e 100644 --- a/frontend/src/pages/Dashboard.tsx +++ b/frontend/src/pages/Dashboard.tsx @@ -29,7 +29,7 @@ export default function Dashboard() { } finally { setLoading(false) } - }, []) + }, [t]) useEffect(() => { fetchDashboard() diff --git a/frontend/src/pages/Login.tsx b/frontend/src/pages/Login.tsx index 5473ffb..19d77a6 100644 --- a/frontend/src/pages/Login.tsx +++ b/frontend/src/pages/Login.tsx @@ -1,5 +1,5 @@ -import { useState, FormEvent } from 'react' -import { useNavigate } from 'react-router-dom' +import { useState, useEffect, FormEvent } from 'react' +import { useNavigate, useSearchParams } from 'react-router-dom' import { useTranslation } from 'react-i18next' import { useAuth } from '../contexts/AuthContext' import { LanguageSwitcher } from '../components/LanguageSwitcher' @@ -9,9 +9,21 @@ export default function Login() { const [email, setEmail] = useState('') const [password, setPassword] = useState('') const [error, setError] = useState('') + const [info, setInfo] = useState('') const [loading, setLoading] = useState(false) const { login } = useAuth() const navigate = useNavigate() + const [searchParams, setSearchParams] = useSearchParams() + + // Check for session expired redirect + useEffect(() => { + const reason = searchParams.get('reason') + if (reason === 'session_expired') { + setInfo(t('errors.sessionExpired')) + // Clean up the URL by removing the query parameter + setSearchParams({}, { replace: true }) + } + }, [searchParams, setSearchParams, t]) const handleSubmit = async (e: FormEvent) => { e.preventDefault() @@ -45,6 +57,7 @@ export default function Login() {

{t('login.subtitle')}

+ {info &&
{info}
} {error &&
{error}
}
@@ -163,4 +176,11 @@ const styles: { [key: string]: React.CSSProperties } = { borderRadius: '4px', fontSize: '14px', }, + info: { + backgroundColor: '#e6f4ff', + color: '#0066cc', + padding: '10px', + borderRadius: '4px', + fontSize: '14px', + }, } diff --git a/frontend/src/services/api.ts b/frontend/src/services/api.ts index 936af42..1567702 100644 --- a/frontend/src/services/api.ts +++ b/frontend/src/services/api.ts @@ -1,4 +1,4 @@ -import axios, { InternalAxiosRequestConfig } from 'axios' +import axios, { InternalAxiosRequestConfig, AxiosError } from 'axios' // API base URL - using legacy routes until v1 migration is complete // TODO: Switch to /api/v1 when all routes are migrated @@ -9,10 +9,141 @@ const API_BASE_URL = '/api' let csrfToken: string | null = null let csrfTokenExpiry: number | null = null const CSRF_TOKEN_HEADER = 'X-CSRF-Token' -const CSRF_PROTECTED_METHODS = ['DELETE', 'PUT', 'PATCH'] +const CSRF_PROTECTED_METHODS = ['POST', 'DELETE', 'PUT', 'PATCH'] // Token expires in 1 hour, refresh 5 minutes before expiry const CSRF_TOKEN_LIFETIME_MS = 55 * 60 * 1000 +// JWT Token refresh configuration +// Access tokens expire in 60 minutes, refresh 5 minutes before expiry +const TOKEN_REFRESH_THRESHOLD_MS = 5 * 60 * 1000 + +// Token refresh state management +let isRefreshing = false +let refreshSubscribers: Array<(token: string) => void> = [] + +/** + * JWT Token Utilities + */ + +/** + * Decode a JWT token payload without verification. + * Note: This is for reading claims only, not for security validation. + * Security validation happens on the backend. + */ +export function decodeJwtPayload(token: string): JwtPayload | null { + try { + const parts = token.split('.') + if (parts.length !== 3) { + return null + } + // Decode base64url to base64 + const base64 = parts[1].replace(/-/g, '+').replace(/_/g, '/') + // Add padding if needed + const padded = base64 + '='.repeat((4 - (base64.length % 4)) % 4) + const decoded = atob(padded) + return JSON.parse(decoded) + } catch { + return null + } +} + +/** + * Get the expiration time (in milliseconds since epoch) from a JWT token. + */ +export function getTokenExpiryTime(token: string): number | null { + const payload = decodeJwtPayload(token) + if (!payload || typeof payload.exp !== 'number') { + return null + } + // JWT exp is in seconds, convert to milliseconds + return payload.exp * 1000 +} + +/** + * Check if a token is about to expire (within threshold). + * Returns true if token will expire within the threshold or has already expired. + */ +export function isTokenExpiringSoon( + token: string, + thresholdMs: number = TOKEN_REFRESH_THRESHOLD_MS +): boolean { + const expiryTime = getTokenExpiryTime(token) + if (expiryTime === null) { + // If we can't determine expiry, assume it needs refresh + return true + } + return Date.now() >= expiryTime - thresholdMs +} + +/** + * Check if a token has already expired. + */ +export function isTokenExpired(token: string): boolean { + const expiryTime = getTokenExpiryTime(token) + if (expiryTime === null) { + return true + } + return Date.now() >= expiryTime +} + +interface JwtPayload { + sub: string + email: string + role?: string | null + department_id?: string | null + is_system_admin?: boolean + exp: number + iat: number +} + +/** + * Token Storage Utilities + * Note: Using localStorage for token storage. While httpOnly cookies are more + * secure against XSS attacks, localStorage is acceptable for this implementation + * as long as proper XSS protections are in place (Content Security Policy, etc.). + * The refresh token mechanism limits exposure time if a token is compromised. + */ +const TOKEN_KEY = 'token' +const REFRESH_TOKEN_KEY = 'refresh_token' +const USER_KEY = 'user' + +export function getStoredToken(): string | null { + return localStorage.getItem(TOKEN_KEY) +} + +export function getStoredRefreshToken(): string | null { + return localStorage.getItem(REFRESH_TOKEN_KEY) +} + +export function storeTokens(accessToken: string, refreshToken?: string): void { + localStorage.setItem(TOKEN_KEY, accessToken) + if (refreshToken) { + localStorage.setItem(REFRESH_TOKEN_KEY, refreshToken) + } +} + +export function clearStoredTokens(): void { + localStorage.removeItem(TOKEN_KEY) + localStorage.removeItem(REFRESH_TOKEN_KEY) + localStorage.removeItem(USER_KEY) +} + +/** + * Subscribe to token refresh completion. + * Used to queue requests while a refresh is in progress. + */ +function subscribeToTokenRefresh(callback: (token: string) => void): void { + refreshSubscribers.push(callback) +} + +/** + * Notify all subscribers that token has been refreshed. + */ +function onTokenRefreshed(newToken: string): void { + refreshSubscribers.forEach((callback) => callback(newToken)) + refreshSubscribers = [] +} + const api = axios.create({ baseURL: API_BASE_URL, headers: { @@ -77,33 +208,149 @@ export async function prefetchCsrfToken(): Promise { await fetchCsrfToken() } -// Add token to requests and CSRF token for protected methods -api.interceptors.request.use(async (config: InternalAxiosRequestConfig) => { - const token = localStorage.getItem('token') - if (token) { - config.headers.Authorization = `Bearer ${token}` +/** + * Refresh the access token using the refresh token. + * This is called automatically when the access token is about to expire. + * + * @returns The new access token, or null if refresh failed + */ +async function refreshAccessToken(): Promise { + const refreshToken = getStoredRefreshToken() + if (!refreshToken) { + return null + } - // Add CSRF token for protected methods - const method = config.method?.toUpperCase() - if (method && CSRF_PROTECTED_METHODS.includes(method)) { - const csrf = await getValidCsrfToken() - if (csrf) { - config.headers[CSRF_TOKEN_HEADER] = csrf + try { + // Use axios directly to avoid interceptor loops + const response = await axios.post<{ + access_token: string + refresh_token?: string + token_type: string + }>( + `${API_BASE_URL}/auth/refresh`, + { refresh_token: refreshToken }, + { + headers: { + 'Content-Type': 'application/json', + }, + } + ) + + const { access_token, refresh_token: newRefreshToken } = response.data + + // Store the new tokens + storeTokens(access_token, newRefreshToken || refreshToken) + + return access_token + } catch (error) { + // If refresh fails (401 or other error), the token is invalid + // Clear all tokens and let the response interceptor handle redirect + return null + } +} + +/** + * Ensure we have a valid access token, refreshing if necessary. + * This implements a queue mechanism to prevent multiple simultaneous refresh requests. + * + * @returns A promise that resolves with a valid token or null if unavailable + */ +async function ensureValidToken(): Promise { + const token = getStoredToken() + + if (!token) { + return null + } + + // If token is not expiring soon, use it as-is + if (!isTokenExpiringSoon(token)) { + return token + } + + // If we're already refreshing, wait for it to complete + if (isRefreshing) { + return new Promise((resolve) => { + subscribeToTokenRefresh((newToken) => { + resolve(newToken) + }) + }) + } + + // Start the refresh process + isRefreshing = true + + try { + const newToken = await refreshAccessToken() + + if (newToken) { + onTokenRefreshed(newToken) + return newToken + } else { + // Refresh failed - clear tokens and redirect to login + clearStoredTokens() + clearCsrfToken() + // Notify subscribers with empty token (they'll fail but won't retry) + refreshSubscribers = [] + window.location.href = '/login?reason=session_expired' + return null + } + } finally { + isRefreshing = false + } +} + +// Add token to requests and CSRF token for protected methods +// This interceptor ensures tokens are refreshed before they expire +api.interceptors.request.use(async (config: InternalAxiosRequestConfig) => { + // Skip token handling for auth endpoints that don't require authentication + const isAuthEndpoint = + config.url?.includes('/auth/login') || config.url?.includes('/auth/refresh') + + if (!isAuthEndpoint) { + // Ensure we have a valid token (will refresh if expiring soon) + const token = await ensureValidToken() + + if (token) { + config.headers.Authorization = `Bearer ${token}` + + // Add CSRF token for protected methods + const method = config.method?.toUpperCase() + if (method && CSRF_PROTECTED_METHODS.includes(method)) { + const csrf = await getValidCsrfToken() + if (csrf) { + config.headers[CSRF_TOKEN_HEADER] = csrf + } } } } + return config }) -// Handle 401 responses +// Handle 401 responses - clear tokens and redirect to login +// Note: Token refresh is handled proactively in the request interceptor +// A 401 here means either: +// 1. The token was revoked on the server +// 2. The refresh token has expired +// 3. Some other authentication issue api.interceptors.response.use( (response) => response, - (error) => { + (error: AxiosError) => { if (error.response?.status === 401) { - localStorage.removeItem('token') - localStorage.removeItem('user') - clearCsrfToken() - window.location.href = '/login' + // Check if this is from a refresh endpoint to avoid redirect loops + const isRefreshRequest = error.config?.url?.includes('/auth/refresh') + + if (!isRefreshRequest) { + // Clear all auth state + clearStoredTokens() + clearCsrfToken() + + // Redirect to login with appropriate message + const currentPath = window.location.pathname + if (currentPath !== '/login') { + window.location.href = '/login?reason=session_expired' + } + } } return Promise.reject(error) } @@ -125,10 +372,17 @@ export interface User { export interface LoginResponse { access_token: string + refresh_token?: string // Optional for backward compatibility during migration token_type: string user: User } +export interface RefreshTokenResponse { + access_token: string + refresh_token?: string // New refresh token if rotation is enabled + token_type: string +} + export const authApi = { login: async (data: LoginRequest): Promise => { const response = await api.post('/auth/login', data) diff --git a/openspec/specs/audit-trail/spec.md b/openspec/specs/audit-trail/spec.md index 99aebd1..4839a60 100644 --- a/openspec/specs/audit-trail/spec.md +++ b/openspec/specs/audit-trail/spec.md @@ -98,6 +98,31 @@ - **WHEN** 異常行為發生 - **THEN** 系統記錄並發送警示 +### Requirement: Security Event Logging +The system SHALL record failed access attempts for security monitoring and intrusion detection. + +#### Scenario: Permission denied logged +- **WHEN** server returns 403 Forbidden for a resource access attempt +- **THEN** audit log entry is created with event_type "security.access_denied" +- **AND** entry includes user_id, resource_type, and attempted_action + +#### Scenario: Repeated auth failures logged +- **WHEN** same IP has 5+ failed authentication attempts in 10 minutes +- **THEN** audit log entry is created with event_type "security.suspicious_auth_pattern" +- **AND** entry includes IP address and failure count +- **AND** alert is generated for security administrators + +### Requirement: Detailed Health Endpoint Security +The detailed system health endpoint SHALL require admin authentication to prevent information disclosure. + +#### Scenario: Admin accesses detailed health +- **WHEN** system administrator requests GET /health/detailed +- **THEN** full system status including connection pools is returned + +#### Scenario: Non-admin accesses detailed health +- **WHEN** non-admin user or unauthenticated request to GET /health/detailed +- **THEN** request is rejected with 401 Unauthorized or 403 Forbidden + ## Data Model ``` diff --git a/openspec/specs/dashboard/spec.md b/openspec/specs/dashboard/spec.md index 2ef9a2d..fd87739 100644 --- a/openspec/specs/dashboard/spec.md +++ b/openspec/specs/dashboard/spec.md @@ -161,6 +161,33 @@ The system SHALL support project templates to standardize project creation. - **THEN** system creates template with project's CustomField definitions - **THEN** template is available for future project creation +### Requirement: Code Splitting +The application SHALL use code splitting with React.lazy() to reduce initial bundle size and improve load times. + +#### Scenario: Initial page load +- **WHEN** user navigates to application +- **THEN** only core framework and current route are loaded +- **AND** other routes are loaded on demand + +#### Scenario: Route-based splitting +- **WHEN** user navigates to a different page +- **THEN** that page's code chunk is loaded dynamically +- **AND** loading fallback is displayed during load + +### Requirement: LocalStorage Data Validation +User data loaded from localStorage SHALL be validated before use to prevent crashes from corrupted data. + +#### Scenario: Corrupted localStorage data +- **WHEN** localStorage contains malformed user JSON +- **THEN** invalid data is cleared +- **AND** user is redirected to login page +- **AND** no application crash occurs + +#### Scenario: Valid localStorage data +- **WHEN** localStorage contains valid user JSON +- **THEN** user is authenticated from stored data +- **AND** application loads normally + ### Requirement: Error Boundary Protection The frontend application SHALL gracefully handle component render errors without crashing the entire application. diff --git a/openspec/specs/task-management/spec.md b/openspec/specs/task-management/spec.md index 6f5319c..3b9f697 100644 --- a/openspec/specs/task-management/spec.md +++ b/openspec/specs/task-management/spec.md @@ -78,6 +78,26 @@ - **WHEN** 管理者嘗試新增第 21 個欄位 - **THEN** 系統拒絕新增並顯示數量已達上限的訊息 +### Requirement: Input Length Validation +All text input fields SHALL have maximum length constraints to prevent abuse and database issues. + +#### Scenario: Task title exceeds limit +- **WHEN** user creates a task with title exceeding 500 characters +- **THEN** request is rejected with 422 Validation Error +- **AND** error indicates field length exceeded + +#### Scenario: Description within limit +- **WHEN** user creates a task with description 10000 characters or less +- **THEN** task is created successfully + +#### Scenario: Description exceeds limit +- **WHEN** user creates a task with description exceeding 10000 characters +- **THEN** request is rejected with 422 Validation Error + +#### Scenario: Comment content limit +- **WHEN** user submits a comment exceeding 5000 characters +- **THEN** request is rejected with 422 Validation Error + ### Requirement: Multiple Views 系統 SHALL 支援多維視角:看板 (Kanban)、甘特圖 (Gantt)、列表 (List)、行事曆 (Calendar)。 diff --git a/openspec/specs/user-auth/spec.md b/openspec/specs/user-auth/spec.md index 0def28b..1b3e549 100644 --- a/openspec/specs/user-auth/spec.md +++ b/openspec/specs/user-auth/spec.md @@ -89,6 +89,34 @@ - **WHEN** 使用者執行登出操作 - **THEN** 系統銷毀 session 並清除 token +### Requirement: Access Token Expiry +Access tokens SHALL expire within 60 minutes to limit exposure window in case of token compromise. + +#### Scenario: Access token expiry +- **WHEN** an access token issued 61 minutes ago is used for API authentication +- **THEN** request is rejected with 401 Unauthorized +- **AND** error indicates "Token expired" + +### Requirement: Refresh Token Support +The system SHALL support refresh tokens for seamless session continuity without requiring re-authentication. + +#### Scenario: Refresh valid token +- **WHEN** POST to /api/auth/refresh with valid refresh token +- **THEN** new access token is issued +- **AND** new refresh token is issued via rotation +- **AND** old refresh token is invalidated + +#### Scenario: Refresh expired token +- **WHEN** POST to /api/auth/refresh with expired refresh token +- **THEN** request is rejected with 401 Unauthorized +- **AND** user must re-authenticate via login + +#### Scenario: Automatic frontend refresh +- **WHEN** access token expires in less than 5 minutes +- **AND** frontend prepares to make API call +- **THEN** token is automatically refreshed first +- **AND** original request proceeds with new token + ### Requirement: API Rate Limiting The system SHALL implement rate limiting to protect against brute force attacks and DoS attempts. @@ -143,8 +171,19 @@ The system SHALL enforce maximum length limits on all user-provided string input - **WHEN** user submits content with description under 10000 characters - **THEN** system accepts the input and processes normally +### Requirement: CORS Security +The system SHALL explicitly define allowed CORS methods and headers instead of using wildcards to reduce attack surface. + +#### Scenario: Request with standard headers +- **WHEN** a cross-origin request includes Content-Type, Authorization, or X-CSRF-Token headers +- **THEN** the request is allowed + +#### Scenario: Request with non-standard header +- **WHEN** a cross-origin request includes a non-whitelisted custom header +- **THEN** CORS preflight fails and request is rejected + ### Requirement: Secure WebSocket Authentication -The system SHALL authenticate WebSocket connections without exposing tokens in URL query parameters. +The system SHALL authenticate WebSocket connections without exposing tokens in URL query parameters. In production environments, query parameter authentication SHALL be disabled. #### Scenario: WebSocket connection with token in first message - **WHEN** client connects to WebSocket endpoint without a query token @@ -161,6 +200,25 @@ The system SHALL authenticate WebSocket connections without exposing tokens in U - **WHEN** client connects but does not send authentication within 10 seconds - **THEN** server closes the connection with appropriate error code +#### Scenario: Query parameter auth in production +- **WHEN** production environment and WebSocket connection includes token in query parameter +- **THEN** connection is rejected with code 4002 +- **AND** error message indicates "Query parameter auth disabled in production" + +### Requirement: WebSocket Connection Limits +The system SHALL limit each user to a maximum of 5 concurrent WebSocket connections to prevent resource exhaustion. + +#### Scenario: User exceeds connection limit +- **WHEN** user already has 5 active WebSocket connections +- **AND** user attempts to open a 6th connection +- **THEN** connection is rejected with code 4005 +- **AND** error message indicates "Too many connections" + +#### Scenario: User within connection limit +- **WHEN** user has fewer than 5 active connections +- **AND** user opens a new WebSocket connection +- **THEN** connection is accepted + ### Requirement: Path Traversal Protection The system SHALL prevent file path traversal attacks by validating all file paths resolve within the designated storage directory. @@ -187,22 +245,25 @@ The system SHALL validate JWT secret key strength on startup. - **THEN** the system SHALL log a security warning ### Requirement: CSRF Protection -The system SHALL protect sensitive state-changing operations with CSRF tokens. +The system SHALL protect all state-changing operations (POST, PUT, PATCH, DELETE) with CSRF tokens. -#### Scenario: CSRF token required for password change -- **WHEN** a user attempts to change their password -- **AND** the request does not include a valid CSRF token +#### Scenario: POST request without CSRF token +- **WHEN** an authenticated user makes a POST request without X-CSRF-Token header - **THEN** the request SHALL be rejected with 403 Forbidden +- **AND** error message indicates "CSRF token is required" -#### Scenario: CSRF token required for account deletion -- **WHEN** a user attempts to delete their account or resources -- **AND** the request does not include a valid CSRF token +#### Scenario: PUT/PATCH/DELETE request without CSRF token +- **WHEN** an authenticated user makes a PUT, PATCH, or DELETE request without X-CSRF-Token header - **THEN** the request SHALL be rejected with 403 Forbidden #### Scenario: Valid CSRF token accepted - **WHEN** a state-changing request includes a valid CSRF token - **THEN** the request SHALL proceed normally +#### Scenario: Public endpoints exempt from CSRF +- **WHEN** POST to /api/auth/login or other public endpoints +- **THEN** CSRF token is not required + ## Data Model ```