feat: implement 8 OpenSpec proposals for security, reliability, and UX improvements
## Security Enhancements (P0)
- Add input validation with max_length and numeric range constraints
- Implement WebSocket token authentication via first message
- Add path traversal prevention in file storage service

## Permission Enhancements (P0)
- Add project member management for cross-department access
- Implement is_department_manager flag for workload visibility

## Cycle Detection (P0)
- Add DFS-based cycle detection for task dependencies
- Add formula field circular reference detection
- Display user-friendly cycle path visualization

## Concurrency & Reliability (P1)
- Implement optimistic locking with version field (409 Conflict on mismatch)
- Add trigger retry mechanism with exponential backoff (1s, 2s, 4s)
- Implement cascade restore for soft-deleted tasks

## Rate Limiting (P1)
- Add tiered rate limits: standard (60/min), sensitive (20/min), heavy (5/min)
- Apply rate limits to tasks, reports, attachments, and comments

## Frontend Improvements (P1)
- Add responsive sidebar with hamburger menu for mobile
- Improve touch-friendly UI with proper tap target sizes
- Complete i18n translations for all components

## Backend Reliability (P2)
- Configure database connection pool (size=10, overflow=20)
- Add Redis fallback mechanism with message queue
- Add blocker check before task deletion

## API Enhancements (P3)
- Add standardized response wrapper utility
- Add /health/ready and /health/live endpoints
- Implement project templates with status/field copying

## Tests Added
- test_input_validation.py - Schema and path traversal tests
- test_concurrency_reliability.py - Optimistic locking and retry tests
- test_backend_reliability.py - Connection pool and Redis tests
- test_api_enhancements.py - Health check and template tests

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
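As a rough sketch of the optimistic-locking behaviour summarized above (not code from this commit; the `Task.version` column and the update helper are assumed for illustration):

```python
# Sketch only: compare-and-swap on a version column, returning 409 when stale.
from fastapi import HTTPException
from sqlalchemy.orm import Session

def update_task_optimistically(db: Session, task_id: str, changes: dict, expected_version: int):
    # Only update the row if nobody else has bumped the version in the meantime.
    updated = (
        db.query(Task)
        .filter(Task.id == task_id, Task.version == expected_version)
        .update({**changes, "version": Task.version + 1}, synchronize_session=False)
    )
    if updated == 0:
        db.rollback()
        raise HTTPException(status_code=409, detail="Task was modified by another user")
    db.commit()
    return db.query(Task).filter(Task.id == task_id).first()
```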
@@ -6,6 +6,7 @@ Handles task dependency validation including:
- Date constraint validation based on dependency types
- Self-reference prevention
- Cross-project dependency prevention
- Bulk dependency operations with cycle detection
"""
from typing import List, Optional, Set, Tuple, Dict, Any
from collections import defaultdict
@@ -25,6 +26,27 @@ class DependencyValidationError(Exception):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class CycleDetectionResult:
|
||||
"""Result of cycle detection with detailed path information."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
has_cycle: bool,
|
||||
cycle_path: Optional[List[str]] = None,
|
||||
cycle_task_titles: Optional[List[str]] = None
|
||||
):
|
||||
self.has_cycle = has_cycle
|
||||
self.cycle_path = cycle_path or []
|
||||
self.cycle_task_titles = cycle_task_titles or []
|
||||
|
||||
def get_cycle_description(self) -> str:
|
||||
"""Get a human-readable description of the cycle."""
|
||||
if not self.has_cycle or not self.cycle_task_titles:
|
||||
return ""
|
||||
# Format: Task A -> Task B -> Task C -> Task A
|
||||
return " -> ".join(self.cycle_task_titles)
|
||||
|
||||
|
||||
class DependencyService:
|
||||
"""Service for managing task dependencies with validation."""
|
||||
|
||||
@@ -53,9 +75,36 @@ class DependencyService:
|
||||
Returns:
|
||||
List of task IDs forming the cycle if circular, None otherwise
|
||||
"""
|
||||
# If adding predecessor -> successor, check if successor can reach predecessor
|
||||
# This would mean predecessor depends (transitively) on successor, creating a cycle
|
||||
result = DependencyService.detect_circular_dependency_detailed(
|
||||
db, predecessor_id, successor_id, project_id
|
||||
)
|
||||
return result.cycle_path if result.has_cycle else None
|
||||
|
||||
@staticmethod
|
||||
def detect_circular_dependency_detailed(
|
||||
db: Session,
|
||||
predecessor_id: str,
|
||||
successor_id: str,
|
||||
project_id: str,
|
||||
additional_edges: Optional[List[Tuple[str, str]]] = None
|
||||
) -> CycleDetectionResult:
|
||||
"""
|
||||
Detect if adding a dependency would create a circular reference.
|
||||
|
||||
Uses DFS to traverse from the successor to check if we can reach
|
||||
the predecessor through existing dependencies.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
predecessor_id: The task that must complete first
|
||||
successor_id: The task that depends on the predecessor
|
||||
project_id: Project ID to scope the query
|
||||
additional_edges: Optional list of additional (predecessor_id, successor_id)
|
||||
edges to consider (for bulk operations)
|
||||
|
||||
Returns:
|
||||
CycleDetectionResult with detailed cycle information
|
||||
"""
|
||||
# Build adjacency list for the project's dependencies
|
||||
dependencies = db.query(TaskDependency).join(
|
||||
Task, TaskDependency.successor_id == Task.id
|
||||
@@ -71,6 +120,20 @@ class DependencyService:
|
||||
# Simulate adding the new edge
|
||||
graph[successor_id].append(predecessor_id)
|
||||
|
||||
# Add any additional edges for bulk operations
|
||||
if additional_edges:
|
||||
for pred_id, succ_id in additional_edges:
|
||||
graph[succ_id].append(pred_id)
|
||||
|
||||
# Build task title map for readable error messages
|
||||
task_ids_in_graph = set()
|
||||
for succ_id, pred_ids in graph.items():
|
||||
task_ids_in_graph.add(succ_id)
|
||||
task_ids_in_graph.update(pred_ids)
|
||||
|
||||
tasks = db.query(Task).filter(Task.id.in_(task_ids_in_graph)).all()
|
||||
task_title_map: Dict[str, str] = {t.id: t.title for t in tasks}
|
||||
|
||||
# DFS to find if there's a path from predecessor back to successor
|
||||
# (which would complete a cycle)
|
||||
visited: Set[str] = set()
|
||||
@@ -101,7 +164,18 @@ class DependencyService:
|
||||
return None
|
||||
|
||||
# Start DFS from the successor to check if we can reach back to it
|
||||
return dfs(successor_id)
|
||||
cycle_path = dfs(successor_id)
|
||||
|
||||
if cycle_path:
|
||||
# Build task titles for the cycle
|
||||
cycle_titles = [task_title_map.get(task_id, task_id) for task_id in cycle_path]
|
||||
return CycleDetectionResult(
|
||||
has_cycle=True,
|
||||
cycle_path=cycle_path,
|
||||
cycle_task_titles=cycle_titles
|
||||
)
|
||||
|
||||
return CycleDetectionResult(has_cycle=False)
|
||||
|
||||
@staticmethod
|
||||
def validate_dependency(
|
||||
@@ -183,15 +257,19 @@ class DependencyService:
|
||||
)
|
||||
|
||||
# Check circular dependency
|
||||
cycle = DependencyService.detect_circular_dependency(
|
||||
cycle_result = DependencyService.detect_circular_dependency_detailed(
|
||||
db, predecessor_id, successor_id, predecessor.project_id
|
||||
)
|
||||
|
||||
if cycle:
|
||||
if cycle_result.has_cycle:
|
||||
raise DependencyValidationError(
|
||||
error_type="circular",
|
||||
message="Adding this dependency would create a circular reference",
|
||||
details={"cycle": cycle}
|
||||
message=f"Adding this dependency would create a circular reference: {cycle_result.get_cycle_description()}",
|
||||
details={
|
||||
"cycle": cycle_result.cycle_path,
|
||||
"cycle_description": cycle_result.get_cycle_description(),
|
||||
"cycle_task_titles": cycle_result.cycle_task_titles
|
||||
}
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -422,3 +500,202 @@ class DependencyService:
|
||||
queue.append(dep.successor_id)
|
||||
|
||||
return successors
|
||||
|
||||
@staticmethod
|
||||
def validate_bulk_dependencies(
|
||||
db: Session,
|
||||
dependencies: List[Tuple[str, str]],
|
||||
project_id: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Validate a batch of dependencies for cycle detection.
|
||||
|
||||
This method validates multiple dependencies together to detect cycles
|
||||
that would only appear when all dependencies are added together.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
dependencies: List of (predecessor_id, successor_id) tuples
|
||||
project_id: Project ID to scope the query
|
||||
|
||||
Returns:
|
||||
List of validation errors (empty if all valid)
|
||||
"""
|
||||
errors: List[Dict[str, Any]] = []
|
||||
|
||||
if not dependencies:
|
||||
return errors
|
||||
|
||||
# First, validate each dependency individually for basic checks
|
||||
for predecessor_id, successor_id in dependencies:
|
||||
# Check self-reference
|
||||
if predecessor_id == successor_id:
|
||||
errors.append({
|
||||
"error_type": "self_reference",
|
||||
"predecessor_id": predecessor_id,
|
||||
"successor_id": successor_id,
|
||||
"message": "A task cannot depend on itself"
|
||||
})
|
||||
continue
|
||||
|
||||
# Get tasks to validate project membership
|
||||
predecessor = db.query(Task).filter(Task.id == predecessor_id).first()
|
||||
successor = db.query(Task).filter(Task.id == successor_id).first()
|
||||
|
||||
if not predecessor:
|
||||
errors.append({
|
||||
"error_type": "not_found",
|
||||
"predecessor_id": predecessor_id,
|
||||
"successor_id": successor_id,
|
||||
"message": f"Predecessor task not found: {predecessor_id}"
|
||||
})
|
||||
continue
|
||||
|
||||
if not successor:
|
||||
errors.append({
|
||||
"error_type": "not_found",
|
||||
"predecessor_id": predecessor_id,
|
||||
"successor_id": successor_id,
|
||||
"message": f"Successor task not found: {successor_id}"
|
||||
})
|
||||
continue
|
||||
|
||||
if predecessor.project_id != project_id or successor.project_id != project_id:
|
||||
errors.append({
|
||||
"error_type": "cross_project",
|
||||
"predecessor_id": predecessor_id,
|
||||
"successor_id": successor_id,
|
||||
"message": "All tasks must be in the same project"
|
||||
})
|
||||
continue
|
||||
|
||||
# Check for duplicates within the batch
|
||||
existing = db.query(TaskDependency).filter(
|
||||
TaskDependency.predecessor_id == predecessor_id,
|
||||
TaskDependency.successor_id == successor_id
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
errors.append({
|
||||
"error_type": "duplicate",
|
||||
"predecessor_id": predecessor_id,
|
||||
"successor_id": successor_id,
|
||||
"message": "This dependency already exists"
|
||||
})
|
||||
|
||||
# If there are basic validation errors, return them first
|
||||
if errors:
|
||||
return errors
|
||||
|
||||
# Now check for cycles considering all dependencies together
|
||||
# Build the graph incrementally and check for cycles
|
||||
accumulated_edges: List[Tuple[str, str]] = []
|
||||
|
||||
for predecessor_id, successor_id in dependencies:
|
||||
# Check if adding this edge (plus all previously accumulated edges)
|
||||
# would create a cycle
|
||||
cycle_result = DependencyService.detect_circular_dependency_detailed(
|
||||
db,
|
||||
predecessor_id,
|
||||
successor_id,
|
||||
project_id,
|
||||
additional_edges=accumulated_edges
|
||||
)
|
||||
|
||||
if cycle_result.has_cycle:
|
||||
errors.append({
|
||||
"error_type": "circular",
|
||||
"predecessor_id": predecessor_id,
|
||||
"successor_id": successor_id,
|
||||
"message": f"Adding this dependency would create a circular reference: {cycle_result.get_cycle_description()}",
|
||||
"cycle": cycle_result.cycle_path,
|
||||
"cycle_description": cycle_result.get_cycle_description(),
|
||||
"cycle_task_titles": cycle_result.cycle_task_titles
|
||||
})
|
||||
else:
|
||||
# Add this edge to accumulated edges for subsequent checks
|
||||
accumulated_edges.append((predecessor_id, successor_id))
|
||||
|
||||
return errors
|
||||
|
||||
@staticmethod
|
||||
def detect_cycles_in_graph(
|
||||
db: Session,
|
||||
project_id: str
|
||||
) -> List[CycleDetectionResult]:
|
||||
"""
|
||||
Detect all cycles in the existing dependency graph for a project.
|
||||
|
||||
This is useful for auditing or cleanup operations.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
project_id: Project ID to check
|
||||
|
||||
Returns:
|
||||
List of CycleDetectionResult for each cycle found
|
||||
"""
|
||||
cycles: List[CycleDetectionResult] = []
|
||||
|
||||
# Get all dependencies for the project
|
||||
dependencies = db.query(TaskDependency).join(
|
||||
Task, TaskDependency.successor_id == Task.id
|
||||
).filter(Task.project_id == project_id).all()
|
||||
|
||||
if not dependencies:
|
||||
return cycles
|
||||
|
||||
# Build the graph
|
||||
graph: Dict[str, List[str]] = defaultdict(list)
|
||||
for dep in dependencies:
|
||||
graph[dep.successor_id].append(dep.predecessor_id)
|
||||
|
||||
# Get task titles
|
||||
task_ids = set()
|
||||
for succ_id, pred_ids in graph.items():
|
||||
task_ids.add(succ_id)
|
||||
task_ids.update(pred_ids)
|
||||
|
||||
tasks = db.query(Task).filter(Task.id.in_(task_ids)).all()
|
||||
task_title_map: Dict[str, str] = {t.id: t.title for t in tasks}
|
||||
|
||||
# Find all cycles using DFS
|
||||
visited: Set[str] = set()
|
||||
found_cycles: Set[Tuple[str, ...]] = set()
|
||||
|
||||
def find_cycles_dfs(node: str, path: List[str], in_path: Set[str]):
|
||||
"""DFS to find all cycles."""
|
||||
if node in in_path:
|
||||
# Found a cycle
|
||||
cycle_start = path.index(node)
|
||||
cycle = tuple(sorted(path[cycle_start:])) # Normalize for dedup
|
||||
if cycle not in found_cycles:
|
||||
found_cycles.add(cycle)
|
||||
actual_cycle = path[cycle_start:] + [node]
|
||||
cycle_titles = [task_title_map.get(tid, tid) for tid in actual_cycle]
|
||||
cycles.append(CycleDetectionResult(
|
||||
has_cycle=True,
|
||||
cycle_path=actual_cycle,
|
||||
cycle_task_titles=cycle_titles
|
||||
))
|
||||
return
|
||||
|
||||
if node in visited:
|
||||
return
|
||||
|
||||
visited.add(node)
|
||||
in_path.add(node)
|
||||
path.append(node)
|
||||
|
||||
for neighbor in graph.get(node, []):
|
||||
find_cycles_dfs(neighbor, path.copy(), in_path.copy())
|
||||
|
||||
path.pop()
|
||||
in_path.remove(node)
|
||||
|
||||
# Start DFS from all nodes
|
||||
for start_node in graph.keys():
|
||||
if start_node not in visited:
|
||||
find_cycles_dfs(start_node, [], set())
|
||||
|
||||
return cycles
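The detection above is a depth-first walk over predecessor edges. A self-contained sketch of the same idea on a plain adjacency map, without the ORM models, for reference:

```python
from typing import Dict, List, Optional

def find_cycle(graph: Dict[str, List[str]], start: str) -> Optional[List[str]]:
    """Return the first cycle reachable from `start`, or None if no cycle exists."""
    path: List[str] = []
    visited = set()

    def dfs(node: str) -> Optional[List[str]]:
        if node in path:
            # Reached a node already on the current path: slice out the cycle.
            return path[path.index(node):] + [node]
        if node in visited:
            return None
        visited.add(node)
        path.append(node)
        for neighbor in graph.get(node, []):
            cycle = dfs(neighbor)
            if cycle:
                return cycle
        path.pop()
        return None

    return dfs(start)

# A depends on B, B on C, C on A -> cycle.
print(find_cycle({"A": ["B"], "B": ["C"], "C": ["A"]}, "A"))  # ['A', 'B', 'C', 'A']
```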
|
||||
|
||||
@@ -1,26 +1,271 @@
|
||||
import os
|
||||
import hashlib
|
||||
import shutil
|
||||
import logging
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import BinaryIO, Optional, Tuple
|
||||
from fastapi import UploadFile, HTTPException
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PathTraversalError(Exception):
|
||||
"""Raised when a path traversal attempt is detected."""
|
||||
pass
|
||||
|
||||
|
||||
class StorageValidationError(Exception):
|
||||
"""Raised when storage validation fails."""
|
||||
pass
|
||||
|
||||
|
||||
class FileStorageService:
|
||||
"""Service for handling file storage operations."""
|
||||
|
||||
# Common NAS mount points to detect
|
||||
NAS_MOUNT_INDICATORS = [
|
||||
"/mnt/", "/mount/", "/nas/", "/nfs/", "/smb/", "/cifs/",
|
||||
"/Volumes/", "/media/", "/srv/", "/storage/"
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
self.base_dir = Path(settings.UPLOAD_DIR)
|
||||
self._ensure_base_dir()
|
||||
self.base_dir = Path(settings.UPLOAD_DIR).resolve()
|
||||
self._storage_status = {
|
||||
"validated": False,
|
||||
"path_exists": False,
|
||||
"writable": False,
|
||||
"is_nas": False,
|
||||
"error": None,
|
||||
}
|
||||
self._validate_storage_on_init()
|
||||
|
||||
def _validate_storage_on_init(self):
|
||||
"""Validate storage configuration on service initialization."""
|
||||
try:
|
||||
# Step 1: Ensure directory exists
|
||||
self._ensure_base_dir()
|
||||
self._storage_status["path_exists"] = True
|
||||
|
||||
# Step 2: Check write permissions
|
||||
self._check_write_permissions()
|
||||
self._storage_status["writable"] = True
|
||||
|
||||
# Step 3: Check if using NAS
|
||||
is_nas = self._detect_nas_storage()
|
||||
self._storage_status["is_nas"] = is_nas
|
||||
|
||||
if not is_nas:
|
||||
logger.warning(
|
||||
"Storage directory '%s' appears to be local storage, not NAS. "
|
||||
"Consider configuring UPLOAD_DIR to a NAS mount point for production use.",
|
||||
self.base_dir
|
||||
)
|
||||
|
||||
self._storage_status["validated"] = True
|
||||
logger.info(
|
||||
"Storage validated successfully: path=%s, is_nas=%s",
|
||||
self.base_dir, is_nas
|
||||
)
|
||||
|
||||
except StorageValidationError as e:
|
||||
self._storage_status["error"] = str(e)
|
||||
logger.error("Storage validation failed: %s", e)
|
||||
raise
|
||||
except Exception as e:
|
||||
self._storage_status["error"] = str(e)
|
||||
logger.error("Unexpected error during storage validation: %s", e)
|
||||
raise StorageValidationError(f"Storage validation failed: {e}")
|
||||
|
||||
def _check_write_permissions(self):
|
||||
"""Check if the storage directory has write permissions."""
|
||||
test_file = self.base_dir / f".write_test_{os.getpid()}"
|
||||
try:
|
||||
# Try to create and write to a test file
|
||||
test_file.write_text("write_test")
|
||||
# Verify we can read it back
|
||||
content = test_file.read_text()
|
||||
if content != "write_test":
|
||||
raise StorageValidationError(
|
||||
f"Write verification failed for directory: {self.base_dir}"
|
||||
)
|
||||
# Clean up
|
||||
test_file.unlink()
|
||||
logger.debug("Write permission check passed for %s", self.base_dir)
|
||||
except PermissionError as e:
|
||||
raise StorageValidationError(
|
||||
f"No write permission for storage directory '{self.base_dir}': {e}"
|
||||
)
|
||||
except OSError as e:
|
||||
raise StorageValidationError(
|
||||
f"Failed to verify write permissions for '{self.base_dir}': {e}"
|
||||
)
|
||||
finally:
|
||||
# Ensure test file is removed even on partial failure
|
||||
if test_file.exists():
|
||||
try:
|
||||
test_file.unlink()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _detect_nas_storage(self) -> bool:
|
||||
"""
|
||||
Detect if the storage directory appears to be on a NAS mount.
|
||||
|
||||
This is a best-effort detection based on common mount point patterns.
|
||||
"""
|
||||
path_str = str(self.base_dir)
|
||||
|
||||
# Check common NAS mount point patterns
|
||||
for indicator in self.NAS_MOUNT_INDICATORS:
|
||||
if indicator in path_str:
|
||||
logger.debug("NAS storage detected: path contains '%s'", indicator)
|
||||
return True
|
||||
|
||||
# Check if it's a mount point (Unix-like systems)
|
||||
try:
|
||||
if self.base_dir.is_mount():
|
||||
logger.debug("NAS storage detected: path is a mount point")
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Check mount info on Linux
|
||||
try:
|
||||
with open("/proc/mounts", "r") as f:
|
||||
mounts = f.read()
|
||||
if path_str in mounts:
|
||||
# Check for network filesystem types
|
||||
for line in mounts.splitlines():
|
||||
if path_str in line:
|
||||
fs_type = line.split()[2] if len(line.split()) > 2 else ""
|
||||
if fs_type in ["nfs", "nfs4", "cifs", "smb", "smbfs"]:
|
||||
logger.debug("NAS storage detected: mounted as %s", fs_type)
|
||||
return True
|
||||
except FileNotFoundError:
|
||||
pass # Not on Linux
|
||||
except Exception as e:
|
||||
logger.debug("Could not check /proc/mounts: %s", e)
|
||||
|
||||
return False
|
||||
|
||||
def get_storage_status(self) -> dict:
|
||||
"""Get current storage status for health checks."""
|
||||
return {
|
||||
**self._storage_status,
|
||||
"base_dir": str(self.base_dir),
|
||||
"exists": self.base_dir.exists(),
|
||||
"is_directory": self.base_dir.is_dir() if self.base_dir.exists() else False,
|
||||
}
|
||||
|
||||
def _ensure_base_dir(self):
|
||||
"""Ensure the base upload directory exists."""
|
||||
self.base_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def _validate_path_component(self, component: str, component_name: str) -> None:
|
||||
"""
|
||||
Validate a path component to prevent path traversal attacks.
|
||||
|
||||
Args:
|
||||
component: The path component to validate (e.g., project_id, task_id)
|
||||
component_name: Name of the component for error messages
|
||||
|
||||
Raises:
|
||||
PathTraversalError: If the component contains path traversal patterns
|
||||
"""
|
||||
if not component:
|
||||
raise PathTraversalError(f"Empty {component_name} is not allowed")
|
||||
|
||||
# Check for path traversal patterns
|
||||
dangerous_patterns = ['..', '/', '\\', '\x00']
|
||||
for pattern in dangerous_patterns:
|
||||
if pattern in component:
|
||||
logger.warning(
|
||||
"Path traversal attempt detected in %s: %r",
|
||||
component_name,
|
||||
component
|
||||
)
|
||||
raise PathTraversalError(
|
||||
f"Invalid characters in {component_name}: path traversal not allowed"
|
||||
)
|
||||
|
||||
# Additional check: component should not start with special characters
|
||||
if component.startswith('.') or component.startswith('-'):
|
||||
logger.warning(
|
||||
"Suspicious path component in %s: %r",
|
||||
component_name,
|
||||
component
|
||||
)
|
||||
raise PathTraversalError(
|
||||
f"Invalid {component_name}: cannot start with '.' or '-'"
|
||||
)
|
||||
|
||||
def _validate_path_in_base_dir(self, path: Path, context: str = "") -> Path:
|
||||
"""
|
||||
Validate that a resolved path is within the base directory.
|
||||
|
||||
Args:
|
||||
path: The path to validate
|
||||
context: Additional context for logging
|
||||
|
||||
Returns:
|
||||
The resolved path if valid
|
||||
|
||||
Raises:
|
||||
PathTraversalError: If the path is outside the base directory
|
||||
"""
|
||||
resolved_path = path.resolve()
|
||||
|
||||
# Check if the resolved path is within the base directory
|
||||
try:
|
||||
resolved_path.relative_to(self.base_dir)
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
"Path traversal attempt detected: path %s is outside base directory %s. Context: %s",
|
||||
resolved_path,
|
||||
self.base_dir,
|
||||
context
|
||||
)
|
||||
raise PathTraversalError(
|
||||
"Access denied: path is outside the allowed directory"
|
||||
)
|
||||
|
||||
return resolved_path
|
||||
|
||||
def _get_file_path(self, project_id: str, task_id: str, attachment_id: str, version: int) -> Path:
|
||||
"""Generate the file path for an attachment version."""
|
||||
return self.base_dir / project_id / task_id / attachment_id / str(version)
|
||||
"""
|
||||
Generate the file path for an attachment version.
|
||||
|
||||
Args:
|
||||
project_id: The project ID
|
||||
task_id: The task ID
|
||||
attachment_id: The attachment ID
|
||||
version: The version number
|
||||
|
||||
Returns:
|
||||
Safe path within the base directory
|
||||
|
||||
Raises:
|
||||
PathTraversalError: If any component contains path traversal patterns
|
||||
"""
|
||||
# Validate all path components
|
||||
self._validate_path_component(project_id, "project_id")
|
||||
self._validate_path_component(task_id, "task_id")
|
||||
self._validate_path_component(attachment_id, "attachment_id")
|
||||
|
||||
if version < 0:
|
||||
raise PathTraversalError("Version must be non-negative")
|
||||
|
||||
# Build the path
|
||||
path = self.base_dir / project_id / task_id / attachment_id / str(version)
|
||||
|
||||
# Validate the final path is within base directory
|
||||
return self._validate_path_in_base_dir(
|
||||
path,
|
||||
f"project={project_id}, task={task_id}, attachment={attachment_id}, version={version}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def calculate_checksum(file: BinaryIO) -> str:
|
||||
@@ -89,6 +334,10 @@ class FileStorageService:
|
||||
"""
|
||||
Save uploaded file to storage.
|
||||
Returns (file_path, file_size, checksum).
|
||||
|
||||
Raises:
|
||||
PathTraversalError: If path traversal is detected
|
||||
HTTPException: If file validation fails
|
||||
"""
|
||||
# Validate file
|
||||
extension, _ = self.validate_file(file)
|
||||
@@ -96,14 +345,22 @@ class FileStorageService:
|
||||
# Calculate checksum first
|
||||
checksum = self.calculate_checksum(file.file)
|
||||
|
||||
# Create directory structure
|
||||
dir_path = self._get_file_path(project_id, task_id, attachment_id, version)
|
||||
# Create directory structure (path validation is done in _get_file_path)
|
||||
try:
|
||||
dir_path = self._get_file_path(project_id, task_id, attachment_id, version)
|
||||
except PathTraversalError as e:
|
||||
logger.error("Path traversal attempt during file save: %s", e)
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
dir_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save file with original extension
|
||||
filename = f"file.{extension}" if extension else "file"
|
||||
file_path = dir_path / filename
|
||||
|
||||
# Final validation of the file path
|
||||
self._validate_path_in_base_dir(file_path, f"saving file {filename}")
|
||||
|
||||
# Get file size
|
||||
file.file.seek(0, 2)
|
||||
file_size = file.file.tell()
|
||||
@@ -125,8 +382,15 @@ class FileStorageService:
|
||||
"""
|
||||
Get the file path for an attachment version.
|
||||
Returns None if file doesn't exist.
|
||||
|
||||
Raises:
|
||||
PathTraversalError: If path traversal is detected
|
||||
"""
|
||||
dir_path = self._get_file_path(project_id, task_id, attachment_id, version)
|
||||
try:
|
||||
dir_path = self._get_file_path(project_id, task_id, attachment_id, version)
|
||||
except PathTraversalError as e:
|
||||
logger.error("Path traversal attempt during file retrieval: %s", e)
|
||||
return None
|
||||
|
||||
if not dir_path.exists():
|
||||
return None
|
||||
@@ -139,21 +403,48 @@ class FileStorageService:
|
||||
return files[0]
|
||||
|
||||
def get_file_by_path(self, file_path: str) -> Optional[Path]:
|
||||
"""Get file by stored path. Handles both absolute and relative paths."""
|
||||
"""
|
||||
Get file by stored path. Handles both absolute and relative paths.
|
||||
|
||||
This method validates that the requested path is within the base directory
|
||||
to prevent path traversal attacks.
|
||||
|
||||
Args:
|
||||
file_path: The stored file path
|
||||
|
||||
Returns:
|
||||
Path object if file exists and is within base directory, None otherwise
|
||||
"""
|
||||
if not file_path:
|
||||
return None
|
||||
|
||||
path = Path(file_path)
|
||||
|
||||
# If path is absolute and exists, return it directly
|
||||
if path.is_absolute() and path.exists():
|
||||
return path
|
||||
# For absolute paths, validate they are within base_dir
|
||||
if path.is_absolute():
|
||||
try:
|
||||
validated_path = self._validate_path_in_base_dir(
|
||||
path,
|
||||
f"get_file_by_path absolute: {file_path}"
|
||||
)
|
||||
if validated_path.exists():
|
||||
return validated_path
|
||||
except PathTraversalError:
|
||||
return None
|
||||
return None
|
||||
|
||||
# If path is relative, try prepending base_dir
|
||||
# For relative paths, resolve from base_dir
|
||||
full_path = self.base_dir / path
|
||||
if full_path.exists():
|
||||
return full_path
|
||||
|
||||
# Fallback: check if original path exists (e.g., relative from current dir)
|
||||
if path.exists():
|
||||
return path
|
||||
try:
|
||||
validated_path = self._validate_path_in_base_dir(
|
||||
full_path,
|
||||
f"get_file_by_path relative: {file_path}"
|
||||
)
|
||||
if validated_path.exists():
|
||||
return validated_path
|
||||
except PathTraversalError:
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
@@ -168,13 +459,29 @@ class FileStorageService:
|
||||
Delete file(s) from storage.
|
||||
If version is None, deletes all versions.
|
||||
Returns True if successful.
|
||||
|
||||
Raises:
|
||||
PathTraversalError: If path traversal is detected
|
||||
"""
|
||||
if version is not None:
|
||||
# Delete specific version
|
||||
dir_path = self._get_file_path(project_id, task_id, attachment_id, version)
|
||||
else:
|
||||
# Delete all versions (attachment directory)
|
||||
dir_path = self.base_dir / project_id / task_id / attachment_id
|
||||
try:
|
||||
if version is not None:
|
||||
# Delete specific version
|
||||
dir_path = self._get_file_path(project_id, task_id, attachment_id, version)
|
||||
else:
|
||||
# Delete all versions (attachment directory)
|
||||
# Validate components first
|
||||
self._validate_path_component(project_id, "project_id")
|
||||
self._validate_path_component(task_id, "task_id")
|
||||
self._validate_path_component(attachment_id, "attachment_id")
|
||||
|
||||
dir_path = self.base_dir / project_id / task_id / attachment_id
|
||||
dir_path = self._validate_path_in_base_dir(
|
||||
dir_path,
|
||||
f"delete attachment: project={project_id}, task={task_id}, attachment={attachment_id}"
|
||||
)
|
||||
except PathTraversalError as e:
|
||||
logger.error("Path traversal attempt during file deletion: %s", e)
|
||||
return False
|
||||
|
||||
if dir_path.exists():
|
||||
shutil.rmtree(dir_path)
|
||||
@@ -182,8 +489,26 @@ class FileStorageService:
|
||||
return False
|
||||
|
||||
def delete_task_files(self, project_id: str, task_id: str) -> bool:
|
||||
"""Delete all files for a task."""
|
||||
dir_path = self.base_dir / project_id / task_id
|
||||
"""
|
||||
Delete all files for a task.
|
||||
|
||||
Raises:
|
||||
PathTraversalError: If path traversal is detected
|
||||
"""
|
||||
try:
|
||||
# Validate components
|
||||
self._validate_path_component(project_id, "project_id")
|
||||
self._validate_path_component(task_id, "task_id")
|
||||
|
||||
dir_path = self.base_dir / project_id / task_id
|
||||
dir_path = self._validate_path_in_base_dir(
|
||||
dir_path,
|
||||
f"delete task files: project={project_id}, task={task_id}"
|
||||
)
|
||||
except PathTraversalError as e:
|
||||
logger.error("Path traversal attempt during task file deletion: %s", e)
|
||||
return False
|
||||
|
||||
if dir_path.exists():
|
||||
shutil.rmtree(dir_path)
|
||||
return True
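A hedged usage sketch of the validation helpers added in this file (assumes `settings.UPLOAD_DIR` points at a writable directory so construction succeeds):

```python
service = FileStorageService()
try:
    # ".." and path separators are rejected before any filesystem access happens.
    service._validate_path_component("../../etc/passwd", "project_id")
except PathTraversalError as exc:
    print(exc)  # Invalid characters in project_id: path traversal not allowed
```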
|
||||
|
||||
@@ -29,7 +29,17 @@ class FormulaError(Exception):
|
||||
|
||||
class CircularReferenceError(FormulaError):
|
||||
"""Exception raised when circular references are detected in formulas."""
|
||||
pass
|
||||
|
||||
def __init__(self, message: str, cycle_path: Optional[List[str]] = None):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
self.cycle_path = cycle_path or []
|
||||
|
||||
def get_cycle_description(self) -> str:
|
||||
"""Get a human-readable description of the cycle."""
|
||||
if not self.cycle_path:
|
||||
return ""
|
||||
return " -> ".join(self.cycle_path)
|
||||
|
||||
|
||||
class FormulaService:
|
||||
@@ -140,24 +150,43 @@ class FormulaService:
|
||||
field_id: str,
|
||||
references: Set[str],
|
||||
visited: Optional[Set[str]] = None,
|
||||
path: Optional[List[str]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Check for circular references in formula fields.
|
||||
|
||||
Raises CircularReferenceError if a cycle is detected.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
project_id: Project ID to scope the query
|
||||
field_id: The current field being validated
|
||||
references: Set of field names referenced in the formula
|
||||
visited: Set of visited field IDs (for cycle detection)
|
||||
path: Current path of field names (for error reporting)
|
||||
"""
|
||||
if visited is None:
|
||||
visited = set()
|
||||
if path is None:
|
||||
path = []
|
||||
|
||||
# Get the current field's name
|
||||
current_field = db.query(CustomField).filter(
|
||||
CustomField.id == field_id
|
||||
).first()
|
||||
|
||||
current_field_name = current_field.name if current_field else "unknown"
|
||||
|
||||
# Add current field to path if not already there
|
||||
if current_field_name not in path:
|
||||
path = path + [current_field_name]
|
||||
|
||||
if current_field:
|
||||
if current_field.name in references:
|
||||
cycle_path = path + [current_field.name]
|
||||
raise CircularReferenceError(
|
||||
f"Circular reference detected: field cannot reference itself"
|
||||
f"Circular reference detected: field '{current_field.name}' cannot reference itself",
|
||||
cycle_path=cycle_path
|
||||
)
|
||||
|
||||
# Get all referenced formula fields
|
||||
@@ -173,22 +202,199 @@ class FormulaService:
|
||||
|
||||
for field in formula_fields:
|
||||
if field.id in visited:
|
||||
# Found a cycle
|
||||
cycle_path = path + [field.name]
|
||||
raise CircularReferenceError(
|
||||
f"Circular reference detected involving field '{field.name}'"
|
||||
f"Circular reference detected: {' -> '.join(cycle_path)}",
|
||||
cycle_path=cycle_path
|
||||
)
|
||||
|
||||
visited.add(field.id)
|
||||
new_path = path + [field.name]
|
||||
|
||||
if field.formula:
|
||||
nested_refs = FormulaService.extract_field_references(field.formula)
|
||||
if current_field and current_field.name in nested_refs:
|
||||
cycle_path = new_path + [current_field.name]
|
||||
raise CircularReferenceError(
|
||||
f"Circular reference detected: '{field.name}' references the current field"
|
||||
f"Circular reference detected: {' -> '.join(cycle_path)}",
|
||||
cycle_path=cycle_path
|
||||
)
|
||||
FormulaService._check_circular_references(
|
||||
db, project_id, field_id, nested_refs, visited
|
||||
db, project_id, field_id, nested_refs, visited, new_path
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def build_formula_dependency_graph(
|
||||
db: Session,
|
||||
project_id: str
|
||||
) -> Dict[str, Set[str]]:
|
||||
"""
|
||||
Build a dependency graph for all formula fields in a project.
|
||||
|
||||
Returns a dict where keys are field names and values are sets of
|
||||
field names that the key field depends on.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
project_id: Project ID to scope the query
|
||||
|
||||
Returns:
|
||||
Dict mapping field names to their dependencies
|
||||
"""
|
||||
graph: Dict[str, Set[str]] = {}
|
||||
|
||||
formula_fields = db.query(CustomField).filter(
|
||||
CustomField.project_id == project_id,
|
||||
CustomField.field_type == "formula",
|
||||
).all()
|
||||
|
||||
for field in formula_fields:
|
||||
if field.formula:
|
||||
refs = FormulaService.extract_field_references(field.formula)
|
||||
# Only include custom field references (not builtin fields)
|
||||
custom_refs = refs - FormulaService.BUILTIN_FIELDS
|
||||
graph[field.name] = custom_refs
|
||||
else:
|
||||
graph[field.name] = set()
|
||||
|
||||
return graph
|
||||
|
||||
@staticmethod
|
||||
def detect_formula_cycles(
|
||||
db: Session,
|
||||
project_id: str
|
||||
) -> List[List[str]]:
|
||||
"""
|
||||
Detect all cycles in the formula dependency graph for a project.
|
||||
|
||||
This is useful for auditing or cleanup operations.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
project_id: Project ID to check
|
||||
|
||||
Returns:
|
||||
List of cycles, where each cycle is a list of field names
|
||||
"""
|
||||
graph = FormulaService.build_formula_dependency_graph(db, project_id)
|
||||
|
||||
if not graph:
|
||||
return []
|
||||
|
||||
cycles: List[List[str]] = []
|
||||
visited: Set[str] = set()
|
||||
found_cycles: Set[Tuple[str, ...]] = set()
|
||||
|
||||
def dfs(node: str, path: List[str], in_path: Set[str]):
|
||||
"""DFS to find cycles."""
|
||||
if node in in_path:
|
||||
# Found a cycle
|
||||
cycle_start = path.index(node)
|
||||
cycle = path[cycle_start:] + [node]
|
||||
# Normalize for deduplication
|
||||
normalized = tuple(sorted(cycle[:-1]))
|
||||
if normalized not in found_cycles:
|
||||
found_cycles.add(normalized)
|
||||
cycles.append(cycle)
|
||||
return
|
||||
|
||||
if node in visited:
|
||||
return
|
||||
|
||||
if node not in graph:
|
||||
return
|
||||
|
||||
visited.add(node)
|
||||
in_path.add(node)
|
||||
path.append(node)
|
||||
|
||||
for neighbor in graph.get(node, set()):
|
||||
dfs(neighbor, path.copy(), in_path.copy())
|
||||
|
||||
path.pop()
|
||||
in_path.discard(node)
|
||||
|
||||
for start_node in graph.keys():
|
||||
if start_node not in visited:
|
||||
dfs(start_node, [], set())
|
||||
|
||||
return cycles
|
||||
|
||||
@staticmethod
|
||||
def validate_formula_with_details(
|
||||
formula: str,
|
||||
project_id: str,
|
||||
db: Session,
|
||||
current_field_id: Optional[str] = None,
|
||||
) -> Tuple[bool, Optional[str], Optional[List[str]]]:
|
||||
"""
|
||||
Validate a formula expression with detailed error information.
|
||||
|
||||
Similar to validate_formula but returns cycle path on circular reference errors.
|
||||
|
||||
Args:
|
||||
formula: The formula expression to validate
|
||||
project_id: Project ID to scope field lookups
|
||||
db: Database session
|
||||
current_field_id: Optional ID of the field being edited (for self-reference check)
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, error_message, cycle_path)
|
||||
"""
|
||||
if not formula or not formula.strip():
|
||||
return False, "Formula cannot be empty", None
|
||||
|
||||
# Extract field references
|
||||
references = FormulaService.extract_field_references(formula)
|
||||
|
||||
if not references:
|
||||
return False, "Formula must reference at least one field", None
|
||||
|
||||
# Validate syntax by trying to parse
|
||||
try:
|
||||
# Replace field references with dummy numbers for syntax check
|
||||
test_formula = formula
|
||||
for ref in references:
|
||||
test_formula = test_formula.replace(f"{{{ref}}}", "1")
|
||||
|
||||
# Try to parse and evaluate with safe operations
|
||||
FormulaService._safe_eval(test_formula)
|
||||
except Exception as e:
|
||||
return False, f"Invalid formula syntax: {str(e)}", None
|
||||
|
||||
# Separate builtin and custom field references
|
||||
custom_references = references - FormulaService.BUILTIN_FIELDS
|
||||
|
||||
# Validate custom field references exist and are numeric types
|
||||
if custom_references:
|
||||
fields = db.query(CustomField).filter(
|
||||
CustomField.project_id == project_id,
|
||||
CustomField.name.in_(custom_references),
|
||||
).all()
|
||||
|
||||
found_names = {f.name for f in fields}
|
||||
missing = custom_references - found_names
|
||||
|
||||
if missing:
|
||||
return False, f"Unknown field references: {', '.join(missing)}", None
|
||||
|
||||
# Check field types (must be number or formula)
|
||||
for field in fields:
|
||||
if field.field_type not in ("number", "formula"):
|
||||
return False, f"Field '{field.name}' is not a numeric type", None
|
||||
|
||||
# Check for circular references
|
||||
if current_field_id:
|
||||
try:
|
||||
FormulaService._check_circular_references(
|
||||
db, project_id, current_field_id, references
|
||||
)
|
||||
except CircularReferenceError as e:
|
||||
return False, str(e), e.cycle_path
|
||||
|
||||
return True, None, None
|
||||
|
||||
@staticmethod
|
||||
def _safe_eval(expression: str) -> Decimal:
|
||||
"""
|
||||
|
||||
@@ -4,8 +4,10 @@ import re
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Optional, Dict, Set
|
||||
from collections import deque
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import event
|
||||
|
||||
@@ -22,9 +24,152 @@ _pending_publish: Dict[int, List[dict]] = {}
|
||||
# Track which sessions have handlers registered
|
||||
_registered_sessions: Set[int] = set()
|
||||
|
||||
# Redis fallback queue configuration
|
||||
REDIS_FALLBACK_MAX_QUEUE_SIZE = int(os.getenv("REDIS_FALLBACK_MAX_QUEUE_SIZE", "1000"))
|
||||
REDIS_FALLBACK_RETRY_INTERVAL = int(os.getenv("REDIS_FALLBACK_RETRY_INTERVAL", "5")) # seconds
|
||||
REDIS_FALLBACK_MAX_RETRIES = int(os.getenv("REDIS_FALLBACK_MAX_RETRIES", "10"))
|
||||
|
||||
# Redis fallback queue for failed publishes
|
||||
_redis_fallback_lock = threading.Lock()
|
||||
_redis_fallback_queue: deque = deque(maxlen=REDIS_FALLBACK_MAX_QUEUE_SIZE)
|
||||
_redis_retry_timer: Optional[threading.Timer] = None
|
||||
_redis_available = True
|
||||
_redis_consecutive_failures = 0
|
||||
|
||||
|
||||
def _add_to_fallback_queue(user_id: str, data: dict, retry_count: int = 0) -> bool:
|
||||
"""
|
||||
Add a failed notification to the fallback queue.
|
||||
|
||||
Returns True if added successfully, False if queue is full.
|
||||
"""
|
||||
global _redis_consecutive_failures
|
||||
|
||||
with _redis_fallback_lock:
|
||||
if len(_redis_fallback_queue) >= REDIS_FALLBACK_MAX_QUEUE_SIZE:
|
||||
logger.warning(
|
||||
"Redis fallback queue is full (%d items), dropping notification for user %s",
|
||||
REDIS_FALLBACK_MAX_QUEUE_SIZE, user_id
|
||||
)
|
||||
return False
|
||||
|
||||
_redis_fallback_queue.append({
|
||||
"user_id": user_id,
|
||||
"data": data,
|
||||
"retry_count": retry_count,
|
||||
"queued_at": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
_redis_consecutive_failures += 1
|
||||
|
||||
queue_size = len(_redis_fallback_queue)
|
||||
logger.debug("Added notification to fallback queue (size: %d)", queue_size)
|
||||
|
||||
# Start retry mechanism if not already running
|
||||
_ensure_retry_timer_running()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _ensure_retry_timer_running():
|
||||
"""Ensure the retry timer is running if there are items in the queue."""
|
||||
global _redis_retry_timer
|
||||
|
||||
if _redis_retry_timer is None or not _redis_retry_timer.is_alive():
|
||||
_redis_retry_timer = threading.Timer(REDIS_FALLBACK_RETRY_INTERVAL, _process_fallback_queue)
|
||||
_redis_retry_timer.daemon = True
|
||||
_redis_retry_timer.start()
|
||||
|
||||
|
||||
def _process_fallback_queue():
|
||||
"""Process the fallback queue and retry sending notifications to Redis."""
|
||||
global _redis_available, _redis_consecutive_failures, _redis_retry_timer
|
||||
|
||||
items_to_retry = []
|
||||
|
||||
with _redis_fallback_lock:
|
||||
# Get all items from queue
|
||||
while _redis_fallback_queue:
|
||||
items_to_retry.append(_redis_fallback_queue.popleft())
|
||||
|
||||
if not items_to_retry:
|
||||
_redis_retry_timer = None
|
||||
return
|
||||
|
||||
logger.info("Processing %d items from Redis fallback queue", len(items_to_retry))
|
||||
|
||||
failed_items = []
|
||||
success_count = 0
|
||||
|
||||
for item in items_to_retry:
|
||||
user_id = item["user_id"]
|
||||
data = item["data"]
|
||||
retry_count = item["retry_count"]
|
||||
|
||||
if retry_count >= REDIS_FALLBACK_MAX_RETRIES:
|
||||
logger.warning(
|
||||
"Notification for user %s exceeded max retries (%d), dropping",
|
||||
user_id, REDIS_FALLBACK_MAX_RETRIES
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
redis_client = get_redis_sync()
|
||||
channel = get_channel_name(user_id)
|
||||
message = json.dumps(data, default=str)
|
||||
redis_client.publish(channel, message)
|
||||
success_count += 1
|
||||
except Exception as e:
|
||||
logger.debug("Retry failed for user %s: %s", user_id, e)
|
||||
failed_items.append({
|
||||
**item,
|
||||
"retry_count": retry_count + 1,
|
||||
})
|
||||
|
||||
# Re-queue failed items
|
||||
if failed_items:
|
||||
with _redis_fallback_lock:
|
||||
for item in failed_items:
|
||||
if len(_redis_fallback_queue) < REDIS_FALLBACK_MAX_QUEUE_SIZE:
|
||||
_redis_fallback_queue.append(item)
|
||||
|
||||
# Log recovery if we had successes
|
||||
if success_count > 0:
|
||||
with _redis_fallback_lock:
|
||||
_redis_consecutive_failures = 0
|
||||
if not _redis_fallback_queue:
|
||||
_redis_available = True
|
||||
logger.info(
|
||||
"Redis connection recovered. Successfully processed %d notifications from fallback queue",
|
||||
success_count
|
||||
)
|
||||
|
||||
# Schedule next retry if queue is not empty
|
||||
with _redis_fallback_lock:
|
||||
if _redis_fallback_queue:
|
||||
_redis_retry_timer = threading.Timer(REDIS_FALLBACK_RETRY_INTERVAL, _process_fallback_queue)
|
||||
_redis_retry_timer.daemon = True
|
||||
_redis_retry_timer.start()
|
||||
else:
|
||||
_redis_retry_timer = None
|
||||
|
||||
|
||||
def get_redis_fallback_status() -> dict:
|
||||
"""Get current Redis fallback queue status for health checks."""
|
||||
with _redis_fallback_lock:
|
||||
return {
|
||||
"queue_size": len(_redis_fallback_queue),
|
||||
"max_queue_size": REDIS_FALLBACK_MAX_QUEUE_SIZE,
|
||||
"redis_available": _redis_available,
|
||||
"consecutive_failures": _redis_consecutive_failures,
|
||||
"retry_interval_seconds": REDIS_FALLBACK_RETRY_INTERVAL,
|
||||
"max_retries": REDIS_FALLBACK_MAX_RETRIES,
|
||||
}
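The status dict is presumably consumed by the new health endpoints mentioned in the commit message; a hypothetical wiring (the actual `/health/ready` route lives elsewhere in this commit, not in this file):

```python
# Hypothetical readiness probe surfacing the fallback queue state.
from fastapi import APIRouter

router = APIRouter()

@router.get("/health/ready")
def readiness() -> dict:
    status = get_redis_fallback_status()
    return {
        "status": "ok" if status["redis_available"] else "degraded",
        "redis_fallback": status,
    }
```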
|
||||
|
||||
|
||||
def _sync_publish(user_id: str, data: dict):
|
||||
"""Sync fallback to publish notification via Redis when no event loop available."""
|
||||
global _redis_available
|
||||
|
||||
try:
|
||||
redis_client = get_redis_sync()
|
||||
channel = get_channel_name(user_id)
|
||||
@@ -33,6 +178,10 @@ def _sync_publish(user_id: str, data: dict):
|
||||
logger.debug(f"Sync published notification to channel {channel}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to sync publish notification to Redis: {e}")
|
||||
# Add to fallback queue for retry
|
||||
with _redis_fallback_lock:
|
||||
_redis_available = False
|
||||
_add_to_fallback_queue(user_id, data)
|
||||
|
||||
|
||||
def _cleanup_session(session_id: int, remove_registration: bool = True):
|
||||
@@ -86,10 +235,16 @@ def _register_session_handlers(db: Session, session_id: int):
|
||||
|
||||
async def _async_publish(user_id: str, data: dict):
|
||||
"""Async helper to publish notification to Redis."""
|
||||
global _redis_available
|
||||
|
||||
try:
|
||||
await redis_publish(user_id, data)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to publish notification to Redis: {e}")
|
||||
# Add to fallback queue for retry
|
||||
with _redis_fallback_lock:
|
||||
_redis_available = False
|
||||
_add_to_fallback_queue(user_id, data)
|
||||
|
||||
|
||||
class NotificationService:
|
||||
|
||||
@@ -7,6 +7,7 @@ scheduled triggers based on their cron schedule, including deadline reminders.
|
||||
|
||||
import uuid
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import Optional, List, Dict, Any, Tuple, Set
|
||||
|
||||
@@ -22,6 +23,10 @@ logger = logging.getLogger(__name__)
|
||||
# Key prefix for tracking deadline reminders already sent
|
||||
DEADLINE_REMINDER_LOG_TYPE = "deadline_reminder"
|
||||
|
||||
# Retry configuration
|
||||
MAX_RETRIES = 3
|
||||
BASE_DELAY_SECONDS = 1 # 1s, 2s, 4s exponential backoff
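The delays this formula yields for attempts 1-3 (see the retry loop below, which computes `BASE_DELAY_SECONDS * (2 ** (attempt - 1))`):

```python
delays = [BASE_DELAY_SECONDS * (2 ** (attempt - 1)) for attempt in range(1, MAX_RETRIES + 1)]
print(delays)  # [1, 2, 4]
```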
|
||||
|
||||
|
||||
class TriggerSchedulerService:
|
||||
"""Service for scheduling and executing cron-based triggers."""
|
||||
@@ -220,50 +225,170 @@ class TriggerSchedulerService:
|
||||
@staticmethod
|
||||
def _execute_trigger(db: Session, trigger: Trigger) -> TriggerLog:
|
||||
"""
|
||||
Execute a scheduled trigger's actions.
|
||||
Execute a scheduled trigger's actions with retry mechanism.
|
||||
|
||||
Implements exponential backoff retry (1s, 2s, 4s) for transient failures.
|
||||
After max retries are exhausted, marks as permanently failed and sends alert.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
trigger: The trigger to execute
|
||||
|
||||
Returns:
|
||||
TriggerLog entry for this execution
|
||||
"""
|
||||
return TriggerSchedulerService._execute_trigger_with_retry(
|
||||
db=db,
|
||||
trigger=trigger,
|
||||
task_id=None,
|
||||
log_type="schedule",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _execute_trigger_with_retry(
|
||||
db: Session,
|
||||
trigger: Trigger,
|
||||
task_id: Optional[str] = None,
|
||||
log_type: str = "schedule",
|
||||
) -> TriggerLog:
|
||||
"""
|
||||
Execute trigger actions with exponential backoff retry.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
trigger: The trigger to execute
|
||||
task_id: Optional task ID for context (deadline reminders)
|
||||
log_type: Type of trigger execution for logging
|
||||
|
||||
Returns:
|
||||
TriggerLog entry for this execution
|
||||
"""
|
||||
actions = trigger.actions if isinstance(trigger.actions, list) else [trigger.actions]
|
||||
executed_actions = []
|
||||
error_message = None
|
||||
last_error = None
|
||||
attempt = 0
|
||||
|
||||
try:
|
||||
for action in actions:
|
||||
action_type = action.get("type")
|
||||
while attempt < MAX_RETRIES:
|
||||
attempt += 1
|
||||
executed_actions = []
|
||||
last_error = None
|
||||
|
||||
if action_type == "notify":
|
||||
TriggerSchedulerService._execute_notify_action(db, action, trigger)
|
||||
executed_actions.append({"type": action_type, "status": "success"})
|
||||
try:
|
||||
logger.info(
|
||||
f"Executing trigger {trigger.id} (attempt {attempt}/{MAX_RETRIES})"
|
||||
)
|
||||
|
||||
# Add more action types here as needed
|
||||
for action in actions:
|
||||
action_type = action.get("type")
|
||||
|
||||
status = "success"
|
||||
if action_type == "notify":
|
||||
TriggerSchedulerService._execute_notify_action(db, action, trigger)
|
||||
executed_actions.append({"type": action_type, "status": "success"})
|
||||
|
||||
except Exception as e:
|
||||
status = "failed"
|
||||
error_message = str(e)
|
||||
executed_actions.append({"type": "error", "message": str(e)})
|
||||
logger.error(f"Error executing trigger {trigger.id} actions: {e}")
|
||||
# Add more action types here as needed
|
||||
|
||||
# Success - return log
|
||||
logger.info(f"Trigger {trigger.id} executed successfully on attempt {attempt}")
|
||||
return TriggerSchedulerService._log_execution(
|
||||
db=db,
|
||||
trigger=trigger,
|
||||
status="success",
|
||||
details={
|
||||
"trigger_name": trigger.name,
|
||||
"trigger_type": log_type,
|
||||
"cron_expression": trigger.conditions.get("cron_expression") if trigger.conditions else None,
|
||||
"actions_executed": executed_actions,
|
||||
"attempts": attempt,
|
||||
},
|
||||
error_message=None,
|
||||
task_id=task_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
executed_actions.append({"type": "error", "message": str(e)})
|
||||
logger.warning(
|
||||
f"Trigger {trigger.id} failed on attempt {attempt}/{MAX_RETRIES}: {e}"
|
||||
)
|
||||
|
||||
# Calculate exponential backoff delay
|
||||
if attempt < MAX_RETRIES:
|
||||
delay = BASE_DELAY_SECONDS * (2 ** (attempt - 1))
|
||||
logger.info(f"Retrying trigger {trigger.id} in {delay}s...")
|
||||
time.sleep(delay)
|
||||
|
||||
# All retries exhausted - permanent failure
|
||||
logger.error(
|
||||
f"Trigger {trigger.id} permanently failed after {MAX_RETRIES} attempts: {last_error}"
|
||||
)
|
||||
|
||||
# Send alert notification for permanent failure
|
||||
TriggerSchedulerService._send_failure_alert(db, trigger, str(last_error), MAX_RETRIES)
|
||||
|
||||
return TriggerSchedulerService._log_execution(
|
||||
db=db,
|
||||
trigger=trigger,
|
||||
status=status,
|
||||
status="permanently_failed",
|
||||
details={
|
||||
"trigger_name": trigger.name,
|
||||
"trigger_type": "schedule",
|
||||
"cron_expression": trigger.conditions.get("cron_expression"),
|
||||
"trigger_type": log_type,
|
||||
"cron_expression": trigger.conditions.get("cron_expression") if trigger.conditions else None,
|
||||
"actions_executed": executed_actions,
|
||||
"attempts": MAX_RETRIES,
|
||||
"permanent_failure": True,
|
||||
},
|
||||
error_message=error_message,
|
||||
error_message=f"Failed after {MAX_RETRIES} retries: {last_error}",
|
||||
task_id=task_id,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _send_failure_alert(
|
||||
db: Session,
|
||||
trigger: Trigger,
|
||||
error_message: str,
|
||||
attempts: int,
|
||||
) -> None:
|
||||
"""
|
||||
Send alert notification when trigger exhausts all retries.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
trigger: The failed trigger
|
||||
error_message: The last error message
|
||||
attempts: Number of attempts made
|
||||
"""
|
||||
try:
|
||||
# Notify the project owner about the failure
|
||||
project = trigger.project
|
||||
if not project:
|
||||
logger.warning(f"Cannot send failure alert: trigger {trigger.id} has no project")
|
||||
return
|
||||
|
||||
target_user_id = project.owner_id
|
||||
if not target_user_id:
|
||||
logger.warning(f"Cannot send failure alert: project {project.id} has no owner")
|
||||
return
|
||||
|
||||
message = (
|
||||
f"Trigger '{trigger.name}' has permanently failed after {attempts} attempts. "
|
||||
f"Last error: {error_message}"
|
||||
)
|
||||
|
||||
NotificationService.create_notification(
|
||||
db=db,
|
||||
user_id=target_user_id,
|
||||
notification_type="trigger_failure",
|
||||
reference_type="trigger",
|
||||
reference_id=trigger.id,
|
||||
title=f"Trigger Failed: {trigger.name}",
|
||||
message=message,
|
||||
)
|
||||
|
||||
logger.info(f"Sent failure alert for trigger {trigger.id} to user {target_user_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send failure alert for trigger {trigger.id}: {e}")
@staticmethod
def _execute_notify_action(db: Session, action: Dict[str, Any], trigger: Trigger) -> None:
"""