""" Formula Service for Custom Fields Supports: - Basic math operations: +, -, *, / - Field references: {field_name} - Built-in task fields: {original_estimate}, {time_spent} - Parentheses for grouping Example formulas: - "{time_spent} / {original_estimate} * 100" - "{cost_per_hour} * {hours_worked}" - "({field_a} + {field_b}) / 2" """ import re import ast import operator from typing import Dict, Any, Optional, List, Set, Tuple from decimal import Decimal, InvalidOperation from sqlalchemy.orm import Session from app.models import Task, CustomField, TaskCustomValue class FormulaError(Exception): """Exception raised for formula parsing or calculation errors.""" pass class CircularReferenceError(FormulaError): """Exception raised when circular references are detected in formulas.""" def __init__(self, message: str, cycle_path: Optional[List[str]] = None): super().__init__(message) self.message = message self.cycle_path = cycle_path or [] def get_cycle_description(self) -> str: """Get a human-readable description of the cycle.""" if not self.cycle_path: return "" return " -> ".join(self.cycle_path) class FormulaService: """Service for parsing and calculating formula fields.""" # Built-in task fields that can be referenced in formulas BUILTIN_FIELDS = { "original_estimate", "time_spent", } # Supported operators OPERATORS = { ast.Add: operator.add, ast.Sub: operator.sub, ast.Mult: operator.mul, ast.Div: operator.truediv, ast.USub: operator.neg, } @staticmethod def extract_field_references(formula: str) -> Set[str]: """ Extract all field references from a formula. Field references are in the format {field_name}. Returns a set of field names referenced in the formula. """ pattern = r'\{([^}]+)\}' matches = re.findall(pattern, formula) return set(matches) @staticmethod def validate_formula( formula: str, project_id: str, db: Session, current_field_id: Optional[str] = None, ) -> Tuple[bool, Optional[str]]: """ Validate a formula expression. Checks: 1. Syntax is valid 2. All referenced fields exist 3. Referenced fields are number or formula type 4. No circular references Returns (is_valid, error_message) """ if not formula or not formula.strip(): return False, "Formula cannot be empty" # Extract field references references = FormulaService.extract_field_references(formula) if not references: return False, "Formula must reference at least one field" # Validate syntax by trying to parse try: # Replace field references with dummy numbers for syntax check test_formula = formula for ref in references: test_formula = test_formula.replace(f"{{{ref}}}", "1") # Try to parse and evaluate with safe operations FormulaService._safe_eval(test_formula) except Exception as e: return False, f"Invalid formula syntax: {str(e)}" # Separate builtin and custom field references custom_references = references - FormulaService.BUILTIN_FIELDS # Validate custom field references exist and are numeric types if custom_references: fields = db.query(CustomField).filter( CustomField.project_id == project_id, CustomField.name.in_(custom_references), ).all() found_names = {f.name for f in fields} missing = custom_references - found_names if missing: return False, f"Unknown field references: {', '.join(missing)}" # Check field types (must be number or formula) for field in fields: if field.field_type not in ("number", "formula"): return False, f"Field '{field.name}' is not a numeric type" # Check for circular references if current_field_id: try: FormulaService._check_circular_references( db, project_id, current_field_id, references ) except CircularReferenceError as e: return False, str(e) return True, None @staticmethod def _check_circular_references( db: Session, project_id: str, field_id: str, references: Set[str], visited: Optional[Set[str]] = None, path: Optional[List[str]] = None, ) -> None: """ Check for circular references in formula fields. Raises CircularReferenceError if a cycle is detected. Args: db: Database session project_id: Project ID to scope the query field_id: The current field being validated references: Set of field names referenced in the formula visited: Set of visited field IDs (for cycle detection) path: Current path of field names (for error reporting) """ if visited is None: visited = set() if path is None: path = [] # Get the current field's name current_field = db.query(CustomField).filter( CustomField.id == field_id ).first() current_field_name = current_field.name if current_field else "unknown" # Add current field to path if not already there if current_field_name not in path: path = path + [current_field_name] if current_field: if current_field.name in references: cycle_path = path + [current_field.name] raise CircularReferenceError( f"Circular reference detected: field '{current_field.name}' cannot reference itself", cycle_path=cycle_path ) # Get all referenced formula fields custom_references = references - FormulaService.BUILTIN_FIELDS if not custom_references: return formula_fields = db.query(CustomField).filter( CustomField.project_id == project_id, CustomField.name.in_(custom_references), CustomField.field_type == "formula", ).all() for field in formula_fields: if field.id in visited: # Found a cycle cycle_path = path + [field.name] raise CircularReferenceError( f"Circular reference detected: {' -> '.join(cycle_path)}", cycle_path=cycle_path ) visited.add(field.id) new_path = path + [field.name] if field.formula: nested_refs = FormulaService.extract_field_references(field.formula) if current_field and current_field.name in nested_refs: cycle_path = new_path + [current_field.name] raise CircularReferenceError( f"Circular reference detected: {' -> '.join(cycle_path)}", cycle_path=cycle_path ) FormulaService._check_circular_references( db, project_id, field_id, nested_refs, visited, new_path ) @staticmethod def build_formula_dependency_graph( db: Session, project_id: str ) -> Dict[str, Set[str]]: """ Build a dependency graph for all formula fields in a project. Returns a dict where keys are field names and values are sets of field names that the key field depends on. Args: db: Database session project_id: Project ID to scope the query Returns: Dict mapping field names to their dependencies """ graph: Dict[str, Set[str]] = {} formula_fields = db.query(CustomField).filter( CustomField.project_id == project_id, CustomField.field_type == "formula", ).all() for field in formula_fields: if field.formula: refs = FormulaService.extract_field_references(field.formula) # Only include custom field references (not builtin fields) custom_refs = refs - FormulaService.BUILTIN_FIELDS graph[field.name] = custom_refs else: graph[field.name] = set() return graph @staticmethod def detect_formula_cycles( db: Session, project_id: str ) -> List[List[str]]: """ Detect all cycles in the formula dependency graph for a project. This is useful for auditing or cleanup operations. Args: db: Database session project_id: Project ID to check Returns: List of cycles, where each cycle is a list of field names """ graph = FormulaService.build_formula_dependency_graph(db, project_id) if not graph: return [] cycles: List[List[str]] = [] visited: Set[str] = set() found_cycles: Set[Tuple[str, ...]] = set() def dfs(node: str, path: List[str], in_path: Set[str]): """DFS to find cycles.""" if node in in_path: # Found a cycle cycle_start = path.index(node) cycle = path[cycle_start:] + [node] # Normalize for deduplication normalized = tuple(sorted(cycle[:-1])) if normalized not in found_cycles: found_cycles.add(normalized) cycles.append(cycle) return if node in visited: return if node not in graph: return visited.add(node) in_path.add(node) path.append(node) for neighbor in graph.get(node, set()): dfs(neighbor, path.copy(), in_path.copy()) path.pop() in_path.discard(node) for start_node in graph.keys(): if start_node not in visited: dfs(start_node, [], set()) return cycles @staticmethod def validate_formula_with_details( formula: str, project_id: str, db: Session, current_field_id: Optional[str] = None, ) -> Tuple[bool, Optional[str], Optional[List[str]]]: """ Validate a formula expression with detailed error information. Similar to validate_formula but returns cycle path on circular reference errors. Args: formula: The formula expression to validate project_id: Project ID to scope field lookups db: Database session current_field_id: Optional ID of the field being edited (for self-reference check) Returns: Tuple of (is_valid, error_message, cycle_path) """ if not formula or not formula.strip(): return False, "Formula cannot be empty", None # Extract field references references = FormulaService.extract_field_references(formula) if not references: return False, "Formula must reference at least one field", None # Validate syntax by trying to parse try: # Replace field references with dummy numbers for syntax check test_formula = formula for ref in references: test_formula = test_formula.replace(f"{{{ref}}}", "1") # Try to parse and evaluate with safe operations FormulaService._safe_eval(test_formula) except Exception as e: return False, f"Invalid formula syntax: {str(e)}", None # Separate builtin and custom field references custom_references = references - FormulaService.BUILTIN_FIELDS # Validate custom field references exist and are numeric types if custom_references: fields = db.query(CustomField).filter( CustomField.project_id == project_id, CustomField.name.in_(custom_references), ).all() found_names = {f.name for f in fields} missing = custom_references - found_names if missing: return False, f"Unknown field references: {', '.join(missing)}", None # Check field types (must be number or formula) for field in fields: if field.field_type not in ("number", "formula"): return False, f"Field '{field.name}' is not a numeric type", None # Check for circular references if current_field_id: try: FormulaService._check_circular_references( db, project_id, current_field_id, references ) except CircularReferenceError as e: return False, str(e), e.cycle_path return True, None, None @staticmethod def _safe_eval(expression: str) -> Decimal: """ Safely evaluate a mathematical expression. Only allows basic arithmetic operations (+, -, *, /). """ try: node = ast.parse(expression, mode='eval') return FormulaService._eval_node(node.body) except Exception as e: raise FormulaError(f"Failed to evaluate expression: {str(e)}") @staticmethod def _eval_node(node: ast.AST) -> Decimal: """Recursively evaluate an AST node.""" if isinstance(node, ast.Constant): if isinstance(node.value, (int, float)): return Decimal(str(node.value)) raise FormulaError(f"Invalid constant: {node.value}") elif isinstance(node, ast.BinOp): left = FormulaService._eval_node(node.left) right = FormulaService._eval_node(node.right) op = FormulaService.OPERATORS.get(type(node.op)) if op is None: raise FormulaError(f"Unsupported operator: {type(node.op).__name__}") # Handle division by zero if isinstance(node.op, ast.Div) and right == 0: return Decimal('0') # Return 0 instead of raising error return Decimal(str(op(float(left), float(right)))) elif isinstance(node, ast.UnaryOp): operand = FormulaService._eval_node(node.operand) op = FormulaService.OPERATORS.get(type(node.op)) if op is None: raise FormulaError(f"Unsupported operator: {type(node.op).__name__}") return Decimal(str(op(float(operand)))) else: raise FormulaError(f"Unsupported expression type: {type(node).__name__}") @staticmethod def calculate_formula( formula: str, task: Task, db: Session, calculated_cache: Optional[Dict[str, Decimal]] = None, ) -> Optional[Decimal]: """ Calculate the value of a formula for a given task. Args: formula: The formula expression task: The task to calculate for db: Database session calculated_cache: Cache for already calculated formula values (for recursion) Returns: The calculated value, or None if calculation fails """ if calculated_cache is None: calculated_cache = {} references = FormulaService.extract_field_references(formula) values: Dict[str, Decimal] = {} # Get builtin field values for ref in references: if ref in FormulaService.BUILTIN_FIELDS: task_value = getattr(task, ref, None) if task_value is not None: values[ref] = Decimal(str(task_value)) else: values[ref] = Decimal('0') # Get custom field values custom_references = references - FormulaService.BUILTIN_FIELDS if custom_references: # Get field definitions fields = db.query(CustomField).filter( CustomField.project_id == task.project_id, CustomField.name.in_(custom_references), ).all() field_map = {f.name: f for f in fields} # Get custom values for this task custom_values = db.query(TaskCustomValue).filter( TaskCustomValue.task_id == task.id, TaskCustomValue.field_id.in_([f.id for f in fields]), ).all() value_map = {cv.field_id: cv.value for cv in custom_values} for ref in custom_references: field = field_map.get(ref) if not field: values[ref] = Decimal('0') continue if field.field_type == "formula": # Recursively calculate formula fields if field.id in calculated_cache: values[ref] = calculated_cache[field.id] else: nested_value = FormulaService.calculate_formula( field.formula, task, db, calculated_cache ) values[ref] = nested_value if nested_value is not None else Decimal('0') calculated_cache[field.id] = values[ref] else: # Get stored value stored_value = value_map.get(field.id) if stored_value: try: values[ref] = Decimal(str(stored_value)) except (InvalidOperation, ValueError): values[ref] = Decimal('0') else: values[ref] = Decimal('0') # Substitute values into formula expression = formula for ref, value in values.items(): expression = expression.replace(f"{{{ref}}}", str(value)) # Evaluate the expression try: result = FormulaService._safe_eval(expression) # Round to 4 decimal places return result.quantize(Decimal('0.0001')) except Exception: return None @staticmethod def recalculate_dependent_formulas( db: Session, task: Task, changed_field_id: str, ) -> Dict[str, Decimal]: """ Recalculate all formula fields that depend on a changed field. Returns a dict of field_id -> calculated_value for updated formulas. """ # Get the changed field changed_field = db.query(CustomField).filter( CustomField.id == changed_field_id ).first() if not changed_field: return {} # Find all formula fields in the project formula_fields = db.query(CustomField).filter( CustomField.project_id == task.project_id, CustomField.field_type == "formula", ).all() results = {} calculated_cache: Dict[str, Decimal] = {} for field in formula_fields: if not field.formula: continue # Check if this formula depends on the changed field references = FormulaService.extract_field_references(field.formula) if changed_field.name in references or changed_field.name in FormulaService.BUILTIN_FIELDS: value = FormulaService.calculate_formula( field.formula, task, db, calculated_cache ) if value is not None: results[field.id] = value calculated_cache[field.id] = value # Update or create the custom value existing = db.query(TaskCustomValue).filter( TaskCustomValue.task_id == task.id, TaskCustomValue.field_id == field.id, ).first() if existing: existing.value = str(value) else: import uuid new_value = TaskCustomValue( id=str(uuid.uuid4()), task_id=task.id, field_id=field.id, value=str(value), ) db.add(new_value) return results @staticmethod def calculate_all_formulas_for_task( db: Session, task: Task, ) -> Dict[str, Decimal]: """ Calculate all formula fields for a task. Used when loading a task to get current formula values. """ formula_fields = db.query(CustomField).filter( CustomField.project_id == task.project_id, CustomField.field_type == "formula", ).all() results = {} calculated_cache: Dict[str, Decimal] = {} for field in formula_fields: if not field.formula: continue value = FormulaService.calculate_formula( field.formula, task, db, calculated_cache ) if value is not None: results[field.id] = value calculated_cache[field.id] = value return results