""" Dependency Service Handles task dependency validation including: - Circular dependency detection using DFS - Date constraint validation based on dependency types - Self-reference prevention - Cross-project dependency prevention - Bulk dependency operations with cycle detection """ from typing import List, Optional, Set, Tuple, Dict, Any from collections import defaultdict from sqlalchemy.orm import Session from datetime import datetime, timedelta from app.models import Task, TaskDependency class DependencyValidationError(Exception): """Custom exception for dependency validation errors.""" def __init__(self, error_type: str, message: str, details: Optional[dict] = None): self.error_type = error_type self.message = message self.details = details or {} super().__init__(message) class CycleDetectionResult: """Result of cycle detection with detailed path information.""" def __init__( self, has_cycle: bool, cycle_path: Optional[List[str]] = None, cycle_task_titles: Optional[List[str]] = None ): self.has_cycle = has_cycle self.cycle_path = cycle_path or [] self.cycle_task_titles = cycle_task_titles or [] def get_cycle_description(self) -> str: """Get a human-readable description of the cycle.""" if not self.has_cycle or not self.cycle_task_titles: return "" # Format: Task A -> Task B -> Task C -> Task A return " -> ".join(self.cycle_task_titles) class DependencyService: """Service for managing task dependencies with validation.""" # Maximum number of direct dependencies per task (as per spec) MAX_DIRECT_DEPENDENCIES = 10 @staticmethod def detect_circular_dependency( db: Session, predecessor_id: str, successor_id: str, project_id: str ) -> Optional[List[str]]: """ Detect if adding a dependency would create a circular reference. Uses DFS to traverse from the successor to check if we can reach the predecessor through existing dependencies. Args: db: Database session predecessor_id: The task that must complete first successor_id: The task that depends on the predecessor project_id: Project ID to scope the query Returns: List of task IDs forming the cycle if circular, None otherwise """ result = DependencyService.detect_circular_dependency_detailed( db, predecessor_id, successor_id, project_id ) return result.cycle_path if result.has_cycle else None @staticmethod def detect_circular_dependency_detailed( db: Session, predecessor_id: str, successor_id: str, project_id: str, additional_edges: Optional[List[Tuple[str, str]]] = None ) -> CycleDetectionResult: """ Detect if adding a dependency would create a circular reference. Uses DFS to traverse from the successor to check if we can reach the predecessor through existing dependencies. 
    @staticmethod
    def detect_circular_dependency_detailed(
        db: Session,
        predecessor_id: str,
        successor_id: str,
        project_id: str,
        additional_edges: Optional[List[Tuple[str, str]]] = None
    ) -> CycleDetectionResult:
        """
        Detect if adding a dependency would create a circular reference.

        Uses DFS to traverse from the successor to check if we can reach
        the predecessor through existing dependencies.

        Args:
            db: Database session
            predecessor_id: The task that must complete first
            successor_id: The task that depends on the predecessor
            project_id: Project ID to scope the query
            additional_edges: Optional list of additional
                (predecessor_id, successor_id) edges to consider
                (for bulk operations)

        Returns:
            CycleDetectionResult with detailed cycle information
        """
        # Build adjacency list for the project's dependencies
        dependencies = db.query(TaskDependency).join(
            Task, TaskDependency.successor_id == Task.id
        ).filter(Task.project_id == project_id).all()

        # Graph: successor -> [predecessors]
        # We need to check if predecessor is reachable from successor
        # by following the chain of "what does this task depend on"
        graph: Dict[str, List[str]] = defaultdict(list)
        for dep in dependencies:
            graph[dep.successor_id].append(dep.predecessor_id)

        # Simulate adding the new edge
        graph[successor_id].append(predecessor_id)

        # Add any additional edges for bulk operations
        if additional_edges:
            for pred_id, succ_id in additional_edges:
                graph[succ_id].append(pred_id)

        # Build task title map for readable error messages
        task_ids_in_graph = set()
        for succ_id, pred_ids in graph.items():
            task_ids_in_graph.add(succ_id)
            task_ids_in_graph.update(pred_ids)

        tasks = db.query(Task).filter(Task.id.in_(task_ids_in_graph)).all()
        task_title_map: Dict[str, str] = {t.id: t.title for t in tasks}

        # DFS to find if there's a path from predecessor back to successor
        # (which would complete a cycle)
        visited: Set[str] = set()
        path: List[str] = []
        in_path: Set[str] = set()

        def dfs(node: str) -> Optional[List[str]]:
            """DFS traversal to detect cycles."""
            if node in in_path:
                # Found a cycle - return the cycle path
                cycle_start = path.index(node)
                return path[cycle_start:] + [node]
            if node in visited:
                return None

            visited.add(node)
            in_path.add(node)
            path.append(node)

            for neighbor in graph.get(node, []):
                result = dfs(neighbor)
                if result:
                    return result

            path.pop()
            in_path.remove(node)
            return None

        # Start DFS from the successor to check if we can reach back to it
        cycle_path = dfs(successor_id)

        if cycle_path:
            # Build task titles for the cycle
            cycle_titles = [task_title_map.get(task_id, task_id) for task_id in cycle_path]
            return CycleDetectionResult(
                has_cycle=True,
                cycle_path=cycle_path,
                cycle_task_titles=cycle_titles
            )

        return CycleDetectionResult(has_cycle=False)
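    # Worked example (illustrative; task IDs are placeholders, not values from the
    # original source): with existing dependencies A -> B and B -> C (A must come
    # before B, B before C), the adjacency list above is {"B": ["A"], "C": ["B"]}.
    # Requesting a new dependency C -> A simulates the edge {"A": ["C"]}, and the DFS
    # starting at the successor "A" walks A -> C -> B -> A, so the method returns a
    # CycleDetectionResult with has_cycle=True and that path.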
    @staticmethod
    def validate_dependency(
        db: Session,
        predecessor_id: str,
        successor_id: str
    ) -> None:
        """
        Validate that a dependency can be created.

        Raises DependencyValidationError if validation fails.

        Checks:
        1. Self-reference
        2. Both tasks exist
        3. Both tasks are in the same project
        4. No duplicate dependency
        5. No circular dependency
        6. Dependency limit not exceeded
        """
        # Check self-reference
        if predecessor_id == successor_id:
            raise DependencyValidationError(
                error_type="self_reference",
                message="A task cannot depend on itself"
            )

        # Get both tasks
        predecessor = db.query(Task).filter(Task.id == predecessor_id).first()
        successor = db.query(Task).filter(Task.id == successor_id).first()

        if not predecessor:
            raise DependencyValidationError(
                error_type="not_found",
                message="Predecessor task not found",
                details={"task_id": predecessor_id}
            )

        if not successor:
            raise DependencyValidationError(
                error_type="not_found",
                message="Successor task not found",
                details={"task_id": successor_id}
            )

        # Check same project
        if predecessor.project_id != successor.project_id:
            raise DependencyValidationError(
                error_type="cross_project",
                message="Dependencies can only be created between tasks in the same project",
                details={
                    "predecessor_project_id": predecessor.project_id,
                    "successor_project_id": successor.project_id
                }
            )

        # Check duplicate
        existing = db.query(TaskDependency).filter(
            TaskDependency.predecessor_id == predecessor_id,
            TaskDependency.successor_id == successor_id
        ).first()
        if existing:
            raise DependencyValidationError(
                error_type="duplicate",
                message="This dependency already exists"
            )

        # Check dependency limit
        current_count = db.query(TaskDependency).filter(
            TaskDependency.successor_id == successor_id
        ).count()
        if current_count >= DependencyService.MAX_DIRECT_DEPENDENCIES:
            raise DependencyValidationError(
                error_type="limit_exceeded",
                message=f"A task cannot have more than {DependencyService.MAX_DIRECT_DEPENDENCIES} direct dependencies",
                details={"current_count": current_count}
            )

        # Check circular dependency
        cycle_result = DependencyService.detect_circular_dependency_detailed(
            db, predecessor_id, successor_id, predecessor.project_id
        )
        if cycle_result.has_cycle:
            raise DependencyValidationError(
                error_type="circular",
                message=f"Adding this dependency would create a circular reference: {cycle_result.get_cycle_description()}",
                details={
                    "cycle": cycle_result.cycle_path,
                    "cycle_description": cycle_result.get_cycle_description(),
                    "cycle_task_titles": cycle_result.cycle_task_titles
                }
            )
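    # Illustrative usage (assumed caller shape, e.g. a FastAPI route handler; not part
    # of the original source): run validation before inserting the TaskDependency row
    # and map the typed error onto an HTTP response.
    #
    #   try:
    #       DependencyService.validate_dependency(db, predecessor_id, successor_id)
    #   except DependencyValidationError as exc:
    #       status = 404 if exc.error_type == "not_found" else 422
    #       return JSONResponse(status_code=status, content={
    #           "error": exc.error_type, "message": exc.message, "details": exc.details
    #       })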
    @staticmethod
    def validate_date_constraints(
        task: Task,
        start_date: Optional[datetime],
        due_date: Optional[datetime],
        db: Session
    ) -> List[Dict[str, Any]]:
        """
        Validate date changes against dependency constraints.

        Returns a list of constraint violations (empty if valid).

        Dependency type meanings:
        - FS: predecessor.due_date + lag <= successor.start_date
        - SS: predecessor.start_date + lag <= successor.start_date
        - FF: predecessor.due_date + lag <= successor.due_date
        - SF: predecessor.start_date + lag <= successor.due_date
        """
        violations = []

        # Use provided dates or fall back to current task dates
        new_start = start_date if start_date is not None else task.start_date
        new_due = due_date if due_date is not None else task.due_date

        # Basic date validation
        if new_start and new_due and new_start > new_due:
            violations.append({
                "type": "date_order",
                "message": "Start date cannot be after due date",
                "start_date": str(new_start),
                "due_date": str(new_due)
            })

        # Get dependencies where this task is the successor (predecessors)
        predecessors = db.query(TaskDependency).filter(
            TaskDependency.successor_id == task.id
        ).all()

        for dep in predecessors:
            pred_task = dep.predecessor
            if not pred_task:
                continue

            lag = timedelta(days=dep.lag_days)
            violation = None

            if dep.dependency_type == "FS":
                # Predecessor must finish before successor starts
                if pred_task.due_date and new_start:
                    required_start = pred_task.due_date + lag
                    if new_start < required_start:
                        violation = {
                            "type": "dependency_constraint",
                            "dependency_type": "FS",
                            "predecessor_id": pred_task.id,
                            "predecessor_title": pred_task.title,
                            "message": f"Start date must be on or after {required_start.date()} (predecessor due date + {dep.lag_days} days lag)"
                        }

            elif dep.dependency_type == "SS":
                # Predecessor must start before successor starts
                if pred_task.start_date and new_start:
                    required_start = pred_task.start_date + lag
                    if new_start < required_start:
                        violation = {
                            "type": "dependency_constraint",
                            "dependency_type": "SS",
                            "predecessor_id": pred_task.id,
                            "predecessor_title": pred_task.title,
                            "message": f"Start date must be on or after {required_start.date()} (predecessor start date + {dep.lag_days} days lag)"
                        }

            elif dep.dependency_type == "FF":
                # Predecessor must finish before successor finishes
                if pred_task.due_date and new_due:
                    required_due = pred_task.due_date + lag
                    if new_due < required_due:
                        violation = {
                            "type": "dependency_constraint",
                            "dependency_type": "FF",
                            "predecessor_id": pred_task.id,
                            "predecessor_title": pred_task.title,
                            "message": f"Due date must be on or after {required_due.date()} (predecessor due date + {dep.lag_days} days lag)"
                        }

            elif dep.dependency_type == "SF":
                # Predecessor must start before successor finishes
                if pred_task.start_date and new_due:
                    required_due = pred_task.start_date + lag
                    if new_due < required_due:
                        violation = {
                            "type": "dependency_constraint",
                            "dependency_type": "SF",
                            "predecessor_id": pred_task.id,
                            "predecessor_title": pred_task.title,
                            "message": f"Due date must be on or after {required_due.date()} (predecessor start date + {dep.lag_days} days lag)"
                        }

            if violation:
                violations.append(violation)
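        # Worked example (illustrative, not from the original source): for an FS link
        # with lag_days=2 whose predecessor due_date is 2024-03-01, the earliest
        # allowed successor start is 2024-03-01 + 2 days = 2024-03-03, so a proposed
        # new_start of 2024-03-02 is reported as a violation above. The mirror checks
        # below apply the same formulas from the predecessor's point of view.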
days lag)" } elif dep.dependency_type == "SS": # This task must start before successor starts if new_start and succ_task.start_date: required_start = succ_task.start_date - lag if new_start > required_start: violation = { "type": "dependency_constraint", "dependency_type": "SS", "successor_id": succ_task.id, "successor_title": succ_task.title, "message": f"Start date must be on or before {required_start.date()} (successor start date - {dep.lag_days} days lag)" } elif dep.dependency_type == "FF": # This task must finish before successor finishes if new_due and succ_task.due_date: required_due = succ_task.due_date - lag if new_due > required_due: violation = { "type": "dependency_constraint", "dependency_type": "FF", "successor_id": succ_task.id, "successor_title": succ_task.title, "message": f"Due date must be on or before {required_due.date()} (successor due date - {dep.lag_days} days lag)" } elif dep.dependency_type == "SF": # This task must start before successor finishes if new_start and succ_task.due_date: required_start = succ_task.due_date - lag if new_start > required_start: violation = { "type": "dependency_constraint", "dependency_type": "SF", "successor_id": succ_task.id, "successor_title": succ_task.title, "message": f"Start date must be on or before {required_start.date()} (successor due date - {dep.lag_days} days lag)" } if violation: violations.append(violation) return violations @staticmethod def get_all_predecessors(db: Session, task_id: str) -> List[str]: """ Get all transitive predecessors of a task. Uses BFS to find all tasks that this task depends on (directly or indirectly). """ visited: Set[str] = set() queue = [task_id] predecessors = [] while queue: current = queue.pop(0) if current in visited: continue visited.add(current) deps = db.query(TaskDependency).filter( TaskDependency.successor_id == current ).all() for dep in deps: if dep.predecessor_id not in visited: predecessors.append(dep.predecessor_id) queue.append(dep.predecessor_id) return predecessors @staticmethod def get_all_successors(db: Session, task_id: str) -> List[str]: """ Get all transitive successors of a task. Uses BFS to find all tasks that depend on this task (directly or indirectly). """ visited: Set[str] = set() queue = [task_id] successors = [] while queue: current = queue.pop(0) if current in visited: continue visited.add(current) deps = db.query(TaskDependency).filter( TaskDependency.predecessor_id == current ).all() for dep in deps: if dep.successor_id not in visited: successors.append(dep.successor_id) queue.append(dep.successor_id) return successors @staticmethod def validate_bulk_dependencies( db: Session, dependencies: List[Tuple[str, str]], project_id: str ) -> List[Dict[str, Any]]: """ Validate a batch of dependencies for cycle detection. This method validates multiple dependencies together to detect cycles that would only appear when all dependencies are added together. 
    @staticmethod
    def validate_bulk_dependencies(
        db: Session,
        dependencies: List[Tuple[str, str]],
        project_id: str
    ) -> List[Dict[str, Any]]:
        """
        Validate a batch of dependencies for cycle detection.

        This method validates multiple dependencies together to detect
        cycles that would only appear when all dependencies are added together.

        Args:
            db: Database session
            dependencies: List of (predecessor_id, successor_id) tuples
            project_id: Project ID to scope the query

        Returns:
            List of validation errors (empty if all valid)
        """
        errors: List[Dict[str, Any]] = []

        if not dependencies:
            return errors

        # First, validate each dependency individually for basic checks
        for predecessor_id, successor_id in dependencies:
            # Check self-reference
            if predecessor_id == successor_id:
                errors.append({
                    "error_type": "self_reference",
                    "predecessor_id": predecessor_id,
                    "successor_id": successor_id,
                    "message": "A task cannot depend on itself"
                })
                continue

            # Get tasks to validate project membership
            predecessor = db.query(Task).filter(Task.id == predecessor_id).first()
            successor = db.query(Task).filter(Task.id == successor_id).first()

            if not predecessor:
                errors.append({
                    "error_type": "not_found",
                    "predecessor_id": predecessor_id,
                    "successor_id": successor_id,
                    "message": f"Predecessor task not found: {predecessor_id}"
                })
                continue

            if not successor:
                errors.append({
                    "error_type": "not_found",
                    "predecessor_id": predecessor_id,
                    "successor_id": successor_id,
                    "message": f"Successor task not found: {successor_id}"
                })
                continue

            if predecessor.project_id != project_id or successor.project_id != project_id:
                errors.append({
                    "error_type": "cross_project",
                    "predecessor_id": predecessor_id,
                    "successor_id": successor_id,
                    "message": "All tasks must be in the same project"
                })
                continue

            # Check for duplicates against existing dependencies
            existing = db.query(TaskDependency).filter(
                TaskDependency.predecessor_id == predecessor_id,
                TaskDependency.successor_id == successor_id
            ).first()
            if existing:
                errors.append({
                    "error_type": "duplicate",
                    "predecessor_id": predecessor_id,
                    "successor_id": successor_id,
                    "message": "This dependency already exists"
                })

        # If there are basic validation errors, return them first
        if errors:
            return errors

        # Now check for cycles considering all dependencies together
        # Build the graph incrementally and check for cycles
        accumulated_edges: List[Tuple[str, str]] = []

        for predecessor_id, successor_id in dependencies:
            # Check if adding this edge (plus all previously accumulated edges)
            # would create a cycle
            cycle_result = DependencyService.detect_circular_dependency_detailed(
                db, predecessor_id, successor_id, project_id,
                additional_edges=accumulated_edges
            )

            if cycle_result.has_cycle:
                errors.append({
                    "error_type": "circular",
                    "predecessor_id": predecessor_id,
                    "successor_id": successor_id,
                    "message": f"Adding this dependency would create a circular reference: {cycle_result.get_cycle_description()}",
                    "cycle": cycle_result.cycle_path,
                    "cycle_description": cycle_result.get_cycle_description(),
                    "cycle_task_titles": cycle_result.cycle_task_titles
                })
            else:
                # Add this edge to accumulated edges for subsequent checks
                accumulated_edges.append((predecessor_id, successor_id))

        return errors
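    # Illustrative usage (not from the original source): validating a batch before a
    # bulk insert. Tuples are (predecessor_id, successor_id); IDs are placeholders.
    #
    #   errors = DependencyService.validate_bulk_dependencies(
    #       db, [("task-a", "task-b"), ("task-b", "task-c")], project_id="proj-1"
    #   )
    #   if errors:
    #       # reject the whole batch; each entry carries error_type and message
    #       ...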
    @staticmethod
    def detect_cycles_in_graph(
        db: Session,
        project_id: str
    ) -> List[CycleDetectionResult]:
        """
        Detect all cycles in the existing dependency graph for a project.

        This is useful for auditing or cleanup operations.

        Args:
            db: Database session
            project_id: Project ID to check

        Returns:
            List of CycleDetectionResult for each cycle found
        """
        cycles: List[CycleDetectionResult] = []

        # Get all dependencies for the project
        dependencies = db.query(TaskDependency).join(
            Task, TaskDependency.successor_id == Task.id
        ).filter(Task.project_id == project_id).all()

        if not dependencies:
            return cycles

        # Build the graph
        graph: Dict[str, List[str]] = defaultdict(list)
        for dep in dependencies:
            graph[dep.successor_id].append(dep.predecessor_id)

        # Get task titles
        task_ids = set()
        for succ_id, pred_ids in graph.items():
            task_ids.add(succ_id)
            task_ids.update(pred_ids)

        tasks = db.query(Task).filter(Task.id.in_(task_ids)).all()
        task_title_map: Dict[str, str] = {t.id: t.title for t in tasks}

        # Find all cycles using DFS
        visited: Set[str] = set()
        found_cycles: Set[Tuple[str, ...]] = set()

        def find_cycles_dfs(node: str, path: List[str], in_path: Set[str]):
            """DFS to find all cycles."""
            if node in in_path:
                # Found a cycle
                cycle_start = path.index(node)
                cycle = tuple(sorted(path[cycle_start:]))  # Normalize for dedup
                if cycle not in found_cycles:
                    found_cycles.add(cycle)
                    actual_cycle = path[cycle_start:] + [node]
                    cycle_titles = [task_title_map.get(tid, tid) for tid in actual_cycle]
                    cycles.append(CycleDetectionResult(
                        has_cycle=True,
                        cycle_path=actual_cycle,
                        cycle_task_titles=cycle_titles
                    ))
                return
            if node in visited:
                return

            visited.add(node)
            in_path.add(node)
            path.append(node)

            for neighbor in graph.get(node, []):
                find_cycles_dfs(neighbor, path.copy(), in_path.copy())

            path.pop()
            in_path.remove(node)

        # Start DFS from all nodes
        for start_node in graph.keys():
            if start_node not in visited:
                find_cycles_dfs(start_node, [], set())

        return cycles
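# Illustrative usage of the audit helper (not from the original source): run it from a
# maintenance script or admin endpoint to report any pre-existing cycles in a project.
#
#   for cycle in DependencyService.detect_cycles_in_graph(db, project_id="proj-1"):
#       logger.warning("Dependency cycle found: %s", cycle.get_cycle_description())
#
# `db`, `project_id`, and `logger` are assumed to be provided by the surrounding script.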