feat: implement 8 OpenSpec proposals for security, reliability, and UX improvements

## Security Enhancements (P0)
- Add input validation with max_length and numeric range constraints
- Implement WebSocket token authentication via first message
- Add path traversal prevention in file storage service

## Permission Enhancements (P0)
- Add project member management for cross-department access
- Implement is_department_manager flag for workload visibility

## Cycle Detection (P0)
- Add DFS-based cycle detection for task dependencies
- Add formula field circular reference detection
- Display user-friendly cycle path visualization

## Concurrency & Reliability (P1)
- Implement optimistic locking with version field (409 Conflict on mismatch)
- Add trigger retry mechanism with exponential backoff (1s, 2s, 4s)
- Implement cascade restore for soft-deleted tasks

## Rate Limiting (P1)
- Add tiered rate limits: standard (60/min), sensitive (20/min), heavy (5/min)
- Apply rate limits to tasks, reports, attachments, and comments

## Frontend Improvements (P1)
- Add responsive sidebar with hamburger menu for mobile
- Improve touch-friendly UI with proper tap target sizes
- Complete i18n translations for all components

## Backend Reliability (P2)
- Configure database connection pool (size=10, overflow=20)
- Add Redis fallback mechanism with message queue
- Add blocker check before task deletion

## API Enhancements (P3)
- Add standardized response wrapper utility
- Add /health/ready and /health/live endpoints
- Implement project templates with status/field copying

## Tests Added
- test_input_validation.py - Schema and path traversal tests
- test_concurrency_reliability.py - Optimistic locking and retry tests
- test_backend_reliability.py - Connection pool and Redis tests
- test_api_enhancements.py - Health check and template tests

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
beabigegg
2026-01-10 22:13:43 +08:00
parent 96210c7ad4
commit 3bdc6ff1c9
106 changed files with 9704 additions and 429 deletions

View File

@@ -115,6 +115,13 @@ class Settings(BaseSettings):
"exe", "bat", "cmd", "sh", "ps1", "dll", "msi", "com", "scr", "vbs", "js"
]
# Rate Limiting Configuration
# Tiers: standard, sensitive, heavy
# Format: "{requests}/{period}" (e.g., "60/minute", "20/minute", "5/minute")
# These strings are consumed by the slowapi limiter tiers.
RATE_LIMIT_STANDARD: str = "60/minute"  # Task CRUD, comments
RATE_LIMIT_SENSITIVE: str = "20/minute"  # Attachments, password change, report export
RATE_LIMIT_HEAVY: str = "5/minute"  # Report generation, bulk operations

class Config:
    # Pydantic settings configuration: read overrides from .env with
    # case-sensitive environment variable names.
    env_file = ".env"
    case_sensitive = True

View File

@@ -1,19 +1,109 @@
from sqlalchemy import create_engine
import logging
import threading
import os
from sqlalchemy import create_engine, event
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from app.core.config import settings
logger = logging.getLogger(__name__)

# Connection pool configuration with environment variable overrides.
POOL_SIZE = int(os.getenv("DB_POOL_SIZE", "10"))
MAX_OVERFLOW = int(os.getenv("DB_MAX_OVERFLOW", "20"))
POOL_TIMEOUT = int(os.getenv("DB_POOL_TIMEOUT", "30"))
POOL_STATS_INTERVAL = int(os.getenv("DB_POOL_STATS_INTERVAL", "300"))  # 5 minutes

engine = create_engine(
    settings.DATABASE_URL,
    pool_pre_ping=True,   # validate connections before handing them out
    pool_recycle=3600,    # recycle connections after 1h to dodge server-side timeouts
    pool_size=POOL_SIZE,
    max_overflow=MAX_OVERFLOW,
    pool_timeout=POOL_TIMEOUT,
)

SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()

# Connection pool statistics tracking. Counters are guarded by the lock
# because pool events can fire from multiple threads.
# NOTE(review): "overflow_connections" is never incremented anywhere in
# this module — confirm whether it should be wired to an event or removed.
_pool_stats_lock = threading.Lock()
_pool_stats = {
    "checkouts": 0,
    "checkins": 0,
    "overflow_connections": 0,
    "invalidated_connections": 0,
}
def _log_pool_statistics():
    """Emit one INFO line with pool gauges and lifetime counters."""
    pool = engine.pool
    with _pool_stats_lock:
        # Snapshot gauges and counters together so the logged values are
        # mutually consistent.
        stats_args = (
            pool.size(),
            pool.checkedin(),
            pool.overflow(),
            _pool_stats["checkouts"],
            _pool_stats["checkins"],
            _pool_stats["invalidated_connections"],
        )
        logger.info(
            "Database connection pool statistics: "
            "size=%d, checked_in=%d, overflow=%d, "
            "total_checkouts=%d, total_checkins=%d, invalidated=%d",
            *stats_args,
        )
