"""
Formula Service for Custom Fields

Supports:
- Basic math operations: +, -, *, /
- Unary plus and minus, e.g. -{field_name}
- Field references: {field_name}
- Built-in task fields: {original_estimate}, {time_spent}
- Parentheses for grouping

Arithmetic is performed with ``Decimal`` and results are rounded to four
decimal places. Division by zero evaluates to 0 instead of raising.

Example formulas:
- "{time_spent} / {original_estimate} * 100"
- "{cost_per_hour} * {hours_worked}"
- "({field_a} + {field_b}) / 2"
"""

import re
import ast
import operator
import uuid
from typing import Dict, Any, Optional, List, Set, Tuple
from decimal import Decimal, InvalidOperation

from sqlalchemy.orm import Session

from app.models import Task, CustomField, TaskCustomValue


class FormulaError(Exception):
    """Exception raised for formula parsing or calculation errors."""
    pass


class CircularReferenceError(FormulaError):
    """Exception raised when circular references are detected in formulas."""
    pass


class FormulaService:
    """Service for parsing, validating and calculating formula fields."""

    # Built-in task fields that can be referenced in formulas. They are read
    # from the Task object and shadow any custom field with the same name.
    BUILTIN_FIELDS = {
        "original_estimate",
        "time_spent",
    }

    # Supported AST operator node types mapped to their implementations.
    OPERATORS = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.USub: operator.neg,
        ast.UAdd: operator.pos,
    }

    # Matches "{name}" references; the name is any run of non-"}" characters.
    _FIELD_REF_PATTERN = re.compile(r'\{([^}]+)\}')

    # Calculated results are rounded to this many decimal places (4).
    _RESULT_QUANTUM = Decimal('0.0001')

    @staticmethod
    def extract_field_references(formula: str) -> Set[str]:
        """
        Extract all field references from a formula.

        Field references are in the format {field_name}.

        Returns:
            The set of field names referenced in the formula.
        """
        return set(FormulaService._FIELD_REF_PATTERN.findall(formula))

    @staticmethod
    def _substitute_references(formula: str, values: Dict[str, Decimal]) -> str:
        """
        Replace every ``{name}`` in *formula* with ``(value)`` from *values*.

        Each value is wrapped in parentheses so it parses as one self-contained
        operand. Without them, adjacent text such as ``{a}{b}`` or ``{a}e3``
        fuses with the substituted digits into a different number literal
        instead of being rejected as a syntax error. References that have no
        entry in *values* are left untouched.
        """
        def _replace(match: "re.Match") -> str:
            value = values.get(match.group(1))
            return match.group(0) if value is None else f"({value})"

        return FormulaService._FIELD_REF_PATTERN.sub(_replace, formula)

    @staticmethod
    def _to_decimal(raw: Any) -> Decimal:
        """
        Convert a stored or built-in field value to a finite Decimal.

        ``None``, empty strings, unparseable values and non-finite values
        (NaN / Infinity) all become ``Decimal('0')``. A bad input therefore
        degrades to 0 instead of breaking the whole formula.
        """
        if raw is None or raw == "":
            return Decimal('0')
        try:
            value = Decimal(str(raw))
        except (InvalidOperation, ValueError, TypeError):
            return Decimal('0')
        return value if value.is_finite() else Decimal('0')

    @staticmethod
    def validate_formula(
        formula: str,
        project_id: str,
        db: Session,
        current_field_id: Optional[str] = None,
    ) -> Tuple[bool, Optional[str]]:
        """
        Validate a formula expression.

        Checks:
        1. Syntax is valid
        2. All referenced fields exist
        3. Referenced fields are number or formula type
        4. No circular references (only when current_field_id is given)

        Args:
            formula: The formula expression.
            project_id: Project whose custom fields the formula may reference.
            db: Database session.
            current_field_id: ID of the field that owns this formula, if it
                already exists. Used for circular-reference detection.

        Returns:
            (is_valid, error_message); error_message is None when valid.
        """
        if not formula or not formula.strip():
            return False, "Formula cannot be empty"

        references = FormulaService.extract_field_references(formula)
        if not references:
            return False, "Formula must reference at least one field"

        # Syntax check: substitute a dummy value for each reference, the same
        # way calculation does, and run the restricted evaluator.
        dummy_values = {ref: Decimal('1') for ref in references}
        try:
            FormulaService._safe_eval(
                FormulaService._substitute_references(formula, dummy_values)
            )
        except FormulaError as e:
            return False, f"Invalid formula syntax: {str(e)}"

        # Built-ins are always available; only custom references need lookup.
        custom_references = references - FormulaService.BUILTIN_FIELDS

        if custom_references:
            fields = db.query(CustomField).filter(
                CustomField.project_id == project_id,
                CustomField.name.in_(custom_references),
            ).all()

            found_names = {f.name for f in fields}
            missing = custom_references - found_names
            if missing:
                # Sorted so the message is deterministic.
                return False, f"Unknown field references: {', '.join(sorted(missing))}"

            # Only numeric-valued fields may take part in arithmetic.
            for field in fields:
                if field.field_type not in ("number", "formula"):
                    return False, f"Field '{field.name}' is not a numeric type"

        if current_field_id:
            try:
                FormulaService._check_circular_references(
                    db, project_id, current_field_id, references
                )
            except CircularReferenceError as e:
                return False, str(e)

        return True, None

    @staticmethod
    def _check_circular_references(
        db: Session,
        project_id: str,
        field_id: str,
        references: Set[str],
        visited: Optional[Set[str]] = None,
    ) -> None:
        """
        Check for circular references in formula fields.

        Args:
            db: Database session.
            project_id: Project to resolve field names in.
            field_id: ID of the field whose (new) formula is being checked.
            references: Field names referenced by that formula.
            visited: IDs of formula fields already fully explored. These are
                skipped, not treated as cycles, so shared ("diamond")
                dependencies are accepted.

        Raises:
            CircularReferenceError: if the formula references its own field,
                reaches it through other formula fields, or runs into an
                existing cycle among the referenced formula fields.
        """
        if visited is None:
            visited = set()

        current_field = db.query(CustomField).filter(
            CustomField.id == field_id
        ).first()
        current_name = current_field.name if current_field else None

        if current_name is not None and current_name in references:
            raise CircularReferenceError(
                "Circular reference detected: field cannot reference itself"
            )

        FormulaService._walk_formula_dependencies(
            db, project_id, current_name, references, visited, set()
        )

    @staticmethod
    def _walk_formula_dependencies(
        db: Session,
        project_id: str,
        current_name: Optional[str],
        references: Set[str],
        done: Set[str],
        path: Set[str],
    ) -> None:
        """
        Depth-first walk over the formula fields named in *references*.

        *path* holds the IDs on the current DFS chain; meeting one again is a
        cycle. *done* holds IDs whose dependencies are fully explored; they
        are skipped, so shared dependencies are not misreported as cycles.
        """
        custom_references = references - FormulaService.BUILTIN_FIELDS
        if not custom_references:
            return

        formula_fields = db.query(CustomField).filter(
            CustomField.project_id == project_id,
            CustomField.name.in_(custom_references),
            CustomField.field_type == "formula",
        ).all()

        for field in formula_fields:
            if field.id in path:
                raise CircularReferenceError(
                    f"Circular reference detected involving field '{field.name}'"
                )
            if field.id in done:
                continue

            if field.formula:
                nested_refs = FormulaService.extract_field_references(field.formula)
                if current_name is not None and current_name in nested_refs:
                    raise CircularReferenceError(
                        f"Circular reference detected: '{field.name}' references the current field"
                    )
                path.add(field.id)
                FormulaService._walk_formula_dependencies(
                    db, project_id, current_name, nested_refs, done, path
                )
                path.discard(field.id)

            done.add(field.id)

    @staticmethod
    def _safe_eval(expression: str) -> Decimal:
        """
        Safely evaluate a mathematical expression.

        Only numeric literals, parentheses, binary + - * / and unary + - are
        allowed. Nothing is executed: the expression is parsed with
        ``ast.parse`` and the tree is walked by ``_eval_node``.

        Raises:
            FormulaError: for any parse or evaluation failure.
        """
        try:
            node = ast.parse(expression, mode='eval')
            return FormulaService._eval_node(node.body)
        except Exception as e:
            raise FormulaError(f"Failed to evaluate expression: {str(e)}") from e

    @staticmethod
    def _eval_node(node: ast.AST) -> Decimal:
        """
        Recursively evaluate an AST node using Decimal arithmetic.

        Staying in Decimal (rather than converting through float) avoids
        precision loss on large or many-digit values.
        """
        if isinstance(node, ast.Constant):
            value = node.value
            # bool subclasses int, so exclude it explicitly; str, None,
            # complex etc. are rejected too.
            if isinstance(value, bool) or not isinstance(value, (int, float)):
                raise FormulaError(f"Invalid constant: {value}")
            number = Decimal(str(value))
            # Literals like 1e400 overflow float to inf; reject them.
            if not number.is_finite():
                raise FormulaError(f"Invalid constant: {value}")
            return number

        if isinstance(node, ast.BinOp):
            left = FormulaService._eval_node(node.left)
            right = FormulaService._eval_node(node.right)
            op = FormulaService.OPERATORS.get(type(node.op))
            if op is None:
                raise FormulaError(f"Unsupported operator: {type(node.op).__name__}")
            # Division by zero deliberately yields 0 instead of an error.
            if isinstance(node.op, ast.Div) and right == 0:
                return Decimal('0')
            return op(left, right)

        if isinstance(node, ast.UnaryOp):
            operand = FormulaService._eval_node(node.operand)
            op = FormulaService.OPERATORS.get(type(node.op))
            if op is None:
                raise FormulaError(f"Unsupported operator: {type(node.op).__name__}")
            return op(operand)

        raise FormulaError(f"Unsupported expression type: {type(node).__name__}")

    @staticmethod
    def calculate_formula(
        formula: str,
        task: Task,
        db: Session,
        calculated_cache: Optional[Dict[str, Decimal]] = None,
    ) -> Optional[Decimal]:
        """
        Calculate the value of a formula for a given task.

        Args:
            formula: The formula expression.
            task: The task to calculate for.
            db: Database session.
            calculated_cache: field_id -> value cache for formula fields that
                were already calculated. It is updated in place, so callers
                can share it across several calls.

        Returns:
            The calculated value rounded to 4 decimal places, or None if the
            formula is empty or cannot be evaluated.
        """
        if calculated_cache is None:
            calculated_cache = {}
        return FormulaService._calculate_formula_guarded(
            formula, task, db, calculated_cache, set()
        )

    @staticmethod
    def _calculate_formula_guarded(
        formula: Optional[str],
        task: Task,
        db: Session,
        calculated_cache: Dict[str, Decimal],
        in_progress: Set[str],
    ) -> Optional[Decimal]:
        """
        Implementation of calculate_formula.

        *in_progress* holds IDs of formula fields currently being evaluated
        higher up the recursion. It bounds recursion if the stored formulas
        contain a cycle.
        """
        if not formula:
            return None

        references = FormulaService.extract_field_references(formula)

        # Built-in fields come straight from the task; missing/invalid -> 0.
        values: Dict[str, Decimal] = {
            ref: FormulaService._to_decimal(getattr(task, ref, None))
            for ref in references & FormulaService.BUILTIN_FIELDS
        }

        custom_references = references - FormulaService.BUILTIN_FIELDS
        if custom_references:
            values.update(
                FormulaService._resolve_custom_values(
                    custom_references, task, db, calculated_cache, in_progress
                )
            )

        expression = FormulaService._substitute_references(formula, values)
        try:
            result = FormulaService._safe_eval(expression)
            return result.quantize(FormulaService._RESULT_QUANTUM)
        except (FormulaError, ArithmeticError):
            # ArithmeticError covers decimal.InvalidOperation from quantize
            # when the result has too many digits for the context precision.
            return None

    @staticmethod
    def _resolve_custom_values(
        custom_references: Set[str],
        task: Task,
        db: Session,
        calculated_cache: Dict[str, Decimal],
        in_progress: Set[str],
    ) -> Dict[str, Decimal]:
        """
        Resolve custom field references to values for *task*.

        Unknown fields and missing or invalid stored values resolve to 0.
        Formula fields are calculated recursively.
        """
        fields = db.query(CustomField).filter(
            CustomField.project_id == task.project_id,
            CustomField.name.in_(custom_references),
        ).all()
        field_map = {f.name: f for f in fields}

        value_map: Dict[str, Any] = {}
        if fields:
            custom_values = db.query(TaskCustomValue).filter(
                TaskCustomValue.task_id == task.id,
                TaskCustomValue.field_id.in_([f.id for f in fields]),
            ).all()
            value_map = {cv.field_id: cv.value for cv in custom_values}

        values: Dict[str, Decimal] = {}
        for ref in custom_references:
            field = field_map.get(ref)
            if field is None:
                values[ref] = Decimal('0')
            elif field.field_type == "formula":
                values[ref] = FormulaService._formula_field_value(
                    field, task, db, calculated_cache, in_progress
                )
            else:
                values[ref] = FormulaService._to_decimal(value_map.get(field.id))
        return values

    @staticmethod
    def _formula_field_value(
        field: CustomField,
        task: Task,
        db: Session,
        calculated_cache: Dict[str, Decimal],
        in_progress: Set[str],
    ) -> Decimal:
        """
        Return the (cached) value of a nested formula field.

        A failed calculation counts as 0. So does a field already being
        evaluated further up the stack; this only happens when stored formulas
        form a cycle, which previously recursed without bound.
        """
        if field.id in calculated_cache:
            return calculated_cache[field.id]
        if field.id in in_progress:
            return Decimal('0')

        in_progress.add(field.id)
        nested_value = FormulaService._calculate_formula_guarded(
            field.formula, task, db, calculated_cache, in_progress
        )
        in_progress.discard(field.id)

        value = nested_value if nested_value is not None else Decimal('0')
        calculated_cache[field.id] = value
        return value

    @staticmethod
    def _find_dependent_formula_ids(
        changed_name: str,
        formula_fields: List[CustomField],
    ) -> Set[str]:
        """
        Return IDs of formula fields depending on *changed_name*.

        Both direct and transitive dependents are included (if C uses B and
        B uses the changed field, C is included). If the changed field's
        name collides with a built-in field, every formula field is returned.
        """
        if changed_name in FormulaService.BUILTIN_FIELDS:
            return {f.id for f in formula_fields}

        # Built-in names resolve to task attributes, never to custom fields.
        refs_by_id = {
            f.id: FormulaService.extract_field_references(f.formula)
            - FormulaService.BUILTIN_FIELDS
            for f in formula_fields
        }

        dirty_names = {changed_name}
        dependent_ids: Set[str] = set()
        grew = True
        while grew:  # iterate to a fixed point over the dependency graph
            grew = False
            for f in formula_fields:
                if f.id not in dependent_ids and refs_by_id[f.id] & dirty_names:
                    dependent_ids.add(f.id)
                    dirty_names.add(f.name)
                    grew = True
        return dependent_ids

    @staticmethod
    def recalculate_dependent_formulas(
        db: Session,
        task: Task,
        changed_field_id: str,
    ) -> Dict[str, Decimal]:
        """
        Recalculate and persist all formula fields depending on a changed field.

        Both direct and transitive dependents are recalculated. Each
        successfully calculated value is written to the task's
        TaskCustomValue row (created if missing) via the session. The session
        is not committed here.

        Returns:
            field_id -> calculated value for the updated formulas. Formulas
            whose calculation failed are omitted and their stored value is
            left untouched.
        """
        changed_field = db.query(CustomField).filter(
            CustomField.id == changed_field_id
        ).first()

        if not changed_field:
            return {}

        formula_fields = [
            f for f in db.query(CustomField).filter(
                CustomField.project_id == task.project_id,
                CustomField.field_type == "formula",
            ).all()
            if f.formula
        ]

        affected_ids = FormulaService._find_dependent_formula_ids(
            changed_field.name, formula_fields
        )

        results: Dict[str, Decimal] = {}
        calculated_cache: Dict[str, Decimal] = {}

        for field in formula_fields:
            if field.id not in affected_ids:
                continue

            value = FormulaService.calculate_formula(
                field.formula, task, db, calculated_cache
            )
            if value is None:
                continue

            results[field.id] = value
            calculated_cache[field.id] = value

            existing = db.query(TaskCustomValue).filter(
                TaskCustomValue.task_id == task.id,
                TaskCustomValue.field_id == field.id,
            ).first()

            if existing:
                existing.value = str(value)
            else:
                db.add(TaskCustomValue(
                    id=str(uuid.uuid4()),
                    task_id=task.id,
                    field_id=field.id,
                    value=str(value),
                ))

        return results

    @staticmethod
    def calculate_all_formulas_for_task(
        db: Session,
        task: Task,
    ) -> Dict[str, Decimal]:
        """
        Calculate all formula fields for a task.

        Used when loading a task to get current formula values. Nothing is
        persisted.

        Returns:
            field_id -> calculated value. Formulas that are empty or fail to
            evaluate are omitted.
        """
        formula_fields = db.query(CustomField).filter(
            CustomField.project_id == task.project_id,
            CustomField.field_type == "formula",
        ).all()

        results: Dict[str, Decimal] = {}
        calculated_cache: Dict[str, Decimal] = {}

        for field in formula_fields:
            if not field.formula:
                continue

            value = FormulaService.calculate_formula(
                field.formula, task, db, calculated_cache
            )
            if value is not None:
                results[field.id] = value
                calculated_cache[field.id] = value

        return results