feat: implement custom fields, gantt view, calendar view, and file encryption
- Custom Fields (FEAT-001): - CustomField and TaskCustomValue models with formula support - CRUD API for custom field management - Formula engine for calculated fields - Frontend: CustomFieldEditor, CustomFieldInput, ProjectSettings page - Task list API now includes custom_values - KanbanBoard displays custom field values - Gantt View (FEAT-003): - TaskDependency model with FS/SS/FF/SF dependency types - Dependency CRUD API with cycle detection - start_date field added to tasks - GanttChart component with Frappe Gantt integration - Dependency type selector in UI - Calendar View (FEAT-004): - CalendarView component with FullCalendar integration - Date range filtering API for tasks - Drag-and-drop date updates - View mode switching in Tasks page - File Encryption (FEAT-010): - AES-256-GCM encryption service - EncryptionKey model with key rotation support - Admin API for key management - Encrypted upload/download for confidential projects - Migrations: 011 (custom fields), 012 (encryption keys), 013 (task dependencies) - Updated issues.md with completion status 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
278
backend/app/services/custom_value_service.py
Normal file
278
backend/app/services/custom_value_service.py
Normal file
@@ -0,0 +1,278 @@
|
||||
"""
|
||||
Service for managing task custom values.
|
||||
"""
|
||||
import uuid
|
||||
from typing import List, Dict, Any, Optional
|
||||
from decimal import Decimal, InvalidOperation
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models import Task, CustomField, TaskCustomValue, User
|
||||
from app.schemas.task import CustomValueInput, CustomValueResponse
|
||||
from app.services.formula_service import FormulaService
|
||||
|
||||
|
||||
class CustomValueService:
    """Service for managing custom field values on tasks."""

    @staticmethod
    def get_custom_values_for_task(
        db: "Session",
        task: "Task",
        include_formula_calculations: bool = True,
    ) -> List["CustomValueResponse"]:
        """
        Get all custom field values for a task.

        Args:
            db: Database session
            task: The task to get values for
            include_formula_calculations: Whether to calculate formula field values

        Returns:
            List of CustomValueResponse objects, ordered by field position
        """
        # All custom fields defined on the task's project, in display order.
        fields = db.query(CustomField).filter(
            CustomField.project_id == task.project_id
        ).order_by(CustomField.position).all()

        if not fields:
            return []

        # Stored (non-formula) values, keyed by field id for O(1) lookup.
        stored_values = db.query(TaskCustomValue).filter(
            TaskCustomValue.task_id == task.id
        ).all()
        value_map = {v.field_id: v.value for v in stored_values}

        # Formula fields are computed on the fly rather than stored.
        formula_values: Dict[str, Any] = {}
        if include_formula_calculations:
            formula_values = FormulaService.calculate_all_formulas_for_task(db, task)

        result = []
        for field in fields:
            if field.field_type == "formula":
                # Use the calculated formula value (stringified for transport).
                calculated = formula_values.get(field.id)
                value = str(calculated) if calculated is not None else None
            else:
                # Use the stored value, if any.
                value = value_map.get(field.id)

            result.append(CustomValueResponse(
                field_id=field.id,
                field_name=field.name,
                field_type=field.field_type,
                value=value,
                # Human-readable rendering (user name, formatted number/date).
                display_value=CustomValueService._format_display_value(field, value, db),
            ))

        return result

    @staticmethod
    def _format_display_value(
        field: "CustomField",
        value: Optional[str],
        db: "Session",
    ) -> Optional[str]:
        """Format a stored value for display based on the field's type.

        Falls back to returning the raw value when formatting fails, so a
        malformed stored value never breaks a task listing.
        """
        if value is None:
            return None

        field_type = field.field_type

        if field_type == "person":
            # Stored value is a user id; show the user's name when possible.
            # (User is imported at module level; no local import needed.)
            user = db.query(User).filter(User.id == value).first()
            return user.name if user else value

        if field_type in ("number", "formula"):
            # Thousands separators, up to 4 decimals, trailing zeros stripped.
            try:
                num = Decimal(value)
            except (InvalidOperation, ValueError):
                return value
            return f"{num:,.4f}".rstrip('0').rstrip('.')

        if field_type == "date":
            # Accept ISO-8601 (incl. trailing 'Z') and render the date only.
            try:
                dt = datetime.fromisoformat(value.replace('Z', '+00:00'))
            except (ValueError, AttributeError):
                return value
            return dt.strftime('%Y-%m-%d')

        return value

    @staticmethod
    def save_custom_values(
        db: "Session",
        task: "Task",
        custom_values: List["CustomValueInput"],
    ) -> List[str]:
        """
        Save custom field values for a task.

        Fields unknown to the task's project are silently skipped, as are
        formula fields (they are calculated, never stored directly).

        Args:
            db: Database session
            task: The task to save values for
            custom_values: List of values to save

        Returns:
            List of field IDs that were updated (for formula recalculation)

        Raises:
            ValueError: If a value fails validation for its field type.
        """
        if not custom_values:
            return []

        updated_field_ids: List[str] = []

        for cv in custom_values:
            # Only accept fields that belong to the task's project; skip
            # formula fields because they are derived, not stored.
            field = db.query(CustomField).filter(
                CustomField.id == cv.field_id,
                CustomField.project_id == task.project_id,
            ).first()
            if not field or field.field_type == "formula":
                continue

            validated_value = CustomValueService._validate_value(field, cv.value, db)

            # Update the existing row or create a new one.
            existing = db.query(TaskCustomValue).filter(
                TaskCustomValue.task_id == task.id,
                TaskCustomValue.field_id == cv.field_id,
            ).first()

            if existing:
                # Only record a change when the value actually differs.
                if existing.value != validated_value:
                    existing.value = validated_value
                    updated_field_ids.append(cv.field_id)
            else:
                db.add(TaskCustomValue(
                    id=str(uuid.uuid4()),
                    task_id=task.id,
                    field_id=cv.field_id,
                    value=validated_value,
                ))
                updated_field_ids.append(cv.field_id)

        # Any changed input may feed a formula field; recalculate dependents.
        # (Looping an empty list is a no-op, so no guard is needed.)
        for field_id in updated_field_ids:
            FormulaService.recalculate_dependent_formulas(db, task, field_id)

        return updated_field_ids

    @staticmethod
    def _validate_value(
        field: "CustomField",
        value: Any,
        db: "Session",
    ) -> Optional[str]:
        """
        Validate and normalize a value based on field type.

        Returns the validated value as a string, or None for empty input.

        Raises:
            ValueError: If the value is invalid for the field type, or a
                required field is left empty.
        """
        if value is None or value == "":
            if field.is_required:
                raise ValueError(f"Field '{field.name}' is required")
            return None

        field_type = field.field_type
        str_value = str(value)

        if field_type == "text":
            return str_value

        if field_type == "number":
            try:
                Decimal(str_value)
            except (InvalidOperation, ValueError):
                # `from None`: the Decimal traceback adds no information.
                raise ValueError(f"Invalid number for field '{field.name}'") from None
            return str_value

        if field_type == "dropdown":
            if field.options and str_value not in field.options:
                raise ValueError(
                    f"Invalid option for field '{field.name}'. "
                    f"Must be one of: {', '.join(field.options)}"
                )
            return str_value

        if field_type == "date":
            # fromisoformat accepts date-only strings too; the strptime
            # fallback is kept for lenient handling of other inputs.
            try:
                datetime.fromisoformat(str_value.replace('Z', '+00:00'))
                return str_value
            except ValueError:
                try:
                    datetime.strptime(str_value, '%Y-%m-%d')
                    return str_value
                except ValueError:
                    raise ValueError(f"Invalid date for field '{field.name}'") from None

        if field_type == "person":
            # Stored value must reference an existing user id.
            # (User is imported at module level; no local import needed.)
            user = db.query(User).filter(User.id == str_value).first()
            if not user:
                raise ValueError(f"Invalid user ID for field '{field.name}'")
            return str_value

        # Unknown field types are stored verbatim.
        return str_value

    @staticmethod
    def validate_required_fields(
        db: "Session",
        project_id: str,
        custom_values: Optional[List["CustomValueInput"]],
    ) -> List[str]:
        """
        Validate that all required custom fields have values.

        Returns:
            List of missing required field names (empty when all satisfied).
        """
        required_fields = db.query(CustomField).filter(
            CustomField.project_id == project_id,
            CustomField.is_required.is_(True),
            CustomField.field_type != "formula",  # Formula fields are calculated
        ).all()

        if not required_fields:
            return []

        # Field ids the caller supplied a non-empty value for.
        provided_field_ids = set()
        if custom_values:
            for cv in custom_values:
                if cv.value is not None and cv.value != "":
                    provided_field_ids.add(cv.field_id)

        return [f.name for f in required_fields if f.id not in provided_field_ids]
