feat: implement 8 OpenSpec proposals for security, reliability, and UX improvements

## Security Enhancements (P0)
- Add input validation with max_length and numeric range constraints
- Implement WebSocket token authentication via first message (example below)
- Add path traversal prevention in file storage service
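
The first-message WebSocket authentication noted above might look roughly like the following FastAPI sketch. The route path, close codes, and `decode_token` helper are illustrative assumptions, not the project's actual implementation; checking the first message rather than a query-string token keeps credentials out of access logs.

```python
# Minimal sketch: authenticate a WebSocket via its first message (assumed flow).
import asyncio
from typing import Optional

from fastapi import FastAPI, WebSocket

app = FastAPI()

def decode_token(token: str) -> Optional[str]:
    """Hypothetical helper: return the user id for a valid token, else None."""
    return "user-123" if token else None

@app.websocket("/ws/notifications")
async def notifications_ws(websocket: WebSocket):
    await websocket.accept()
    try:
        # Client must send {"token": "..."} as its first message within 5 seconds.
        first_message = await asyncio.wait_for(websocket.receive_json(), timeout=5)
    except asyncio.TimeoutError:
        await websocket.close(code=1008)  # policy violation: no credentials sent
        return
    user_id = decode_token(first_message.get("token", ""))
    if user_id is None:
        await websocket.close(code=1008)
        return
    await websocket.send_json({"type": "auth_ok"})
    # ...stream notifications for user_id from here on...
```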

## Permission Enhancements (P0)
- Add project member management for cross-department access
- Implement is_department_manager flag for workload visibility (example below)
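
A guard on the workload-visibility flag could look like this sketch; the user model, dependency, and route are stand-ins for the project's own names, not the actual code.

```python
# Sketch: gate workload endpoints on the is_department_manager flag (assumed names).
from dataclasses import dataclass

from fastapi import Depends, FastAPI, HTTPException

app = FastAPI()

@dataclass
class CurrentUser:  # placeholder for the real user model
    id: str
    is_department_manager: bool = False

def get_current_user() -> CurrentUser:
    # Placeholder: the real dependency resolves the user from the auth token.
    return CurrentUser(id="demo-user")

def require_department_manager(user: CurrentUser = Depends(get_current_user)) -> CurrentUser:
    if not user.is_department_manager:
        raise HTTPException(status_code=403, detail="Department manager role required")
    return user

@app.get("/api/v1/departments/{department_id}/workload")
def department_workload(department_id: str, manager: CurrentUser = Depends(require_department_manager)):
    return {"department_id": department_id, "requested_by": manager.id}
```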

## Cycle Detection (P0)
- Add DFS-based cycle detection for task dependencies
- Add formula field circular reference detection
- Surface a human-readable cycle path (e.g. Task A -> Task B -> Task A) in error responses

## Concurrency & Reliability (P1)
- Implement optimistic locking with version field (409 Conflict on mismatch); example below
- Add trigger retry mechanism with exponential backoff (1s, 2s, 4s)
- Implement cascade restore for soft-deleted tasks
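
The optimistic-locking behaviour described above typically reduces to a conditional UPDATE on the version column. The `Task` model and helper below are a simplified stand-in for the project's code, not its real schema.

```python
# Sketch of version-checked updates returning 409 on conflict (stand-in model).
from fastapi import HTTPException
from sqlalchemy import Column, Integer, String
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Task(Base):  # minimal stand-in for the real Task model
    __tablename__ = "tasks"
    id = Column(String, primary_key=True)
    title = Column(String)
    version = Column(Integer, nullable=False, default=1)

def update_task_optimistic(db: Session, task_id: str, expected_version: int, changes: dict) -> None:
    # UPDATE ... WHERE id = :task_id AND version = :expected_version
    updated = (
        db.query(Task)
        .filter(Task.id == task_id, Task.version == expected_version)
        .update({**changes, "version": Task.version + 1}, synchronize_session=False)
    )
    if updated == 0:
        # Either the task is gone or someone else bumped the version first.
        db.rollback()
        raise HTTPException(status_code=409, detail="Task was modified by another user")
    db.commit()
```

Clients send back the `version` they last read; a zero row count means the row changed underneath them and the write is rejected.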

## Rate Limiting (P1)
- Add tiered rate limits: standard (60/min), sensitive (20/min), heavy (5/min); example below
- Apply rate limits to tasks, reports, attachments, and comments
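
One way to wire the three tiers, assuming a slowapi-style limiter; the commit does not name the rate-limiting library, and the endpoint-to-tier mapping here is illustrative only.

```python
# Sketch of tiered per-client rate limits (assuming the slowapi library).
from fastapi import FastAPI, Request
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address

limiter = Limiter(key_func=get_remote_address)
app = FastAPI()
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

STANDARD = "60/minute"   # task CRUD, comments
SENSITIVE = "20/minute"  # attachments and other sensitive operations
HEAVY = "5/minute"       # report generation

@app.get("/api/v1/tasks")
@limiter.limit(STANDARD)
def list_tasks(request: Request):
    return []

@app.post("/api/v1/reports")
@limiter.limit(HEAVY)
def generate_report(request: Request):
    return {"status": "queued"}
```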

## Frontend Improvements (P1)
- Add responsive sidebar with hamburger menu for mobile
- Make the UI touch-friendly with adequately sized tap targets
- Complete i18n translations for all components

## Backend Reliability (P2)
- Configure database connection pool (size=10, overflow=20); example below
- Add Redis fallback with a bounded in-memory retry queue
- Add blocker check before task deletion
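
The pool sizes above map directly onto SQLAlchemy engine arguments. This is a sketch with a placeholder URL; `pool_pre_ping` and `pool_recycle` are extra settings assumed here, not ones named in the commit.

```python
# Sketch: connection-pool settings matching size=10 / overflow=20 (placeholder URL).
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

engine = create_engine(
    "postgresql+psycopg2://user:pass@db/app",  # placeholder, real URL comes from config
    pool_size=10,        # persistent connections kept in the pool
    max_overflow=20,     # extra connections allowed under burst load
    pool_pre_ping=True,  # detect and replace stale connections before use (assumed extra)
    pool_recycle=1800,   # recycle connections periodically (assumed extra)
)
SessionLocal = sessionmaker(bind=engine, autocommit=False, autoflush=False)
```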

## API Enhancements (P3)
- Add standardized response wrapper utility
- Add /health/ready and /health/live endpoints (example below)
- Implement project templates with status/field copying
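
A liveness/readiness split along the lines described above might look like this sketch. The probe contents (a `SELECT 1` against a placeholder engine) are assumptions; the real readiness check could also cover Redis and file storage.

```python
# Sketch of /health/live and /health/ready (placeholder engine and checks).
from fastapi import FastAPI, Response, status
from sqlalchemy import create_engine, text

app = FastAPI()
engine = create_engine("postgresql+psycopg2://user:pass@db/app")  # placeholder URL

@app.get("/health/live")
def health_live():
    # Liveness: the process is up and serving requests.
    return {"status": "ok"}

@app.get("/health/ready")
def health_ready(response: Response):
    # Readiness: the dependencies needed to do useful work.
    checks = {}
    ready = True
    try:
        with engine.connect() as conn:
            conn.execute(text("SELECT 1"))
        checks["database"] = "ok"
    except Exception as exc:  # report the failure instead of crashing the probe
        checks["database"] = f"error: {exc}"
        ready = False
    if not ready:
        response.status_code = status.HTTP_503_SERVICE_UNAVAILABLE
    return {"status": "ready" if ready else "degraded", "checks": checks}
```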

## Tests Added
- test_input_validation.py - Schema and path traversal tests
- test_concurrency_reliability.py - Optimistic locking and retry tests
- test_backend_reliability.py - Connection pool and Redis tests
- test_api_enhancements.py - Health check and template tests

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
beabigegg
2026-01-10 22:13:43 +08:00
parent 96210c7ad4
commit 3bdc6ff1c9
106 changed files with 9704 additions and 429 deletions

View File

@@ -6,6 +6,7 @@ Handles task dependency validation including:
- Date constraint validation based on dependency types
- Self-reference prevention
- Cross-project dependency prevention
- Bulk dependency operations with cycle detection
"""
from typing import List, Optional, Set, Tuple, Dict, Any
from collections import defaultdict
@@ -25,6 +26,27 @@ class DependencyValidationError(Exception):
super().__init__(message)
class CycleDetectionResult:
"""Result of cycle detection with detailed path information."""
def __init__(
self,
has_cycle: bool,
cycle_path: Optional[List[str]] = None,
cycle_task_titles: Optional[List[str]] = None
):
self.has_cycle = has_cycle
self.cycle_path = cycle_path or []
self.cycle_task_titles = cycle_task_titles or []
def get_cycle_description(self) -> str:
"""Get a human-readable description of the cycle."""
if not self.has_cycle or not self.cycle_task_titles:
return ""
# Format: Task A -> Task B -> Task C -> Task A
return " -> ".join(self.cycle_task_titles)
class DependencyService:
"""Service for managing task dependencies with validation."""
@@ -53,9 +75,36 @@ class DependencyService:
Returns:
List of task IDs forming the cycle if circular, None otherwise
"""
# If adding predecessor -> successor, check if successor can reach predecessor
# This would mean predecessor depends (transitively) on successor, creating a cycle
result = DependencyService.detect_circular_dependency_detailed(
db, predecessor_id, successor_id, project_id
)
return result.cycle_path if result.has_cycle else None
@staticmethod
def detect_circular_dependency_detailed(
db: Session,
predecessor_id: str,
successor_id: str,
project_id: str,
additional_edges: Optional[List[Tuple[str, str]]] = None
) -> CycleDetectionResult:
"""
Detect if adding a dependency would create a circular reference.
Uses DFS to traverse from the successor to check if we can reach
the predecessor through existing dependencies.
Args:
db: Database session
predecessor_id: The task that must complete first
successor_id: The task that depends on the predecessor
project_id: Project ID to scope the query
additional_edges: Optional list of additional (predecessor_id, successor_id)
edges to consider (for bulk operations)
Returns:
CycleDetectionResult with detailed cycle information
"""
# Build adjacency list for the project's dependencies
dependencies = db.query(TaskDependency).join(
Task, TaskDependency.successor_id == Task.id
@@ -71,6 +120,20 @@ class DependencyService:
# Simulate adding the new edge
graph[successor_id].append(predecessor_id)
# Add any additional edges for bulk operations
if additional_edges:
for pred_id, succ_id in additional_edges:
graph[succ_id].append(pred_id)
# Build task title map for readable error messages
task_ids_in_graph = set()
for succ_id, pred_ids in graph.items():
task_ids_in_graph.add(succ_id)
task_ids_in_graph.update(pred_ids)
tasks = db.query(Task).filter(Task.id.in_(task_ids_in_graph)).all()
task_title_map: Dict[str, str] = {t.id: t.title for t in tasks}
# DFS to find if there's a path from predecessor back to successor
# (which would complete a cycle)
visited: Set[str] = set()
@@ -101,7 +164,18 @@ class DependencyService:
return None
# Start DFS from the successor to check if we can reach back to it
return dfs(successor_id)
cycle_path = dfs(successor_id)
if cycle_path:
# Build task titles for the cycle
cycle_titles = [task_title_map.get(task_id, task_id) for task_id in cycle_path]
return CycleDetectionResult(
has_cycle=True,
cycle_path=cycle_path,
cycle_task_titles=cycle_titles
)
return CycleDetectionResult(has_cycle=False)
@staticmethod
def validate_dependency(
@@ -183,15 +257,19 @@ class DependencyService:
)
# Check circular dependency
cycle = DependencyService.detect_circular_dependency(
cycle_result = DependencyService.detect_circular_dependency_detailed(
db, predecessor_id, successor_id, predecessor.project_id
)
if cycle:
if cycle_result.has_cycle:
raise DependencyValidationError(
error_type="circular",
message="Adding this dependency would create a circular reference",
details={"cycle": cycle}
message=f"Adding this dependency would create a circular reference: {cycle_result.get_cycle_description()}",
details={
"cycle": cycle_result.cycle_path,
"cycle_description": cycle_result.get_cycle_description(),
"cycle_task_titles": cycle_result.cycle_task_titles
}
)
@staticmethod
@@ -422,3 +500,202 @@ class DependencyService:
queue.append(dep.successor_id)
return successors
@staticmethod
def validate_bulk_dependencies(
db: Session,
dependencies: List[Tuple[str, str]],
project_id: str
) -> List[Dict[str, Any]]:
"""
Validate a batch of dependencies for cycle detection.
This method validates multiple dependencies together to detect cycles
that would only appear when all dependencies are added together.
Args:
db: Database session
dependencies: List of (predecessor_id, successor_id) tuples
project_id: Project ID to scope the query
Returns:
List of validation errors (empty if all valid)
"""
errors: List[Dict[str, Any]] = []
if not dependencies:
return errors
# First, validate each dependency individually for basic checks
for predecessor_id, successor_id in dependencies:
# Check self-reference
if predecessor_id == successor_id:
errors.append({
"error_type": "self_reference",
"predecessor_id": predecessor_id,
"successor_id": successor_id,
"message": "A task cannot depend on itself"
})
continue
# Get tasks to validate project membership
predecessor = db.query(Task).filter(Task.id == predecessor_id).first()
successor = db.query(Task).filter(Task.id == successor_id).first()
if not predecessor:
errors.append({
"error_type": "not_found",
"predecessor_id": predecessor_id,
"successor_id": successor_id,
"message": f"Predecessor task not found: {predecessor_id}"
})
continue
if not successor:
errors.append({
"error_type": "not_found",
"predecessor_id": predecessor_id,
"successor_id": successor_id,
"message": f"Successor task not found: {successor_id}"
})
continue
if predecessor.project_id != project_id or successor.project_id != project_id:
errors.append({
"error_type": "cross_project",
"predecessor_id": predecessor_id,
"successor_id": successor_id,
"message": "All tasks must be in the same project"
})
continue
# Check whether the dependency already exists in the database
existing = db.query(TaskDependency).filter(
TaskDependency.predecessor_id == predecessor_id,
TaskDependency.successor_id == successor_id
).first()
if existing:
errors.append({
"error_type": "duplicate",
"predecessor_id": predecessor_id,
"successor_id": successor_id,
"message": "This dependency already exists"
})
# If there are basic validation errors, return them first
if errors:
return errors
# Now check for cycles considering all dependencies together
# Build the graph incrementally and check for cycles
accumulated_edges: List[Tuple[str, str]] = []
for predecessor_id, successor_id in dependencies:
# Check if adding this edge (plus all previously accumulated edges)
# would create a cycle
cycle_result = DependencyService.detect_circular_dependency_detailed(
db,
predecessor_id,
successor_id,
project_id,
additional_edges=accumulated_edges
)
if cycle_result.has_cycle:
errors.append({
"error_type": "circular",
"predecessor_id": predecessor_id,
"successor_id": successor_id,
"message": f"Adding this dependency would create a circular reference: {cycle_result.get_cycle_description()}",
"cycle": cycle_result.cycle_path,
"cycle_description": cycle_result.get_cycle_description(),
"cycle_task_titles": cycle_result.cycle_task_titles
})
else:
# Add this edge to accumulated edges for subsequent checks
accumulated_edges.append((predecessor_id, successor_id))
return errors
@staticmethod
def detect_cycles_in_graph(
db: Session,
project_id: str
) -> List[CycleDetectionResult]:
"""
Detect all cycles in the existing dependency graph for a project.
This is useful for auditing or cleanup operations.
Args:
db: Database session
project_id: Project ID to check
Returns:
List of CycleDetectionResult for each cycle found
"""
cycles: List[CycleDetectionResult] = []
# Get all dependencies for the project
dependencies = db.query(TaskDependency).join(
Task, TaskDependency.successor_id == Task.id
).filter(Task.project_id == project_id).all()
if not dependencies:
return cycles
# Build the graph
graph: Dict[str, List[str]] = defaultdict(list)
for dep in dependencies:
graph[dep.successor_id].append(dep.predecessor_id)
# Get task titles
task_ids = set()
for succ_id, pred_ids in graph.items():
task_ids.add(succ_id)
task_ids.update(pred_ids)
tasks = db.query(Task).filter(Task.id.in_(task_ids)).all()
task_title_map: Dict[str, str] = {t.id: t.title for t in tasks}
# Find all cycles using DFS
visited: Set[str] = set()
found_cycles: Set[Tuple[str, ...]] = set()
def find_cycles_dfs(node: str, path: List[str], in_path: Set[str]):
"""DFS to find all cycles."""
if node in in_path:
# Found a cycle
cycle_start = path.index(node)
cycle = tuple(sorted(path[cycle_start:])) # Normalize for dedup
if cycle not in found_cycles:
found_cycles.add(cycle)
actual_cycle = path[cycle_start:] + [node]
cycle_titles = [task_title_map.get(tid, tid) for tid in actual_cycle]
cycles.append(CycleDetectionResult(
has_cycle=True,
cycle_path=actual_cycle,
cycle_task_titles=cycle_titles
))
return
if node in visited:
return
visited.add(node)
in_path.add(node)
path.append(node)
for neighbor in graph.get(node, []):
find_cycles_dfs(neighbor, path.copy(), in_path.copy())
path.pop()
in_path.remove(node)
# Start DFS from all nodes
for start_node in graph.keys():
if start_node not in visited:
find_cycles_dfs(start_node, [], set())
return cycles

View File

@@ -1,26 +1,271 @@
import os
import hashlib
import shutil
import logging
import tempfile
from pathlib import Path
from typing import BinaryIO, Optional, Tuple
from fastapi import UploadFile, HTTPException
from app.core.config import settings
logger = logging.getLogger(__name__)
class PathTraversalError(Exception):
"""Raised when a path traversal attempt is detected."""
pass
class StorageValidationError(Exception):
"""Raised when storage validation fails."""
pass
class FileStorageService:
"""Service for handling file storage operations."""
# Common NAS mount points to detect
NAS_MOUNT_INDICATORS = [
"/mnt/", "/mount/", "/nas/", "/nfs/", "/smb/", "/cifs/",
"/Volumes/", "/media/", "/srv/", "/storage/"
]
def __init__(self):
self.base_dir = Path(settings.UPLOAD_DIR)
self._ensure_base_dir()
self.base_dir = Path(settings.UPLOAD_DIR).resolve()
self._storage_status = {
"validated": False,
"path_exists": False,
"writable": False,
"is_nas": False,
"error": None,
}
self._validate_storage_on_init()
def _validate_storage_on_init(self):
"""Validate storage configuration on service initialization."""
try:
# Step 1: Ensure directory exists
self._ensure_base_dir()
self._storage_status["path_exists"] = True
# Step 2: Check write permissions
self._check_write_permissions()
self._storage_status["writable"] = True
# Step 3: Check if using NAS
is_nas = self._detect_nas_storage()
self._storage_status["is_nas"] = is_nas
if not is_nas:
logger.warning(
"Storage directory '%s' appears to be local storage, not NAS. "
"Consider configuring UPLOAD_DIR to a NAS mount point for production use.",
self.base_dir
)
self._storage_status["validated"] = True
logger.info(
"Storage validated successfully: path=%s, is_nas=%s",
self.base_dir, is_nas
)
except StorageValidationError as e:
self._storage_status["error"] = str(e)
logger.error("Storage validation failed: %s", e)
raise
except Exception as e:
self._storage_status["error"] = str(e)
logger.error("Unexpected error during storage validation: %s", e)
raise StorageValidationError(f"Storage validation failed: {e}")
def _check_write_permissions(self):
"""Check if the storage directory has write permissions."""
test_file = self.base_dir / f".write_test_{os.getpid()}"
try:
# Try to create and write to a test file
test_file.write_text("write_test")
# Verify we can read it back
content = test_file.read_text()
if content != "write_test":
raise StorageValidationError(
f"Write verification failed for directory: {self.base_dir}"
)
# Clean up
test_file.unlink()
logger.debug("Write permission check passed for %s", self.base_dir)
except PermissionError as e:
raise StorageValidationError(
f"No write permission for storage directory '{self.base_dir}': {e}"
)
except OSError as e:
raise StorageValidationError(
f"Failed to verify write permissions for '{self.base_dir}': {e}"
)
finally:
# Ensure test file is removed even on partial failure
if test_file.exists():
try:
test_file.unlink()
except Exception:
pass
def _detect_nas_storage(self) -> bool:
"""
Detect if the storage directory appears to be on a NAS mount.
This is a best-effort detection based on common mount point patterns.
"""
path_str = str(self.base_dir)
# Check common NAS mount point patterns
for indicator in self.NAS_MOUNT_INDICATORS:
if indicator in path_str:
logger.debug("NAS storage detected: path contains '%s'", indicator)
return True
# Check if it's a mount point (Unix-like systems)
try:
if self.base_dir.is_mount():
logger.debug("NAS storage detected: path is a mount point")
return True
except Exception:
pass
# Check mount info on Linux
try:
with open("/proc/mounts", "r") as f:
mounts = f.read()
if path_str in mounts:
# Check for network filesystem types
for line in mounts.splitlines():
if path_str in line:
fs_type = line.split()[2] if len(line.split()) > 2 else ""
if fs_type in ["nfs", "nfs4", "cifs", "smb", "smbfs"]:
logger.debug("NAS storage detected: mounted as %s", fs_type)
return True
except FileNotFoundError:
pass # Not on Linux
except Exception as e:
logger.debug("Could not check /proc/mounts: %s", e)
return False
def get_storage_status(self) -> dict:
"""Get current storage status for health checks."""
return {
**self._storage_status,
"base_dir": str(self.base_dir),
"exists": self.base_dir.exists(),
"is_directory": self.base_dir.is_dir() if self.base_dir.exists() else False,
}
def _ensure_base_dir(self):
"""Ensure the base upload directory exists."""
self.base_dir.mkdir(parents=True, exist_ok=True)
def _validate_path_component(self, component: str, component_name: str) -> None:
"""
Validate a path component to prevent path traversal attacks.
Args:
component: The path component to validate (e.g., project_id, task_id)
component_name: Name of the component for error messages
Raises:
PathTraversalError: If the component contains path traversal patterns
"""
if not component:
raise PathTraversalError(f"Empty {component_name} is not allowed")
# Check for path traversal patterns
dangerous_patterns = ['..', '/', '\\', '\x00']
for pattern in dangerous_patterns:
if pattern in component:
logger.warning(
"Path traversal attempt detected in %s: %r",
component_name,
component
)
raise PathTraversalError(
f"Invalid characters in {component_name}: path traversal not allowed"
)
# Additional check: component should not start with special characters
if component.startswith('.') or component.startswith('-'):
logger.warning(
"Suspicious path component in %s: %r",
component_name,
component
)
raise PathTraversalError(
f"Invalid {component_name}: cannot start with '.' or '-'"
)
def _validate_path_in_base_dir(self, path: Path, context: str = "") -> Path:
"""
Validate that a resolved path is within the base directory.
Args:
path: The path to validate
context: Additional context for logging
Returns:
The resolved path if valid
Raises:
PathTraversalError: If the path is outside the base directory
"""
resolved_path = path.resolve()
# Check if the resolved path is within the base directory
try:
resolved_path.relative_to(self.base_dir)
except ValueError:
logger.warning(
"Path traversal attempt detected: path %s is outside base directory %s. Context: %s",
resolved_path,
self.base_dir,
context
)
raise PathTraversalError(
"Access denied: path is outside the allowed directory"
)
return resolved_path
def _get_file_path(self, project_id: str, task_id: str, attachment_id: str, version: int) -> Path:
"""Generate the file path for an attachment version."""
return self.base_dir / project_id / task_id / attachment_id / str(version)
"""
Generate the file path for an attachment version.
Args:
project_id: The project ID
task_id: The task ID
attachment_id: The attachment ID
version: The version number
Returns:
Safe path within the base directory
Raises:
PathTraversalError: If any component contains path traversal patterns
"""
# Validate all path components
self._validate_path_component(project_id, "project_id")
self._validate_path_component(task_id, "task_id")
self._validate_path_component(attachment_id, "attachment_id")
if version < 0:
raise PathTraversalError("Version must be non-negative")
# Build the path
path = self.base_dir / project_id / task_id / attachment_id / str(version)
# Validate the final path is within base directory
return self._validate_path_in_base_dir(
path,
f"project={project_id}, task={task_id}, attachment={attachment_id}, version={version}"
)
@staticmethod
def calculate_checksum(file: BinaryIO) -> str:
@@ -89,6 +334,10 @@ class FileStorageService:
"""
Save uploaded file to storage.
Returns (file_path, file_size, checksum).
Raises:
PathTraversalError: If path traversal is detected
HTTPException: If file validation fails
"""
# Validate file
extension, _ = self.validate_file(file)
@@ -96,14 +345,22 @@ class FileStorageService:
# Calculate checksum first
checksum = self.calculate_checksum(file.file)
# Create directory structure
dir_path = self._get_file_path(project_id, task_id, attachment_id, version)
# Create directory structure (path validation is done in _get_file_path)
try:
dir_path = self._get_file_path(project_id, task_id, attachment_id, version)
except PathTraversalError as e:
logger.error("Path traversal attempt during file save: %s", e)
raise HTTPException(status_code=400, detail=str(e))
dir_path.mkdir(parents=True, exist_ok=True)
# Save file with original extension
filename = f"file.{extension}" if extension else "file"
file_path = dir_path / filename
# Final validation of the file path
self._validate_path_in_base_dir(file_path, f"saving file {filename}")
# Get file size
file.file.seek(0, 2)
file_size = file.file.tell()
@@ -125,8 +382,15 @@ class FileStorageService:
"""
Get the file path for an attachment version.
Returns None if file doesn't exist.
Raises:
PathTraversalError: If path traversal is detected
"""
dir_path = self._get_file_path(project_id, task_id, attachment_id, version)
try:
dir_path = self._get_file_path(project_id, task_id, attachment_id, version)
except PathTraversalError as e:
logger.error("Path traversal attempt during file retrieval: %s", e)
return None
if not dir_path.exists():
return None
@@ -139,21 +403,48 @@ class FileStorageService:
return files[0]
def get_file_by_path(self, file_path: str) -> Optional[Path]:
"""Get file by stored path. Handles both absolute and relative paths."""
"""
Get file by stored path. Handles both absolute and relative paths.
This method validates that the requested path is within the base directory
to prevent path traversal attacks.
Args:
file_path: The stored file path
Returns:
Path object if file exists and is within base directory, None otherwise
"""
if not file_path:
return None
path = Path(file_path)
# If path is absolute and exists, return it directly
if path.is_absolute() and path.exists():
return path
# For absolute paths, validate they are within base_dir
if path.is_absolute():
try:
validated_path = self._validate_path_in_base_dir(
path,
f"get_file_by_path absolute: {file_path}"
)
if validated_path.exists():
return validated_path
except PathTraversalError:
return None
return None
# If path is relative, try prepending base_dir
# For relative paths, resolve from base_dir
full_path = self.base_dir / path
if full_path.exists():
return full_path
# Fallback: check if original path exists (e.g., relative from current dir)
if path.exists():
return path
try:
validated_path = self._validate_path_in_base_dir(
full_path,
f"get_file_by_path relative: {file_path}"
)
if validated_path.exists():
return validated_path
except PathTraversalError:
return None
return None
@@ -168,13 +459,29 @@ class FileStorageService:
Delete file(s) from storage.
If version is None, deletes all versions.
Returns True if successful.
Raises:
PathTraversalError: If path traversal is detected
"""
if version is not None:
# Delete specific version
dir_path = self._get_file_path(project_id, task_id, attachment_id, version)
else:
# Delete all versions (attachment directory)
dir_path = self.base_dir / project_id / task_id / attachment_id
try:
if version is not None:
# Delete specific version
dir_path = self._get_file_path(project_id, task_id, attachment_id, version)
else:
# Delete all versions (attachment directory)
# Validate components first
self._validate_path_component(project_id, "project_id")
self._validate_path_component(task_id, "task_id")
self._validate_path_component(attachment_id, "attachment_id")
dir_path = self.base_dir / project_id / task_id / attachment_id
dir_path = self._validate_path_in_base_dir(
dir_path,
f"delete attachment: project={project_id}, task={task_id}, attachment={attachment_id}"
)
except PathTraversalError as e:
logger.error("Path traversal attempt during file deletion: %s", e)
return False
if dir_path.exists():
shutil.rmtree(dir_path)
@@ -182,8 +489,26 @@ class FileStorageService:
return False
def delete_task_files(self, project_id: str, task_id: str) -> bool:
"""Delete all files for a task."""
dir_path = self.base_dir / project_id / task_id
"""
Delete all files for a task.
Raises:
PathTraversalError: If path traversal is detected
"""
try:
# Validate components
self._validate_path_component(project_id, "project_id")
self._validate_path_component(task_id, "task_id")
dir_path = self.base_dir / project_id / task_id
dir_path = self._validate_path_in_base_dir(
dir_path,
f"delete task files: project={project_id}, task={task_id}"
)
except PathTraversalError as e:
logger.error("Path traversal attempt during task file deletion: %s", e)
return False
if dir_path.exists():
shutil.rmtree(dir_path)
return True

View File

@@ -29,7 +29,17 @@ class FormulaError(Exception):
class CircularReferenceError(FormulaError):
"""Exception raised when circular references are detected in formulas."""
pass
def __init__(self, message: str, cycle_path: Optional[List[str]] = None):
super().__init__(message)
self.message = message
self.cycle_path = cycle_path or []
def get_cycle_description(self) -> str:
"""Get a human-readable description of the cycle."""
if not self.cycle_path:
return ""
return " -> ".join(self.cycle_path)
class FormulaService:
@@ -140,24 +150,43 @@ class FormulaService:
field_id: str,
references: Set[str],
visited: Optional[Set[str]] = None,
path: Optional[List[str]] = None,
) -> None:
"""
Check for circular references in formula fields.
Raises CircularReferenceError if a cycle is detected.
Args:
db: Database session
project_id: Project ID to scope the query
field_id: The current field being validated
references: Set of field names referenced in the formula
visited: Set of visited field IDs (for cycle detection)
path: Current path of field names (for error reporting)
"""
if visited is None:
visited = set()
if path is None:
path = []
# Get the current field's name
current_field = db.query(CustomField).filter(
CustomField.id == field_id
).first()
current_field_name = current_field.name if current_field else "unknown"
# Add current field to path if not already there
if current_field_name not in path:
path = path + [current_field_name]
if current_field:
if current_field.name in references:
cycle_path = path + [current_field.name]
raise CircularReferenceError(
f"Circular reference detected: field cannot reference itself"
f"Circular reference detected: field '{current_field.name}' cannot reference itself",
cycle_path=cycle_path
)
# Get all referenced formula fields
@@ -173,22 +202,199 @@ class FormulaService:
for field in formula_fields:
if field.id in visited:
# Found a cycle
cycle_path = path + [field.name]
raise CircularReferenceError(
f"Circular reference detected involving field '{field.name}'"
f"Circular reference detected: {' -> '.join(cycle_path)}",
cycle_path=cycle_path
)
visited.add(field.id)
new_path = path + [field.name]
if field.formula:
nested_refs = FormulaService.extract_field_references(field.formula)
if current_field and current_field.name in nested_refs:
cycle_path = new_path + [current_field.name]
raise CircularReferenceError(
f"Circular reference detected: '{field.name}' references the current field"
f"Circular reference detected: {' -> '.join(cycle_path)}",
cycle_path=cycle_path
)
FormulaService._check_circular_references(
db, project_id, field_id, nested_refs, visited
db, project_id, field_id, nested_refs, visited, new_path
)
@staticmethod
def build_formula_dependency_graph(
db: Session,
project_id: str
) -> Dict[str, Set[str]]:
"""
Build a dependency graph for all formula fields in a project.
Returns a dict where keys are field names and values are sets of
field names that the key field depends on.
Args:
db: Database session
project_id: Project ID to scope the query
Returns:
Dict mapping field names to their dependencies
"""
graph: Dict[str, Set[str]] = {}
formula_fields = db.query(CustomField).filter(
CustomField.project_id == project_id,
CustomField.field_type == "formula",
).all()
for field in formula_fields:
if field.formula:
refs = FormulaService.extract_field_references(field.formula)
# Only include custom field references (not builtin fields)
custom_refs = refs - FormulaService.BUILTIN_FIELDS
graph[field.name] = custom_refs
else:
graph[field.name] = set()
return graph
@staticmethod
def detect_formula_cycles(
db: Session,
project_id: str
) -> List[List[str]]:
"""
Detect all cycles in the formula dependency graph for a project.
This is useful for auditing or cleanup operations.
Args:
db: Database session
project_id: Project ID to check
Returns:
List of cycles, where each cycle is a list of field names
"""
graph = FormulaService.build_formula_dependency_graph(db, project_id)
if not graph:
return []
cycles: List[List[str]] = []
visited: Set[str] = set()
found_cycles: Set[Tuple[str, ...]] = set()
def dfs(node: str, path: List[str], in_path: Set[str]):
"""DFS to find cycles."""
if node in in_path:
# Found a cycle
cycle_start = path.index(node)
cycle = path[cycle_start:] + [node]
# Normalize for deduplication
normalized = tuple(sorted(cycle[:-1]))
if normalized not in found_cycles:
found_cycles.add(normalized)
cycles.append(cycle)
return
if node in visited:
return
if node not in graph:
return
visited.add(node)
in_path.add(node)
path.append(node)
for neighbor in graph.get(node, set()):
dfs(neighbor, path.copy(), in_path.copy())
path.pop()
in_path.discard(node)
for start_node in graph.keys():
if start_node not in visited:
dfs(start_node, [], set())
return cycles
@staticmethod
def validate_formula_with_details(
formula: str,
project_id: str,
db: Session,
current_field_id: Optional[str] = None,
) -> Tuple[bool, Optional[str], Optional[List[str]]]:
"""
Validate a formula expression with detailed error information.
Similar to validate_formula but returns cycle path on circular reference errors.
Args:
formula: The formula expression to validate
project_id: Project ID to scope field lookups
db: Database session
current_field_id: Optional ID of the field being edited (for self-reference check)
Returns:
Tuple of (is_valid, error_message, cycle_path)
"""
if not formula or not formula.strip():
return False, "Formula cannot be empty", None
# Extract field references
references = FormulaService.extract_field_references(formula)
if not references:
return False, "Formula must reference at least one field", None
# Validate syntax by trying to parse
try:
# Replace field references with dummy numbers for syntax check
test_formula = formula
for ref in references:
test_formula = test_formula.replace(f"{{{ref}}}", "1")
# Try to parse and evaluate with safe operations
FormulaService._safe_eval(test_formula)
except Exception as e:
return False, f"Invalid formula syntax: {str(e)}", None
# Separate builtin and custom field references
custom_references = references - FormulaService.BUILTIN_FIELDS
# Validate custom field references exist and are numeric types
if custom_references:
fields = db.query(CustomField).filter(
CustomField.project_id == project_id,
CustomField.name.in_(custom_references),
).all()
found_names = {f.name for f in fields}
missing = custom_references - found_names
if missing:
return False, f"Unknown field references: {', '.join(missing)}", None
# Check field types (must be number or formula)
for field in fields:
if field.field_type not in ("number", "formula"):
return False, f"Field '{field.name}' is not a numeric type", None
# Check for circular references
if current_field_id:
try:
FormulaService._check_circular_references(
db, project_id, current_field_id, references
)
except CircularReferenceError as e:
return False, str(e), e.cycle_path
return True, None, None
@staticmethod
def _safe_eval(expression: str) -> Decimal:
"""

View File

@@ -4,8 +4,10 @@ import re
import asyncio
import logging
import threading
import os
from datetime import datetime, timezone
from typing import List, Optional, Dict, Set
from collections import deque
from sqlalchemy.orm import Session
from sqlalchemy import event
@@ -22,9 +24,152 @@ _pending_publish: Dict[int, List[dict]] = {}
# Track which sessions have handlers registered
_registered_sessions: Set[int] = set()
# Redis fallback queue configuration
REDIS_FALLBACK_MAX_QUEUE_SIZE = int(os.getenv("REDIS_FALLBACK_MAX_QUEUE_SIZE", "1000"))
REDIS_FALLBACK_RETRY_INTERVAL = int(os.getenv("REDIS_FALLBACK_RETRY_INTERVAL", "5")) # seconds
REDIS_FALLBACK_MAX_RETRIES = int(os.getenv("REDIS_FALLBACK_MAX_RETRIES", "10"))
# Redis fallback queue for failed publishes
_redis_fallback_lock = threading.Lock()
_redis_fallback_queue: deque = deque(maxlen=REDIS_FALLBACK_MAX_QUEUE_SIZE)
_redis_retry_timer: Optional[threading.Timer] = None
_redis_available = True
_redis_consecutive_failures = 0
def _add_to_fallback_queue(user_id: str, data: dict, retry_count: int = 0) -> bool:
"""
Add a failed notification to the fallback queue.
Returns True if added successfully, False if queue is full.
"""
global _redis_consecutive_failures
with _redis_fallback_lock:
if len(_redis_fallback_queue) >= REDIS_FALLBACK_MAX_QUEUE_SIZE:
logger.warning(
"Redis fallback queue is full (%d items), dropping notification for user %s",
REDIS_FALLBACK_MAX_QUEUE_SIZE, user_id
)
return False
_redis_fallback_queue.append({
"user_id": user_id,
"data": data,
"retry_count": retry_count,
"queued_at": datetime.now(timezone.utc).isoformat(),
})
_redis_consecutive_failures += 1
queue_size = len(_redis_fallback_queue)
logger.debug("Added notification to fallback queue (size: %d)", queue_size)
# Start retry mechanism if not already running
_ensure_retry_timer_running()
return True
def _ensure_retry_timer_running():
"""Ensure the retry timer is running if there are items in the queue."""
global _redis_retry_timer
if _redis_retry_timer is None or not _redis_retry_timer.is_alive():
_redis_retry_timer = threading.Timer(REDIS_FALLBACK_RETRY_INTERVAL, _process_fallback_queue)
_redis_retry_timer.daemon = True
_redis_retry_timer.start()
def _process_fallback_queue():
"""Process the fallback queue and retry sending notifications to Redis."""
global _redis_available, _redis_consecutive_failures, _redis_retry_timer
items_to_retry = []
with _redis_fallback_lock:
# Get all items from queue
while _redis_fallback_queue:
items_to_retry.append(_redis_fallback_queue.popleft())
if not items_to_retry:
_redis_retry_timer = None
return
logger.info("Processing %d items from Redis fallback queue", len(items_to_retry))
failed_items = []
success_count = 0
for item in items_to_retry:
user_id = item["user_id"]
data = item["data"]
retry_count = item["retry_count"]
if retry_count >= REDIS_FALLBACK_MAX_RETRIES:
logger.warning(
"Notification for user %s exceeded max retries (%d), dropping",
user_id, REDIS_FALLBACK_MAX_RETRIES
)
continue
try:
redis_client = get_redis_sync()
channel = get_channel_name(user_id)
message = json.dumps(data, default=str)
redis_client.publish(channel, message)
success_count += 1
except Exception as e:
logger.debug("Retry failed for user %s: %s", user_id, e)
failed_items.append({
**item,
"retry_count": retry_count + 1,
})
# Re-queue failed items
if failed_items:
with _redis_fallback_lock:
for item in failed_items:
if len(_redis_fallback_queue) < REDIS_FALLBACK_MAX_QUEUE_SIZE:
_redis_fallback_queue.append(item)
# Log recovery if we had successes
if success_count > 0:
with _redis_fallback_lock:
_redis_consecutive_failures = 0
if not _redis_fallback_queue:
_redis_available = True
logger.info(
"Redis connection recovered. Successfully processed %d notifications from fallback queue",
success_count
)
# Schedule next retry if queue is not empty
with _redis_fallback_lock:
if _redis_fallback_queue:
_redis_retry_timer = threading.Timer(REDIS_FALLBACK_RETRY_INTERVAL, _process_fallback_queue)
_redis_retry_timer.daemon = True
_redis_retry_timer.start()
else:
_redis_retry_timer = None
def get_redis_fallback_status() -> dict:
"""Get current Redis fallback queue status for health checks."""
with _redis_fallback_lock:
return {
"queue_size": len(_redis_fallback_queue),
"max_queue_size": REDIS_FALLBACK_MAX_QUEUE_SIZE,
"redis_available": _redis_available,
"consecutive_failures": _redis_consecutive_failures,
"retry_interval_seconds": REDIS_FALLBACK_RETRY_INTERVAL,
"max_retries": REDIS_FALLBACK_MAX_RETRIES,
}
def _sync_publish(user_id: str, data: dict):
"""Sync fallback to publish notification via Redis when no event loop available."""
global _redis_available
try:
redis_client = get_redis_sync()
channel = get_channel_name(user_id)
@@ -33,6 +178,10 @@ def _sync_publish(user_id: str, data: dict):
logger.debug(f"Sync published notification to channel {channel}")
except Exception as e:
logger.error(f"Failed to sync publish notification to Redis: {e}")
# Add to fallback queue for retry
with _redis_fallback_lock:
_redis_available = False
_add_to_fallback_queue(user_id, data)
def _cleanup_session(session_id: int, remove_registration: bool = True):
@@ -86,10 +235,16 @@ def _register_session_handlers(db: Session, session_id: int):
async def _async_publish(user_id: str, data: dict):
"""Async helper to publish notification to Redis."""
global _redis_available
try:
await redis_publish(user_id, data)
except Exception as e:
logger.error(f"Failed to publish notification to Redis: {e}")
# Add to fallback queue for retry
with _redis_fallback_lock:
_redis_available = False
_add_to_fallback_queue(user_id, data)
class NotificationService:

View File

@@ -7,6 +7,7 @@ scheduled triggers based on their cron schedule, including deadline reminders.
import uuid
import logging
import time
from datetime import datetime, timezone, timedelta
from typing import Optional, List, Dict, Any, Tuple, Set
@@ -22,6 +23,10 @@ logger = logging.getLogger(__name__)
# Key prefix for tracking deadline reminders already sent
DEADLINE_REMINDER_LOG_TYPE = "deadline_reminder"
# Retry configuration
MAX_RETRIES = 3
BASE_DELAY_SECONDS = 1 # 1s, 2s, 4s exponential backoff
class TriggerSchedulerService:
"""Service for scheduling and executing cron-based triggers."""
@@ -220,50 +225,170 @@ class TriggerSchedulerService:
@staticmethod
def _execute_trigger(db: Session, trigger: Trigger) -> TriggerLog:
"""
Execute a scheduled trigger's actions.
Execute a scheduled trigger's actions with retry mechanism.
Implements exponential backoff retry (1s, 2s, 4s) for transient failures.
After max retries are exhausted, marks as permanently failed and sends alert.
Args:
db: Database session
trigger: The trigger to execute
Returns:
TriggerLog entry for this execution
"""
return TriggerSchedulerService._execute_trigger_with_retry(
db=db,
trigger=trigger,
task_id=None,
log_type="schedule",
)
@staticmethod
def _execute_trigger_with_retry(
db: Session,
trigger: Trigger,
task_id: Optional[str] = None,
log_type: str = "schedule",
) -> TriggerLog:
"""
Execute trigger actions with exponential backoff retry.
Args:
db: Database session
trigger: The trigger to execute
task_id: Optional task ID for context (deadline reminders)
log_type: Type of trigger execution for logging
Returns:
TriggerLog entry for this execution
"""
actions = trigger.actions if isinstance(trigger.actions, list) else [trigger.actions]
executed_actions = []
error_message = None
last_error = None
attempt = 0
try:
for action in actions:
action_type = action.get("type")
while attempt < MAX_RETRIES:
attempt += 1
executed_actions = []
last_error = None
if action_type == "notify":
TriggerSchedulerService._execute_notify_action(db, action, trigger)
executed_actions.append({"type": action_type, "status": "success"})
try:
logger.info(
f"Executing trigger {trigger.id} (attempt {attempt}/{MAX_RETRIES})"
)
# Add more action types here as needed
for action in actions:
action_type = action.get("type")
status = "success"
if action_type == "notify":
TriggerSchedulerService._execute_notify_action(db, action, trigger)
executed_actions.append({"type": action_type, "status": "success"})
except Exception as e:
status = "failed"
error_message = str(e)
executed_actions.append({"type": "error", "message": str(e)})
logger.error(f"Error executing trigger {trigger.id} actions: {e}")
# Add more action types here as needed
# Success - return log
logger.info(f"Trigger {trigger.id} executed successfully on attempt {attempt}")
return TriggerSchedulerService._log_execution(
db=db,
trigger=trigger,
status="success",
details={
"trigger_name": trigger.name,
"trigger_type": log_type,
"cron_expression": trigger.conditions.get("cron_expression") if trigger.conditions else None,
"actions_executed": executed_actions,
"attempts": attempt,
},
error_message=None,
task_id=task_id,
)
except Exception as e:
last_error = e
executed_actions.append({"type": "error", "message": str(e)})
logger.warning(
f"Trigger {trigger.id} failed on attempt {attempt}/{MAX_RETRIES}: {e}"
)
# Calculate exponential backoff delay
if attempt < MAX_RETRIES:
delay = BASE_DELAY_SECONDS * (2 ** (attempt - 1))
logger.info(f"Retrying trigger {trigger.id} in {delay}s...")
time.sleep(delay)
# All retries exhausted - permanent failure
logger.error(
f"Trigger {trigger.id} permanently failed after {MAX_RETRIES} attempts: {last_error}"
)
# Send alert notification for permanent failure
TriggerSchedulerService._send_failure_alert(db, trigger, str(last_error), MAX_RETRIES)
return TriggerSchedulerService._log_execution(
db=db,
trigger=trigger,
status=status,
status="permanently_failed",
details={
"trigger_name": trigger.name,
"trigger_type": "schedule",
"cron_expression": trigger.conditions.get("cron_expression"),
"trigger_type": log_type,
"cron_expression": trigger.conditions.get("cron_expression") if trigger.conditions else None,
"actions_executed": executed_actions,
"attempts": MAX_RETRIES,
"permanent_failure": True,
},
error_message=error_message,
error_message=f"Failed after {MAX_RETRIES} retries: {last_error}",
task_id=task_id,
)
@staticmethod
def _send_failure_alert(
db: Session,
trigger: Trigger,
error_message: str,
attempts: int,
) -> None:
"""
Send alert notification when trigger exhausts all retries.
Args:
db: Database session
trigger: The failed trigger
error_message: The last error message
attempts: Number of attempts made
"""
try:
# Notify the project owner about the failure
project = trigger.project
if not project:
logger.warning(f"Cannot send failure alert: trigger {trigger.id} has no project")
return
target_user_id = project.owner_id
if not target_user_id:
logger.warning(f"Cannot send failure alert: project {project.id} has no owner")
return
message = (
f"Trigger '{trigger.name}' has permanently failed after {attempts} attempts. "
f"Last error: {error_message}"
)
NotificationService.create_notification(
db=db,
user_id=target_user_id,
notification_type="trigger_failure",
reference_type="trigger",
reference_id=trigger.id,
title=f"Trigger Failed: {trigger.name}",
message=message,
)
logger.info(f"Sent failure alert for trigger {trigger.id} to user {target_user_id}")
except Exception as e:
logger.error(f"Failed to send failure alert for trigger {trigger.id}: {e}")
@staticmethod
def _execute_notify_action(db: Session, action: Dict[str, Any], trigger: Trigger) -> None:
"""