import uuid
import logging
from datetime import datetime, date
from decimal import Decimal, InvalidOperation
from typing import List, Dict, Any, Optional, Set, Tuple

from sqlalchemy.orm import Session
from sqlalchemy import func

from app.models import Trigger, TriggerLog, Task, User, ProjectMember, CustomField, Role
from app.services.notification_service import NotificationService
from app.services.custom_value_service import CustomValueService
from app.services.action_executor import (
    ActionExecutor,
    ActionExecutionError,
    ActionValidationError,
)

logger = logging.getLogger(__name__)


class TriggerService:
    """Service for evaluating and executing triggers.

    A trigger's ``conditions`` is either a single rule
    ``{"field": ..., "operator": ..., "value": ...[, "field_id": ...]}`` or a
    rule group ``{"logic": "and", "rules": [rule, ...]}``. Only ``"and"``
    groups are evaluated. ``field_id`` names the custom field when ``field``
    is ``"custom_fields"``.
    """

    SUPPORTED_FIELDS = ["status_id", "assignee_id", "priority", "start_date", "due_date", "custom_fields"]
    SUPPORTED_OPERATORS = ["equals", "not_equals", "changed_to", "changed_from", "before", "after", "in"]
    DATE_FIELDS = {"start_date", "due_date"}

    @staticmethod
    def evaluate_triggers(
        db: Session,
        task: Task,
        old_values: Dict[str, Any],
        new_values: Dict[str, Any],
        current_user: User,
    ) -> List[TriggerLog]:
        """Evaluate all active ``field_change`` triggers for the task's project.

        Args:
            db: Database session. Trigger logs (and any action side effects)
                are added to it; this method does not commit.
            task: The task whose values changed.
            old_values: Field values before the change (``None`` is treated as ``{}``).
            new_values: Field values after the change (``None`` is treated as ``{}``).
            current_user: The user who made the change.

        Returns:
            One ``TriggerLog`` per trigger whose conditions matched and whose
            actions were therefore executed (successfully or not).
        """
        logs: List[TriggerLog] = []
        old_values = old_values or {}
        new_values = new_values or {}

        # Get active field_change triggers for the project
        triggers = (
            db.query(Trigger)
            .filter(
                Trigger.project_id == task.project_id,
                Trigger.is_active == True,  # noqa: E712 - SQLAlchemy column expression
                Trigger.trigger_type == "field_change",
            )
            .all()
        )
        if not triggers:
            return logs

        # Collect the custom fields referenced by any rule so their types are
        # loaded with a single query, and task custom values are only fetched
        # when at least one rule needs them.
        custom_field_ids: Set[str] = set()
        needs_custom_fields = False
        for trigger in triggers:
            for rule in TriggerService._extract_rules(trigger.conditions or {}):
                if not isinstance(rule, dict):
                    continue  # malformed rule; _check_conditions rejects it
                if rule.get("field") == "custom_fields":
                    needs_custom_fields = True
                    field_id = rule.get("field_id")
                    if field_id:
                        custom_field_ids.add(field_id)

        custom_field_types: Dict[str, str] = {}
        if custom_field_ids:
            fields = db.query(CustomField).filter(CustomField.id.in_(list(custom_field_ids))).all()
            custom_field_types = {f.id: f.field_type for f in fields}

        current_custom_values = None
        if needs_custom_fields:
            # NOTE(review): when new_values carries a custom_fields dict it is
            # used as-is; if callers send only the changed subset, rules on
            # untouched custom fields will see None. Confirm with callers.
            if isinstance(new_values.get("custom_fields"), dict):
                current_custom_values = new_values.get("custom_fields")
            else:
                current_custom_values = TriggerService._get_custom_values_map(db, task)

        current_values = {
            "status_id": task.status_id,
            "assignee_id": task.assignee_id,
            "priority": task.priority,
            "start_date": task.start_date,
            "due_date": task.due_date,
            "custom_fields": current_custom_values or {},
        }
        changed_fields = TriggerService._detect_field_changes(old_values, new_values)
        changed_custom_field_ids = TriggerService._detect_custom_field_changes(
            old_values.get("custom_fields"),
            new_values.get("custom_fields"),
        )

        for trigger in triggers:
            # `or {}` mirrors the collection pass above: a trigger with NULL
            # conditions must not abort evaluation of every other trigger.
            if TriggerService._check_conditions(
                trigger.conditions or {},
                old_values,
                new_values,
                current_values=current_values,
                changed_fields=changed_fields,
                changed_custom_field_ids=changed_custom_field_ids,
                custom_field_types=custom_field_types,
            ):
                log = TriggerService._execute_actions(db, trigger, task, current_user, old_values, new_values)
                logs.append(log)

        return logs

    @staticmethod
    def _check_conditions(
        conditions: Dict[str, Any],
        old_values: Dict[str, Any],
        new_values: Dict[str, Any],
        current_values: Optional[Dict[str, Any]] = None,
        changed_fields: Optional[Set[str]] = None,
        changed_custom_field_ids: Optional[Set[str]] = None,
        custom_field_types: Optional[Dict[str, str]] = None,
    ) -> bool:
        """Check if trigger conditions are met.

        Every rule must pass, and at least one rule's field must actually have
        changed, so a trigger does not fire on an update that touched none of
        the fields it watches.

        Args:
            conditions: A single rule dict or an ``{"logic", "rules"}`` group.
            old_values: Values before the change.
            new_values: Values after the change.
            current_values: Values used by non-change operators (``equals``,
                ``before``, ...). Falls back to ``new_values`` when empty/None.
            changed_fields: Precomputed changed standard fields; computed from
                ``old_values``/``new_values`` when ``None``.
            changed_custom_field_ids: Precomputed changed custom-field ids;
                computed when ``None``.
            custom_field_types: Map of custom field id to field type.

        Returns:
            ``True`` if all rules pass and at least one referenced field
            changed. ``False`` for empty or malformed conditions, unsupported
            fields or operators, and rule groups whose logic is not ``"and"``.
        """
        if not isinstance(conditions, dict):
            return False
        old_values = old_values or {}
        new_values = new_values or {}
        current_values = current_values or new_values
        if changed_fields is None:
            changed_fields = TriggerService._detect_field_changes(old_values, new_values)
        if changed_custom_field_ids is None:
            changed_custom_field_ids = TriggerService._detect_custom_field_changes(
                old_values.get("custom_fields"),
                new_values.get("custom_fields"),
            )
        custom_field_types = custom_field_types or {}

        rules = TriggerService._extract_rules(conditions)
        if not rules:
            return False
        # Only AND groups are supported; any other (or missing) logic is rejected.
        if conditions.get("rules") is not None and conditions.get("logic") != "and":
            return False

        any_rule_changed = False
        for rule in rules:
            if not isinstance(rule, dict):
                return False
            field = rule.get("field")
            operator = rule.get("operator")
            if field not in TriggerService.SUPPORTED_FIELDS:
                return False
            if operator not in TriggerService.SUPPORTED_OPERATORS:
                return False

            if field == "custom_fields":
                field_id = rule.get("field_id")
                if not field_id:
                    return False
                current_value = (current_values.get("custom_fields") or {}).get(field_id)
                old_value = (old_values.get("custom_fields") or {}).get(field_id)
                new_value = (new_values.get("custom_fields") or {}).get(field_id)
                field_type = TriggerService._normalize_field_type(custom_field_types.get(field_id))
                field_changed = field_id in changed_custom_field_ids
            else:
                current_value = current_values.get(field)
                old_value = old_values.get(field)
                new_value = new_values.get(field)
                field_type = "date" if field in TriggerService.DATE_FIELDS else None
                field_changed = field in changed_fields

            if not TriggerService._evaluate_rule(
                operator,
                current_value,
                old_value,
                new_value,
                rule.get("value"),
                field_type,
                field_changed,
            ):
                return False
            if field_changed:
                any_rule_changed = True

        return any_rule_changed

    @staticmethod
    def _extract_rules(conditions: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Return the list of rules described by ``conditions``.

        A ``rules`` list is returned as-is. Otherwise a top-level ``field``
        key is wrapped into a single-rule list. Anything else (including a
        non-dict ``conditions``) yields ``[]``.
        """
        if not isinstance(conditions, dict):
            return []
        rules = conditions.get("rules")
        if isinstance(rules, list):
            return rules
        field = conditions.get("field")
        if not field:
            return []
        rule = {
            "field": field,
            "operator": conditions.get("operator"),
            "value": conditions.get("value"),
        }
        if conditions.get("field_id"):
            rule["field_id"] = conditions.get("field_id")
        return [rule]

    @staticmethod
    def _get_custom_values_map(db: Session, task: Task) -> Dict[str, Any]:
        """Load the task's custom values (formulas evaluated) as ``{field_id: value}``."""
        values = CustomValueService.get_custom_values_for_task(
            db,
            task,
            include_formula_calculations=True,
        )
        return {cv.field_id: cv.value for cv in values}

    @staticmethod
    def _detect_field_changes(
        old_values: Dict[str, Any],
        new_values: Dict[str, Any],
    ) -> Set[str]:
        """Return the standard (non-custom) fields whose value differs.

        A field is only considered when it appears in at least one of the two
        dicts; a key missing from one side compares as ``None``.
        """
        return {
            field
            for field in TriggerService.SUPPORTED_FIELDS
            if field != "custom_fields"
            and (field in old_values or field in new_values)
            and old_values.get(field) != new_values.get(field)
        }

    @staticmethod
    def _detect_custom_field_changes(
        old_custom_values: Any,
        new_custom_values: Any,
    ) -> Set[str]:
        """Return the custom-field ids whose value differs between the two maps.

        Both arguments must be dicts; otherwise no change can be determined and
        an empty set is returned.
        """
        if not isinstance(old_custom_values, dict) or not isinstance(new_custom_values, dict):
            return set()
        field_ids = set(old_custom_values) | set(new_custom_values)
        return {
            field_id
            for field_id in field_ids
            if old_custom_values.get(field_id) != new_custom_values.get(field_id)
        }

    @staticmethod
    def _normalize_field_type(field_type: Optional[str]) -> Optional[str]:
        """Map a custom field type to a comparison type (formulas compare as numbers)."""
        if not field_type:
            return None
        if field_type == "formula":
            return "number"
        return field_type

    @staticmethod
    def _evaluate_rule(
        operator: str,
        current_value: Any,
        old_value: Any,
        new_value: Any,
        target_value: Any,
        field_type: Optional[str],
        field_changed: bool,
    ) -> bool:
        """Evaluate a single rule operator.

        ``changed_to`` and ``changed_from`` require the field to have changed
        and compare ``new_value`` and ``old_value``. All other operators
        compare ``current_value`` with ``target_value``. Unknown operators
        return ``False``.
        """
        if operator in ("changed_to", "changed_from"):
            if not field_changed:
                return False
            old_matches = TriggerService._value_equals(old_value, target_value, field_type)
            new_matches = TriggerService._value_equals(new_value, target_value, field_type)
            if operator == "changed_to":
                return new_matches and not old_matches
            return old_matches and not new_matches

        if operator == "equals":
            return TriggerService._value_equals(current_value, target_value, field_type)
        if operator == "not_equals":
            return not TriggerService._value_equals(current_value, target_value, field_type)
        if operator == "before":
            return TriggerService._compare_before(current_value, target_value, field_type)
        if operator == "after":
            return TriggerService._compare_after(current_value, target_value, field_type)
        if operator == "in":
            return TriggerService._compare_in(current_value, target_value, field_type)
        return False

    @staticmethod
    def _value_equals(current_value: Any, target_value: Any, field_type: Optional[str]) -> bool:
        """Type-aware equality between a field value and a rule target.

        ``None`` only equals ``None``. Date values compare by calendar day when
        either side is date-only. Number values compare as ``Decimal``. Any
        other value is compared by its string form, and list/dict targets
        never match.
        """
        if current_value is None:
            return target_value is None
        if target_value is None:
            # Without this, the string fallback would let the string "None" match.
            return False

        if field_type == "date":
            current_dt, current_date_only = TriggerService._parse_datetime_value(current_value)
            target_dt, target_date_only = TriggerService._parse_datetime_value(target_value)
            if current_dt is None or target_dt is None:
                return False
            if current_date_only or target_date_only:
                return current_dt.date() == target_dt.date()
            return current_dt == target_dt

        if field_type == "number":
            current_num = TriggerService._parse_number_value(current_value)
            target_num = TriggerService._parse_number_value(target_value)
            if current_num is None or target_num is None:
                return False
            return current_num == target_num

        if isinstance(target_value, (list, dict)):
            return False
        return str(current_value) == str(target_value)

    @staticmethod
    def _comparable_pair(
        current_value: Any,
        target_value: Any,
        field_type: Optional[str],
    ) -> Optional[Tuple[Any, Any]]:
        """Convert both operands into mutually orderable values.

        Returns ``(current, target)`` as dates (when either side is date-only),
        datetimes, or Decimals. Returns ``None`` when either side is missing
        or unparseable, or when the field type has no ordering.
        """
        if current_value is None or target_value is None:
            return None

        if field_type == "date":
            current_dt, current_date_only = TriggerService._parse_datetime_value(current_value)
            target_dt, target_date_only = TriggerService._parse_datetime_value(target_value)
            if current_dt is None or target_dt is None:
                return None
            if current_date_only or target_date_only:
                return current_dt.date(), target_dt.date()
            return current_dt, target_dt

        if field_type == "number":
            current_num = TriggerService._parse_number_value(current_value)
            target_num = TriggerService._parse_number_value(target_value)
            if current_num is None or target_num is None:
                return None
            return current_num, target_num

        return None

    @staticmethod
    def _compare_before(current_value: Any, target_value: Any, field_type: Optional[str]) -> bool:
        """Return ``True`` if the value is strictly before/less than the target."""
        pair = TriggerService._comparable_pair(current_value, target_value, field_type)
        return pair is not None and pair[0] < pair[1]

    @staticmethod
    def _compare_after(current_value: Any, target_value: Any, field_type: Optional[str]) -> bool:
        """Return ``True`` if the value is strictly after/greater than the target."""
        pair = TriggerService._comparable_pair(current_value, target_value, field_type)
        return pair is not None and pair[0] > pair[1]

    @staticmethod
    def _compare_in(current_value: Any, target_value: Any, field_type: Optional[str]) -> bool:
        """Membership test.

        For date fields, ``target_value`` must be ``{"start": ..., "end": ...}``
        and the bounds are inclusive; the comparison is by day if any operand
        is date-only. For other fields, ``target_value`` must be a
        list/tuple/set. Number fields compare as Decimals and all other fields
        compare by string form.
        """
        if current_value is None or target_value is None:
            return False

        if field_type == "date":
            if not isinstance(target_value, dict):
                return False
            start_dt, start_date_only = TriggerService._parse_datetime_value(target_value.get("start"))
            end_dt, end_date_only = TriggerService._parse_datetime_value(target_value.get("end"))
            current_dt, current_date_only = TriggerService._parse_datetime_value(current_value)
            if start_dt is None or end_dt is None or current_dt is None:
                return False
            if current_date_only or start_date_only or end_date_only:
                return start_dt.date() <= current_dt.date() <= end_dt.date()
            return start_dt <= current_dt <= end_dt

        if not isinstance(target_value, (list, tuple, set)):
            return False

        if field_type == "number":
            current_num = TriggerService._parse_number_value(current_value)
            if current_num is None:
                return False
            return any(
                item_num is not None and item_num == current_num
                for item_num in (TriggerService._parse_number_value(item) for item in target_value)
            )

        return str(current_value) in {str(item) for item in target_value if item is not None}

    @staticmethod
    def _parse_datetime_value(value: Any) -> Tuple[Optional[datetime], bool]:
        """Parse a date/datetime-like value into ``(naive_datetime, is_date_only)``.

        Accepts ``datetime``, ``date``, ``"YYYY-MM-DD"`` strings and ISO-8601
        datetime strings (a trailing ``Z`` is accepted). Returns
        ``(None, False)`` when the value cannot be parsed.

        Timezone offsets are dropped (the wall-clock time is kept) for both
        ISO strings and aware ``datetime`` objects. This keeps every result
        naive, so comparisons never mix naive and aware values, which raises
        ``TypeError``. NOTE(review): if stored datetimes are UTC, converting
        with ``astimezone(timezone.utc)`` before dropping the offset may be
        more accurate.
        """
        if value is None:
            return None, False
        if isinstance(value, datetime):
            if value.tzinfo is not None:
                value = value.replace(tzinfo=None)
            return value, False
        if isinstance(value, date):
            return datetime.combine(value, datetime.min.time()), True
        if isinstance(value, str):
            try:
                if len(value) == 10:
                    return datetime.strptime(value, "%Y-%m-%d"), True
                parsed = datetime.fromisoformat(value.replace("Z", "+00:00"))
            except ValueError:
                return None, False
            return parsed.replace(tzinfo=None), False
        return None, False

    @staticmethod
    def _parse_number_value(value: Any) -> Optional[Decimal]:
        """Parse a value as ``Decimal``, returning ``None`` if empty, invalid or NaN."""
        if value is None or value == "":
            return None
        try:
            parsed = Decimal(str(value))
        except (InvalidOperation, ValueError):
            return None
        # Ordering comparisons with a Decimal NaN raise InvalidOperation, and
        # a signaling NaN raises even on ==, so NaN is treated as "not a number".
        if parsed.is_nan():
            return None
        return parsed

    @staticmethod
    def _execute_actions(
        db: Session,
        trigger: Trigger,
        task: Task,
        current_user: User,
        old_values: Dict[str, Any],
        new_values: Dict[str, Any],
    ) -> TriggerLog:
        """Execute trigger actions and log the result.

        Uses a database savepoint to ensure atomicity - if any action fails,
        all previously executed actions within this trigger are rolled back.
        The returned log is added to the session but not committed. On
        failure, ``actions_executed`` still lists the actions that ran before
        the error (and were then rolled back), followed by an ``"error"`` entry.
        """
        raw_actions = trigger.actions
        if raw_actions is None:
            actions: List[Any] = []
        elif isinstance(raw_actions, list):
            actions = raw_actions
        else:
            actions = [raw_actions]

        executed_actions: List[Dict[str, Any]] = []
        error_message: Optional[str] = None

        # Build execution context
        context = {
            "old_values": old_values,
            "new_values": new_values,
            "current_user": current_user,
            "trigger": trigger,
        }

        # Use savepoint for transaction atomicity - if any action fails,
        # all changes made by previous actions will be rolled back
        savepoint = db.begin_nested()
        try:
            for action in actions:
                action_type = action.get("type")

                # Handle built-in notify action
                if action_type == "notify":
                    TriggerService._execute_notify_action(db, action, task, current_user, old_values, new_values)
                    executed_actions.append({"type": action_type, "status": "success"})
                    continue

                # update_field (FEAT-014), auto_assign (FEAT-015) and any other
                # registered action type are delegated to ActionExecutor.
                result = ActionExecutor.execute_action(db, task, action, context)
                if not result:
                    continue
                executed_actions.append(result)
                if action_type == "update_field":
                    logger.info(
                        "Trigger '%s' executed update_field: field=%s, new_value=%s",
                        trigger.name,
                        result.get("field"),
                        result.get("new_value"),
                    )
                elif action_type == "auto_assign":
                    logger.info(
                        "Trigger '%s' executed auto_assign: strategy=%s, assignee=%s",
                        trigger.name,
                        result.get("strategy"),
                        result.get("new_assignee_id"),
                    )

            # All actions succeeded, commit the savepoint
            savepoint.commit()
            status = "success"

        except ActionExecutionError as e:
            # Rollback all changes made by previously executed actions
            savepoint.rollback()
            status = "failed"
            error_message = str(e)
            executed_actions.append({"type": "error", "message": str(e)})
            logger.error("Trigger '%s' action execution failed, rolling back: %s", trigger.name, e)

        except Exception as e:
            # Rollback all changes made by previously executed actions
            savepoint.rollback()
            status = "failed"
            error_message = str(e)
            executed_actions.append({"type": "error", "message": str(e)})
            logger.exception("Trigger '%s' unexpected error, rolling back: %s", trigger.name, e)

        log = TriggerLog(
            id=str(uuid.uuid4()),
            trigger_id=trigger.id,
            task_id=task.id,
            status=status,
            details={
                "trigger_name": trigger.name,
                "old_values": old_values,
                "new_values": new_values,
                "actions_executed": executed_actions,
            },
            error_message=error_message,
        )
        db.add(log)

        return log

    @staticmethod
    def _execute_notify_action(
        db: Session,
        action: Dict[str, Any],
        task: Task,
        current_user: User,
        old_values: Dict[str, Any],
        new_values: Dict[str, Any],
    ) -> None:
        """Execute a notify action.

        Resolves ``action["target"]`` (default ``"assignee"``) to user ids and
        creates one notification per recipient. The acting user is skipped.
        The message is ``action["template"]`` (with a default template) after
        variable substitution.
        """
        target = action.get("target", "assignee")
        template = action.get("template", "任務 {task_title} 已觸發自動化規則")

        recipients = TriggerService._resolve_targets(db, task, target)
        if not recipients:
            return

        # Format message with variables
        message = TriggerService._format_template(template, task, old_values, new_values)

        for user_id in recipients:
            if user_id == current_user.id:
                continue
            NotificationService.create_notification(
                db=db,
                user_id=user_id,
                notification_type="status_change",
                reference_type="task",
                reference_id=task.id,
                title=f"自動化通知: {task.title}",
                message=message,
            )

    @staticmethod
    def _resolve_targets(db: Session, task: Task, target: str) -> List[str]:
        """Resolve notification target to user IDs.

        Supported targets:
        - ``assignee``, ``creator`` and ``project_owner``.
        - ``project_members``: the project owner plus active members.
        - ``department:<id>``: active users in that department.
        - ``role:<name>``: active users with that role; the role name is
          matched case-insensitively.
        - ``user:<id>``: that user.

        Duplicates are removed, and the order of the returned list is not
        defined.
        """
        recipients: Set[str] = set()

        if target == "assignee":
            if task.assignee_id:
                recipients.add(task.assignee_id)
        elif target == "creator":
            if task.created_by:
                recipients.add(task.created_by)
        elif target == "project_owner":
            if task.project and task.project.owner_id:
                recipients.add(task.project.owner_id)
        elif target == "project_members":
            if task.project:
                if task.project.owner_id:
                    recipients.add(task.project.owner_id)
                member_rows = db.query(ProjectMember.user_id).join(
                    User,
                    User.id == ProjectMember.user_id,
                ).filter(
                    ProjectMember.project_id == task.project_id,
                    User.is_active == True,  # noqa: E712 - SQLAlchemy column expression
                ).all()
                recipients.update(row[0] for row in member_rows if row and row[0])
        elif target.startswith("department:"):
            department_id = target.split(":", 1)[1]
            if department_id:
                user_rows = db.query(User.id).filter(
                    User.department_id == department_id,
                    User.is_active == True,  # noqa: E712 - SQLAlchemy column expression
                ).all()
                recipients.update(row[0] for row in user_rows if row and row[0])
        elif target.startswith("role:"):
            role_name = target.split(":", 1)[1].strip()
            if role_name:
                role = db.query(Role).filter(func.lower(Role.name) == role_name.lower()).first()
                if role:
                    user_rows = db.query(User.id).filter(
                        User.role_id == role.id,
                        User.is_active == True,  # noqa: E712 - SQLAlchemy column expression
                    ).all()
                    recipients.update(row[0] for row in user_rows if row and row[0])
        elif target.startswith("user:"):
            user_id = target.split(":", 1)[1]
            if user_id:
                recipients.add(user_id)

        return list(recipients)

    @staticmethod
    def _format_template(
        template: str,
        task: Task,
        old_values: Dict[str, Any],
        new_values: Dict[str, Any],
    ) -> str:
        """Format message template with task variables.

        Supported placeholders: ``{task_title}``, ``{task_id}``,
        ``{old_value}`` and ``{new_value}``. The old/new value is taken from
        the first of ``status_id``, ``assignee_id`` or ``priority`` that is
        present as a key.
        """
        replacements = {
            # str() guards against None/non-str attributes, which str.replace rejects.
            "{task_title}": "" if task.title is None else str(task.title),
            "{task_id}": "" if task.id is None else str(task.id),
            "{old_value}": str(old_values.get("status_id", old_values.get("assignee_id", old_values.get("priority", "")))),
            "{new_value}": str(new_values.get("status_id", new_values.get("assignee_id", new_values.get("priority", "")))),
        }

        result = template
        for key, value in replacements.items():
            result = result.replace(key, value)
        return result

    @staticmethod
    def log_execution(
        db: Session,
        trigger: Trigger,
        task: Optional[Task],
        status: str,
        details: Optional[Dict[str, Any]] = None,
        error_message: Optional[str] = None,
    ) -> TriggerLog:
        """Create a ``TriggerLog`` for a trigger execution and add it to the session (not committed)."""
        log = TriggerLog(
            id=str(uuid.uuid4()),
            trigger_id=trigger.id,
            task_id=task.id if task else None,
            status=status,
            details=details,
            error_message=error_message,
        )
        db.add(log)
        return log

    @staticmethod
    def validate_actions(actions: List[Dict[str, Any]], db: Session) -> None:
        """Validate trigger actions configuration.

        Args:
            actions: List of action configurations
            db: Database session

        Raises:
            ActionValidationError: If any action is invalid
        """
        valid_action_types = ["notify", "update_field", "auto_assign"]

        for action in actions:
            if not isinstance(action, dict):
                raise ActionValidationError("Each action must be an object")

            action_type = action.get("type")
            if not action_type:
                raise ActionValidationError("Missing action 'type'")

            if action_type not in valid_action_types:
                raise ActionValidationError(
                    f"Invalid action type '{action_type}'. "
                    f"Valid types: {valid_action_types}"
                )

            # Validate via ActionExecutor for extensible actions
            if action_type in ["update_field", "auto_assign"]:
                ActionExecutor.validate_action(action, db)

    @staticmethod
    def get_supported_action_types() -> List[str]:
        """Get list of all supported action types."""
        return ["notify"] + ActionExecutor.get_supported_actions()