feat: implement 5 QA-driven security and quality proposals

Implemented proposals from comprehensive QA review:

1. extend-csrf-protection
   - Add POST to CSRF protected methods in frontend
   - Global CSRF middleware for all state-changing operations
   - Update tests with CSRF token fixtures

2. tighten-cors-websocket-security
   - Replace wildcard CORS with explicit method/header lists
   - Disable query parameter auth in production (code 4002)
   - Add per-user WebSocket connection limit (max 5, code 4005)

3. shorten-jwt-expiry
   - Reduce JWT expiry from 7 days to 60 minutes
   - Add refresh token support with 7-day expiry
   - Implement token rotation on refresh
   - Frontend auto-refresh when token near expiry (<5 min)

4. fix-frontend-quality
   - Add React.lazy() code splitting for all pages
   - Fix useCallback dependency arrays (Dashboard, Comments)
   - Add localStorage data validation in AuthContext
   - Complete i18n for AttachmentUpload component

5. enhance-backend-validation
   - Add SecurityAuditMiddleware for access denied logging
   - Add ErrorSanitizerMiddleware for production error messages
   - Protect /health/detailed with admin authentication
   - Add input length validation (comment 5000, desc 10000)

All 521 backend tests passing. Frontend builds successfully.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
beabigegg
2026-01-12 23:19:05 +08:00
parent df50d5e7f8
commit 35c90fe76b
48 changed files with 2132 additions and 403 deletions

View File