def _start_pool_stats_logging():
    """Begin periodic pool-statistics logging using daemon Timer threads.

    A zero or negative POOL_STATS_INTERVAL disables the feature entirely.
    """
    if POOL_STATS_INTERVAL <= 0:
        return

    def _schedule():
        # Daemon timers so pending logging never blocks interpreter shutdown.
        next_timer = threading.Timer(POOL_STATS_INTERVAL, _tick)
        next_timer.daemon = True
        next_timer.start()

    def _tick():
        _log_pool_statistics()
        _schedule()  # re-arm for the next interval

    _schedule()  # arm the first timer
# Log the effective pool configuration once at import time.
logger.info(
    "Database connection pool initialized: pool_size=%d, max_overflow=%d, pool_timeout=%d, stats_interval=%ds",
    POOL_SIZE, MAX_OVERFLOW, POOL_TIMEOUT, POOL_STATS_INTERVAL
)
# Register pool event listeners for statistics
@event.listens_for(engine, "checkout")
def _on_checkout(dbapi_conn, connection_record, connection_proxy):
    """Increment the checkout counter each time a connection leaves the pool."""
    with _pool_stats_lock:
        _pool_stats["checkouts"] = _pool_stats["checkouts"] + 1
@event.listens_for(engine, "checkin")
def _on_checkin(dbapi_conn, connection_record):
    """Increment the checkin counter each time a connection returns to the pool."""
    with _pool_stats_lock:
        _pool_stats["checkins"] = _pool_stats["checkins"] + 1
@event.listens_for(engine, "invalidate")
def _on_invalidate(dbapi_conn, connection_record, exception):
    """Count invalidated connections; warn when an exception triggered it."""
    with _pool_stats_lock:
        count = _pool_stats["invalidated_connections"] + 1
        _pool_stats["invalidated_connections"] = count
    if exception:
        logger.warning("Database connection invalidated due to exception: %s", exception)
# Start pool statistics logging on module load (no-op if the interval
# is configured to be <= 0).
_start_pool_stats_logging()
def get_db():
"""Dependency for getting database session."""
@@ -22,3 +112,18 @@ def get_db():
yield db
finally:
db.close()
def get_pool_status() -> dict:
    """Return a snapshot of pool gauges and counters for health checks.

    Returns:
        dict with pool_size / checked_in / checked_out / overflow gauges
        plus the lifetime checkout, checkin and invalidation counters.
    """
    pool = engine.pool
    with _pool_stats_lock:
        status = {
            "pool_size": pool.size(),
            "checked_in": pool.checkedin(),
            "checked_out": pool.checkedout(),
            "overflow": pool.overflow(),
            "total_checkouts": _pool_stats["checkouts"],
            "total_checkins": _pool_stats["checkins"],
            "invalidated_connections": _pool_stats["invalidated_connections"],
        }
    return status

View File

@@ -0,0 +1,45 @@
"""Deprecation middleware for legacy API routes.
Provides middleware to add deprecation warning headers to legacy /api/ routes
during the transition to /api/v1/.
"""
import logging
from datetime import datetime
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
logger = logging.getLogger(__name__)
class DeprecationMiddleware(BaseHTTPMiddleware):
    """Middleware to add deprecation headers to legacy API routes.

    This middleware checks if a request is using a legacy /api/ route
    (instead of /api/v1/) and adds appropriate deprecation headers
    (RFC 8594) to encourage migration to the new versioned API.
    """

    # Sunset date for legacy routes (6 months from now, adjust as needed)
    SUNSET_DATE = "2026-07-01T00:00:00Z"

    # Health-check paths that must never carry deprecation headers, stored
    # WITHOUT the "/api" prefix so they can be matched against legacy paths.
    _HEALTH_PATHS = ("/health", "/health/ready", "/health/live", "/health/detailed")

    async def dispatch(self, request: Request, call_next):
        """Add deprecation headers to responses for legacy /api/ routes."""
        response = await call_next(request)

        # Check if this is a legacy /api/ route (not /api/v1/)
        path = request.url.path
        if path.startswith("/api/") and not path.startswith("/api/v1/"):
            # Bug fix: the original compared the full path (e.g. "/api/health")
            # against bare "/health" entries inside this branch, so the
            # exemption could never match. Strip the "/api" prefix first.
            if path[4:] in self._HEALTH_PATHS:
                return response

            # Add deprecation headers (RFC 8594)
            response.headers["Deprecation"] = "true"
            response.headers["Sunset"] = self.SUNSET_DATE
            response.headers["Link"] = f'</api/v1{path[4:]}>; rel="successor-version"'
            response.headers["X-Deprecation-Notice"] = (
                "This API endpoint is deprecated. "
                "Please migrate to /api/v1/ prefix. "
                f"This endpoint will be removed after {self.SUNSET_DATE}."
            )

        return response