|
||||
424
backend/app/services/dependency_service.py
Normal file
424
backend/app/services/dependency_service.py
Normal file
@@ -0,0 +1,424 @@
|
||||
"""
|
||||
Dependency Service
|
||||
|
||||
Handles task dependency validation including:
|
||||
- Circular dependency detection using DFS
|
||||
- Date constraint validation based on dependency types
|
||||
- Self-reference prevention
|
||||
- Cross-project dependency prevention
|
||||
"""
|
||||
from collections import defaultdict, deque
from datetime import datetime, timedelta
from typing import List, Optional, Set, Tuple, Dict, Any

from sqlalchemy.orm import Session

from app.models import Task, TaskDependency
|
||||
|
||||
|
||||
class DependencyValidationError(Exception):
    """Raised when a task dependency fails validation.

    Attributes:
        error_type: Machine-readable category (e.g. "circular", "duplicate").
        message: Human-readable explanation.
        details: Optional extra context for API responses (always a dict).
    """

    def __init__(self, error_type: str, message: str, details: Optional[dict] = None):
        super().__init__(message)
        self.error_type = error_type
        self.message = message
        self.details = {} if details is None else details
|
||||
|
||||
|
||||
class DependencyService:
    """Service for managing task dependencies with validation.

    Provides cycle detection (DFS), date-constraint validation for the four
    dependency types (FS/SS/FF/SF), and transitive predecessor/successor
    queries (BFS).
    """

    # Maximum number of direct dependencies per task (as per spec)
    MAX_DIRECT_DEPENDENCIES = 10

    # Rule tables for date-constraint checks, keyed by dependency type.
    # Each entry is (date attribute read on the *other* task, which of this
    # task's dates is constrained ("start" or "due"), label used in the
    # violation message, description of the other task's date for the message).
    _PREDECESSOR_RULES: Dict[str, Tuple[str, str, str, str]] = {
        "FS": ("due_date", "start", "Start date", "predecessor due date"),
        "SS": ("start_date", "start", "Start date", "predecessor start date"),
        "FF": ("due_date", "due", "Due date", "predecessor due date"),
        "SF": ("start_date", "due", "Due date", "predecessor start date"),
    }
    _SUCCESSOR_RULES: Dict[str, Tuple[str, str, str, str]] = {
        "FS": ("start_date", "due", "Due date", "successor start date"),
        "SS": ("start_date", "start", "Start date", "successor start date"),
        "FF": ("due_date", "due", "Due date", "successor due date"),
        "SF": ("due_date", "start", "Start date", "successor due date"),
    }

    @staticmethod
    def detect_circular_dependency(
        db: "Session",
        predecessor_id: str,
        successor_id: str,
        project_id: str
    ) -> Optional[List[str]]:
        """
        Detect if adding a dependency would create a circular reference.

        Builds the project's dependency graph (successor -> predecessors),
        simulates adding the proposed edge, then runs DFS from the successor
        to see whether it can reach back to itself.

        Args:
            db: Database session
            predecessor_id: The task that must complete first
            successor_id: The task that depends on the predecessor
            project_id: Project ID to scope the query

        Returns:
            List of task IDs forming the cycle if circular, None otherwise
        """
        dependencies = db.query(TaskDependency).join(
            Task, TaskDependency.successor_id == Task.id
        ).filter(Task.project_id == project_id).all()

        # Graph maps each task to the tasks it depends on.
        graph: Dict[str, List[str]] = defaultdict(list)
        for dep in dependencies:
            graph[dep.successor_id].append(dep.predecessor_id)

        # Simulate adding the new edge before searching.
        graph[successor_id].append(predecessor_id)

        visited: Set[str] = set()
        path: List[str] = []
        in_path: Set[str] = set()

        def dfs(node: str) -> Optional[List[str]]:
            """Return the cycle path if *node* leads back into the current path."""
            if node in in_path:
                # Found a cycle - return the cycle path.
                cycle_start = path.index(node)
                return path[cycle_start:] + [node]

            if node in visited:
                return None

            visited.add(node)
            in_path.add(node)
            path.append(node)

            # .get() instead of indexing so the defaultdict is not mutated
            # during traversal.
            for neighbor in graph.get(node, []):
                result = dfs(neighbor)
                if result:
                    return result

            path.pop()
            in_path.remove(node)
            return None

        # NOTE(review): recursion depth grows with the longest dependency
        # chain; extremely deep chains could hit the interpreter's recursion
        # limit — confirm expected project sizes.
        return dfs(successor_id)

    @staticmethod
    def validate_dependency(
        db: "Session",
        predecessor_id: str,
        successor_id: str
    ) -> None:
        """
        Validate that a dependency can be created.

        Raises DependencyValidationError if validation fails.

        Checks, in order:
        1. Self-reference
        2. Both tasks exist
        3. Both tasks are in the same project
        4. No duplicate dependency
        5. Dependency limit not exceeded
        6. No circular dependency
        """
        # Check self-reference.
        if predecessor_id == successor_id:
            raise DependencyValidationError(
                error_type="self_reference",
                message="A task cannot depend on itself"
            )

        # Both tasks must exist.
        predecessor = db.query(Task).filter(Task.id == predecessor_id).first()
        successor = db.query(Task).filter(Task.id == successor_id).first()

        if not predecessor:
            raise DependencyValidationError(
                error_type="not_found",
                message="Predecessor task not found",
                details={"task_id": predecessor_id}
            )

        if not successor:
            raise DependencyValidationError(
                error_type="not_found",
                message="Successor task not found",
                details={"task_id": successor_id}
            )

        # Cross-project dependencies are not allowed.
        if predecessor.project_id != successor.project_id:
            raise DependencyValidationError(
                error_type="cross_project",
                message="Dependencies can only be created between tasks in the same project",
                details={
                    "predecessor_project_id": predecessor.project_id,
                    "successor_project_id": successor.project_id
                }
            )

        # Reject an exact duplicate edge.
        existing = db.query(TaskDependency).filter(
            TaskDependency.predecessor_id == predecessor_id,
            TaskDependency.successor_id == successor_id
        ).first()
        if existing:
            raise DependencyValidationError(
                error_type="duplicate",
                message="This dependency already exists"
            )

        # Enforce the per-task direct-dependency limit.
        current_count = db.query(TaskDependency).filter(
            TaskDependency.successor_id == successor_id
        ).count()
        if current_count >= DependencyService.MAX_DIRECT_DEPENDENCIES:
            raise DependencyValidationError(
                error_type="limit_exceeded",
                message=f"A task cannot have more than {DependencyService.MAX_DIRECT_DEPENDENCIES} direct dependencies",
                details={"current_count": current_count}
            )

        # Cycle detection last: it is the most expensive check.
        cycle = DependencyService.detect_circular_dependency(
            db, predecessor_id, successor_id, predecessor.project_id
        )
        if cycle:
            raise DependencyValidationError(
                error_type="circular",
                message="Adding this dependency would create a circular reference",
                details={"cycle": cycle}
            )

    @staticmethod
    def _constraint_violation(
        dep: "TaskDependency",
        other_task: "Task",
        new_start: Optional[datetime],
        new_due: Optional[datetime],
        *,
        is_predecessor: bool,
    ) -> Optional[Dict[str, Any]]:
        """
        Check one dependency edge against the proposed dates.

        Args:
            dep: The dependency edge (provides dependency_type and lag_days)
            other_task: The task on the far side of the edge
            new_start: Proposed start date of the task being edited
            new_due: Proposed due date of the task being edited
            is_predecessor: True when other_task precedes the edited task

        Returns:
            A violation dict when the constraint is broken, else None.
        """
        rules = (
            DependencyService._PREDECESSOR_RULES
            if is_predecessor
            else DependencyService._SUCCESSOR_RULES
        )
        rule = rules.get(dep.dependency_type)
        if rule is None:
            # Unknown dependency type: nothing to check.
            return None

        other_attr, own_key, label, basis = rule
        other_date = getattr(other_task, other_attr)
        own_date = new_start if own_key == "start" else new_due

        # Constraints only apply when both dates are known.
        if not (other_date and own_date):
            return None

        lag = timedelta(days=dep.lag_days)
        if is_predecessor:
            # The edited task must come on or after the predecessor's date + lag.
            required = other_date + lag
            if own_date >= required:
                return None
            role = "predecessor"
            message = (
                f"{label} must be on or after {required.date()} "
                f"({basis} + {dep.lag_days} days lag)"
            )
        else:
            # The edited task must come on or before the successor's date - lag.
            required = other_date - lag
            if own_date <= required:
                return None
            role = "successor"
            message = (
                f"{label} must be on or before {required.date()} "
                f"({basis} - {dep.lag_days} days lag)"
            )

        return {
            "type": "dependency_constraint",
            "dependency_type": dep.dependency_type,
            f"{role}_id": other_task.id,
            f"{role}_title": other_task.title,
            "message": message,
        }

    @staticmethod
    def validate_date_constraints(
        task: "Task",
        start_date: Optional[datetime],
        due_date: Optional[datetime],
        db: "Session"
    ) -> List[Dict[str, Any]]:
        """
        Validate date changes against dependency constraints.

        Returns a list of constraint violations (empty if valid).

        Dependency type meanings:
        - FS: predecessor.due_date + lag <= successor.start_date
        - SS: predecessor.start_date + lag <= successor.start_date
        - FF: predecessor.due_date + lag <= successor.due_date
        - SF: predecessor.start_date + lag <= successor.due_date
        """
        violations: List[Dict[str, Any]] = []

        # Use provided dates or fall back to current task dates.
        new_start = start_date if start_date is not None else task.start_date
        new_due = due_date if due_date is not None else task.due_date

        # Basic sanity: start must not come after due.
        if new_start and new_due and new_start > new_due:
            violations.append({
                "type": "date_order",
                "message": "Start date cannot be after due date",
                "start_date": str(new_start),
                "due_date": str(new_due)
            })

        # Edges where this task is the successor (its predecessors).
        predecessors = db.query(TaskDependency).filter(
            TaskDependency.successor_id == task.id
        ).all()
        for dep in predecessors:
            pred_task = dep.predecessor
            if not pred_task:
                continue
            violation = DependencyService._constraint_violation(
                dep, pred_task, new_start, new_due, is_predecessor=True
            )
            if violation:
                violations.append(violation)

        # Edges where this task is the predecessor (its successors).
        successors = db.query(TaskDependency).filter(
            TaskDependency.predecessor_id == task.id
        ).all()
        for dep in successors:
            succ_task = dep.successor
            if not succ_task:
                continue
            violation = DependencyService._constraint_violation(
                dep, succ_task, new_start, new_due, is_predecessor=False
            )
            if violation:
                violations.append(violation)

        return violations

    @staticmethod
    def _collect_transitive(
        db: "Session",
        task_id: str,
        *,
        follow_predecessors: bool,
    ) -> List[str]:
        """
        BFS over dependency edges starting at task_id.

        Args:
            db: Database session
            task_id: Starting task
            follow_predecessors: True walks successor->predecessor edges,
                False walks predecessor->successor edges.

        Returns:
            Task IDs reachable from task_id (excluding task_id itself).
        """
        visited: Set[str] = set()
        queue = deque([task_id])  # deque: O(1) pops from the left
        found: List[str] = []

        while queue:
            current = queue.popleft()
            if current in visited:
                continue
            visited.add(current)

            if follow_predecessors:
                deps = db.query(TaskDependency).filter(
                    TaskDependency.successor_id == current
                ).all()
                neighbors = [d.predecessor_id for d in deps]
            else:
                deps = db.query(TaskDependency).filter(
                    TaskDependency.predecessor_id == current
                ).all()
                neighbors = [d.successor_id for d in deps]

            for neighbor in neighbors:
                if neighbor not in visited:
                    found.append(neighbor)
                    queue.append(neighbor)

        return found

    @staticmethod
    def get_all_predecessors(db: "Session", task_id: str) -> List[str]:
        """
        Get all transitive predecessors of a task.

        Uses BFS to find all tasks that this task depends on (directly or indirectly).
        """
        return DependencyService._collect_transitive(
            db, task_id, follow_predecessors=True
        )

    @staticmethod
    def get_all_successors(db: "Session", task_id: str) -> List[str]:
        """
        Get all transitive successors of a task.

        Uses BFS to find all tasks that depend on this task (directly or indirectly).
        """
        return DependencyService._collect_transitive(
            db, task_id, follow_predecessors=False
        )
|
||||
300
backend/app/services/encryption_service.py
Normal file
300
backend/app/services/encryption_service.py
Normal file
@@ -0,0 +1,300 @@
|
||||
"""
|
||||
Encryption service for AES-256-GCM file encryption.
|
||||
|
||||
This service handles:
|
||||
- File encryption key generation and management
|
||||
- Encrypting/decrypting file encryption keys with Master Key
|
||||
- Streaming file encryption/decryption with AES-256-GCM
|
||||
"""
|
||||
import os
|
||||
import base64
|
||||
import secrets
|
||||
import logging
|
||||
from typing import BinaryIO, Tuple, Optional, Generator
|
||||
from io import BytesIO
|
||||
|
||||
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
# Module-level logger for encryption diagnostics.
logger = logging.getLogger(__name__)

# Constants
KEY_SIZE = 32  # 256 bits for AES-256
NONCE_SIZE = 12  # 96 bits for GCM recommended nonce size
# NOTE(review): TAG_SIZE is not referenced in this module's visible code —
# AESGCM appends/verifies the tag itself; kept for documentation value.
TAG_SIZE = 16  # 128 bits for GCM authentication tag
CHUNK_SIZE = 64 * 1024  # 64KB chunks for streaming
|
||||
|
||||
|
||||
class EncryptionError(Exception):
    """Base exception for all encryption-service errors."""