@@ -3,12 +3,28 @@ from sqlalchemy.orm import Session
from app.core.config import settings
from app.core.database import get_db
from app.core.security import create_access_token, create_token_payload
from app.core.security import (
create_access_token,
create_token_payload,
generate_refresh_token,
store_refresh_token,
validate_refresh_token,
invalidate_refresh_token,
invalidate_all_user_refresh_tokens,
decode_refresh_token_user_id,
)
from app.core.redis import get_redis
from app.core.rate_limiter import limiter
from app.models.user import User
from app.models.audit_log import AuditAction
from app.schemas.auth import LoginRequest, LoginResponse, UserInfo, CSRFTokenResponse
from app.schemas.auth import (
LoginRequest,
LoginResponse,
UserInfo,
CSRFTokenResponse,
RefreshTokenRequest,
RefreshTokenResponse,
)
from app.services.auth_client import (
verify_credentials,
AuthAPIError,
@@ -119,6 +135,9 @@ async def login(
# Create access token
access_token = create_access_token(token_data)
# Generate refresh token
refresh_token = generate_refresh_token()
# Store session in Redis (sync with JWT expiry)
redis_client.setex(
f"session:{user.id}",
@@ -126,6 +145,9 @@ async def login(
access_token,
)
# Store refresh token in Redis with user binding
store_refresh_token(redis_client, user.id, refresh_token)
# Log successful login
AuditService.log_event(
db=db,
@@ -141,6 +163,8 @@ async def login(
return LoginResponse(
access_token=access_token,
refresh_token=refresh_token,
expires_in=settings.JWT_EXPIRE_MINUTES * 60,
user=UserInfo(
id=user.id,
email=user.email,
@@ -158,14 +182,114 @@ async def logout(
redis_client=Depends(get_redis),
):
"""
Logout user and invalidate session.
Logout user and invalidate session and all refresh tokens.
"""
# Remove session from Redis
redis_client.delete(f"session:{current_user.id}")
# Invalidate all refresh tokens for this user
invalidate_all_user_refresh_tokens(redis_client, current_user.id)
return {"detail": "Successfully logged out"}
@router.post("/refresh", response_model=RefreshTokenResponse)
@limiter.limit("10/minute")
async def refresh_access_token(
request: Request,
refresh_request: RefreshTokenRequest,
db: Session = Depends(get_db),
redis_client=Depends(get_redis),
):
"""
Refresh access token using a valid refresh token.
This endpoint implements refresh token rotation:
- Validates the provided refresh token
- Issues a new access token
- Issues a new refresh token (rotating the old one)
- Invalidates the old refresh token
This provides enhanced security by ensuring refresh tokens are single-use.
"""
old_refresh_token = refresh_request.refresh_token
# Find the user ID associated with this refresh token
user_id = decode_refresh_token_user_id(old_refresh_token, redis_client)
if user_id is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or expired refresh token",
headers={"WWW-Authenticate": "Bearer"},
)
# Validate the refresh token is still valid and bound to this user
if not validate_refresh_token(redis_client, user_id, old_refresh_token):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or expired refresh token",
headers={"WWW-Authenticate": "Bearer"},
)
# Get user from database
user = db.query(User).filter(User.id == user_id).first()
if user is None:
# Invalidate the token since user no longer exists
invalidate_refresh_token(redis_client, user_id, old_refresh_token)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User not found",
headers={"WWW-Authenticate": "Bearer"},
)
if not user.is_active:
# Invalidate all tokens for disabled user
invalidate_all_user_refresh_tokens(redis_client, user_id)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="User account is disabled",
)
# Invalidate the old refresh token (rotation)
invalidate_refresh_token(redis_client, user_id, old_refresh_token)
# Get role name
role_name = user.role.name if user.role else None
# Create new token payload
token_data = create_token_payload(
user_id=user.id,
email=user.email,
role=role_name,
department_id=user.department_id,
is_system_admin=user.is_system_admin,
)
# Create new access token
new_access_token = create_access_token(token_data)
# Generate new refresh token (rotation)
new_refresh_token = generate_refresh_token()
# Store new session in Redis
redis_client.setex(
f"session:{user.id}",
settings.JWT_EXPIRE_MINUTES * 60,
new_access_token,
)
# Store new refresh token
store_refresh_token(redis_client, user.id, new_refresh_token)
return RefreshTokenResponse(
access_token=new_access_token,
refresh_token=new_refresh_token,
expires_in=settings.JWT_EXPIRE_MINUTES * 60,
)
@router.get("/me", response_model=UserInfo)
async def get_current_user_info(
current_user: User = Depends(get_current_user),

View File

@@ -9,6 +9,7 @@ from sqlalchemy.orm import Session
from app.core import database
from app.core.security import decode_access_token
from app.core.redis import get_redis_sync
from app.core.config import settings
from app.models import User, Notification, Project
from app.services.websocket_manager import manager
from app.core.redis_pubsub import NotificationSubscriber, ProjectTaskSubscriber
@@ -72,14 +73,24 @@ async def authenticate_websocket(
Supports two authentication methods:
1. First message authentication (preferred, more secure)
- Client sends: {"type": "auth", "token": "<jwt_token>"}
2. Query parameter authentication (deprecated, for backward compatibility)
2. Query parameter authentication (disabled in production, for backward compatibility only)
- Client connects with: ?token=<jwt_token>
Returns:
Tuple of (user_id, error_reason). user_id is None if authentication fails.
Error reasons: "invalid_token", "invalid_message", "missing_token",
"timeout", "error", "query_auth_disabled"
"""
# If token provided via query parameter (backward compatibility)
if query_token:
# Reject query parameter auth in production for security
if settings.ENVIRONMENT == "production":
logger.warning(
"WebSocket query parameter authentication attempted in production environment. "
"This is disabled for security reasons."
)
return None, "query_auth_disabled"
logger.warning(
"WebSocket authentication via query parameter is deprecated. "
"Please use first-message authentication for better security."
@@ -195,9 +206,21 @@ async def websocket_notifications(
user_id, error_reason = await authenticate_websocket(websocket, token)
if user_id is None:
if error_reason == "invalid_token":
if error_reason == "query_auth_disabled":
await websocket.send_json({"type": "error", "message": "Query parameter auth disabled in production"})
await websocket.close(code=4002, reason="Query parameter auth disabled in production")
elif error_reason == "invalid_token":
await websocket.send_json({"type": "error", "message": "Invalid or expired token"})
await websocket.close(code=4001, reason="Invalid or expired token")
await websocket.close(code=4001, reason="Invalid or expired token")
else:
await websocket.close(code=4001, reason="Invalid or expired token")
return
# Check connection limit before accepting
can_connect, reject_reason = await manager.check_connection_limit(user_id)
if not can_connect:
await websocket.send_json({"type": "error", "message": reject_reason})
await websocket.close(code=4005, reason=reject_reason)
return
await manager.connect(websocket, user_id)
@@ -394,9 +417,21 @@ async def websocket_project_sync(
user_id, error_reason = await authenticate_websocket(websocket, token)
if user_id is None:
if error_reason == "invalid_token":
if error_reason == "query_auth_disabled":
await websocket.send_json({"type": "error", "message": "Query parameter auth disabled in production"})
await websocket.close(code=4002, reason="Query parameter auth disabled in production")
elif error_reason == "invalid_token":
await websocket.send_json({"type": "error", "message": "Invalid or expired token"})
await websocket.close(code=4001, reason="Invalid or expired token")
await websocket.close(code=4001, reason="Invalid or expired token")
else:
await websocket.close(code=4001, reason="Invalid or expired token")
return
# Check connection limit before accepting
can_connect, reject_reason = await manager.check_connection_limit(user_id)
if not can_connect:
await websocket.send_json({"type": "error", "message": reject_reason})
await websocket.close(code=4005, reason=reject_reason)
return
# Verify user has access to the project

View File

@@ -28,7 +28,8 @@ class Settings(BaseSettings):
# JWT - Must be set in environment, no default allowed
JWT_SECRET_KEY: str = ""
JWT_ALGORITHM: str = "HS256"
JWT_EXPIRE_MINUTES: int = 10080 # 7 days
JWT_EXPIRE_MINUTES: int = 60 # 1 hour (short-lived access token)
REFRESH_TOKEN_EXPIRE_DAYS: int = 7 # Refresh token valid for 7 days
@field_validator("JWT_SECRET_KEY")
@classmethod
@@ -127,6 +128,12 @@ class Settings(BaseSettings):
QUERY_LOGGING: bool = False # Enable SQLAlchemy query logging
QUERY_COUNT_THRESHOLD: int = 10 # Warn when query count exceeds this threshold
# Environment
ENVIRONMENT: str = "development" # Options: development, staging, production
# WebSocket Settings
MAX_WEBSOCKET_CONNECTIONS_PER_USER: int = 5 # Maximum concurrent WebSocket connections per user
class Config:
env_file = ".env"
case_sensitive = True

View File

@@ -356,3 +356,140 @@ def create_token_payload(
"department_id": department_id,
"is_system_admin": is_system_admin,
}
# Refresh Token Functions
REFRESH_TOKEN_BYTES = 32
def generate_refresh_token() -> str:
"""
Generate a cryptographically secure refresh token.
Returns:
A URL-safe base64-encoded random token
"""
return secrets.token_urlsafe(REFRESH_TOKEN_BYTES)
def get_refresh_token_key(user_id: str, token: str) -> str:
"""
Generate the Redis key for a refresh token.
Args:
user_id: The user's ID
token: The refresh token
Returns:
Redis key string
"""
# Hash the token to avoid storing it directly as a key
token_hash = hashlib.sha256(token.encode()).hexdigest()[:16]
return f"refresh_token:{user_id}:{token_hash}"
def store_refresh_token(redis_client, user_id: str, token: str) -> None:
"""
Store a refresh token in Redis with user binding.
Args:
redis_client: Redis client instance
user_id: The user's ID
token: The refresh token to store
"""
key = get_refresh_token_key(user_id, token)
# Store with TTL based on REFRESH_TOKEN_EXPIRE_DAYS
ttl_seconds = settings.REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60
redis_client.setex(key, ttl_seconds, user_id)
def validate_refresh_token(redis_client, user_id: str, token: str) -> bool:
"""
Validate a refresh token exists in Redis and is bound to the user.
Args:
redis_client: Redis client instance
user_id: The expected user ID
token: The refresh token to validate
Returns:
True if token is valid, False otherwise
"""
key = get_refresh_token_key(user_id, token)
stored_user_id = redis_client.get(key)
if stored_user_id is None:
return False
# Handle Redis bytes type
if isinstance(stored_user_id, bytes):
stored_user_id = stored_user_id.decode("utf-8")
return stored_user_id == user_id
def invalidate_refresh_token(redis_client, user_id: str, token: str) -> bool:
"""
Invalidate (delete) a refresh token from Redis.
Args:
redis_client: Redis client instance
user_id: The user's ID
token: The refresh token to invalidate
Returns:
True if token was deleted, False if it didn't exist
"""
key = get_refresh_token_key(user_id, token)
result = redis_client.delete(key)
return result > 0 if isinstance(result, int) else bool(result)
def invalidate_all_user_refresh_tokens(redis_client, user_id: str) -> int:
"""
Invalidate all refresh tokens for a user.
Args:
redis_client: Redis client instance
user_id: The user's ID
Returns:
Number of tokens invalidated
"""
pattern = f"refresh_token:{user_id}:*"
count = 0
for key in redis_client.scan_iter(match=pattern):
redis_client.delete(key)
count += 1
return count
def decode_refresh_token_user_id(token: str, redis_client) -> Optional[str]:
"""
Find the user ID associated with a refresh token by searching Redis.
This is used when we only have the token and need to find which user it belongs to.
Note: This is less efficient but necessary for refresh token validation when
the user_id is not provided in the request.
Args:
token: The refresh token
redis_client: Redis client instance
Returns:
User ID if found, None otherwise
"""
# We need to search for the token across all users
# This is done by checking the token hash pattern
token_hash = hashlib.sha256(token.encode()).hexdigest()[:16]
pattern = f"refresh_token:*:{token_hash}"
for key in redis_client.scan_iter(match=pattern):
# Extract user_id from key format: refresh_token:{user_id}:{token_hash}
if isinstance(key, bytes):
key = key.decode("utf-8")
parts = key.split(":")
if len(parts) == 3:
return parts[1]
return None

View File

@@ -1,7 +1,7 @@
import os
from contextlib import asynccontextmanager
from datetime import datetime
from fastapi import FastAPI, Request, APIRouter
from fastapi import FastAPI, Request, APIRouter, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from slowapi import _rate_limit_exceeded_handler
@@ -9,6 +9,9 @@ from slowapi.errors import RateLimitExceeded
from sqlalchemy import text
from app.middleware.audit import AuditMiddleware
from app.middleware.csrf import CSRFMiddleware
from app.middleware.security_audit import SecurityAuditMiddleware
from app.middleware.error_sanitizer import ErrorSanitizerMiddleware
from app.core.scheduler import start_scheduler, shutdown_scheduler, scheduler
from app.core.rate_limiter import limiter
from app.core.deprecation import DeprecationMiddleware
@@ -61,6 +64,8 @@ from app.core.database import get_pool_status, engine
from app.core.redis import redis_client
from app.services.notification_service import get_redis_fallback_status
from app.services.file_storage_service import file_storage_service
from app.middleware.auth import require_system_admin
from app.models import User
app = FastAPI(
title="Project Control API",
@@ -73,18 +78,28 @@ app = FastAPI(
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# CORS middleware
# CORS middleware - Explicit methods and headers for security
app.add_middleware(
CORSMiddleware,
allow_origins=settings.CORS_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
allow_headers=["Content-Type", "Authorization", "X-CSRF-Token", "X-Request-ID"],
)
# Error sanitizer middleware - sanitizes error messages in production
# Must be first in the chain to intercept all error responses
app.add_middleware(ErrorSanitizerMiddleware)
# Audit middleware - extracts request metadata for audit logging
app.add_middleware(AuditMiddleware)
# Security audit middleware - logs 401/403 responses to audit trail
app.add_middleware(SecurityAuditMiddleware)
# CSRF middleware - validates CSRF tokens for state-changing requests
app.add_middleware(CSRFMiddleware)
# Deprecation middleware - adds deprecation headers to legacy /api/ routes
app.add_middleware(DeprecationMiddleware)
@@ -252,14 +267,20 @@ async def readiness_check():
@app.get("/health/detailed")
async def detailed_health_check():
"""Detailed health check endpoint.
async def detailed_health_check(
current_user: User = Depends(require_system_admin),
):
"""Detailed health check endpoint (requires system admin).
Returns comprehensive status of all system components:
- database: Connection pool status and connectivity
- redis: Connection status and fallback queue status
- storage: File storage validation status
- scheduler: Background job scheduler status
Note: This endpoint requires system admin authentication because it exposes
sensitive infrastructure details including connection pool statistics and
internal service states.
"""
db_health = check_database_health()
redis_health = check_redis_health()

View File

@@ -1,38 +1,55 @@
"""
CSRF (Cross-Site Request Forgery) Protection Middleware.
This module provides CSRF protection for sensitive state-changing operations.
It validates CSRF tokens for specified protected endpoints.
This module provides CSRF protection for all state-changing operations.
It validates CSRF tokens globally for authenticated POST, PUT, PATCH, DELETE requests.
"""
from fastapi import Request, HTTPException, status, Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from typing import Optional, Callable, List
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse
from fastapi import HTTPException, status
from typing import Optional, Callable, List, Set
from functools import wraps
import logging
from app.core.security import validate_csrf_token, generate_csrf_token
from app.core.security import validate_csrf_token, generate_csrf_token, decode_access_token
logger = logging.getLogger(__name__)
# Header name for CSRF token
CSRF_TOKEN_HEADER = "X-CSRF-Token"
# List of endpoint patterns that require CSRF protection
# These are sensitive state-changing operations
CSRF_PROTECTED_PATTERNS = [
# User operations
"/api/v1/users/{user_id}/admin", # Admin status change
"/api/users/{user_id}/admin", # Legacy
# Password changes would go here if implemented
# Delete operations
"/api/attachments/{attachment_id}", # DELETE method
"/api/tasks/{task_id}", # DELETE method (soft delete)
"/api/projects/{project_id}", # DELETE method
]
# Methods that require CSRF protection (all state-changing operations)
CSRF_PROTECTED_METHODS = {"POST", "PUT", "PATCH", "DELETE"}
# Methods that require CSRF protection
CSRF_PROTECTED_METHODS = ["DELETE", "PUT", "PATCH"]
# Safe methods that don't require CSRF protection
CSRF_SAFE_METHODS = {"GET", "HEAD", "OPTIONS"}
# Public endpoints that don't require CSRF validation
# These are endpoints that either:
# 1. Don't require authentication (login, health checks)
# 2. Are not state-changing in a security-sensitive way
CSRF_EXCLUDED_PATHS: Set[str] = {
# Authentication endpoints (unauthenticated)
"/api/auth/login",
"/api/v1/auth/login",
# Health check endpoints (unauthenticated)
"/health",
"/health/live",
"/health/ready",
"/health/detailed",
# WebSocket endpoints (use different auth mechanism)
"/api/ws",
"/ws",
}
# Path prefixes that are excluded from CSRF validation
CSRF_EXCLUDED_PREFIXES: List[str] = [
# WebSocket paths
"/api/ws/",
"/ws/",
]
class CSRFProtectionError(HTTPException):
@@ -45,6 +62,114 @@ class CSRFProtectionError(HTTPException):
)
class CSRFMiddleware(BaseHTTPMiddleware):
"""
Global CSRF protection middleware.
Validates CSRF tokens for all authenticated state-changing requests
(POST, PUT, PATCH, DELETE) except for explicitly excluded endpoints.
"""
async def dispatch(self, request: Request, call_next):
"""Process the request and validate CSRF token if needed."""
method = request.method.upper()
path = request.url.path
# Skip CSRF validation for safe methods
if method in CSRF_SAFE_METHODS:
return await call_next(request)
# Skip CSRF validation for excluded paths
if self._is_excluded_path(path):
logger.debug("CSRF validation skipped for excluded path: %s", path)
return await call_next(request)
# Try to extract user ID from the Authorization header
user_id = self._extract_user_id_from_token(request)
# If no user ID (unauthenticated request), skip CSRF validation
# The authentication middleware will handle unauthorized access
if user_id is None:
logger.debug(
"CSRF validation skipped (no auth token): %s %s",
method, path
)
return await call_next(request)
# Get CSRF token from header
csrf_token = request.headers.get(CSRF_TOKEN_HEADER)
if not csrf_token:
logger.warning(
"CSRF validation failed: Missing token for user %s on %s %s",
user_id, method, path
)
return JSONResponse(
status_code=status.HTTP_403_FORBIDDEN,
content={"detail": "CSRF token is required"}
)
# Validate the token
is_valid, error_message = validate_csrf_token(csrf_token, user_id)
if not is_valid:
logger.warning(
"CSRF validation failed for user %s on %s %s: %s",
user_id, method, path, error_message
)
return JSONResponse(
status_code=status.HTTP_403_FORBIDDEN,
content={"detail": error_message}
)
logger.debug(
"CSRF validation passed for user %s on %s %s",
user_id, method, path
)
return await call_next(request)
def _is_excluded_path(self, path: str) -> bool:
"""Check if the path is excluded from CSRF validation."""
# Check exact path matches
if path in CSRF_EXCLUDED_PATHS:
return True
# Check path prefixes
for prefix in CSRF_EXCLUDED_PREFIXES:
if path.startswith(prefix):
return True
return False
def _extract_user_id_from_token(self, request: Request) -> Optional[str]:
"""
Extract user ID from the Authorization header.
Returns None if no valid token is found (unauthenticated request).
"""
auth_header = request.headers.get("Authorization")
if not auth_header:
return None
# Parse Bearer token
parts = auth_header.split()
if len(parts) != 2 or parts[0].lower() != "bearer":
return None
token = parts[1]
# Decode the token to get user ID
try:
payload = decode_access_token(token)
if payload is None:
return None
return payload.get("sub")
except Exception as e:
logger.debug("Failed to decode token for CSRF validation: %s", e)
return None
def require_csrf_token(func: Callable) -> Callable:
"""
Decorator to require CSRF token validation for an endpoint.

View File

@@ -0,0 +1,187 @@
"""Error message sanitization middleware for production environments.
This middleware intercepts error responses and sanitizes them to prevent
information disclosure in production environments. Detailed error messages
are only shown when DEBUG mode is enabled.
"""
import json
import logging
from typing import Optional
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response, JSONResponse
from app.core.config import settings
logger = logging.getLogger(__name__)
# Generic error messages for production
GENERIC_ERROR_MESSAGES = {
400: "Bad Request",
401: "Authentication required",
403: "Access denied",
404: "Resource not found",
405: "Method not allowed",
409: "Request conflict",
422: "Validation error",
429: "Too many requests",
500: "Internal server error",
502: "Service unavailable",
503: "Service temporarily unavailable",
504: "Request timeout",
}
# Status codes that should preserve their original message even in production
# These are typically user-facing validation errors that don't leak sensitive info
PRESERVE_MESSAGE_CODES = {
400, # Bad request - users need to know what's wrong with their request
401, # Unauthorized - users need to know why auth failed
403, # Forbidden - users need to know what permission they lack
404, # Not found - usually safe to preserve
409, # Conflict - users need to know about conflicts
422, # Validation errors - users need to know what to fix
}
# Patterns that indicate sensitive information in error messages
SENSITIVE_PATTERNS = [
"traceback",
"stack trace",
"file path",
"/usr/",
"/var/",
"/home/",
"connection refused",
"connection error",
"timeout connecting",
"database error",
"sql",
"query failed",
"password",
"secret",
"token",
"key=",
"credentials",
".py line",
"exception in",
]
def _contains_sensitive_info(message: str) -> bool:
"""Check if an error message contains potentially sensitive information."""
if not message:
return False
message_lower = message.lower()
return any(pattern.lower() in message_lower for pattern in SENSITIVE_PATTERNS)
def _sanitize_detail(detail: any, status_code: int) -> any:
"""Sanitize error detail, removing sensitive information in production.
Args:
detail: The error detail (can be string, list, or dict)
status_code: The HTTP status code
Returns:
Sanitized detail for production, or original detail for debug mode
"""
# In debug mode, return original detail
if settings.DEBUG:
return detail
# For preserved status codes, keep the detail if it doesn't contain sensitive info
if status_code in PRESERVE_MESSAGE_CODES:
if isinstance(detail, str) and not _contains_sensitive_info(detail):
return detail
if isinstance(detail, list):
# For validation errors (list of dicts), keep the structure but sanitize
sanitized = []
for item in detail:
if isinstance(item, dict):
# Keep loc, msg, type for pydantic validation errors
sanitized_item = {}
if 'loc' in item:
sanitized_item['loc'] = item['loc']
if 'msg' in item and not _contains_sensitive_info(str(item['msg'])):
sanitized_item['msg'] = item['msg']
else:
sanitized_item['msg'] = 'Validation failed'
if 'type' in item:
sanitized_item['type'] = item['type']
sanitized.append(sanitized_item)
else:
sanitized.append(item if not _contains_sensitive_info(str(item)) else 'Invalid value')
return sanitized
return detail
# For other status codes, use generic message
return GENERIC_ERROR_MESSAGES.get(status_code, "An error occurred")
class ErrorSanitizerMiddleware(BaseHTTPMiddleware):
"""Middleware to sanitize error responses in production.
This middleware:
1. Intercepts error responses (4xx and 5xx status codes)
2. Parses JSON response bodies
3. Sanitizes the 'detail' field to remove sensitive information
4. Returns the sanitized response
In DEBUG mode, original error messages are preserved for development.
"""
async def dispatch(self, request: Request, call_next) -> Response:
response = await call_next(request)
# Only process error responses with JSON content
if response.status_code < 400:
return response
content_type = response.headers.get("content-type", "")
if "application/json" not in content_type:
return response
# Read the response body
body = b""
async for chunk in response.body_iterator:
body += chunk
if not body:
return response
try:
data = json.loads(body)
except (json.JSONDecodeError, UnicodeDecodeError):
# Not valid JSON, return as-is
return Response(
content=body,
status_code=response.status_code,
headers=dict(response.headers),
media_type=response.media_type,
)
# Sanitize the detail field if present
if "detail" in data:
original_detail = data["detail"]
data["detail"] = _sanitize_detail(original_detail, response.status_code)
# Log the original error in production for debugging
if not settings.DEBUG and original_detail != data["detail"]:
logger.warning(
"Sanitized error response",
extra={
"status_code": response.status_code,
"path": str(request.url.path),
"method": request.method,
"original_detail_length": len(str(original_detail)),
}
)
# Return the sanitized response
return JSONResponse(
content=data,
status_code=response.status_code,
headers={
k: v for k, v in response.headers.items()
if k.lower() not in ("content-length", "content-type")
},
)

View File

@@ -0,0 +1,215 @@
"""Security audit middleware for logging access denials and suspicious auth patterns."""
import time
from collections import defaultdict
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Tuple
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
from sqlalchemy.orm import Session
from app.core.database import SessionLocal
from app.models import AuditLog, AuditAction
from app.services.audit_service import AuditService
# In-memory storage for tracking auth failures
# Structure: {ip_address: [(timestamp, path), ...]}
_auth_failure_tracker: Dict[str, List[Tuple[float, str]]] = defaultdict(list)
# Configuration constants
AUTH_FAILURE_THRESHOLD = 5 # Number of failures to trigger suspicious pattern alert
AUTH_FAILURE_WINDOW_SECONDS = 600 # 10 minutes window
def _cleanup_old_failures(ip: str) -> None:
"""Remove auth failures older than the tracking window."""
if ip not in _auth_failure_tracker:
return
cutoff = time.time() - AUTH_FAILURE_WINDOW_SECONDS
_auth_failure_tracker[ip] = [
(ts, path) for ts, path in _auth_failure_tracker[ip]
if ts > cutoff
]
# Clean up empty entries
if not _auth_failure_tracker[ip]:
del _auth_failure_tracker[ip]
def _track_auth_failure(ip: str, path: str) -> int:
"""Track an auth failure and return the count in the window."""
_cleanup_old_failures(ip)
_auth_failure_tracker[ip].append((time.time(), path))
return len(_auth_failure_tracker[ip])
def _get_recent_failures(ip: str) -> List[str]:
"""Get list of paths that failed auth for this IP."""
_cleanup_old_failures(ip)
return [path for _, path in _auth_failure_tracker.get(ip, [])]
class SecurityAuditMiddleware(BaseHTTPMiddleware):
"""Middleware to audit security-related events like 401/403 responses."""
async def dispatch(self, request: Request, call_next) -> Response:
response = await call_next(request)
# Only process 401 and 403 responses
if response.status_code not in (401, 403):
return response
# Get client IP from audit metadata if available
ip_address = self._get_client_ip(request)
path = str(request.url.path)
method = request.method
# Get user_id if available from request state (set by auth middleware)
user_id = getattr(request.state, 'user_id', None)
db: Session = SessionLocal()
try:
if response.status_code == 403:
self._log_access_denied(db, ip_address, path, method, user_id, request)
elif response.status_code == 401:
self._log_auth_failure(db, ip_address, path, method, request)
db.commit()
except Exception:
db.rollback()
# Don't fail the request due to audit logging errors
finally:
db.close()
return response
def _get_client_ip(self, request: Request) -> str:
"""Get the real client IP address from request."""
# Check for audit metadata first (set by AuditMiddleware)
audit_metadata = getattr(request.state, 'audit_metadata', None)
if audit_metadata and 'ip_address' in audit_metadata:
return audit_metadata['ip_address']
# Fallback to checking headers directly
forwarded = request.headers.get("x-forwarded-for")
if forwarded:
return forwarded.split(",")[0].strip()
real_ip = request.headers.get("x-real-ip")
if real_ip:
return real_ip
if request.client:
return request.client.host
return "unknown"
def _log_access_denied(
self,
db: Session,
ip_address: str,
path: str,
method: str,
user_id: str | None,
request: Request,
) -> None:
"""Log a 403 Forbidden response to the audit trail."""
request_metadata = {
"ip_address": ip_address,
"user_agent": request.headers.get("user-agent", ""),
"method": method,
"path": path,
}
# Try to extract resource info from path
resource_type = self._extract_resource_type(path)
AuditService.log_event(
db=db,
event_type="security.access_denied",
resource_type=resource_type,
action=AuditAction.ACCESS_DENIED,
user_id=user_id,
resource_id=None,
changes=[{
"attempted_path": path,
"attempted_method": method,
"ip_address": ip_address,
}],
request_metadata=request_metadata,
)
def _log_auth_failure(
self,
db: Session,
ip_address: str,
path: str,
method: str,
request: Request,
) -> None:
"""Log a 401 Unauthorized response and check for suspicious patterns."""
# Track this failure
failure_count = _track_auth_failure(ip_address, path)
request_metadata = {
"ip_address": ip_address,
"user_agent": request.headers.get("user-agent", ""),
"method": method,
"path": path,
}
resource_type = self._extract_resource_type(path)
# Log the auth failure
AuditService.log_event(
db=db,
event_type="security.auth_failed",
resource_type=resource_type,
action=AuditAction.AUTH_FAILED,
user_id=None, # No user for 401
resource_id=None,
changes=[{
"attempted_path": path,
"attempted_method": method,
"ip_address": ip_address,
"failure_count_in_window": failure_count,
}],
request_metadata=request_metadata,
)
# Check for suspicious pattern
if failure_count >= AUTH_FAILURE_THRESHOLD:
recent_paths = _get_recent_failures(ip_address)
AuditService.log_event(
db=db,
event_type="security.suspicious_auth_pattern",
resource_type="security",
action=AuditAction.AUTH_FAILED,
user_id=None,
resource_id=None,
changes=[{
"ip_address": ip_address,
"failure_count": failure_count,
"window_minutes": AUTH_FAILURE_WINDOW_SECONDS // 60,
"attempted_paths": list(set(recent_paths)),
}],
request_metadata=request_metadata,
)
def _extract_resource_type(self, path: str) -> str:
"""Extract resource type from path for audit logging."""
# Remove /api/ or /api/v1/ prefix
clean_path = path
if clean_path.startswith("/api/v1/"):
clean_path = clean_path[8:]
elif clean_path.startswith("/api/"):
clean_path = clean_path[5:]
# Get the first path segment as resource type
parts = clean_path.strip("/").split("/")
if parts and parts[0]:
return parts[0]
return "unknown"

View File

@@ -13,6 +13,8 @@ class AuditAction(str, enum.Enum):
RESTORE = "restore"
LOGIN = "login"
LOGOUT = "logout"
ACCESS_DENIED = "access_denied"
AUTH_FAILED = "auth_failed"
class SensitivityLevel(str, enum.Enum):
@@ -42,10 +44,20 @@ EVENT_SENSITIVITY = {
"attachment.upload": SensitivityLevel.LOW,
"attachment.download": SensitivityLevel.LOW,
"attachment.delete": SensitivityLevel.MEDIUM,
# Security events
"security.access_denied": SensitivityLevel.MEDIUM,
"security.auth_failed": SensitivityLevel.MEDIUM,
"security.suspicious_auth_pattern": SensitivityLevel.HIGH,
}
# Events that should trigger alerts
ALERT_EVENTS = {"project.delete", "user.permission_change", "user.admin_change", "role.permission_change"}
ALERT_EVENTS = {
"project.delete",
"user.permission_change",
"user.admin_change",
"role.permission_change",
"security.suspicious_auth_pattern",
}
class AuditLog(Base):
@@ -57,7 +69,7 @@ class AuditLog(Base):
resource_id = Column(String(36), nullable=True)
user_id = Column(String(36), ForeignKey("pjctrl_users.id", ondelete="SET NULL"), nullable=True)
action = Column(
Enum("create", "update", "delete", "restore", "login", "logout", name="audit_action_enum"),
Enum("create", "update", "delete", "restore", "login", "logout", "access_denied", "auth_failed", name="audit_action_enum"),
nullable=False
)
changes = Column(JSON, nullable=True)

View File

@@ -9,10 +9,25 @@ class LoginRequest(BaseModel):
class LoginResponse(BaseModel):
access_token: str
refresh_token: str
token_type: str = "bearer"
expires_in: int = Field(default=3600, description="Access token expiry in seconds")
user: "UserInfo"
class RefreshTokenRequest(BaseModel):
"""Request body for refresh token endpoint."""
refresh_token: str = Field(..., description="The refresh token to use for obtaining a new access token")
class RefreshTokenResponse(BaseModel):
"""Response for refresh token endpoint."""
access_token: str
refresh_token: str # New refresh token (rotation)
token_type: str = "bearer"
expires_in: int = Field(default=3600, description="Access token expiry in seconds")
class UserInfo(BaseModel):
id: str
email: str

View File

@@ -4,12 +4,12 @@ from pydantic import BaseModel, Field
class CommentCreate(BaseModel):
content: str = Field(..., min_length=1, max_length=10000)
content: str = Field(..., min_length=1, max_length=5000)
parent_comment_id: Optional[str] = None
class CommentUpdate(BaseModel):
content: str = Field(..., min_length=1, max_length=10000)
content: str = Field(..., min_length=1, max_length=5000)
class CommentAuthor(BaseModel):

View File

@@ -25,7 +25,7 @@ class CustomFieldDefinition(BaseModel):
class ProjectTemplateBase(BaseModel):
"""Base schema for project template."""
name: str = Field(..., min_length=1, max_length=200)
description: Optional[str] = None
description: Optional[str] = Field(None, max_length=2000)
is_public: bool = Field(default=False)
task_statuses: Optional[List[TaskStatusDefinition]] = None
custom_fields: Optional[List[CustomFieldDefinition]] = None
@@ -43,7 +43,7 @@ class ProjectTemplateCreate(ProjectTemplateBase):
class ProjectTemplateUpdate(BaseModel):
"""Schema for updating a project template."""
name: Optional[str] = Field(None, min_length=1, max_length=200)
description: Optional[str] = None
description: Optional[str] = Field(None, max_length=2000)
is_public: Optional[bool] = None
task_statuses: Optional[List[TaskStatusDefinition]] = None
custom_fields: Optional[List[CustomFieldDefinition]] = None

View File

@@ -4,6 +4,7 @@ import logging
from typing import Dict, Set, Optional, Tuple
from fastapi import WebSocket
from app.core.redis import get_redis_sync
from app.core.config import settings
logger = logging.getLogger(__name__)
@@ -19,13 +20,48 @@ class ConnectionManager:
self._lock = asyncio.Lock()
self._project_lock = asyncio.Lock()
async def check_connection_limit(self, user_id: str) -> Tuple[bool, Optional[str]]:
"""
Check if user can create a new WebSocket connection.
Args:
user_id: The user's ID
Returns:
Tuple of (can_connect: bool, reject_reason: str | None)
- can_connect: True if user is within connection limit
- reject_reason: Error message if connection should be rejected
"""
max_connections = settings.MAX_WEBSOCKET_CONNECTIONS_PER_USER
async with self._lock:
current_count = len(self.active_connections.get(user_id, set()))
if current_count >= max_connections:
logger.warning(
f"User {user_id} exceeded WebSocket connection limit "
f"({current_count}/{max_connections})"
)
return False, "Too many connections"
return True, None
def get_user_connection_count(self, user_id: str) -> int:
"""Get the current number of WebSocket connections for a user."""
return len(self.active_connections.get(user_id, set()))
async def connect(self, websocket: WebSocket, user_id: str):
"""Accept and track a new WebSocket connection."""
await websocket.accept()
"""
Track a new WebSocket connection.
Note: WebSocket must already be accepted before calling this method.
Connection limit should be checked via check_connection_limit() before calling.
"""
async with self._lock:
if user_id not in self.active_connections:
self.active_connections[user_id] = set()
self.active_connections[user_id].add(websocket)
logger.debug(
f"User {user_id} connected. Total connections: "
f"{len(self.active_connections[user_id])}"
)
async def disconnect(self, websocket: WebSocket, user_id: str):
"""Remove a WebSocket connection."""

View File

@@ -166,3 +166,20 @@ def admin_token(client, mock_redis):
mock_redis.setex("session:00000000-0000-0000-0000-000000000001", 900, token)
return token
@pytest.fixture
def csrf_token():
"""Generate a CSRF token for the admin user."""
from app.core.security import generate_csrf_token
return generate_csrf_token("00000000-0000-0000-0000-000000000001")
@pytest.fixture
def auth_headers(admin_token, csrf_token):
"""Get complete auth headers including both Authorization and CSRF token."""
return {
"Authorization": f"Bearer {admin_token}",
"X-CSRF-Token": csrf_token,
}

View File

@@ -173,7 +173,7 @@ class TestProjectTemplates:
# Should return list of templates
assert "templates" in data or isinstance(data, list)
def test_create_template(self, client, admin_token, db):
def test_create_template(self, client, auth_headers, db):
"""Test creating a new project template."""
from app.models import Space
@@ -192,14 +192,14 @@ class TestProjectTemplates:
{"name": "Done", "color": "#00FF00"}
]
},
headers={"Authorization": f"Bearer {admin_token}"}
headers=auth_headers
)
assert response.status_code in [200, 201]
data = response.json()
assert data.get("name") == "Test Template"
def test_create_project_from_template(self, client, admin_token, db):
def test_create_project_from_template(self, client, auth_headers, db):
"""Test creating a project from a template."""
from app.models import Space, ProjectTemplate
@@ -228,14 +228,14 @@ class TestProjectTemplates:
"description": "Created from template",
"template_id": "test-template-id"
},
headers={"Authorization": f"Bearer {admin_token}"}
headers=auth_headers
)
assert response.status_code in [200, 201]
data = response.json()
assert data.get("name") == "Project from Template"
def test_delete_template(self, client, admin_token, db):
def test_delete_template(self, client, auth_headers, db):
"""Test deleting a project template."""
from app.models import ProjectTemplate
@@ -251,7 +251,7 @@ class TestProjectTemplates:
response = client.delete(
"/api/templates/delete-template-id",
headers={"Authorization": f"Bearer {admin_token}"}
headers=auth_headers
)
assert response.status_code in [200, 204]

View File

@@ -42,6 +42,22 @@ def test_user_token(client, mock_redis, test_user):
return token
@pytest.fixture
def test_user_csrf_token(test_user):
"""Generate a CSRF token for the test user."""
from app.core.security import generate_csrf_token
return generate_csrf_token(test_user.id)
@pytest.fixture
def test_user_auth_headers(test_user_token, test_user_csrf_token):
"""Get complete auth headers for test user."""
return {
"Authorization": f"Bearer {test_user_token}",
"X-CSRF-Token": test_user_csrf_token,
}
@pytest.fixture
def test_space(db, test_user):
"""Create a test space."""
@@ -154,7 +170,7 @@ class TestFileStorageService:
class TestAttachmentAPI:
"""Tests for Attachment API endpoints."""
def test_upload_attachment(self, client, test_user_token, test_task, db, monkeypatch, temp_upload_dir):
def test_upload_attachment(self, client, test_user_auth_headers, test_task, db, monkeypatch, temp_upload_dir):
"""Test uploading an attachment."""
monkeypatch.setattr("app.core.config.settings.UPLOAD_DIR", temp_upload_dir)
@@ -163,7 +179,7 @@ class TestAttachmentAPI:
response = client.post(
f"/api/tasks/{test_task.id}/attachments",
headers={"Authorization": f"Bearer {test_user_token}"},
headers=test_user_auth_headers,
files=files,
)
@@ -271,14 +287,14 @@ class TestAttachmentAPI:
db.refresh(attachment)
assert attachment.is_deleted == True
def test_upload_blocked_file_type(self, client, test_user_token, test_task):
def test_upload_blocked_file_type(self, client, test_user_auth_headers, test_task):
"""Test that blocked file types are rejected."""
content = b"malicious content"
files = {"file": ("virus.exe", BytesIO(content), "application/octet-stream")}
response = client.post(
f"/api/tasks/{test_task.id}/attachments",
headers={"Authorization": f"Bearer {test_user_token}"},
headers=test_user_auth_headers,
files=files,
)
@@ -322,7 +338,7 @@ class TestAttachmentAPI:
assert data["total"] == 2
assert len(data["versions"]) == 2
def test_restore_version(self, client, test_user_token, test_task, db):
def test_restore_version(self, client, test_user_auth_headers, test_task, db):
"""Test restoring to a previous version."""
attachment = Attachment(
id=str(uuid.uuid4()),
@@ -351,7 +367,7 @@ class TestAttachmentAPI:
response = client.post(
f"/api/attachments/{attachment.id}/restore/1",
headers={"Authorization": f"Bearer {test_user_token}"},
headers=test_user_auth_headers,
)
assert response.status_code == 200

View File

@@ -253,7 +253,7 @@ class TestAuditAPI:
assert data["total"] == 3
assert all(log["resource_id"] == resource_id for log in data["logs"])
def test_verify_integrity(self, client, admin_token, db):
def test_verify_integrity(self, client, auth_headers, db):
"""Test integrity verification."""
now = datetime.utcnow()
@@ -270,7 +270,7 @@ class TestAuditAPI:
response = client.post(
"/api/audit-logs/verify-integrity",
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
json={
"start_date": (now - timedelta(hours=1)).isoformat(),
"end_date": (now + timedelta(hours=1)).isoformat(),
@@ -281,7 +281,7 @@ class TestAuditAPI:
assert data["total_checked"] >= 1
assert data["invalid_count"] == 0
def test_acknowledge_alert(self, client, admin_token, db):
def test_acknowledge_alert(self, client, auth_headers, db):
"""Test acknowledging an alert."""
# Create a log and alert
log = AuditLog(
@@ -309,7 +309,7 @@ class TestAuditAPI:
response = client.put(
f"/api/audit-alerts/{alert.id}/acknowledge",
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()

View File

@@ -1,5 +1,16 @@
import pytest
from app.core.security import create_access_token, decode_access_token, create_token_payload
from app.core.security import (
create_access_token,
decode_access_token,
create_token_payload,
generate_refresh_token,
store_refresh_token,
validate_refresh_token,
invalidate_refresh_token,
invalidate_all_user_refresh_tokens,
decode_refresh_token_user_id,
get_refresh_token_key,
)
class TestJWT:
@@ -59,7 +70,7 @@ class TestAuthEndpoints:
def test_get_me_without_auth(self, client):
"""Test accessing /me without authentication."""
response = client.get("/api/auth/me")
assert response.status_code == 403
assert response.status_code == 401 # 401 for unauthenticated, 403 for unauthorized
def test_get_me_with_auth(self, client, admin_token):
"""Test accessing /me with valid authentication."""
@@ -72,13 +83,196 @@ class TestAuthEndpoints:
assert data["email"] == "ymirliu@panjit.com.tw"
assert data["is_system_admin"] is True
def test_logout(self, client, admin_token, mock_redis):
def test_logout(self, client, auth_headers, mock_redis):
"""Test logout endpoint."""
response = client.post(
"/api/auth/logout",
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
)
assert response.status_code == 200
# Verify session is removed
assert mock_redis.get("session:00000000-0000-0000-0000-000000000001") is None
class TestRefreshToken:
"""Test refresh token functionality."""
def test_generate_refresh_token(self):
"""Test that refresh tokens are generated correctly."""
token = generate_refresh_token()
assert token is not None
assert isinstance(token, str)
assert len(token) > 20 # URL-safe base64 encoded 32 bytes
def test_generate_unique_refresh_tokens(self):
"""Test that each generated token is unique."""
tokens = [generate_refresh_token() for _ in range(100)]
assert len(set(tokens)) == 100 # All tokens should be unique
def test_store_and_validate_refresh_token(self, mock_redis):
"""Test storing and validating refresh tokens."""
user_id = "test-user-123"
token = generate_refresh_token()
# Store the token
store_refresh_token(mock_redis, user_id, token)
# Validate the token
assert validate_refresh_token(mock_redis, user_id, token) is True
# Wrong user should fail
assert validate_refresh_token(mock_redis, "wrong-user", token) is False
# Wrong token should fail
assert validate_refresh_token(mock_redis, user_id, "wrong-token") is False
def test_invalidate_refresh_token(self, mock_redis):
"""Test invalidating a refresh token."""
user_id = "test-user-123"
token = generate_refresh_token()
# Store and verify
store_refresh_token(mock_redis, user_id, token)
assert validate_refresh_token(mock_redis, user_id, token) is True
# Invalidate
result = invalidate_refresh_token(mock_redis, user_id, token)
assert result is True
# Should no longer be valid
assert validate_refresh_token(mock_redis, user_id, token) is False
def test_invalidate_all_user_refresh_tokens(self, mock_redis):
"""Test invalidating all refresh tokens for a user."""
user_id = "test-user-123"
tokens = [generate_refresh_token() for _ in range(3)]
# Store multiple tokens
for token in tokens:
store_refresh_token(mock_redis, user_id, token)
# Verify all are valid
for token in tokens:
assert validate_refresh_token(mock_redis, user_id, token) is True
# Invalidate all
count = invalidate_all_user_refresh_tokens(mock_redis, user_id)
assert count == 3
# All should be invalid now
for token in tokens:
assert validate_refresh_token(mock_redis, user_id, token) is False
def test_decode_refresh_token_user_id(self, mock_redis):
"""Test finding user ID from refresh token."""
user_id = "test-user-456"
token = generate_refresh_token()
# Store the token
store_refresh_token(mock_redis, user_id, token)
# Find user ID
found_user_id = decode_refresh_token_user_id(token, mock_redis)
assert found_user_id == user_id
# Invalid token should return None
assert decode_refresh_token_user_id("invalid-token", mock_redis) is None
class TestRefreshTokenEndpoint:
"""Test the refresh token API endpoint."""
def test_refresh_token_success(self, client, db, mock_redis):
"""Test successful token refresh."""
user_id = "00000000-0000-0000-0000-000000000001"
# Generate and store a refresh token
refresh_token = generate_refresh_token()
store_refresh_token(mock_redis, user_id, refresh_token)
# Call refresh endpoint
response = client.post(
"/api/auth/refresh",
json={"refresh_token": refresh_token},
)
assert response.status_code == 200
data = response.json()
assert "access_token" in data
assert "refresh_token" in data
assert data["token_type"] == "bearer"
assert data["expires_in"] > 0
# Old refresh token should be invalidated (rotation)
assert validate_refresh_token(mock_redis, user_id, refresh_token) is False
# New refresh token should be valid
assert validate_refresh_token(mock_redis, user_id, data["refresh_token"]) is True
def test_refresh_token_invalid(self, client, mock_redis):
"""Test refresh with invalid token."""
response = client.post(
"/api/auth/refresh",
json={"refresh_token": "invalid-token"},
)
assert response.status_code == 401
assert "Invalid or expired refresh token" in response.json()["detail"]
def test_refresh_token_rotation(self, client, db, mock_redis):
"""Test that refresh tokens are rotated (old one invalidated)."""
user_id = "00000000-0000-0000-0000-000000000001"
# Generate and store initial refresh token
initial_token = generate_refresh_token()
store_refresh_token(mock_redis, user_id, initial_token)
# First refresh
response1 = client.post(
"/api/auth/refresh",
json={"refresh_token": initial_token},
)
assert response1.status_code == 200
new_token = response1.json()["refresh_token"]
# Try to reuse the old token (should fail due to rotation)
response2 = client.post(
"/api/auth/refresh",
json={"refresh_token": initial_token},
)
assert response2.status_code == 401
# New token should still work
response3 = client.post(
"/api/auth/refresh",
json={"refresh_token": new_token},
)
assert response3.status_code == 200
def test_refresh_token_disabled_user(self, client, db, mock_redis):
"""Test that disabled users cannot refresh tokens."""
from app.models.user import User
# Create a disabled user
disabled_user = User(
id="disabled-user-123",
email="disabled@example.com",
name="Disabled User",
is_active=False,
)
db.add(disabled_user)
db.commit()
# Generate and store refresh token for disabled user
refresh_token = generate_refresh_token()
store_refresh_token(mock_redis, disabled_user.id, refresh_token)
# Try to refresh
response = client.post(
"/api/auth/refresh",
json={"refresh_token": refresh_token},
)
assert response.status_code == 403
assert "disabled" in response.json()["detail"].lower()

View File

@@ -128,7 +128,7 @@ class TestRedisFailover:
class TestBlockerDeletionCheck:
"""Test blocker check before task deletion."""
def test_delete_task_with_blockers_warning(self, client, admin_token, db):
def test_delete_task_with_blockers_warning(self, client, admin_token, csrf_token, db):
"""Test that deleting task with blockers shows warning."""
from app.models import Space, Project, Task, TaskStatus, TaskDependency
@@ -174,7 +174,7 @@ class TestBlockerDeletionCheck:
# Try to delete without force
response = client.delete(
"/api/tasks/blocker-task",
headers={"Authorization": f"Bearer {admin_token}"}
headers={"Authorization": f"Bearer {admin_token}", "X-CSRF-Token": csrf_token}
)
# Should return warning or require confirmation
@@ -185,7 +185,7 @@ class TestBlockerDeletionCheck:
if "warning" in data or "blocker_count" in data:
assert data.get("blocker_count", 0) >= 1 or "blocker" in str(data).lower()
def test_force_delete_resolves_blockers(self, client, admin_token, db):
def test_force_delete_resolves_blockers(self, client, admin_token, csrf_token, db):
"""Test that force delete resolves blockers."""
from app.models import Space, Project, Task, TaskStatus, TaskDependency
@@ -231,7 +231,7 @@ class TestBlockerDeletionCheck:
# Force delete
response = client.delete(
"/api/tasks/force-del-task?force_delete=true",
headers={"Authorization": f"Bearer {admin_token}"}
headers={"Authorization": f"Bearer {admin_token}", "X-CSRF-Token": csrf_token}
)
assert response.status_code == 200
@@ -240,7 +240,7 @@ class TestBlockerDeletionCheck:
db.refresh(task_to_delete)
assert task_to_delete.is_deleted is True
def test_delete_task_without_blockers(self, client, admin_token, db):
def test_delete_task_without_blockers(self, client, admin_token, csrf_token, db):
"""Test deleting task without blockers succeeds normally."""
from app.models import Space, Project, Task, TaskStatus
@@ -267,7 +267,7 @@ class TestBlockerDeletionCheck:
# Delete should succeed without warning
response = client.delete(
"/api/tasks/no-blocker-task",
headers={"Authorization": f"Bearer {admin_token}"}
headers={"Authorization": f"Bearer {admin_token}", "X-CSRF-Token": csrf_token}
)
assert response.status_code == 200

View File

@@ -36,6 +36,13 @@ def user_token(client, mock_redis, test_user):
return token
@pytest.fixture
def user_csrf_token(test_user):
"""Generate a CSRF token for the test user."""
from app.core.security import generate_csrf_token
return generate_csrf_token(test_user.id)
@pytest.fixture
def test_space(db):
"""Create a test space."""
@@ -100,11 +107,11 @@ def test_task(db, test_project, test_status):
class TestComments:
"""Tests for Comments API."""
def test_create_comment(self, client, admin_token, test_task):
def test_create_comment(self, client, auth_headers, test_task):
"""Test creating a comment."""
response = client.post(
f"/api/tasks/{test_task.id}/comments",
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
json={"content": "This is a test comment"},
)
assert response.status_code == 201
@@ -136,7 +143,7 @@ class TestComments:
assert len(data["comments"]) == 1
assert data["comments"][0]["content"] == "Test comment"
def test_update_comment(self, client, admin_token, db, test_task):
def test_update_comment(self, client, auth_headers, db, test_task):
"""Test updating a comment."""
comment = Comment(
id=str(uuid.uuid4()),
@@ -149,7 +156,7 @@ class TestComments:
response = client.put(
f"/api/comments/{comment.id}",
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
json={"content": "Updated content"},
)
assert response.status_code == 200
@@ -157,7 +164,7 @@ class TestComments:
assert data["content"] == "Updated content"
assert data["is_edited"] is True
def test_delete_comment(self, client, admin_token, db, test_task):
def test_delete_comment(self, client, auth_headers, db, test_task):
"""Test deleting a comment (soft delete)."""
comment = Comment(
id=str(uuid.uuid4()),
@@ -170,7 +177,7 @@ class TestComments:
response = client.delete(
f"/api/comments/{comment.id}",
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
)
assert response.status_code == 204
@@ -178,13 +185,13 @@ class TestComments:
db.refresh(comment)
assert comment.is_deleted is True
def test_mention_limit(self, client, admin_token, test_task):
def test_mention_limit(self, client, auth_headers, test_task):
"""Test that @mention limit is enforced."""
# Create content with more than 10 mentions
mentions = " ".join([f"@user{i}" for i in range(15)])
response = client.post(
f"/api/tasks/{test_task.id}/comments",
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
json={"content": f"Test with many mentions: {mentions}"},
)
assert response.status_code == 400
@@ -218,7 +225,7 @@ class TestNotifications:
assert data["total"] >= 1
assert data["unread_count"] >= 1
def test_mark_notification_as_read(self, client, admin_token, db):
def test_mark_notification_as_read(self, client, auth_headers, db):
"""Test marking a notification as read."""
notification = Notification(
id=str(uuid.uuid4()),
@@ -233,14 +240,14 @@ class TestNotifications:
response = client.put(
f"/api/notifications/{notification.id}/read",
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
assert data["is_read"] is True
assert data["read_at"] is not None
def test_mark_all_as_read(self, client, admin_token, db):
def test_mark_all_as_read(self, client, auth_headers, db):
"""Test marking all notifications as read."""
# Create multiple unread notifications
for i in range(3):
@@ -257,7 +264,7 @@ class TestNotifications:
response = client.put(
"/api/notifications/read-all",
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
@@ -290,11 +297,11 @@ class TestNotifications:
class TestBlockers:
"""Tests for Blockers API."""
def test_create_blocker(self, client, admin_token, test_task):
def test_create_blocker(self, client, auth_headers, test_task):
"""Test creating a blocker."""
response = client.post(
f"/api/tasks/{test_task.id}/blockers",
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
json={"reason": "Waiting for external dependency"},
)
assert response.status_code == 201
@@ -302,7 +309,7 @@ class TestBlockers:
assert data["reason"] == "Waiting for external dependency"
assert data["resolved_at"] is None
def test_resolve_blocker(self, client, admin_token, db, test_task):
def test_resolve_blocker(self, client, auth_headers, db, test_task):
"""Test resolving a blocker."""
blocker = Blocker(
id=str(uuid.uuid4()),
@@ -316,7 +323,7 @@ class TestBlockers:
response = client.put(
f"/api/blockers/{blocker.id}/resolve",
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
json={"resolution_note": "Issue resolved by updating config"},
)
assert response.status_code == 200
@@ -348,7 +355,7 @@ class TestBlockers:
assert data["total"] == 1
assert data["blockers"][0]["reason"] == "Test blocker"
def test_cannot_create_duplicate_active_blocker(self, client, admin_token, db, test_task):
def test_cannot_create_duplicate_active_blocker(self, client, auth_headers, db, test_task):
"""Test that duplicate active blockers are prevented."""
# Create first blocker
blocker = Blocker(
@@ -363,7 +370,7 @@ class TestBlockers:
# Try to create second blocker
response = client.post(
f"/api/tasks/{test_task.id}/blockers",
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
json={"reason": "Second blocker"},
)
assert response.status_code == 400

View File

@@ -18,7 +18,7 @@ from datetime import datetime, timedelta
class TestOptimisticLocking:
"""Test optimistic locking for concurrent updates."""
def test_version_increments_on_update(self, client, admin_token, db):
def test_version_increments_on_update(self, client, admin_token, csrf_token, db):
"""Test that task version increments on successful update."""
from app.models import Space, Project, Task, TaskStatus
@@ -47,7 +47,7 @@ class TestOptimisticLocking:
response = client.patch(
"/api/tasks/task-1",
json={"title": "Updated Task", "version": 1},
headers={"Authorization": f"Bearer {admin_token}"}
headers={"Authorization": f"Bearer {admin_token}", "X-CSRF-Token": csrf_token}
)
assert response.status_code == 200
@@ -55,7 +55,7 @@ class TestOptimisticLocking:
assert data["title"] == "Updated Task"
assert data["version"] == 2 # Version should increment
def test_version_conflict_returns_409(self, client, admin_token, db):
def test_version_conflict_returns_409(self, client, admin_token, csrf_token, db):
"""Test that stale version returns 409 Conflict."""
from app.models import Space, Project, Task, TaskStatus
@@ -84,7 +84,7 @@ class TestOptimisticLocking:
response = client.patch(
"/api/tasks/task-2",
json={"title": "Stale Update", "version": 1},
headers={"Authorization": f"Bearer {admin_token}"}
headers={"Authorization": f"Bearer {admin_token}", "X-CSRF-Token": csrf_token}
)
assert response.status_code == 409
@@ -94,7 +94,7 @@ class TestOptimisticLocking:
assert detail.get("current_version") == 5
assert detail.get("provided_version") == 1
def test_update_without_version_succeeds(self, client, admin_token, db):
def test_update_without_version_succeeds(self, client, admin_token, csrf_token, db):
"""Test that update without version (for backward compatibility) still works."""
from app.models import Space, Project, Task, TaskStatus
@@ -123,7 +123,7 @@ class TestOptimisticLocking:
response = client.patch(
"/api/tasks/task-3",
json={"title": "No Version Update"},
headers={"Authorization": f"Bearer {admin_token}"}
headers={"Authorization": f"Bearer {admin_token}", "X-CSRF-Token": csrf_token}
)
# Should succeed (backward compatibility)
@@ -179,7 +179,7 @@ class TestTriggerRetryMechanism:
class TestCascadeRestore:
"""Test cascade restore for soft-deleted tasks."""
def test_restore_parent_with_children(self, client, admin_token, db):
def test_restore_parent_with_children(self, client, admin_token, csrf_token, db):
"""Test restoring parent task also restores children deleted at same time."""
from app.models import Space, Project, Task, TaskStatus
from datetime import datetime
@@ -236,7 +236,7 @@ class TestCascadeRestore:
response = client.post(
"/api/tasks/parent-task/restore",
json={"cascade": True},
headers={"Authorization": f"Bearer {admin_token}"}
headers={"Authorization": f"Bearer {admin_token}", "X-CSRF-Token": csrf_token}
)
assert response.status_code == 200
@@ -254,7 +254,7 @@ class TestCascadeRestore:
assert child_task1.is_deleted is False
assert child_task2.is_deleted is False
def test_restore_parent_only(self, client, admin_token, db):
def test_restore_parent_only(self, client, admin_token, csrf_token, db):
"""Test restoring parent task without cascade leaves children deleted."""
from app.models import Space, Project, Task, TaskStatus
from datetime import datetime
@@ -299,7 +299,7 @@ class TestCascadeRestore:
response = client.post(
"/api/tasks/parent-task-2/restore",
json={"cascade": False},
headers={"Authorization": f"Bearer {admin_token}"}
headers={"Authorization": f"Bearer {admin_token}", "X-CSRF-Token": csrf_token}
)
assert response.status_code == 200

View File

@@ -39,7 +39,7 @@ class TestCustomFieldsCRUD:
db.commit()
return project
def test_create_text_field(self, client, db, admin_token):
def test_create_text_field(self, client, db, auth_headers):
"""Test creating a text custom field."""
project = self.setup_project(db, "00000000-0000-0000-0000-000000000001")
@@ -50,7 +50,7 @@ class TestCustomFieldsCRUD:
"field_type": "text",
"is_required": False,
},
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
)
assert response.status_code == 201
@@ -59,7 +59,7 @@ class TestCustomFieldsCRUD:
assert data["field_type"] == "text"
assert data["is_required"] is False
def test_create_number_field(self, client, db, admin_token):
def test_create_number_field(self, client, db, auth_headers):
"""Test creating a number custom field."""
project = self.setup_project(db, "00000000-0000-0000-0000-000000000001")
@@ -70,7 +70,7 @@ class TestCustomFieldsCRUD:
"field_type": "number",
"is_required": True,
},
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
)
assert response.status_code == 201
@@ -79,7 +79,7 @@ class TestCustomFieldsCRUD:
assert data["field_type"] == "number"
assert data["is_required"] is True
def test_create_dropdown_field(self, client, db, admin_token):
def test_create_dropdown_field(self, client, db, auth_headers):
"""Test creating a dropdown custom field."""
project = self.setup_project(db, "00000000-0000-0000-0000-000000000001")
@@ -91,7 +91,7 @@ class TestCustomFieldsCRUD:
"options": ["Frontend", "Backend", "Database", "API"],
"is_required": False,
},
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
)
assert response.status_code == 201
@@ -100,7 +100,7 @@ class TestCustomFieldsCRUD:
assert data["field_type"] == "dropdown"
assert data["options"] == ["Frontend", "Backend", "Database", "API"]
def test_create_dropdown_field_without_options_fails(self, client, db, admin_token):
def test_create_dropdown_field_without_options_fails(self, client, db, auth_headers):
"""Test that creating a dropdown field without options fails."""
project = self.setup_project(db, "00000000-0000-0000-0000-000000000001")
@@ -111,12 +111,12 @@ class TestCustomFieldsCRUD:
"field_type": "dropdown",
"options": [],
},
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
)
assert response.status_code == 422 # Validation error
def test_create_formula_field(self, client, db, admin_token):
def test_create_formula_field(self, client, db, auth_headers):
"""Test creating a formula custom field."""
project = self.setup_project(db, "00000000-0000-0000-0000-000000000001")
@@ -127,7 +127,7 @@ class TestCustomFieldsCRUD:
"name": "hours_worked",
"field_type": "number",
},
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
)
# Create formula field
@@ -138,7 +138,7 @@ class TestCustomFieldsCRUD:
"field_type": "formula",
"formula": "{time_spent} / {original_estimate} * 100",
},
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
)
assert response.status_code == 201
@@ -147,7 +147,7 @@ class TestCustomFieldsCRUD:
assert data["field_type"] == "formula"
assert "{time_spent}" in data["formula"]
def test_list_custom_fields(self, client, db, admin_token):
def test_list_custom_fields(self, client, db, auth_headers):
"""Test listing custom fields for a project."""
project = self.setup_project(db, "00000000-0000-0000-0000-000000000001")
@@ -155,17 +155,17 @@ class TestCustomFieldsCRUD:
client.post(
f"/api/projects/{project.id}/custom-fields",
json={"name": "Field 1", "field_type": "text"},
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
)
client.post(
f"/api/projects/{project.id}/custom-fields",
json={"name": "Field 2", "field_type": "number"},
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
)
response = client.get(
f"/api/projects/{project.id}/custom-fields",
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
)
assert response.status_code == 200
@@ -173,7 +173,7 @@ class TestCustomFieldsCRUD:
assert data["total"] == 2
assert len(data["fields"]) == 2
def test_update_custom_field(self, client, db, admin_token):
def test_update_custom_field(self, client, db, auth_headers):
"""Test updating a custom field."""
project = self.setup_project(db, "00000000-0000-0000-0000-000000000001")
@@ -181,7 +181,7 @@ class TestCustomFieldsCRUD:
create_response = client.post(
f"/api/projects/{project.id}/custom-fields",
json={"name": "Original Name", "field_type": "text"},
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
)
field_id = create_response.json()["id"]
@@ -189,7 +189,7 @@ class TestCustomFieldsCRUD:
response = client.put(
f"/api/custom-fields/{field_id}",
json={"name": "Updated Name", "is_required": True},
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
)
assert response.status_code == 200
@@ -197,7 +197,7 @@ class TestCustomFieldsCRUD:
assert data["name"] == "Updated Name"
assert data["is_required"] is True
def test_delete_custom_field(self, client, db, admin_token):
def test_delete_custom_field(self, client, db, auth_headers):
"""Test deleting a custom field."""
project = self.setup_project(db, "00000000-0000-0000-0000-000000000001")
@@ -205,14 +205,14 @@ class TestCustomFieldsCRUD:
create_response = client.post(
f"/api/projects/{project.id}/custom-fields",
json={"name": "To Delete", "field_type": "text"},
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
)
field_id = create_response.json()["id"]
# Delete it
response = client.delete(
f"/api/custom-fields/{field_id}",
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
)
assert response.status_code == 204
@@ -220,11 +220,11 @@ class TestCustomFieldsCRUD:
# Verify it's gone
get_response = client.get(
f"/api/custom-fields/{field_id}",
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
)
assert get_response.status_code == 404
def test_max_fields_limit(self, client, db, admin_token):
def test_max_fields_limit(self, client, db, auth_headers):
"""Test that maximum 20 custom fields per project is enforced."""
project = self.setup_project(db, "00000000-0000-0000-0000-000000000001")
@@ -233,7 +233,7 @@ class TestCustomFieldsCRUD:
response = client.post(
f"/api/projects/{project.id}/custom-fields",
json={"name": f"Field {i}", "field_type": "text"},
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
)
assert response.status_code == 201
@@ -241,12 +241,12 @@ class TestCustomFieldsCRUD:
response = client.post(
f"/api/projects/{project.id}/custom-fields",
json={"name": "Field 21", "field_type": "text"},
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
)
assert response.status_code == 400
assert "Maximum" in response.json()["detail"]
def test_duplicate_name_rejected(self, client, db, admin_token):
def test_duplicate_name_rejected(self, client, db, auth_headers):
"""Test that duplicate field names are rejected."""
project = self.setup_project(db, "00000000-0000-0000-0000-000000000001")
@@ -254,14 +254,14 @@ class TestCustomFieldsCRUD:
client.post(
f"/api/projects/{project.id}/custom-fields",
json={"name": "Unique Name", "field_type": "text"},
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
)
# Try to create another with same name
response = client.post(
f"/api/projects/{project.id}/custom-fields",
json={"name": "Unique Name", "field_type": "number"},
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
)
assert response.status_code == 400
assert "already exists" in response.json()["detail"]
@@ -311,7 +311,7 @@ class TestFormulaService:
class TestCustomValuesWithTasks:
"""Test custom values integration with tasks."""
def setup_project_with_fields(self, db, client, admin_token, owner_id: str):
def setup_project_with_fields(self, db, client, auth_headers, owner_id: str):
"""Create a project with custom fields for testing."""
space = Space(
id="test-space-002",
@@ -342,23 +342,23 @@ class TestCustomValuesWithTasks:
text_response = client.post(
f"/api/projects/{project.id}/custom-fields",
json={"name": "sprint_number", "field_type": "text"},
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
)
text_field_id = text_response.json()["id"]
number_response = client.post(
f"/api/projects/{project.id}/custom-fields",
json={"name": "story_points", "field_type": "number"},
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
)
number_field_id = number_response.json()["id"]
return project, text_field_id, number_field_id
def test_create_task_with_custom_values(self, client, db, admin_token):
def test_create_task_with_custom_values(self, client, db, auth_headers):
"""Test creating a task with custom values."""
project, text_field_id, number_field_id = self.setup_project_with_fields(
db, client, admin_token, "00000000-0000-0000-0000-000000000001"
db, client, auth_headers, "00000000-0000-0000-0000-000000000001"
)
response = client.post(
@@ -370,15 +370,15 @@ class TestCustomValuesWithTasks:
{"field_id": number_field_id, "value": "8"},
],
},
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
)
assert response.status_code == 201
def test_get_task_includes_custom_values(self, client, db, admin_token):
def test_get_task_includes_custom_values(self, client, db, auth_headers):
"""Test that getting a task includes custom values."""
project, text_field_id, number_field_id = self.setup_project_with_fields(
db, client, admin_token, "00000000-0000-0000-0000-000000000001"
db, client, auth_headers, "00000000-0000-0000-0000-000000000001"
)
# Create task with custom values
@@ -391,14 +391,14 @@ class TestCustomValuesWithTasks:
{"field_id": number_field_id, "value": "8"},
],
},
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
)
task_id = create_response.json()["id"]
# Get task and check custom values
get_response = client.get(
f"/api/tasks/{task_id}",
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
)
assert get_response.status_code == 200
@@ -406,10 +406,10 @@ class TestCustomValuesWithTasks:
assert data["custom_values"] is not None
assert len(data["custom_values"]) >= 2
def test_update_task_custom_values(self, client, db, admin_token):
def test_update_task_custom_values(self, client, db, auth_headers):
"""Test updating custom values on a task."""
project, text_field_id, number_field_id = self.setup_project_with_fields(
db, client, admin_token, "00000000-0000-0000-0000-000000000001"
db, client, auth_headers, "00000000-0000-0000-0000-000000000001"
)
# Create task
@@ -421,7 +421,7 @@ class TestCustomValuesWithTasks:
{"field_id": text_field_id, "value": "Sprint 5"},
],
},
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
)
task_id = create_response.json()["id"]
@@ -434,7 +434,7 @@ class TestCustomValuesWithTasks:
{"field_id": number_field_id, "value": "13"},
],
},
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
)
assert update_response.status_code == 200

View File

@@ -619,7 +619,7 @@ class TestDashboardAPI:
def test_dashboard_unauthorized(self, client, db):
"""Unauthenticated requests should fail."""
response = client.get("/api/dashboard")
assert response.status_code == 403
assert response.status_code == 401 # 401 for unauthenticated, 403 for unauthorized
def test_dashboard_with_user_tasks(self, client, db, admin_token):
"""Dashboard should reflect user's tasks correctly."""

View File

@@ -312,6 +312,13 @@ class TestConfidentialProjectUpload:
mock_redis.setex(f"session:{test_user.id}", 900, token)
return token
@pytest.fixture
def test_user_csrf_token(self, test_user):
"""Generate a CSRF token for the test user."""
from app.core.security import generate_csrf_token
return generate_csrf_token(test_user.id)
@pytest.fixture
def test_space(self, db, test_user):
"""Create a test space."""
@@ -364,7 +371,7 @@ class TestConfidentialProjectUpload:
return task
def test_upload_confidential_project_encryption_unavailable(
self, client, test_user_token, test_task, db
self, client, test_user_token, test_user_csrf_token, test_task, db
):
"""Test that uploading to confidential project returns 400 when encryption is unavailable."""
from io import BytesIO
@@ -378,7 +385,7 @@ class TestConfidentialProjectUpload:
response = client.post(
f"/api/tasks/{test_task.id}/attachments",
headers={"Authorization": f"Bearer {test_user_token}"},
headers={"Authorization": f"Bearer {test_user_token}", "X-CSRF-Token": test_user_csrf_token},
files=files,
)
@@ -387,7 +394,7 @@ class TestConfidentialProjectUpload:
assert "environment variable" in response.json()["detail"]
def test_upload_confidential_project_no_active_key(
self, client, test_user_token, test_task, db
self, client, test_user_token, test_user_csrf_token, test_task, db
):
"""Test that uploading to confidential project returns 400 when no active encryption key exists."""
from io import BytesIO
@@ -408,7 +415,7 @@ class TestConfidentialProjectUpload:
response = client.post(
f"/api/tasks/{test_task.id}/attachments",
headers={"Authorization": f"Bearer {test_user_token}"},
headers={"Authorization": f"Bearer {test_user_token}", "X-CSRF-Token": test_user_csrf_token},
files=files,
)

View File

@@ -614,7 +614,7 @@ class TestHealthAPI:
def test_unauthorized_access(self, client, db):
"""Unauthenticated requests should fail."""
response = client.get("/api/projects/health/dashboard")
assert response.status_code == 403
assert response.status_code == 401 # 401 for unauthenticated, 403 for unauthorized
def test_dashboard_with_status_filter(self, client, db, admin_token):
"""Dashboard should respect status filter."""

View File

@@ -38,6 +38,14 @@ def test_user_token(client, mock_redis, test_user):
return token
@pytest.fixture
def test_user_csrf_token(test_user):
"""Generate a CSRF token for the test user."""
from app.core.security import generate_csrf_token
return generate_csrf_token(test_user.id)
@pytest.fixture
def test_space(db, test_user):
"""Create a test space."""
@@ -284,11 +292,11 @@ class TestReportAPI:
assert "projects" in data
assert data["summary"]["total_tasks"] == 3
def test_generate_weekly_report_api(self, client, test_user_token, test_project, test_tasks, test_statuses):
def test_generate_weekly_report_api(self, client, test_user_token, test_user_csrf_token, test_project, test_tasks, test_statuses):
"""Test generating weekly report via API."""
response = client.post(
"/api/reports/weekly/generate",
headers={"Authorization": f"Bearer {test_user_token}"},
headers={"Authorization": f"Bearer {test_user_token}", "X-CSRF-Token": test_user_csrf_token},
)
assert response.status_code == 200
@@ -297,7 +305,7 @@ class TestReportAPI:
assert "report_id" in data
assert "summary" in data
def test_weekly_report_subscription_toggle(self, client, test_user_token, db, test_user):
def test_weekly_report_subscription_toggle(self, client, test_user_token, test_user_csrf_token, db, test_user):
"""Test weekly report subscription toggle endpoints."""
response = client.get(
"/api/reports/weekly/subscription",
@@ -308,7 +316,7 @@ class TestReportAPI:
response = client.put(
"/api/reports/weekly/subscription",
headers={"Authorization": f"Bearer {test_user_token}"},
headers={"Authorization": f"Bearer {test_user_token}", "X-CSRF-Token": test_user_csrf_token},
json={"is_active": True},
)
assert response.status_code == 200
@@ -323,7 +331,7 @@ class TestReportAPI:
response = client.put(
"/api/reports/weekly/subscription",
headers={"Authorization": f"Bearer {test_user_token}"},
headers={"Authorization": f"Bearer {test_user_token}", "X-CSRF-Token": test_user_csrf_token},
json={"is_active": False},
)
assert response.status_code == 200

View File

@@ -52,6 +52,14 @@ def test_user_token(client, mock_redis, test_user):
return token
@pytest.fixture
def test_user_csrf_token(test_user):
"""Generate a CSRF token for the test user."""
from app.core.security import generate_csrf_token
return generate_csrf_token(test_user.id)
@pytest.fixture
def test_space(db, test_user):
"""Create a test space."""
@@ -445,11 +453,11 @@ class TestDeadlineReminderLogic:
class TestScheduleTriggerAPI:
"""Tests for Schedule Trigger API endpoints."""
def test_create_cron_trigger(self, client, test_user_token, test_project):
def test_create_cron_trigger(self, client, test_user_token, test_user_csrf_token, test_project):
"""Test creating a schedule trigger with cron expression."""
response = client.post(
f"/api/projects/{test_project.id}/triggers",
headers={"Authorization": f"Bearer {test_user_token}"},
headers={"Authorization": f"Bearer {test_user_token}", "X-CSRF-Token": test_user_csrf_token},
json={
"name": "Weekly Monday Reminder",
"description": "Remind every Monday at 9am",
@@ -471,11 +479,11 @@ class TestScheduleTriggerAPI:
assert data["trigger_type"] == "schedule"
assert data["conditions"]["cron_expression"] == "0 9 * * 1"
def test_create_deadline_trigger(self, client, test_user_token, test_project):
def test_create_deadline_trigger(self, client, test_user_token, test_user_csrf_token, test_project):
"""Test creating a schedule trigger with deadline reminder."""
response = client.post(
f"/api/projects/{test_project.id}/triggers",
headers={"Authorization": f"Bearer {test_user_token}"},
headers={"Authorization": f"Bearer {test_user_token}", "X-CSRF-Token": test_user_csrf_token},
json={
"name": "Deadline Reminder",
"description": "Remind 5 days before deadline",
@@ -494,11 +502,11 @@ class TestScheduleTriggerAPI:
data = response.json()
assert data["conditions"]["deadline_reminder_days"] == 5
def test_create_schedule_trigger_invalid_cron(self, client, test_user_token, test_project):
def test_create_schedule_trigger_invalid_cron(self, client, test_user_token, test_user_csrf_token, test_project):
"""Test creating a schedule trigger with invalid cron expression."""
response = client.post(
f"/api/projects/{test_project.id}/triggers",
headers={"Authorization": f"Bearer {test_user_token}"},
headers={"Authorization": f"Bearer {test_user_token}", "X-CSRF-Token": test_user_csrf_token},
json={
"name": "Invalid Cron Trigger",
"trigger_type": "schedule",
@@ -512,11 +520,11 @@ class TestScheduleTriggerAPI:
assert response.status_code == 400
assert "Invalid cron expression" in response.json()["detail"]
def test_create_schedule_trigger_missing_condition(self, client, test_user_token, test_project):
def test_create_schedule_trigger_missing_condition(self, client, test_user_token, test_user_csrf_token, test_project):
"""Test creating a schedule trigger without cron or deadline condition."""
response = client.post(
f"/api/projects/{test_project.id}/triggers",
headers={"Authorization": f"Bearer {test_user_token}"},
headers={"Authorization": f"Bearer {test_user_token}", "X-CSRF-Token": test_user_csrf_token},
json={
"name": "Empty Schedule Trigger",
"trigger_type": "schedule",
@@ -528,11 +536,11 @@ class TestScheduleTriggerAPI:
assert response.status_code == 400
assert "require either cron_expression or deadline_reminder_days" in response.json()["detail"]
def test_update_schedule_trigger_cron(self, client, test_user_token, cron_trigger):
def test_update_schedule_trigger_cron(self, client, test_user_token, test_user_csrf_token, cron_trigger):
"""Test updating a schedule trigger's cron expression."""
response = client.put(
f"/api/triggers/{cron_trigger.id}",
headers={"Authorization": f"Bearer {test_user_token}"},
headers={"Authorization": f"Bearer {test_user_token}", "X-CSRF-Token": test_user_csrf_token},
json={
"conditions": {
"cron_expression": "0 10 * * *", # Changed to 10am
@@ -544,11 +552,11 @@ class TestScheduleTriggerAPI:
data = response.json()
assert data["conditions"]["cron_expression"] == "0 10 * * *"
def test_update_schedule_trigger_invalid_cron(self, client, test_user_token, cron_trigger):
def test_update_schedule_trigger_invalid_cron(self, client, test_user_token, test_user_csrf_token, cron_trigger):
"""Test updating a schedule trigger with invalid cron expression."""
response = client.put(
f"/api/triggers/{cron_trigger.id}",
headers={"Authorization": f"Bearer {test_user_token}"},
headers={"Authorization": f"Bearer {test_user_token}", "X-CSRF-Token": test_user_csrf_token},
json={
"conditions": {
"cron_expression": "not valid",

View File

@@ -69,6 +69,22 @@ def regular_token(client, mock_redis, test_regular_user):
return token
@pytest.fixture
def csrf_token(test_admin):
"""Generate a CSRF token for the test admin user."""
from app.core.security import generate_csrf_token
return generate_csrf_token(test_admin.id)
@pytest.fixture
def auth_headers(admin_token, csrf_token):
"""Get complete auth headers including both Authorization and CSRF token."""
return {
"Authorization": f"Bearer {admin_token}",
"X-CSRF-Token": csrf_token,
}
@pytest.fixture
def test_space(db, test_admin):
"""Create a test space."""
@@ -148,11 +164,11 @@ def test_task_with_subtask(db, test_project, test_admin, test_status, test_task)
class TestSoftDelete:
"""Tests for soft delete functionality."""
def test_delete_task_soft_deletes(self, client, admin_token, test_task, db):
def test_delete_task_soft_deletes(self, client, auth_headers, test_task, db):
"""Test that DELETE soft-deletes a task."""
response = client.delete(
f"/api/tasks/{test_task.id}",
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
)
assert response.status_code == 200
@@ -165,36 +181,36 @@ class TestSoftDelete:
assert test_task.deleted_at is not None
assert test_task.deleted_by is not None
def test_deleted_task_not_in_list(self, client, admin_token, test_project, test_task, db):
def test_deleted_task_not_in_list(self, client, auth_headers, test_project, test_task, db):
"""Test that deleted tasks are not shown in list."""
# Delete the task
client.delete(
f"/api/tasks/{test_task.id}",
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
)
# List tasks
response = client.get(
f"/api/projects/{test_project.id}/tasks",
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
assert data["total"] == 0
def test_admin_can_list_deleted_with_include_deleted(self, client, admin_token, test_project, test_task, db):
def test_admin_can_list_deleted_with_include_deleted(self, client, auth_headers, test_project, test_task, db):
"""Test that admin can see deleted tasks with include_deleted parameter."""
# Delete the task
client.delete(
f"/api/tasks/{test_task.id}",
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
)
# List with include_deleted
response = client.get(
f"/api/projects/{test_project.id}/tasks?include_deleted=true",
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
)
assert response.status_code == 200
@@ -202,12 +218,12 @@ class TestSoftDelete:
assert data["total"] == 1
assert data["tasks"][0]["id"] == test_task.id
def test_regular_user_cannot_see_deleted_with_include_deleted(self, client, regular_token, test_project, test_task, admin_token, db):
def test_regular_user_cannot_see_deleted_with_include_deleted(self, client, regular_token, test_project, test_task, auth_headers, db, csrf_token):
"""Test that non-admin cannot see deleted tasks even with include_deleted."""
# Delete the task as admin
client.delete(
f"/api/tasks/{test_task.id}",
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
)
# Try to list with include_deleted as regular user
@@ -220,12 +236,12 @@ class TestSoftDelete:
data = response.json()
assert data["total"] == 0
def test_get_deleted_task_returns_404_for_regular_user(self, client, admin_token, regular_token, test_task, db):
def test_get_deleted_task_returns_404_for_regular_user(self, client, auth_headers, regular_token, test_task, db):
"""Test that getting a deleted task returns 404 for non-admin."""
# Delete the task
client.delete(
f"/api/tasks/{test_task.id}",
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
)
# Try to get as regular user
@@ -236,28 +252,28 @@ class TestSoftDelete:
assert response.status_code == 404
def test_admin_can_view_deleted_task(self, client, admin_token, test_task, db):
def test_admin_can_view_deleted_task(self, client, auth_headers, test_task, db):
"""Test that admin can view a deleted task."""
# Delete the task
client.delete(
f"/api/tasks/{test_task.id}",
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
)
# Get as admin
response = client.get(
f"/api/tasks/{test_task.id}",
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
)
assert response.status_code == 200
def test_cascade_soft_delete_subtasks(self, client, admin_token, test_task, test_task_with_subtask, db):
def test_cascade_soft_delete_subtasks(self, client, auth_headers, test_task, test_task_with_subtask, db):
"""Test that deleting a parent task soft-deletes its subtasks."""
# Delete the parent task
response = client.delete(
f"/api/tasks/{test_task.id}",
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
)
assert response.status_code == 200
@@ -270,18 +286,18 @@ class TestSoftDelete:
class TestRestoreTask:
"""Tests for task restoration functionality."""
def test_restore_task(self, client, admin_token, test_task, db):
def test_restore_task(self, client, auth_headers, test_task, db):
"""Test that admin can restore a deleted task."""
# Delete the task
client.delete(
f"/api/tasks/{test_task.id}",
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
)
# Restore the task
response = client.post(
f"/api/tasks/{test_task.id}/restore",
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
)
assert response.status_code == 200
@@ -292,27 +308,29 @@ class TestRestoreTask:
assert test_task.deleted_at is None
assert test_task.deleted_by is None
def test_regular_user_cannot_restore(self, client, admin_token, regular_token, test_task, db):
def test_regular_user_cannot_restore(self, client, auth_headers, regular_token, test_task, db, test_regular_user):
"""Test that non-admin cannot restore a deleted task."""
from app.core.security import generate_csrf_token
# Delete the task
client.delete(
f"/api/tasks/{test_task.id}",
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
)
# Try to restore as regular user
regular_csrf = generate_csrf_token(test_regular_user.id)
response = client.post(
f"/api/tasks/{test_task.id}/restore",
headers={"Authorization": f"Bearer {regular_token}"},
headers={"Authorization": f"Bearer {regular_token}", "X-CSRF-Token": regular_csrf},
)
assert response.status_code == 403
def test_cannot_restore_non_deleted_task(self, client, admin_token, test_task, db):
def test_cannot_restore_non_deleted_task(self, client, auth_headers, test_task, db):
"""Test that restoring a non-deleted task returns error."""
response = client.post(
f"/api/tasks/{test_task.id}/restore",
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
)
assert response.status_code == 400
@@ -322,12 +340,12 @@ class TestRestoreTask:
class TestSubtaskCount:
"""Tests for subtask count excluding deleted."""
def test_subtask_count_excludes_deleted(self, client, admin_token, test_task, test_task_with_subtask, db):
def test_subtask_count_excludes_deleted(self, client, auth_headers, test_task, test_task_with_subtask, db):
"""Test that subtask_count excludes deleted subtasks."""
# Get parent task before deletion
response = client.get(
f"/api/tasks/{test_task.id}",
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
)
assert response.status_code == 200
assert response.json()["subtask_count"] == 1
@@ -335,13 +353,13 @@ class TestSubtaskCount:
# Delete subtask
client.delete(
f"/api/tasks/{test_task_with_subtask.id}",
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
)
# Get parent task after deletion
response = client.get(
f"/api/tasks/{test_task.id}",
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
)
assert response.status_code == 200
assert response.json()["subtask_count"] == 0

View File

@@ -57,7 +57,7 @@ class TestSpacesAPI:
"/api/spaces",
json={"name": "Test Space", "description": "Test"}
)
assert response.status_code == 403 # No auth header
assert response.status_code == 401 # 401 for unauthenticated, 403 for unauthorized
def test_space_routes_exist(self):
"""Test that all space routes are registered."""

View File

@@ -783,7 +783,7 @@ class TestDateValidation:
class TestDependencyCRUDAPI:
"""Test dependency CRUD API endpoints."""
def test_create_dependency(self, client, db, admin_token):
def test_create_dependency(self, client, db, admin_token, csrf_token):
"""Test creating a dependency via API."""
# Create test data
space = Space(
@@ -838,7 +838,7 @@ class TestDependencyCRUDAPI:
"dependency_type": "FS",
"lag_days": 0
},
headers={"Authorization": f"Bearer {admin_token}"},
headers={"Authorization": f"Bearer {admin_token}", "X-CSRF-Token": csrf_token},
)
assert response.status_code == 201
@@ -914,7 +914,7 @@ class TestDependencyCRUDAPI:
assert data["total"] >= 1
assert any(d["predecessor_id"] == "task-api-list-1" for d in data["dependencies"])
def test_delete_dependency(self, client, db, admin_token):
def test_delete_dependency(self, client, db, admin_token, csrf_token):
"""Test deleting a dependency."""
# Create test data
space = Space(
@@ -973,7 +973,7 @@ class TestDependencyCRUDAPI:
# Delete dependency
response = client.delete(
"/api/task-dependencies/dep-api-del",
headers={"Authorization": f"Bearer {admin_token}"},
headers={"Authorization": f"Bearer {admin_token}", "X-CSRF-Token": csrf_token},
)
assert response.status_code == 204
@@ -984,7 +984,7 @@ class TestDependencyCRUDAPI:
).first()
assert dep_check is None
def test_circular_dependency_rejected_via_api(self, client, db, admin_token):
def test_circular_dependency_rejected_via_api(self, client, db, admin_token, csrf_token):
"""Test that circular dependencies are rejected via API."""
# Create test data
space = Space(
@@ -1049,7 +1049,7 @@ class TestDependencyCRUDAPI:
"dependency_type": "FS",
"lag_days": 0
},
headers={"Authorization": f"Bearer {admin_token}"},
headers={"Authorization": f"Bearer {admin_token}", "X-CSRF-Token": csrf_token},
)
assert response.status_code == 400
@@ -1060,7 +1060,7 @@ class TestDependencyCRUDAPI:
class TestTaskDateValidationAPI:
"""Test task date validation in task API."""
def test_create_task_with_invalid_dates_rejected(self, client, db, admin_token):
def test_create_task_with_invalid_dates_rejected(self, client, db, admin_token, csrf_token):
"""Test that creating a task with start_date > due_date is rejected."""
# Create test data
space = Space(
@@ -1099,13 +1099,13 @@ class TestTaskDateValidationAPI:
"start_date": (now + timedelta(days=10)).isoformat(),
"due_date": now.isoformat(),
},
headers={"Authorization": f"Bearer {admin_token}"},
headers={"Authorization": f"Bearer {admin_token}", "X-CSRF-Token": csrf_token},
)
assert response.status_code == 400
assert "Start date cannot be after due date" in response.json()["detail"]
def test_update_task_with_invalid_dates_rejected(self, client, db, admin_token):
def test_update_task_with_invalid_dates_rejected(self, client, db, admin_token, csrf_token):
"""Test that updating a task to have start_date > due_date is rejected."""
# Create test data
space = Space(
@@ -1153,12 +1153,12 @@ class TestTaskDateValidationAPI:
json={
"start_date": (now + timedelta(days=20)).isoformat(),
},
headers={"Authorization": f"Bearer {admin_token}"},
headers={"Authorization": f"Bearer {admin_token}", "X-CSRF-Token": csrf_token},
)
assert response.status_code == 400
def test_create_task_with_valid_dates_accepted(self, client, db, admin_token):
def test_create_task_with_valid_dates_accepted(self, client, db, admin_token, csrf_token):
"""Test that creating a task with valid dates is accepted."""
# Create test data
space = Space(
@@ -1197,7 +1197,7 @@ class TestTaskDateValidationAPI:
"start_date": now.isoformat(),
"due_date": (now + timedelta(days=10)).isoformat(),
},
headers={"Authorization": f"Bearer {admin_token}"},
headers={"Authorization": f"Bearer {admin_token}", "X-CSRF-Token": csrf_token},
)
assert response.status_code == 201
@@ -1217,7 +1217,7 @@ class TestDependencyTypes:
assert DependencyType.FF.value == "FF"
assert DependencyType.SF.value == "SF"
def test_create_dependency_with_different_types(self, client, db, admin_token):
def test_create_dependency_with_different_types(self, client, db, admin_token, csrf_token):
"""Test creating dependencies with different types via API."""
# Create test data
space = Space(
@@ -1268,7 +1268,7 @@ class TestDependencyTypes:
"dependency_type": dep_type,
"lag_days": i
},
headers={"Authorization": f"Bearer {admin_token}"},
headers={"Authorization": f"Bearer {admin_token}", "X-CSRF-Token": csrf_token},
)
assert response.status_code == 201

View File

@@ -43,6 +43,22 @@ def test_user_token(client, mock_redis, test_user):
return token
@pytest.fixture
def test_user_csrf_token(test_user):
"""Generate a CSRF token for the test user."""
from app.core.security import generate_csrf_token
return generate_csrf_token(test_user.id)
@pytest.fixture
def test_user_auth_headers(test_user_token, test_user_csrf_token):
"""Get complete auth headers for test user."""
return {
"Authorization": f"Bearer {test_user_token}",
"X-CSRF-Token": test_user_csrf_token,
}
@pytest.fixture
def test_space(db, test_user):
"""Create a test space."""
@@ -513,11 +529,11 @@ class TestTriggerNotifications:
class TestTriggerAPI:
"""Tests for Trigger API endpoints."""
def test_create_trigger(self, client, test_user_token, test_project, test_status):
def test_create_trigger(self, client, test_user_auth_headers, test_project, test_status):
"""Test creating a trigger."""
response = client.post(
f"/api/projects/{test_project.id}/triggers",
headers={"Authorization": f"Bearer {test_user_token}"},
headers=test_user_auth_headers,
json={
"name": "New Trigger",
"description": "Test trigger",
@@ -563,11 +579,11 @@ class TestTriggerAPI:
assert data["id"] == test_trigger.id
assert data["name"] == test_trigger.name
def test_update_trigger(self, client, test_user_token, test_trigger):
def test_update_trigger(self, client, test_user_auth_headers, test_trigger):
"""Test updating a trigger."""
response = client.put(
f"/api/triggers/{test_trigger.id}",
headers={"Authorization": f"Bearer {test_user_token}"},
headers=test_user_auth_headers,
json={
"name": "Updated Trigger",
"is_active": False,
@@ -579,11 +595,11 @@ class TestTriggerAPI:
assert data["name"] == "Updated Trigger"
assert data["is_active"] is False
def test_delete_trigger(self, client, test_user_token, test_trigger):
def test_delete_trigger(self, client, test_user_auth_headers, test_trigger, test_user_token):
"""Test deleting a trigger."""
response = client.delete(
f"/api/triggers/{test_trigger.id}",
headers={"Authorization": f"Bearer {test_user_token}"},
headers=test_user_auth_headers,
)
assert response.status_code == 204
@@ -616,11 +632,11 @@ class TestTriggerAPI:
data = response.json()
assert data["total"] >= 1
def test_create_trigger_invalid_field(self, client, test_user_token, test_project):
def test_create_trigger_invalid_field(self, client, test_user_auth_headers, test_project):
"""Test creating a trigger with invalid field."""
response = client.post(
f"/api/projects/{test_project.id}/triggers",
headers={"Authorization": f"Bearer {test_user_token}"},
headers=test_user_auth_headers,
json={
"name": "Invalid Trigger",
"trigger_type": "field_change",
@@ -636,11 +652,11 @@ class TestTriggerAPI:
assert response.status_code == 400
assert "Invalid condition field" in response.json()["detail"]
def test_create_trigger_invalid_operator(self, client, test_user_token, test_project):
def test_create_trigger_invalid_operator(self, client, test_user_auth_headers, test_project):
"""Test creating a trigger with invalid operator."""
response = client.post(
f"/api/projects/{test_project.id}/triggers",
headers={"Authorization": f"Bearer {test_user_token}"},
headers=test_user_auth_headers,
json={
"name": "Invalid Trigger",
"trigger_type": "field_change",

View File

@@ -1,6 +1,7 @@
import pytest
from app.models.user import User
from app.models.department import Department
from app.core.security import generate_csrf_token
class TestUserEndpoints:
@@ -35,7 +36,7 @@ class TestUserEndpoints:
)
assert response.status_code == 404
def test_update_user(self, client, admin_token, db):
def test_update_user(self, client, auth_headers, db):
"""Test updating a user."""
# Create a test user
test_user = User(
@@ -49,7 +50,7 @@ class TestUserEndpoints:
response = client.patch(
"/api/users/test-user-001",
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
json={"name": "Updated Name"},
)
assert response.status_code == 200
@@ -84,9 +85,10 @@ class TestUserEndpoints:
mock_redis.setex("session:non-admin-001", 900, token)
# Try to modify system admin - should fail with 403
csrf_token = generate_csrf_token("non-admin-001")
response = client.patch(
"/api/users/00000000-0000-0000-0000-000000000001",
headers={"Authorization": f"Bearer {token}"},
headers={"Authorization": f"Bearer {token}", "X-CSRF-Token": csrf_token},
json={"name": "Hacked Name"},
)
# Engineer role doesn't have users.write permission
@@ -123,16 +125,17 @@ class TestCapacityUpdate:
mock_redis.setex("session:capacity-user-001", 900, token)
# Update own capacity
csrf_token = generate_csrf_token("capacity-user-001")
response = client.put(
"/api/users/capacity-user-001/capacity",
headers={"Authorization": f"Bearer {token}"},
headers={"Authorization": f"Bearer {token}", "X-CSRF-Token": csrf_token},
json={"capacity_hours": 35.5},
)
assert response.status_code == 200
data = response.json()
assert float(data["capacity"]) == 35.5
def test_admin_can_update_other_user_capacity(self, client, admin_token, db):
def test_admin_can_update_other_user_capacity(self, client, auth_headers, db):
"""Test that admin can update another user's capacity."""
# Create a test user
test_user = User(
@@ -148,7 +151,7 @@ class TestCapacityUpdate:
# Admin updates another user's capacity
response = client.put(
"/api/users/capacity-user-002/capacity",
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
json={"capacity_hours": 20.0},
)
assert response.status_code == 200
@@ -189,15 +192,16 @@ class TestCapacityUpdate:
mock_redis.setex("session:capacity-user-003", 900, token)
# User1 tries to update user2's capacity - should fail
csrf_token = generate_csrf_token("capacity-user-003")
response = client.put(
"/api/users/capacity-user-004/capacity",
headers={"Authorization": f"Bearer {token}"},
headers={"Authorization": f"Bearer {token}", "X-CSRF-Token": csrf_token},
json={"capacity_hours": 30.0},
)
assert response.status_code == 403
assert "Only admin, manager, or the user themselves" in response.json()["detail"]
def test_update_capacity_invalid_value_negative(self, client, admin_token, db):
def test_update_capacity_invalid_value_negative(self, client, auth_headers, db):
"""Test that negative capacity hours are rejected."""
# Create a test user
test_user = User(
@@ -212,7 +216,7 @@ class TestCapacityUpdate:
response = client.put(
"/api/users/capacity-user-005/capacity",
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
json={"capacity_hours": -5.0},
)
# Pydantic validation returns 422 Unprocessable Entity
@@ -221,7 +225,7 @@ class TestCapacityUpdate:
# Check validation error message in Pydantic format
assert any("non-negative" in str(err).lower() for err in error_detail)
def test_update_capacity_invalid_value_too_high(self, client, admin_token, db):
def test_update_capacity_invalid_value_too_high(self, client, auth_headers, db):
"""Test that capacity hours exceeding 168 are rejected."""
# Create a test user
test_user = User(
@@ -236,7 +240,7 @@ class TestCapacityUpdate:
response = client.put(
"/api/users/capacity-user-006/capacity",
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
json={"capacity_hours": 200.0},
)
# Pydantic validation returns 422 Unprocessable Entity
@@ -245,11 +249,11 @@ class TestCapacityUpdate:
# Check validation error message in Pydantic format
assert any("168" in str(err) for err in error_detail)
def test_update_capacity_nonexistent_user(self, client, admin_token):
def test_update_capacity_nonexistent_user(self, client, auth_headers):
"""Test updating capacity for a nonexistent user."""
response = client.put(
"/api/users/nonexistent-user-id/capacity",
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
json={"capacity_hours": 40.0},
)
assert response.status_code == 404
@@ -303,16 +307,17 @@ class TestCapacityUpdate:
mock_redis.setex("session:manager-cap-001", 900, token)
# Manager updates regular user's capacity
csrf_token = generate_csrf_token("manager-cap-001")
response = client.put(
"/api/users/regular-cap-001/capacity",
headers={"Authorization": f"Bearer {token}"},
headers={"Authorization": f"Bearer {token}", "X-CSRF-Token": csrf_token},
json={"capacity_hours": 30.0},
)
assert response.status_code == 200
data = response.json()
assert float(data["capacity"]) == 30.0
def test_capacity_change_creates_audit_log(self, client, admin_token, db):
def test_capacity_change_creates_audit_log(self, client, auth_headers, db):
"""Test that capacity changes are recorded in audit trail."""
from app.models import AuditLog
@@ -330,7 +335,7 @@ class TestCapacityUpdate:
# Update capacity
response = client.put(
"/api/users/capacity-audit-001/capacity",
headers={"Authorization": f"Bearer {admin_token}"},
headers=auth_headers,
json={"capacity_hours": 35.0},
)
assert response.status_code == 200

View File

@@ -449,7 +449,7 @@ class TestWorkloadAPI:
def test_unauthorized_access(self, client, db):
"""Unauthenticated requests should fail."""
response = client.get("/api/workload/heatmap")
assert response.status_code == 403 # No auth header
assert response.status_code == 401 # 401 for unauthenticated, 403 for unauthorized
class TestWorkloadAccessControl: