feat: implement 8 OpenSpec proposals for security, reliability, and UX improvements
## Security Enhancements (P0) - Add input validation with max_length and numeric range constraints - Implement WebSocket token authentication via first message - Add path traversal prevention in file storage service ## Permission Enhancements (P0) - Add project member management for cross-department access - Implement is_department_manager flag for workload visibility ## Cycle Detection (P0) - Add DFS-based cycle detection for task dependencies - Add formula field circular reference detection - Display user-friendly cycle path visualization ## Concurrency & Reliability (P1) - Implement optimistic locking with version field (409 Conflict on mismatch) - Add trigger retry mechanism with exponential backoff (1s, 2s, 4s) - Implement cascade restore for soft-deleted tasks ## Rate Limiting (P1) - Add tiered rate limits: standard (60/min), sensitive (20/min), heavy (5/min) - Apply rate limits to tasks, reports, attachments, and comments ## Frontend Improvements (P1) - Add responsive sidebar with hamburger menu for mobile - Improve touch-friendly UI with proper tap target sizes - Complete i18n translations for all components ## Backend Reliability (P2) - Configure database connection pool (size=10, overflow=20) - Add Redis fallback mechanism with message queue - Add blocker check before task deletion ## API Enhancements (P3) - Add standardized response wrapper utility - Add /health/ready and /health/live endpoints - Implement project templates with status/field copying ## Tests Added - test_input_validation.py - Schema and path traversal tests - test_concurrency_reliability.py - Optimistic locking and retry tests - test_backend_reliability.py - Connection pool and Redis tests - test_api_enhancements.py - Health check and template tests Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -115,6 +115,13 @@ class Settings(BaseSettings):
|
||||
"exe", "bat", "cmd", "sh", "ps1", "dll", "msi", "com", "scr", "vbs", "js"
|
||||
]
|
||||
|
||||
# Rate Limiting Configuration
|
||||
# Tiers: standard, sensitive, heavy
|
||||
# Format: "{requests}/{period}" (e.g., "60/minute", "20/minute", "5/minute")
|
||||
RATE_LIMIT_STANDARD: str = "60/minute" # Task CRUD, comments
|
||||
RATE_LIMIT_SENSITIVE: str = "20/minute" # Attachments, password change, report export
|
||||
RATE_LIMIT_HEAVY: str = "5/minute" # Report generation, bulk operations
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
case_sensitive = True
|
||||
|
||||
@@ -1,19 +1,109 @@
|
||||
from sqlalchemy import create_engine
|
||||
import logging
|
||||
import threading
|
||||
import os
|
||||
from sqlalchemy import create_engine, event
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Connection pool configuration with environment variable overrides
|
||||
POOL_SIZE = int(os.getenv("DB_POOL_SIZE", "10"))
|
||||
MAX_OVERFLOW = int(os.getenv("DB_MAX_OVERFLOW", "20"))
|
||||
POOL_TIMEOUT = int(os.getenv("DB_POOL_TIMEOUT", "30"))
|
||||
POOL_STATS_INTERVAL = int(os.getenv("DB_POOL_STATS_INTERVAL", "300")) # 5 minutes
|
||||
|
||||
engine = create_engine(
|
||||
settings.DATABASE_URL,
|
||||
pool_pre_ping=True,
|
||||
pool_recycle=3600,
|
||||
pool_size=POOL_SIZE,
|
||||
max_overflow=MAX_OVERFLOW,
|
||||
pool_timeout=POOL_TIMEOUT,
|
||||
)
|
||||
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
# Connection pool statistics tracking
|
||||
_pool_stats_lock = threading.Lock()
|
||||
_pool_stats = {
|
||||
"checkouts": 0,
|
||||
"checkins": 0,
|
||||
"overflow_connections": 0,
|
||||
"invalidated_connections": 0,
|
||||
}
|
||||
|
||||
|
||||
def _log_pool_statistics():
|
||||
"""Log current connection pool statistics."""
|
||||
pool = engine.pool
|
||||
with _pool_stats_lock:
|
||||
logger.info(
|
||||
"Database connection pool statistics: "
|
||||
"size=%d, checked_in=%d, overflow=%d, "
|
||||
"total_checkouts=%d, total_checkins=%d, invalidated=%d",
|
||||
pool.size(),
|
||||
pool.checkedin(),
|
||||
pool.overflow(),
|
||||
_pool_stats["checkouts"],
|
||||
_pool_stats["checkins"],
|
||||
_pool_stats["invalidated_connections"],
|
||||
)
|
||||
|
||||
|
||||
def _start_pool_stats_logging():
|
||||
"""Start periodic logging of connection pool statistics."""
|
||||
if POOL_STATS_INTERVAL <= 0:
|
||||
return
|
||||
|
||||
def log_stats():
|
||||
_log_pool_statistics()
|
||||
# Schedule next log
|
||||
timer = threading.Timer(POOL_STATS_INTERVAL, log_stats)
|
||||
timer.daemon = True
|
||||
timer.start()
|
||||
|
||||
# Start the first timer
|
||||
timer = threading.Timer(POOL_STATS_INTERVAL, log_stats)
|
||||
timer.daemon = True
|
||||
timer.start()
|
||||
logger.info(
|
||||
"Database connection pool initialized: pool_size=%d, max_overflow=%d, pool_timeout=%d, stats_interval=%ds",
|
||||
POOL_SIZE, MAX_OVERFLOW, POOL_TIMEOUT, POOL_STATS_INTERVAL
|
||||
)
|
||||
|
||||
|
||||
# Register pool event listeners for statistics
|
||||
@event.listens_for(engine, "checkout")
|
||||
def _on_checkout(dbapi_conn, connection_record, connection_proxy):
|
||||
"""Track connection checkout events."""
|
||||
with _pool_stats_lock:
|
||||
_pool_stats["checkouts"] += 1
|
||||
|
||||
|
||||
@event.listens_for(engine, "checkin")
|
||||
def _on_checkin(dbapi_conn, connection_record):
|
||||
"""Track connection checkin events."""
|
||||
with _pool_stats_lock:
|
||||
_pool_stats["checkins"] += 1
|
||||
|
||||
|
||||
@event.listens_for(engine, "invalidate")
|
||||
def _on_invalidate(dbapi_conn, connection_record, exception):
|
||||
"""Track connection invalidation events."""
|
||||
with _pool_stats_lock:
|
||||
_pool_stats["invalidated_connections"] += 1
|
||||
if exception:
|
||||
logger.warning("Database connection invalidated due to exception: %s", exception)
|
||||
|
||||
|
||||
# Start pool statistics logging on module load
|
||||
_start_pool_stats_logging()
|
||||
|
||||
|
||||
def get_db():
|
||||
"""Dependency for getting database session."""
|
||||
@@ -22,3 +112,18 @@ def get_db():
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def get_pool_status() -> dict:
|
||||
"""Get current connection pool status for health checks."""
|
||||
pool = engine.pool
|
||||
with _pool_stats_lock:
|
||||
return {
|
||||
"pool_size": pool.size(),
|
||||
"checked_in": pool.checkedin(),
|
||||
"checked_out": pool.checkedout(),
|
||||
"overflow": pool.overflow(),
|
||||
"total_checkouts": _pool_stats["checkouts"],
|
||||
"total_checkins": _pool_stats["checkins"],
|
||||
"invalidated_connections": _pool_stats["invalidated_connections"],
|
||||
}
|
||||
|
||||
45
backend/app/core/deprecation.py
Normal file
45
backend/app/core/deprecation.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""Deprecation middleware for legacy API routes.
|
||||
|
||||
Provides middleware to add deprecation warning headers to legacy /api/ routes
|
||||
during the transition to /api/v1/.
|
||||
"""
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DeprecationMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware to add deprecation headers to legacy API routes.
|
||||
|
||||
This middleware checks if a request is using a legacy /api/ route
|
||||
(instead of /api/v1/) and adds appropriate deprecation headers to
|
||||
encourage migration to the new versioned API.
|
||||
"""
|
||||
|
||||
# Sunset date for legacy routes (6 months from now, adjust as needed)
|
||||
SUNSET_DATE = "2026-07-01T00:00:00Z"
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
response = await call_next(request)
|
||||
|
||||
# Check if this is a legacy /api/ route (not /api/v1/)
|
||||
path = request.url.path
|
||||
if path.startswith("/api/") and not path.startswith("/api/v1/"):
|
||||
# Skip deprecation headers for health check endpoints
|
||||
if path in ["/health", "/health/ready", "/health/live", "/health/detailed"]:
|
||||
return response
|
||||
|
||||
# Add deprecation headers (RFC 8594)
|
||||
response.headers["Deprecation"] = "true"
|
||||
response.headers["Sunset"] = self.SUNSET_DATE
|
||||
response.headers["Link"] = f'</api/v1{path[4:]}>; rel="successor-version"'
|
||||
response.headers["X-Deprecation-Notice"] = (
|
||||
"This API endpoint is deprecated. "
|
||||
"Please migrate to /api/v1/ prefix. "
|
||||
f"This endpoint will be removed after {self.SUNSET_DATE}."
|
||||
)
|
||||
|
||||
return response
|
||||
@@ -3,11 +3,19 @@ Rate limiting configuration using slowapi with Redis backend.
|
||||
|
||||
This module provides rate limiting functionality to protect against
|
||||
brute force attacks and DoS attempts on sensitive endpoints.
|
||||
|
||||
Rate Limit Tiers:
|
||||
- standard: 60/minute - For normal CRUD operations (tasks, comments)
|
||||
- sensitive: 20/minute - For sensitive operations (attachments, password change)
|
||||
- heavy: 5/minute - For resource-intensive operations (reports, bulk operations)
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from functools import wraps
|
||||
from typing import Callable, Optional
|
||||
|
||||
from fastapi import Request, Response
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
|
||||
@@ -60,8 +68,56 @@ _storage_uri = _get_storage_uri()
|
||||
|
||||
# Create limiter instance with appropriate storage
|
||||
# Uses the client's remote address (IP) as the key for rate limiting
|
||||
# Note: headers_enabled=False because slowapi's header injection requires Response objects,
|
||||
# which conflicts with endpoints that return Pydantic models directly.
|
||||
# Rate limit status can be checked via the 429 Too Many Requests response.
|
||||
limiter = Limiter(
|
||||
key_func=get_remote_address,
|
||||
storage_uri=_storage_uri,
|
||||
strategy="fixed-window", # Fixed window strategy for predictable rate limiting
|
||||
headers_enabled=False, # Disabled due to compatibility issues with Pydantic model responses
|
||||
)
|
||||
|
||||
|
||||
# Convenience functions for rate limit tiers
|
||||
def get_rate_limit_standard() -> str:
|
||||
"""Get the standard rate limit tier (60/minute by default)."""
|
||||
return settings.RATE_LIMIT_STANDARD
|
||||
|
||||
|
||||
def get_rate_limit_sensitive() -> str:
|
||||
"""Get the sensitive rate limit tier (20/minute by default)."""
|
||||
return settings.RATE_LIMIT_SENSITIVE
|
||||
|
||||
|
||||
def get_rate_limit_heavy() -> str:
|
||||
"""Get the heavy rate limit tier (5/minute by default)."""
|
||||
return settings.RATE_LIMIT_HEAVY
|
||||
|
||||
|
||||
# Pre-configured rate limit decorators for common use cases
|
||||
def rate_limit_standard(func: Optional[Callable] = None):
|
||||
"""
|
||||
Apply standard rate limit (60/minute) for normal CRUD operations.
|
||||
|
||||
Use for: Task creation/update, comment creation, etc.
|
||||
"""
|
||||
return limiter.limit(get_rate_limit_standard())(func) if func else limiter.limit(get_rate_limit_standard())
|
||||
|
||||
|
||||
def rate_limit_sensitive(func: Optional[Callable] = None):
|
||||
"""
|
||||
Apply sensitive rate limit (20/minute) for sensitive operations.
|
||||
|
||||
Use for: File uploads, password changes, report exports, etc.
|
||||
"""
|
||||
return limiter.limit(get_rate_limit_sensitive())(func) if func else limiter.limit(get_rate_limit_sensitive())
|
||||
|
||||
|
||||
def rate_limit_heavy(func: Optional[Callable] = None):
|
||||
"""
|
||||
Apply heavy rate limit (5/minute) for resource-intensive operations.
|
||||
|
||||
Use for: Report generation, bulk operations, data exports, etc.
|
||||
"""
|
||||
return limiter.limit(get_rate_limit_heavy())(func) if func else limiter.limit(get_rate_limit_heavy())
|
||||
|
||||
178
backend/app/core/response.py
Normal file
178
backend/app/core/response.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""Standardized API response wrapper.
|
||||
|
||||
Provides utility classes and functions for consistent API response formatting
|
||||
across all endpoints.
|
||||
"""
|
||||
from datetime import datetime
|
||||
from typing import Any, Generic, Optional, TypeVar
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class ErrorDetail(BaseModel):
|
||||
"""Detailed error information."""
|
||||
error_code: str = Field(..., description="Machine-readable error code")
|
||||
message: str = Field(..., description="Human-readable error message")
|
||||
field: Optional[str] = Field(None, description="Field that caused the error, if applicable")
|
||||
details: Optional[dict] = Field(None, description="Additional error details")
|
||||
|
||||
|
||||
class ApiResponse(BaseModel, Generic[T]):
|
||||
"""Standard API response wrapper.
|
||||
|
||||
All API endpoints should return responses in this format for consistency.
|
||||
|
||||
Attributes:
|
||||
success: Whether the request was successful
|
||||
data: The actual response data (null for errors)
|
||||
message: Human-readable message about the result
|
||||
timestamp: ISO 8601 timestamp of the response
|
||||
error: Error details if success is False
|
||||
"""
|
||||
success: bool = Field(..., description="Whether the request was successful")
|
||||
data: Optional[T] = Field(None, description="Response data")
|
||||
message: Optional[str] = Field(None, description="Human-readable message")
|
||||
timestamp: str = Field(
|
||||
default_factory=lambda: datetime.utcnow().isoformat() + "Z",
|
||||
description="ISO 8601 timestamp"
|
||||
)
|
||||
error: Optional[ErrorDetail] = Field(None, description="Error details if failed")
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class PaginatedData(BaseModel, Generic[T]):
|
||||
"""Paginated data structure."""
|
||||
items: list[T] = Field(default_factory=list, description="List of items")
|
||||
total: int = Field(..., description="Total number of items")
|
||||
page: int = Field(..., description="Current page number (1-indexed)")
|
||||
page_size: int = Field(..., description="Number of items per page")
|
||||
total_pages: int = Field(..., description="Total number of pages")
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
# Error codes for common scenarios
|
||||
class ErrorCode:
|
||||
"""Standard error codes for API responses."""
|
||||
# Authentication & Authorization
|
||||
UNAUTHORIZED = "AUTH_001"
|
||||
FORBIDDEN = "AUTH_002"
|
||||
TOKEN_EXPIRED = "AUTH_003"
|
||||
INVALID_TOKEN = "AUTH_004"
|
||||
|
||||
# Validation
|
||||
VALIDATION_ERROR = "VAL_001"
|
||||
INVALID_INPUT = "VAL_002"
|
||||
MISSING_FIELD = "VAL_003"
|
||||
INVALID_FORMAT = "VAL_004"
|
||||
|
||||
# Resource
|
||||
NOT_FOUND = "RES_001"
|
||||
ALREADY_EXISTS = "RES_002"
|
||||
CONFLICT = "RES_003"
|
||||
DELETED = "RES_004"
|
||||
|
||||
# Business Logic
|
||||
BUSINESS_ERROR = "BIZ_001"
|
||||
INVALID_STATE = "BIZ_002"
|
||||
LIMIT_EXCEEDED = "BIZ_003"
|
||||
DEPENDENCY_ERROR = "BIZ_004"
|
||||
|
||||
# Server
|
||||
INTERNAL_ERROR = "SRV_001"
|
||||
DATABASE_ERROR = "SRV_002"
|
||||
EXTERNAL_SERVICE_ERROR = "SRV_003"
|
||||
RATE_LIMITED = "SRV_004"
|
||||
|
||||
|
||||
def success_response(
|
||||
data: Any = None,
|
||||
message: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""Create a successful API response.
|
||||
|
||||
Args:
|
||||
data: The response data
|
||||
message: Optional human-readable message
|
||||
|
||||
Returns:
|
||||
Dictionary with standard response structure
|
||||
"""
|
||||
return {
|
||||
"success": True,
|
||||
"data": data,
|
||||
"message": message,
|
||||
"timestamp": datetime.utcnow().isoformat() + "Z",
|
||||
"error": None,
|
||||
}
|
||||
|
||||
|
||||
def error_response(
|
||||
error_code: str,
|
||||
message: str,
|
||||
field: Optional[str] = None,
|
||||
details: Optional[dict] = None,
|
||||
) -> dict:
|
||||
"""Create an error API response.
|
||||
|
||||
Args:
|
||||
error_code: Machine-readable error code (use ErrorCode constants)
|
||||
message: Human-readable error message
|
||||
field: Optional field name that caused the error
|
||||
details: Optional additional error details
|
||||
|
||||
Returns:
|
||||
Dictionary with standard error response structure
|
||||
"""
|
||||
return {
|
||||
"success": False,
|
||||
"data": None,
|
||||
"message": message,
|
||||
"timestamp": datetime.utcnow().isoformat() + "Z",
|
||||
"error": {
|
||||
"error_code": error_code,
|
||||
"message": message,
|
||||
"field": field,
|
||||
"details": details,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def paginated_response(
|
||||
items: list,
|
||||
total: int,
|
||||
page: int,
|
||||
page_size: int,
|
||||
message: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""Create a paginated API response.
|
||||
|
||||
Args:
|
||||
items: List of items for current page
|
||||
total: Total number of items across all pages
|
||||
page: Current page number (1-indexed)
|
||||
page_size: Number of items per page
|
||||
message: Optional human-readable message
|
||||
|
||||
Returns:
|
||||
Dictionary with standard paginated response structure
|
||||
"""
|
||||
total_pages = (total + page_size - 1) // page_size if page_size > 0 else 0
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": {
|
||||
"items": items,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"total_pages": total_pages,
|
||||
},
|
||||
"message": message,
|
||||
"timestamp": datetime.utcnow().isoformat() + "Z",
|
||||
"error": None,
|
||||
}
|
||||
Reference in New Issue
Block a user