View File

@@ -3,11 +3,19 @@ Rate limiting configuration using slowapi with Redis backend.
This module provides rate limiting functionality to protect against
brute force attacks and DoS attempts on sensitive endpoints.
Rate Limit Tiers:
- standard: 60/minute - For normal CRUD operations (tasks, comments)
- sensitive: 20/minute - For sensitive operations (attachments, password change)
- heavy: 5/minute - For resource-intensive operations (reports, bulk operations)
"""
import logging
import os
from functools import wraps
from typing import Callable, Optional
from fastapi import Request, Response
from slowapi import Limiter
from slowapi.util import get_remote_address
@@ -60,8 +68,56 @@ _storage_uri = _get_storage_uri()
# Create limiter instance with appropriate storage.
# Uses the client's remote address (IP) as the key for rate limiting.
# NOTE: headers_enabled=False because slowapi's header injection requires
# Response objects, which conflicts with endpoints that return Pydantic
# models directly. Rate limit status can still be observed via the
# 429 Too Many Requests response.
limiter = Limiter(
    key_func=get_remote_address,
    storage_uri=_storage_uri,
    strategy="fixed-window",  # Fixed window strategy for predictable rate limiting
    headers_enabled=False,  # Disabled due to compatibility issues with Pydantic model responses
)
# Convenience functions for rate limit tiers
def get_rate_limit_standard() -> str:
    """Return the configured standard-tier limit string (default "60/minute")."""
    tier = settings.RATE_LIMIT_STANDARD
    return tier
def get_rate_limit_sensitive() -> str:
    """Return the configured sensitive-tier limit string (default "20/minute")."""
    tier = settings.RATE_LIMIT_SENSITIVE
    return tier
def get_rate_limit_heavy() -> str:
    """Return the configured heavy-tier limit string (default "5/minute")."""
    tier = settings.RATE_LIMIT_HEAVY
    return tier
# Pre-configured rate limit decorators for common use cases
def rate_limit_standard(func: Optional[Callable] = None):
    """Apply the standard rate limit (60/minute) for normal CRUD operations.

    Use for: Task creation/update, comment creation, etc.
    Works both as ``@rate_limit_standard`` and ``@rate_limit_standard()``.
    """
    decorator = limiter.limit(get_rate_limit_standard())
    if func:
        return decorator(func)
    return decorator
def rate_limit_sensitive(func: Optional[Callable] = None):
    """Apply the sensitive rate limit (20/minute) for sensitive operations.

    Use for: File uploads, password changes, report exports, etc.
    Works both as ``@rate_limit_sensitive`` and ``@rate_limit_sensitive()``.
    """
    decorator = limiter.limit(get_rate_limit_sensitive())
    if func:
        return decorator(func)
    return decorator
def rate_limit_heavy(func: Optional[Callable] = None):
    """Apply the heavy rate limit (5/minute) for resource-intensive operations.

    Use for: Report generation, bulk operations, data exports, etc.
    Works both as ``@rate_limit_heavy`` and ``@rate_limit_heavy()``.
    """
    decorator = limiter.limit(get_rate_limit_heavy())
    if func:
        return decorator(func)
    return decorator

View File

@@ -0,0 +1,178 @@
"""Standardized API response wrapper.
Provides utility classes and functions for consistent API response formatting
across all endpoints.
"""
from datetime import datetime, timezone
from typing import Any, Generic, Optional, TypeVar

from pydantic import BaseModel, Field
T = TypeVar("T")
class ErrorDetail(BaseModel):
    """Detailed error information attached to a failed API response."""

    # Machine-readable code; see the ErrorCode constants in this module.
    error_code: str = Field(..., description="Machine-readable error code")
    message: str = Field(..., description="Human-readable error message")
    field: Optional[str] = Field(None, description="Field that caused the error, if applicable")
    details: Optional[dict] = Field(None, description="Additional error details")
class ApiResponse(BaseModel, Generic[T]):
    """Standard API response wrapper.

    All API endpoints should return responses in this format for consistency.

    Attributes:
        success: Whether the request was successful
        data: The actual response data (null for errors)
        message: Human-readable message about the result
        timestamp: ISO 8601 timestamp of the response
        error: Error details if success is False
    """

    success: bool = Field(..., description="Whether the request was successful")
    data: Optional[T] = Field(None, description="Response data")
    message: Optional[str] = Field(None, description="Human-readable message")
    # NOTE(review): datetime.utcnow() is deprecated since Python 3.12;
    # datetime.now(timezone.utc) with the offset normalized to "Z" would be
    # the modern equivalent.
    timestamp: str = Field(
        default_factory=lambda: datetime.utcnow().isoformat() + "Z",
        description="ISO 8601 timestamp"
    )
    error: Optional[ErrorDetail] = Field(None, description="Error details if failed")

    class Config:
        # Allow constructing the model from objects via attribute access.
        from_attributes = True
class PaginatedData(BaseModel, Generic[T]):
    """Paginated data structure carried in the ``data`` field of a response."""

    items: list[T] = Field(default_factory=list, description="List of items")
    total: int = Field(..., description="Total number of items")
    page: int = Field(..., description="Current page number (1-indexed)")
    page_size: int = Field(..., description="Number of items per page")
    total_pages: int = Field(..., description="Total number of pages")

    class Config:
        # Allow constructing the model from objects via attribute access.
        from_attributes = True
# Error codes for common scenarios
class ErrorCode:
    """Standard error codes for API responses.

    Codes are grouped by prefix: AUTH_* (authentication/authorization),
    VAL_* (validation), RES_* (resources), BIZ_* (business logic),
    SRV_* (server-side failures).
    """

    # Authentication & Authorization
    UNAUTHORIZED = "AUTH_001"
    FORBIDDEN = "AUTH_002"
    TOKEN_EXPIRED = "AUTH_003"
    INVALID_TOKEN = "AUTH_004"

    # Validation
    VALIDATION_ERROR = "VAL_001"
    INVALID_INPUT = "VAL_002"
    MISSING_FIELD = "VAL_003"
    INVALID_FORMAT = "VAL_004"

    # Resource
    NOT_FOUND = "RES_001"
    ALREADY_EXISTS = "RES_002"
    CONFLICT = "RES_003"
    DELETED = "RES_004"

    # Business Logic
    BUSINESS_ERROR = "BIZ_001"
    INVALID_STATE = "BIZ_002"
    LIMIT_EXCEEDED = "BIZ_003"
    DEPENDENCY_ERROR = "BIZ_004"

    # Server
    INTERNAL_ERROR = "SRV_001"
    DATABASE_ERROR = "SRV_002"
    EXTERNAL_SERVICE_ERROR = "SRV_003"
    RATE_LIMITED = "SRV_004"
def success_response(
    data: Any = None,
    message: Optional[str] = None,
) -> dict:
    """Create a successful API response.

    Args:
        data: The response data
        message: Optional human-readable message

    Returns:
        Dictionary with standard response structure
        (success / data / message / timestamp / error keys)
    """
    return {
        "success": True,
        "data": data,
        "message": message,
        # datetime.utcnow() is deprecated (3.12+); use an aware UTC time and
        # normalize the "+00:00" offset to the same trailing-"Z" format.
        "timestamp": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
        "error": None,
    }
def error_response(
    error_code: str,
    message: str,
    field: Optional[str] = None,
    details: Optional[dict] = None,
) -> dict:
    """Create an error API response.

    Args:
        error_code: Machine-readable error code (use ErrorCode constants)
        message: Human-readable error message
        field: Optional field name that caused the error
        details: Optional additional error details

    Returns:
        Dictionary with standard error response structure; the ``message``
        is duplicated at the top level and inside ``error`` by design.
    """
    return {
        "success": False,
        "data": None,
        "message": message,
        # datetime.utcnow() is deprecated (3.12+); use an aware UTC time and
        # normalize the "+00:00" offset to the same trailing-"Z" format.
        "timestamp": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
        "error": {
            "error_code": error_code,
            "message": message,
            "field": field,
            "details": details,
        },
    }
def paginated_response(
    items: list,
    total: int,
    page: int,
    page_size: int,
    message: Optional[str] = None,
) -> dict:
    """Create a paginated API response.

    Args:
        items: List of items for current page
        total: Total number of items across all pages
        page: Current page number (1-indexed)
        page_size: Number of items per page
        message: Optional human-readable message

    Returns:
        Dictionary with standard paginated response structure
    """
    # Ceiling division; a non-positive page_size yields 0 pages rather
    # than raising ZeroDivisionError.
    total_pages = (total + page_size - 1) // page_size if page_size > 0 else 0
    return {
        "success": True,
        "data": {
            "items": items,
            "total": total,
            "page": page,
            "page_size": page_size,
            "total_pages": total_pages,
        },
        "message": message,
        # datetime.utcnow() is deprecated (3.12+); use an aware UTC time and
        # normalize the "+00:00" offset to the same trailing-"Z" format.
        "timestamp": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
        "error": None,
    }