|
||||
|
||||
|
||||
class MasterKeyNotConfiguredError(EncryptionError):
    """Raised when no master key is configured (encryption unavailable)."""
|
||||
|
||||
|
||||
class DecryptionError(EncryptionError):
    """Raised when decryption fails (bad key, corrupt or truncated data)."""
|
||||
|
||||
|
||||
class EncryptionService:
|
||||
"""
|
||||
Service for file encryption using AES-256-GCM.
|
||||
|
||||
Key hierarchy:
|
||||
1. Master Key (from environment) -> encrypts file encryption keys
|
||||
2. File Encryption Keys (stored in DB) -> encrypt actual files
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._master_key: Optional[bytes] = None
|
||||
|
||||
@property
|
||||
def master_key(self) -> bytes:
|
||||
"""Get the master key, loading from config if needed."""
|
||||
if self._master_key is None:
|
||||
if not settings.ENCRYPTION_MASTER_KEY:
|
||||
raise MasterKeyNotConfiguredError(
|
||||
"ENCRYPTION_MASTER_KEY is not configured. "
|
||||
"File encryption is disabled."
|
||||
)
|
||||
self._master_key = base64.urlsafe_b64decode(settings.ENCRYPTION_MASTER_KEY)
|
||||
return self._master_key
|
||||
|
||||
def is_encryption_available(self) -> bool:
|
||||
"""Check if encryption is available (master key configured)."""
|
||||
return settings.ENCRYPTION_MASTER_KEY is not None
|
||||
|
||||
def generate_key(self) -> bytes:
|
||||
"""
|
||||
Generate a new AES-256 encryption key.
|
||||
|
||||
Returns:
|
||||
32-byte random key
|
||||
"""
|
||||
return secrets.token_bytes(KEY_SIZE)
|
||||
|
||||
def encrypt_key(self, key: bytes) -> str:
|
||||
"""
|
||||
Encrypt a file encryption key using the Master Key.
|
||||
|
||||
Args:
|
||||
key: The raw 32-byte file encryption key
|
||||
|
||||
Returns:
|
||||
Base64-encoded encrypted key (nonce + ciphertext + tag)
|
||||
"""
|
||||
aesgcm = AESGCM(self.master_key)
|
||||
nonce = secrets.token_bytes(NONCE_SIZE)
|
||||
|
||||
# Encrypt the key
|
||||
ciphertext = aesgcm.encrypt(nonce, key, None)
|
||||
|
||||
# Combine nonce + ciphertext (includes tag)
|
||||
encrypted_data = nonce + ciphertext
|
||||
|
||||
return base64.urlsafe_b64encode(encrypted_data).decode('utf-8')
|
||||
|
||||
def decrypt_key(self, encrypted_key: str) -> bytes:
|
||||
"""
|
||||
Decrypt a file encryption key using the Master Key.
|
||||
|
||||
Args:
|
||||
encrypted_key: Base64-encoded encrypted key
|
||||
|
||||
Returns:
|
||||
The raw 32-byte file encryption key
|
||||
"""
|
||||
try:
|
||||
encrypted_data = base64.urlsafe_b64decode(encrypted_key)
|
||||
|
||||
# Extract nonce and ciphertext
|
||||
nonce = encrypted_data[:NONCE_SIZE]
|
||||
ciphertext = encrypted_data[NONCE_SIZE:]
|
||||
|
||||
# Decrypt
|
||||
aesgcm = AESGCM(self.master_key)
|
||||
return aesgcm.decrypt(nonce, ciphertext, None)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to decrypt encryption key: {e}")
|
||||
raise DecryptionError("Failed to decrypt file encryption key")
|
||||
|
||||
def encrypt_file(self, file_content: BinaryIO, key: bytes) -> bytes:
|
||||
"""
|
||||
Encrypt file content using AES-256-GCM.
|
||||
|
||||
For smaller files, encrypts the entire content at once.
|
||||
The format is: nonce (12 bytes) + ciphertext + tag (16 bytes)
|
||||
|
||||
Args:
|
||||
file_content: File-like object to encrypt
|
||||
key: 32-byte AES-256 key
|
||||
|
||||
Returns:
|
||||
Encrypted bytes (nonce + ciphertext + tag)
|
||||
"""
|
||||
# Read all content
|
||||
plaintext = file_content.read()
|
||||
|
||||
# Generate nonce
|
||||
nonce = secrets.token_bytes(NONCE_SIZE)
|
||||
|
||||
# Encrypt
|
||||
aesgcm = AESGCM(key)
|
||||
ciphertext = aesgcm.encrypt(nonce, plaintext, None)
|
||||
|
||||
# Return nonce + ciphertext (tag is appended by encrypt)
|
||||
return nonce + ciphertext
|
||||
|
||||
def decrypt_file(self, encrypted_content: BinaryIO, key: bytes) -> bytes:
|
||||
"""
|
||||
Decrypt file content using AES-256-GCM.
|
||||
|
||||
Args:
|
||||
encrypted_content: File-like object containing encrypted data
|
||||
key: 32-byte AES-256 key
|
||||
|
||||
Returns:
|
||||
Decrypted bytes
|
||||
"""
|
||||
try:
|
||||
# Read all encrypted content
|
||||
encrypted_data = encrypted_content.read()
|
||||
|
||||
# Extract nonce and ciphertext
|
||||
nonce = encrypted_data[:NONCE_SIZE]
|
||||
ciphertext = encrypted_data[NONCE_SIZE:]
|
||||
|
||||
# Decrypt
|
||||
aesgcm = AESGCM(key)
|
||||
plaintext = aesgcm.decrypt(nonce, ciphertext, None)
|
||||
|
||||
return plaintext
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to decrypt file: {e}")
|
||||
raise DecryptionError("Failed to decrypt file. The file may be corrupted or the key is incorrect.")
|
||||
|
||||
def encrypt_file_streaming(self, file_content: BinaryIO, key: bytes) -> Generator[bytes, None, None]:
|
||||
"""
|
||||
Encrypt file content using AES-256-GCM with streaming.
|
||||
|
||||
For large files, encrypts in chunks. Each chunk has its own nonce.
|
||||
Format per chunk: chunk_size (4 bytes) + nonce (12 bytes) + ciphertext + tag
|
||||
|
||||
Args:
|
||||
file_content: File-like object to encrypt
|
||||
key: 32-byte AES-256 key
|
||||
|
||||
Yields:
|
||||
Encrypted chunks
|
||||
"""
|
||||
aesgcm = AESGCM(key)
|
||||
|
||||
# Write header with version byte
|
||||
yield b'\x01' # Version 1 for streaming format
|
||||
|
||||
while True:
|
||||
chunk = file_content.read(CHUNK_SIZE)
|
||||
if not chunk:
|
||||
break
|
||||
|
||||
# Generate nonce for this chunk
|
||||
nonce = secrets.token_bytes(NONCE_SIZE)
|
||||
|
||||
# Encrypt chunk
|
||||
ciphertext = aesgcm.encrypt(nonce, chunk, None)
|
||||
|
||||
# Write chunk size (4 bytes, little endian)
|
||||
chunk_size = len(ciphertext) + NONCE_SIZE
|
||||
yield chunk_size.to_bytes(4, 'little')
|
||||
|
||||
# Write nonce + ciphertext
|
||||
yield nonce + ciphertext
|
||||
|
||||
# Write end marker (zero size)
|
||||
yield b'\x00\x00\x00\x00'
|
||||
|
||||
def decrypt_file_streaming(self, encrypted_content: BinaryIO, key: bytes) -> Generator[bytes, None, None]:
    """
    Decrypt a stream produced by ``encrypt_file_streaming``.

    Expects the chunked format: a 1-byte version header, then per chunk a
    4-byte little-endian length prefix followed by nonce + ciphertext, and
    a zero-length end marker.

    Args:
        encrypted_content: File-like object containing encrypted data
        key: 32-byte AES-256 key

    Yields:
        Decrypted chunks

    Raises:
        DecryptionError: On an unknown format version, truncated input,
            or authentication/decryption failure.
    """
    aesgcm = AESGCM(key)

    # Read and validate the format version byte.
    version = encrypted_content.read(1)
    if version != b'\x01':
        # Include the observed byte so corrupt or legacy files are diagnosable.
        raise DecryptionError(f"Unknown encryption format version: {version!r}")

    while True:
        # Each chunk is prefixed with its size (4 bytes, little endian).
        size_bytes = encrypted_content.read(4)
        if len(size_bytes) < 4:
            raise DecryptionError("Unexpected end of file")

        chunk_size = int.from_bytes(size_bytes, 'little')

        # A zero size marks the end of the stream.
        if chunk_size == 0:
            break

        chunk = encrypted_content.read(chunk_size)
        if len(chunk) < chunk_size:
            raise DecryptionError("Unexpected end of file")

        # Chunk body is nonce + ciphertext (ciphertext includes the GCM tag).
        nonce = chunk[:NONCE_SIZE]
        ciphertext = chunk[NONCE_SIZE:]

        try:
            plaintext = aesgcm.decrypt(nonce, ciphertext, None)
        except Exception as e:
            # Chain the cause so the underlying crypto error is preserved.
            raise DecryptionError(f"Failed to decrypt chunk: {e}") from e

        # Yield outside the try block so exceptions thrown into the
        # generator by the consumer are not masked as DecryptionError.
        yield plaintext
|
||||
def encrypt_bytes(self, data: bytes, key: bytes) -> bytes:
    """Encrypt an in-memory byte string (convenience wrapper).

    Adapts the raw bytes to the file-like interface expected by
    ``encrypt_file``.

    Args:
        data: Bytes to encrypt
        key: 32-byte AES-256 key

    Returns:
        Encrypted bytes
    """
    buffer = BytesIO(data)
    return self.encrypt_file(buffer, key)
|
||||
def decrypt_bytes(self, encrypted_data: bytes, key: bytes) -> bytes:
    """Decrypt an in-memory byte string (convenience wrapper).

    Adapts the raw bytes to the file-like interface expected by
    ``decrypt_file``.

    Args:
        encrypted_data: Encrypted bytes
        key: 32-byte AES-256 key

    Returns:
        Decrypted bytes
    """
    buffer = BytesIO(encrypted_data)
    return self.decrypt_file(buffer, key)
|
||||
|
||||
# Singleton instance
# Module-level shared EncryptionService; import this instead of
# constructing a new service at each call site.
encryption_service = EncryptionService()
420
backend/app/services/formula_service.py
Normal file
420
backend/app/services/formula_service.py
Normal file
@@ -0,0 +1,420 @@
|
||||
"""
|
||||
Formula Service for Custom Fields
|
||||
|
||||
Supports:
|
||||
- Basic math operations: +, -, *, /
|
||||
- Field references: {field_name}
|
||||
- Built-in task fields: {original_estimate}, {time_spent}
|
||||
- Parentheses for grouping
|
||||
|
||||
Example formulas:
|
||||
- "{time_spent} / {original_estimate} * 100"
|
||||
- "{cost_per_hour} * {hours_worked}"
|
||||
- "({field_a} + {field_b}) / 2"
|
||||
"""
|
||||
import re
|
||||
import ast
|
||||
import operator
|
||||
from typing import Dict, Any, Optional, List, Set, Tuple
|
||||
from decimal import Decimal, InvalidOperation
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models import Task, CustomField, TaskCustomValue
|
||||
|
||||
|
||||
class FormulaError(Exception):
    """Raised when a formula cannot be parsed or calculated."""
|
||||
|
||||
class CircularReferenceError(FormulaError):
    """Raised when formula fields reference each other in a cycle."""
|
||||
|
||||
class FormulaService:
    """Service for parsing and calculating formula fields."""

    # Built-in task fields that can be referenced in formulas.
    # These are read straight off the Task model via getattr;
    # missing/None values are treated as 0 during calculation.
    BUILTIN_FIELDS = {
        "original_estimate",
        "time_spent",
    }

    # Supported operators
    # Maps AST operator node types to their arithmetic implementation;
    # anything not listed here is rejected by the evaluator.
    OPERATORS = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.USub: operator.neg,
    }
|
||||
@staticmethod
|
||||
def extract_field_references(formula: str) -> Set[str]:
|
||||
"""
|
||||
Extract all field references from a formula.
|
||||
|
||||
Field references are in the format {field_name}.
|
||||
Returns a set of field names referenced in the formula.
|
||||
"""
|
||||
pattern = r'\{([^}]+)\}'
|
||||
matches = re.findall(pattern, formula)
|
||||
return set(matches)
|
||||
|
||||
@staticmethod
def validate_formula(
    formula: str,
    project_id: str,
    db: Session,
    current_field_id: Optional[str] = None,
) -> Tuple[bool, Optional[str]]:
    """Validate a formula expression for a project.

    Checks, in order: the formula is non-empty, references at least one
    field, parses as safe arithmetic, only references existing
    number/formula custom fields, and introduces no circular references.

    Returns (is_valid, error_message); error_message is None on success.
    """
    if not formula or not formula.strip():
        return False, "Formula cannot be empty"

    referenced = FormulaService.extract_field_references(formula)
    if not referenced:
        return False, "Formula must reference at least one field"

    # Syntax check: substitute a dummy "1" for every reference and make
    # sure the result evaluates under the restricted evaluator.
    try:
        probe = formula
        for name in referenced:
            probe = probe.replace(f"{{{name}}}", "1")
        FormulaService._safe_eval(probe)
    except Exception as e:
        return False, f"Invalid formula syntax: {str(e)}"

    # Built-in fields need no lookup; only custom references are checked.
    custom_refs = referenced - FormulaService.BUILTIN_FIELDS
    if custom_refs:
        matching = db.query(CustomField).filter(
            CustomField.project_id == project_id,
            CustomField.name.in_(custom_refs),
        ).all()

        unknown = custom_refs - {f.name for f in matching}
        if unknown:
            return False, f"Unknown field references: {', '.join(unknown)}"

        # Every referenced custom field must be numeric-valued.
        for fld in matching:
            if fld.field_type not in ("number", "formula"):
                return False, f"Field '{fld.name}' is not a numeric type"

    # Cycle detection only applies when editing an existing field.
    if current_field_id:
        try:
            FormulaService._check_circular_references(
                db, project_id, current_field_id, referenced
            )
        except CircularReferenceError as e:
            return False, str(e)

    return True, None
|
||||
@staticmethod
def _check_circular_references(
    db: Session,
    project_id: str,
    field_id: str,
    references: Set[str],
    visited: Optional[Set[str]] = None,
) -> None:
    """
    Check for circular references in formula fields.

    Walks the reference graph starting from the formula belonging to
    ``field_id``: every referenced formula field has its own references
    expanded recursively, while ``visited`` tracks the formula fields
    already seen on this walk.

    Args:
        db: Database session used to resolve field names.
        project_id: Project scoping the custom-field lookups.
        field_id: ID of the field whose formula is being validated.
        references: Field names referenced at the current level.
        visited: IDs of formula fields already expanded; created on the
            first (outermost) call.

    Raises CircularReferenceError if a cycle is detected.
    """
    if visited is None:
        visited = set()

    # Get the current field's name
    current_field = db.query(CustomField).filter(
        CustomField.id == field_id
    ).first()

    if current_field:
        # Direct self-reference: the formula names its own field.
        if current_field.name in references:
            raise CircularReferenceError(
                f"Circular reference detected: field cannot reference itself"
            )

    # Get all referenced formula fields
    # Built-in task fields can never participate in a cycle, so only
    # custom-field references need expanding.
    custom_references = references - FormulaService.BUILTIN_FIELDS
    if not custom_references:
        return

    formula_fields = db.query(CustomField).filter(
        CustomField.project_id == project_id,
        CustomField.name.in_(custom_references),
        CustomField.field_type == "formula",
    ).all()

    for field in formula_fields:
        # Revisiting an already-expanded formula field means the walk
        # has looped back on itself.
        if field.id in visited:
            raise CircularReferenceError(
                f"Circular reference detected involving field '{field.name}'"
            )

        visited.add(field.id)

        if field.formula:
            nested_refs = FormulaService.extract_field_references(field.formula)
            # Indirect cycle back to the field under validation.
            if current_field and current_field.name in nested_refs:
                raise CircularReferenceError(
                    f"Circular reference detected: '{field.name}' references the current field"
                )
            # NOTE(review): the recursion passes the ORIGINAL field_id
            # (not field.id), so every level keeps checking against the
            # field being validated; cycles among other fields rely on
            # the shared ``visited`` set. Looks intentional — confirm.
            FormulaService._check_circular_references(
                db, project_id, field_id, nested_refs, visited
            )
|
||||
@staticmethod
def _safe_eval(expression: str) -> Decimal:
    """Parse *expression* and evaluate it with the restricted AST walker.

    Only basic arithmetic (+, -, *, /) is permitted; any parse or
    evaluation problem surfaces as a FormulaError.
    """
    try:
        parsed = ast.parse(expression, mode='eval')
        outcome = FormulaService._eval_node(parsed.body)
    except Exception as e:
        raise FormulaError(f"Failed to evaluate expression: {str(e)}")
    return outcome
|
||||
@staticmethod
def _eval_node(node: ast.AST) -> Decimal:
    """Recursively evaluate an AST node.

    Supports numeric constants, binary +, -, *, / and unary minus. All
    arithmetic is performed directly on Decimal operands — no float
    round-trip — so no binary floating-point error is introduced.

    Raises:
        FormulaError: For unsupported node types, operators, or constants.
    """
    if isinstance(node, ast.Constant):
        # Only numeric literals are allowed (no strings, None, etc.).
        if isinstance(node.value, (int, float)):
            return Decimal(str(node.value))
        raise FormulaError(f"Invalid constant: {node.value}")

    elif isinstance(node, ast.BinOp):
        left = FormulaService._eval_node(node.left)
        right = FormulaService._eval_node(node.right)
        op = FormulaService.OPERATORS.get(type(node.op))
        if op is None:
            raise FormulaError(f"Unsupported operator: {type(node.op).__name__}")

        # Handle division by zero
        if isinstance(node.op, ast.Div) and right == 0:
            return Decimal('0')  # Return 0 instead of raising error

        # Operate on Decimal directly; the previous float(...) round-trip
        # discarded the precision Decimal is used for.
        return op(left, right)

    elif isinstance(node, ast.UnaryOp):
        operand = FormulaService._eval_node(node.operand)
        op = FormulaService.OPERATORS.get(type(node.op))
        if op is None:
            raise FormulaError(f"Unsupported operator: {type(node.op).__name__}")
        return op(operand)

    else:
        raise FormulaError(f"Unsupported expression type: {type(node).__name__}")
|
||||
@staticmethod
def calculate_formula(
    formula: str,
    task: Task,
    db: Session,
    calculated_cache: Optional[Dict[str, Decimal]] = None,
) -> Optional[Decimal]:
    """
    Calculate the value of a formula for a given task.

    Resolution order: built-in task fields are read via getattr; custom
    formula fields are computed recursively (memoized through
    ``calculated_cache``); other custom fields are read from their stored
    TaskCustomValue rows. Missing or unparsable values fall back to 0.

    Args:
        formula: The formula expression
        task: The task to calculate for
        db: Database session
        calculated_cache: Cache for already calculated formula values (for recursion)

    Returns:
        The calculated value rounded to 4 decimal places, or None if
        evaluation fails.
    """
    if calculated_cache is None:
        calculated_cache = {}

    references = FormulaService.extract_field_references(formula)
    values: Dict[str, Decimal] = {}

    # Get builtin field values
    for ref in references:
        if ref in FormulaService.BUILTIN_FIELDS:
            task_value = getattr(task, ref, None)
            if task_value is not None:
                values[ref] = Decimal(str(task_value))
            else:
                # Missing builtin values are treated as zero.
                values[ref] = Decimal('0')

    # Get custom field values
    custom_references = references - FormulaService.BUILTIN_FIELDS
    if custom_references:
        # Get field definitions
        fields = db.query(CustomField).filter(
            CustomField.project_id == task.project_id,
            CustomField.name.in_(custom_references),
        ).all()

        field_map = {f.name: f for f in fields}

        # Get custom values for this task
        custom_values = db.query(TaskCustomValue).filter(
            TaskCustomValue.task_id == task.id,
            TaskCustomValue.field_id.in_([f.id for f in fields]),
        ).all()

        value_map = {cv.field_id: cv.value for cv in custom_values}

        for ref in custom_references:
            field = field_map.get(ref)
            if not field:
                # Unknown reference: validation should prevent this, but
                # degrade to zero rather than fail the whole formula.
                values[ref] = Decimal('0')
                continue

            if field.field_type == "formula":
                # Recursively calculate formula fields, memoizing so a
                # field shared by several formulas is computed only once.
                if field.id in calculated_cache:
                    values[ref] = calculated_cache[field.id]
                else:
                    nested_value = FormulaService.calculate_formula(
                        field.formula, task, db, calculated_cache
                    )
                    values[ref] = nested_value if nested_value is not None else Decimal('0')
                    calculated_cache[field.id] = values[ref]
            else:
                # Get stored value
                stored_value = value_map.get(field.id)
                if stored_value:
                    try:
                        values[ref] = Decimal(str(stored_value))
                    except (InvalidOperation, ValueError):
                        # Non-numeric stored text degrades to zero.
                        values[ref] = Decimal('0')
                else:
                    values[ref] = Decimal('0')

    # Substitute values into formula
    expression = formula
    for ref, value in values.items():
        expression = expression.replace(f"{{{ref}}}", str(value))

    # Evaluate the expression
    try:
        result = FormulaService._safe_eval(expression)
        # Round to 4 decimal places
        return result.quantize(Decimal('0.0001'))
    except Exception:
        # Any evaluation failure is reported as "no value".
        return None
|
||||
@staticmethod
def recalculate_dependent_formulas(
    db: Session,
    task: Task,
    changed_field_id: str,
) -> Dict[str, Decimal]:
    """
    Recalculate all formula fields that depend on a changed field.

    Each recomputed value is also written back to the task's
    TaskCustomValue rows via the session (no commit is performed here).

    Args:
        db: Database session.
        task: Task whose custom value changed.
        changed_field_id: ID of the custom field whose value changed.

    Returns a dict of field_id -> calculated_value for updated formulas.
    """
    import uuid  # hoisted out of the per-field loop below

    # Get the changed field
    changed_field = db.query(CustomField).filter(
        CustomField.id == changed_field_id
    ).first()

    if not changed_field:
        return {}

    # Find all formula fields in the project
    formula_fields = db.query(CustomField).filter(
        CustomField.project_id == task.project_id,
        CustomField.field_type == "formula",
    ).all()

    results = {}
    calculated_cache: Dict[str, Decimal] = {}

    for field in formula_fields:
        if not field.formula:
            continue

        # Check if this formula depends on the changed field.
        # NOTE(review): the second condition recalculates every formula
        # whenever the changed field's name collides with a built-in
        # field name — presumably a safety net; confirm it is intended.
        references = FormulaService.extract_field_references(field.formula)
        if changed_field.name in references or changed_field.name in FormulaService.BUILTIN_FIELDS:
            value = FormulaService.calculate_formula(
                field.formula, task, db, calculated_cache
            )
            if value is not None:
                results[field.id] = value
                calculated_cache[field.id] = value

                # Update or create the custom value
                existing = db.query(TaskCustomValue).filter(
                    TaskCustomValue.task_id == task.id,
                    TaskCustomValue.field_id == field.id,
                ).first()

                if existing:
                    existing.value = str(value)
                else:
                    new_value = TaskCustomValue(
                        id=str(uuid.uuid4()),
                        task_id=task.id,
                        field_id=field.id,
                        value=str(value),
                    )
                    db.add(new_value)

    return results
|
||||
@staticmethod
def calculate_all_formulas_for_task(
    db: Session,
    task: Task,
) -> Dict[str, Decimal]:
    """Compute every formula field defined in the task's project.

    Used when loading a task to get current formula values; nothing is
    persisted here — the caller receives a field_id -> value mapping.
    """
    formula_fields = db.query(CustomField).filter(
        CustomField.project_id == task.project_id,
        CustomField.field_type == "formula",
    ).all()

    computed: Dict[str, Decimal] = {}
    shared_cache: Dict[str, Decimal] = {}

    for fld in formula_fields:
        # Formula fields without an expression contribute nothing.
        if not fld.formula:
            continue

        outcome = FormulaService.calculate_formula(
            fld.formula, task, db, shared_cache
        )
        if outcome is None:
            continue

        computed[fld.id] = outcome
        shared_cache[fld.id] = outcome

    return computed
||||
Reference in New Issue
Block a user