Files
PROJECT-CONTORL/backend/app/services/trigger_service.py
2026-01-11 08:37:21 +08:00

659 lines
25 KiB
Python

import logging
import re
import uuid
from datetime import date, datetime
from decimal import Decimal, InvalidOperation
from typing import Any, Dict, List, Optional, Set, Tuple

from sqlalchemy import func
from sqlalchemy.orm import Session

from app.models import Trigger, TriggerLog, Task, User, ProjectMember, CustomField, Role
from app.services.action_executor import (
    ActionExecutor,
    ActionExecutionError,
    ActionValidationError,
)
from app.services.custom_value_service import CustomValueService
from app.services.notification_service import NotificationService
logger = logging.getLogger(__name__)


class TriggerService:
    """Service for evaluating and executing triggers.

    ``field_change`` triggers belong to a project. When a task's values change,
    each active trigger's rules are checked against the old, new and current
    values. For every trigger that matches, its actions run inside a database
    savepoint and a ``TriggerLog`` is added to the session.
    """

    SUPPORTED_FIELDS = ["status_id", "assignee_id", "priority", "start_date", "due_date", "custom_fields"]
    SUPPORTED_OPERATORS = ["equals", "not_equals", "changed_to", "changed_from", "before", "after", "in"]
    DATE_FIELDS = {"start_date", "due_date"}

    # ------------------------------------------------------------------
    # Trigger evaluation
    # ------------------------------------------------------------------

    @staticmethod
    def evaluate_triggers(
        db: Session,
        task: Task,
        old_values: Dict[str, Any],
        new_values: Dict[str, Any],
        current_user: User,
    ) -> List[TriggerLog]:
        """Evaluate all active triggers for a project when task values change.

        Args:
            db: Database session.
            task: The task whose values changed.
            old_values: Field values before the change (``None`` is treated as ``{}``).
            new_values: Field values after the change (``None`` is treated as ``{}``).
                When ``new_values["custom_fields"]`` is a dict it is used as the
                current custom-field values; otherwise they are loaded for ``task``.
            current_user: The user who made the change.

        Returns:
            One ``TriggerLog`` per trigger whose conditions matched. Each log is
            added to the session; this method does not commit the outer transaction.
        """
        logs: List[TriggerLog] = []
        old_values = old_values or {}
        new_values = new_values or {}

        # Get active field_change triggers for the project
        triggers = db.query(Trigger).filter(
            Trigger.project_id == task.project_id,
            Trigger.is_active == True,  # noqa: E712 - SQLAlchemy column expression, not a Python bool test
            Trigger.trigger_type == "field_change",
        ).all()
        if not triggers:
            return logs

        needs_custom_fields, custom_field_ids = TriggerService._collect_custom_field_rules(triggers)

        # Field types decide how custom-field values are compared (date / number / text).
        custom_field_types: Dict[str, str] = {}
        if custom_field_ids:
            fields = db.query(CustomField).filter(CustomField.id.in_(custom_field_ids)).all()
            custom_field_types = {f.id: f.field_type for f in fields}

        current_custom_values = None
        if needs_custom_fields:
            if isinstance(new_values.get("custom_fields"), dict):
                current_custom_values = new_values.get("custom_fields")
            else:
                current_custom_values = TriggerService._get_custom_values_map(db, task)

        current_values = {
            "status_id": task.status_id,
            "assignee_id": task.assignee_id,
            "priority": task.priority,
            "start_date": task.start_date,
            "due_date": task.due_date,
            "custom_fields": current_custom_values or {},
        }

        # Change sets are computed once and shared by every trigger.
        changed_fields = TriggerService._detect_field_changes(old_values, new_values)
        changed_custom_field_ids = TriggerService._detect_custom_field_changes(
            old_values.get("custom_fields"),
            new_values.get("custom_fields"),
        )

        for trigger in triggers:
            # Guard against NULL conditions here as well as in
            # _collect_custom_field_rules; a missing config must not abort the
            # evaluation of every other trigger.
            if TriggerService._check_conditions(
                trigger.conditions or {},
                old_values,
                new_values,
                current_values=current_values,
                changed_fields=changed_fields,
                changed_custom_field_ids=changed_custom_field_ids,
                custom_field_types=custom_field_types,
            ):
                log = TriggerService._execute_actions(db, trigger, task, current_user, old_values, new_values)
                logs.append(log)

        return logs

    @staticmethod
    def _collect_custom_field_rules(triggers: List[Trigger]) -> Tuple[bool, Set[str]]:
        """Report whether any trigger has a custom-field rule, and which field IDs those rules reference."""
        needs_custom_fields = False
        field_ids: Set[str] = set()
        for trigger in triggers:
            for rule in TriggerService._extract_rules(trigger.conditions or {}):
                if not isinstance(rule, dict) or rule.get("field") != "custom_fields":
                    continue
                needs_custom_fields = True
                field_id = rule.get("field_id")
                if field_id:
                    field_ids.add(field_id)
        return needs_custom_fields, field_ids

    @staticmethod
    def _check_conditions(
        conditions: Dict[str, Any],
        old_values: Dict[str, Any],
        new_values: Dict[str, Any],
        current_values: Optional[Dict[str, Any]] = None,
        changed_fields: Optional[Set[str]] = None,
        changed_custom_field_ids: Optional[Set[str]] = None,
        custom_field_types: Optional[Dict[str, str]] = None,
    ) -> bool:
        """Check if trigger conditions are met.

        ``conditions`` is either a single rule (``field`` / ``operator`` /
        ``value`` and optionally ``field_id`` at the top level) or a group
        ``{"logic": "and", "rules": [...]}``. A ``rules`` list with any
        ``logic`` other than ``"and"`` (including a missing one) never matches.

        Args:
            conditions: The trigger's condition config; non-dict input never matches.
            old_values: Values before the change.
            new_values: Values after the change.
            current_values: Values used by non-transition operators; defaults to ``new_values``.
            changed_fields: Changed built-in fields; computed from old/new when ``None``.
            changed_custom_field_ids: Changed custom-field IDs; computed when ``None``.
            custom_field_types: Custom-field ID -> field type, used to pick comparison semantics.

        Returns:
            True only when every rule matches AND at least one rule's field
            actually changed, so a trigger does not fire on unrelated edits.
        """
        if not isinstance(conditions, dict):
            return False
        old_values = old_values or {}
        new_values = new_values or {}
        # `is None` (not `or`) so that an explicitly passed empty set/dict is respected.
        if current_values is None:
            current_values = new_values
        if changed_fields is None:
            changed_fields = TriggerService._detect_field_changes(old_values, new_values)
        if changed_custom_field_ids is None:
            changed_custom_field_ids = TriggerService._detect_custom_field_changes(
                old_values.get("custom_fields"),
                new_values.get("custom_fields"),
            )
        custom_field_types = custom_field_types or {}

        rules = TriggerService._extract_rules(conditions)
        if not rules:
            return False
        if conditions.get("rules") is not None and conditions.get("logic") != "and":
            return False

        any_rule_changed = False
        for rule in rules:
            # A malformed rule makes the whole condition unsatisfiable rather than crashing.
            if not isinstance(rule, dict):
                return False
            field = rule.get("field")
            operator = rule.get("operator")
            if field not in TriggerService.SUPPORTED_FIELDS:
                return False
            if operator not in TriggerService.SUPPORTED_OPERATORS:
                return False

            if field == "custom_fields":
                field_id = rule.get("field_id")
                if not field_id:
                    return False
                current_value = TriggerService._as_dict(current_values.get("custom_fields")).get(field_id)
                old_value = TriggerService._as_dict(old_values.get("custom_fields")).get(field_id)
                new_value = TriggerService._as_dict(new_values.get("custom_fields")).get(field_id)
                field_type = TriggerService._normalize_field_type(custom_field_types.get(field_id))
                field_changed = field_id in changed_custom_field_ids
            else:
                current_value = current_values.get(field)
                old_value = old_values.get(field)
                new_value = new_values.get(field)
                field_type = "date" if field in TriggerService.DATE_FIELDS else None
                field_changed = field in changed_fields

            if not TriggerService._evaluate_rule(
                operator,
                current_value,
                old_value,
                new_value,
                rule.get("value"),
                field_type,
                field_changed,
            ):
                return False
            if field_changed:
                any_rule_changed = True

        return any_rule_changed

    @staticmethod
    def _extract_rules(conditions: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Normalize a condition config into a list of rules.

        Returns ``conditions["rules"]`` as-is when it is a list (its items are
        not validated here). Otherwise a single rule is built from top-level
        ``field``/``operator``/``value``/``field_id``. Returns ``[]`` when
        ``conditions`` is not a dict or has neither form.
        """
        if not isinstance(conditions, dict):
            return []
        rules = conditions.get("rules")
        if isinstance(rules, list):
            return rules
        field = conditions.get("field")
        if not field:
            return []
        rule = {
            "field": field,
            "operator": conditions.get("operator"),
            "value": conditions.get("value"),
        }
        if conditions.get("field_id"):
            rule["field_id"] = conditions.get("field_id")
        return [rule]

    @staticmethod
    def _as_dict(value: Any) -> Dict[str, Any]:
        """Return ``value`` when it is a dict, otherwise an empty dict."""
        return value if isinstance(value, dict) else {}

    @staticmethod
    def _get_custom_values_map(db: Session, task: Task) -> Dict[str, Any]:
        """Load the task's custom values (including formula results) as a field_id -> value map."""
        values = CustomValueService.get_custom_values_for_task(
            db,
            task,
            include_formula_calculations=True,
        )
        return {cv.field_id: cv.value for cv in values}

    @staticmethod
    def _detect_field_changes(
        old_values: Dict[str, Any],
        new_values: Dict[str, Any],
    ) -> Set[str]:
        """Return the built-in supported fields that appear in old or new values with differing values."""
        return {
            field
            for field in TriggerService.SUPPORTED_FIELDS
            if field != "custom_fields"
            and (field in old_values or field in new_values)
            and old_values.get(field) != new_values.get(field)
        }

    @staticmethod
    def _detect_custom_field_changes(
        old_custom_values: Any,
        new_custom_values: Any,
    ) -> Set[str]:
        """Return custom-field IDs whose values differ; empty unless both inputs are dicts."""
        if not isinstance(old_custom_values, dict) or not isinstance(new_custom_values, dict):
            return set()
        field_ids = set(old_custom_values) | set(new_custom_values)
        return {
            field_id
            for field_id in field_ids
            if old_custom_values.get(field_id) != new_custom_values.get(field_id)
        }

    @staticmethod
    def _normalize_field_type(field_type: Optional[str]) -> Optional[str]:
        """Map a custom field type to its comparison type (formula fields compare as numbers)."""
        if not field_type:
            return None
        if field_type == "formula":
            return "number"
        return field_type

    # ------------------------------------------------------------------
    # Rule operators and value comparison
    # ------------------------------------------------------------------

    @staticmethod
    def _evaluate_rule(
        operator: str,
        current_value: Any,
        old_value: Any,
        new_value: Any,
        target_value: Any,
        field_type: Optional[str],
        field_changed: bool,
    ) -> bool:
        """Evaluate one operator.

        ``changed_to``/``changed_from`` inspect the old->new transition and
        require the field to have changed. All other operators test
        ``current_value`` against ``target_value``.
        """
        if operator in ("changed_to", "changed_from"):
            if not field_changed:
                return False
            if operator == "changed_to":
                return (
                    TriggerService._value_equals(new_value, target_value, field_type)
                    and not TriggerService._value_equals(old_value, target_value, field_type)
                )
            return (
                TriggerService._value_equals(old_value, target_value, field_type)
                and not TriggerService._value_equals(new_value, target_value, field_type)
            )
        if operator == "equals":
            return TriggerService._value_equals(current_value, target_value, field_type)
        if operator == "not_equals":
            return not TriggerService._value_equals(current_value, target_value, field_type)
        if operator == "before":
            return TriggerService._compare_before(current_value, target_value, field_type)
        if operator == "after":
            return TriggerService._compare_after(current_value, target_value, field_type)
        if operator == "in":
            return TriggerService._compare_in(current_value, target_value, field_type)
        return False

    @staticmethod
    def _coerce_comparable_pair(
        current_value: Any,
        target_value: Any,
        field_type: Optional[str],
    ) -> Optional[Tuple[Any, Any]]:
        """Parse both values as dates or numbers according to ``field_type``.

        Returns ``None`` when ``field_type`` is neither ``"date"`` nor
        ``"number"``, or when either value cannot be parsed. Date values are
        reduced to ``date`` (day granularity) when either side is date-only.
        """
        if field_type == "date":
            current_dt, current_date_only = TriggerService._parse_datetime_value(current_value)
            target_dt, target_date_only = TriggerService._parse_datetime_value(target_value)
            if current_dt is None or target_dt is None:
                return None
            if current_date_only or target_date_only:
                return current_dt.date(), target_dt.date()
            return current_dt, target_dt
        if field_type == "number":
            current_num = TriggerService._parse_number_value(current_value)
            target_num = TriggerService._parse_number_value(target_value)
            if current_num is None or target_num is None:
                return None
            return current_num, target_num
        return None

    @staticmethod
    def _value_equals(current_value: Any, target_value: Any, field_type: Optional[str]) -> bool:
        """Type-aware equality. ``None`` only equals ``None``; text compares as strings."""
        if current_value is None:
            return target_value is None
        if field_type in ("date", "number"):
            pair = TriggerService._coerce_comparable_pair(current_value, target_value, field_type)
            return pair is not None and pair[0] == pair[1]
        if isinstance(target_value, (list, dict)):
            return False
        return str(current_value) == str(target_value)

    @staticmethod
    def _compare_before(current_value: Any, target_value: Any, field_type: Optional[str]) -> bool:
        """True when current < target; only defined for date and number fields."""
        if current_value is None or target_value is None:
            return False
        pair = TriggerService._coerce_comparable_pair(current_value, target_value, field_type)
        return pair is not None and pair[0] < pair[1]

    @staticmethod
    def _compare_after(current_value: Any, target_value: Any, field_type: Optional[str]) -> bool:
        """True when current > target; only defined for date and number fields."""
        if current_value is None or target_value is None:
            return False
        pair = TriggerService._coerce_comparable_pair(current_value, target_value, field_type)
        return pair is not None and pair[0] > pair[1]

    @staticmethod
    def _compare_in(current_value: Any, target_value: Any, field_type: Optional[str]) -> bool:
        """Membership test.

        For date fields ``target_value`` must be ``{"start": ..., "end": ...}``
        (inclusive range). Otherwise it must be a list/tuple/set of candidates.
        Those candidates are compared numerically for number fields, else as strings.
        """
        if current_value is None or target_value is None:
            return False

        if field_type == "date":
            if not isinstance(target_value, dict):
                return False
            start_dt, start_date_only = TriggerService._parse_datetime_value(target_value.get("start"))
            end_dt, end_date_only = TriggerService._parse_datetime_value(target_value.get("end"))
            current_dt, current_date_only = TriggerService._parse_datetime_value(current_value)
            if start_dt is None or end_dt is None or current_dt is None:
                return False
            if current_date_only or start_date_only or end_date_only:
                return start_dt.date() <= current_dt.date() <= end_dt.date()
            return start_dt <= current_dt <= end_dt

        if not isinstance(target_value, (list, tuple, set)):
            return False

        if field_type == "number":
            current_num = TriggerService._parse_number_value(current_value)
            if current_num is None:
                return False
            return any(
                item_num is not None and item_num == current_num
                for item_num in (TriggerService._parse_number_value(item) for item in target_value)
            )

        return str(current_value) in {str(item) for item in target_value if item is not None}

    @staticmethod
    def _parse_datetime_value(value: Any) -> Tuple[Optional[datetime], bool]:
        """Parse ``value`` into a naive ``datetime``.

        Returns ``(parsed, date_only)``. ``date_only`` is True for ``date``
        objects and ``YYYY-MM-DD`` strings. Returns ``(None, False)`` when the
        value cannot be parsed.

        Timezone info is dropped without conversion. This happens both for ISO
        strings (``Z`` / ``+HH:MM``) and for aware ``datetime`` objects, so naive
        and aware values never meet in a ``<``/``>`` comparison, which would raise
        ``TypeError``.
        NOTE(review): confirm stored datetimes use the same wall-clock convention.
        """
        if value is None:
            return None, False
        # datetime must be checked before date (datetime is a date subclass).
        if isinstance(value, datetime):
            return value.replace(tzinfo=None), False
        if isinstance(value, date):
            return datetime.combine(value, datetime.min.time()), True
        if isinstance(value, str):
            try:
                if len(value) == 10:
                    return datetime.strptime(value, "%Y-%m-%d"), True
                parsed = datetime.fromisoformat(value.replace("Z", "+00:00"))
            except ValueError:
                return None, False
            return parsed.replace(tzinfo=None), False
        return None, False

    @staticmethod
    def _parse_number_value(value: Any) -> Optional[Decimal]:
        """Parse ``value`` as a ``Decimal``; ``None``/``""``/unparseable/NaN yield ``None``."""
        if value is None or value == "":
            return None
        try:
            number = Decimal(str(value))
        except (InvalidOperation, ValueError):
            return None
        # Decimal NaN cannot be ordered: `<`/`>` raise InvalidOperation (and sNaN
        # raises even on `==`), which would abort trigger evaluation.
        if number.is_nan():
            return None
        return number

    # ------------------------------------------------------------------
    # Action execution
    # ------------------------------------------------------------------

    @staticmethod
    def _execute_actions(
        db: Session,
        trigger: Trigger,
        task: Task,
        current_user: User,
        old_values: Dict[str, Any],
        new_values: Dict[str, Any],
    ) -> TriggerLog:
        """Execute trigger actions and log the result.

        Uses a database savepoint to ensure atomicity - if any action fails,
        all previously executed actions within this trigger are rolled back.
        Errors are not re-raised; they are recorded on the returned log with
        status ``"failed"``.

        Returns:
            The ``TriggerLog`` added to the session.
        """
        raw_actions = trigger.actions
        if raw_actions is None:
            # A trigger without configured actions has nothing to run.
            actions: List[Any] = []
        elif isinstance(raw_actions, list):
            actions = raw_actions
        else:
            actions = [raw_actions]

        executed_actions: List[Dict[str, Any]] = []
        error_message: Optional[str] = None

        # Build execution context
        context = {
            "old_values": old_values,
            "new_values": new_values,
            "current_user": current_user,
            "trigger": trigger,
        }

        # Use savepoint for transaction atomicity - if any action fails,
        # all changes made by previous actions will be rolled back
        savepoint = db.begin_nested()
        try:
            for action in actions:
                action_type = action.get("type")

                # Built-in notify action
                if action_type == "notify":
                    TriggerService._execute_notify_action(db, action, task, current_user, old_values, new_values)
                    executed_actions.append({"type": action_type, "status": "success"})
                    continue

                # update_field (FEAT-014), auto_assign (FEAT-015) and any other
                # type are delegated to ActionExecutor for extensibility.
                result = ActionExecutor.execute_action(db, task, action, context)
                if not result:
                    continue
                executed_actions.append(result)
                if action_type == "update_field":
                    logger.info(
                        "Trigger '%s' executed update_field: field=%s, new_value=%s",
                        trigger.name, result.get("field"), result.get("new_value"),
                    )
                elif action_type == "auto_assign":
                    logger.info(
                        "Trigger '%s' executed auto_assign: strategy=%s, assignee=%s",
                        trigger.name, result.get("strategy"), result.get("new_assignee_id"),
                    )

            # All actions succeeded, commit the savepoint
            savepoint.commit()
            status = "success"
        except ActionExecutionError as e:
            # Rollback all changes made by previously executed actions
            savepoint.rollback()
            status = "failed"
            error_message = str(e)
            executed_actions.append({"type": "error", "message": error_message})
            logger.error("Trigger '%s' action execution failed, rolling back: %s", trigger.name, e)
        except Exception as e:
            # Rollback all changes made by previously executed actions
            savepoint.rollback()
            status = "failed"
            error_message = str(e)
            executed_actions.append({"type": "error", "message": error_message})
            logger.exception("Trigger '%s' unexpected error, rolling back: %s", trigger.name, e)

        return TriggerService.log_execution(
            db,
            trigger,
            task,
            status,
            details={
                "trigger_name": trigger.name,
                "old_values": old_values,
                "new_values": new_values,
                "actions_executed": executed_actions,
            },
            error_message=error_message,
        )

    @staticmethod
    def _execute_notify_action(
        db: Session,
        action: Dict[str, Any],
        task: Task,
        current_user: User,
        old_values: Dict[str, Any],
        new_values: Dict[str, Any],
    ) -> None:
        """Execute a notify action.

        Resolves ``action["target"]`` (default ``"assignee"``) to user IDs and
        creates one notification per recipient, skipping the user who made the
        change. ``action["template"]`` (with a built-in default) is rendered
        with ``_format_template``.
        """
        # An explicit null in the config falls back to the default instead of crashing.
        target = action.get("target")
        if target is None:
            target = "assignee"
        template = action.get("template")
        if template is None:
            template = "任務 {task_title} 已觸發自動化規則"

        recipients = TriggerService._resolve_targets(db, task, target)
        if not recipients:
            return

        # Format message with variables
        message = TriggerService._format_template(template, task, old_values, new_values)

        for user_id in recipients:
            if user_id == current_user.id:
                continue
            NotificationService.create_notification(
                db=db,
                user_id=user_id,
                notification_type="status_change",
                reference_type="task",
                reference_id=task.id,
                title=f"自動化通知: {task.title}",
                message=message,
            )

    @staticmethod
    def _resolve_targets(db: Session, task: Task, target: str) -> List[str]:
        """Resolve a notification target to user IDs (deduplicated, unordered).

        Supported targets: ``assignee``, ``creator``, ``project_owner``,
        ``project_members`` (owner plus active members), ``department:<id>``
        and ``role:<name>`` (active users only; role name matched
        case-insensitively), and ``user:<id>``. Unknown or non-string targets
        resolve to no recipients.
        """
        if not isinstance(target, str):
            return []

        recipients: Set[str] = set()
        if target == "assignee":
            if task.assignee_id:
                recipients.add(task.assignee_id)
        elif target == "creator":
            if task.created_by:
                recipients.add(task.created_by)
        elif target == "project_owner":
            if task.project and task.project.owner_id:
                recipients.add(task.project.owner_id)
        elif target == "project_members":
            if task.project:
                if task.project.owner_id:
                    recipients.add(task.project.owner_id)
                member_rows = db.query(ProjectMember.user_id).join(
                    User,
                    User.id == ProjectMember.user_id,
                ).filter(
                    ProjectMember.project_id == task.project_id,
                    User.is_active == True,  # noqa: E712 - SQLAlchemy column expression
                ).all()
                recipients.update(row[0] for row in member_rows if row and row[0])
        elif target.startswith("department:"):
            department_id = target.partition(":")[2]
            if department_id:
                user_rows = db.query(User.id).filter(
                    User.department_id == department_id,
                    User.is_active == True,  # noqa: E712 - SQLAlchemy column expression
                ).all()
                recipients.update(row[0] for row in user_rows if row and row[0])
        elif target.startswith("role:"):
            role_name = target.partition(":")[2].strip()
            if role_name:
                role = db.query(Role).filter(func.lower(Role.name) == role_name.lower()).first()
                if role:
                    user_rows = db.query(User.id).filter(
                        User.role_id == role.id,
                        User.is_active == True,  # noqa: E712 - SQLAlchemy column expression
                    ).all()
                    recipients.update(row[0] for row in user_rows if row and row[0])
        elif target.startswith("user:"):
            user_id = target.partition(":")[2]
            if user_id:
                recipients.add(user_id)

        return list(recipients)

    @staticmethod
    def _format_template(
        template: str,
        task: Task,
        old_values: Dict[str, Any],
        new_values: Dict[str, Any],
    ) -> str:
        """Format message template with task variables.

        Supported placeholders: ``{task_title}``, ``{task_id}``,
        ``{old_value}`` and ``{new_value}``. The value placeholders use
        ``status_id``, falling back to ``assignee_id`` and then ``priority``.
        Substitution is a single pass, so placeholder-like text inside a
        substituted value (e.g. a task title) is not expanded again.
        """

        def _summary(values: Optional[Dict[str, Any]]) -> str:
            values = values or {}
            return str(values.get("status_id", values.get("assignee_id", values.get("priority", ""))))

        replacements = {
            "{task_title}": "" if task.title is None else str(task.title),
            "{task_id}": "" if task.id is None else str(task.id),
            "{old_value}": _summary(old_values),
            "{new_value}": _summary(new_values),
        }
        pattern = re.compile("|".join(re.escape(key) for key in replacements))
        return pattern.sub(lambda match: replacements[match.group(0)], template)

    # ------------------------------------------------------------------
    # Logging, validation and metadata
    # ------------------------------------------------------------------

    @staticmethod
    def log_execution(
        db: Session,
        trigger: Trigger,
        task: Optional[Task],
        status: str,
        details: Optional[Dict[str, Any]] = None,
        error_message: Optional[str] = None,
    ) -> TriggerLog:
        """Create a ``TriggerLog`` for a trigger execution, add it to the session, and return it."""
        log = TriggerLog(
            id=str(uuid.uuid4()),
            trigger_id=trigger.id,
            task_id=task.id if task else None,
            status=status,
            details=details,
            error_message=error_message,
        )
        db.add(log)
        return log

    @staticmethod
    def validate_actions(actions: List[Dict[str, Any]], db: Session) -> None:
        """Validate trigger actions configuration.

        The accepted action types are those reported by
        ``get_supported_action_types`` (``notify`` plus every ActionExecutor
        action). This matches what ``_execute_actions`` can actually run.
        Non-notify actions are further validated by ``ActionExecutor``.

        Args:
            actions: List of action configurations
            db: Database session

        Raises:
            ActionValidationError: If any action is invalid
        """
        valid_action_types = TriggerService.get_supported_action_types()

        for action in actions:
            if not isinstance(action, dict):
                raise ActionValidationError("Each action must be an object")

            action_type = action.get("type")
            if not action_type:
                raise ActionValidationError("Missing action 'type'")

            if action_type not in valid_action_types:
                raise ActionValidationError(
                    f"Invalid action type '{action_type}'. "
                    f"Valid types: {valid_action_types}"
                )

            # Validate via ActionExecutor for extensible actions
            if action_type != "notify":
                ActionExecutor.validate_action(action, db)

    @staticmethod
    def get_supported_action_types() -> List[str]:
        """Get list of all supported action types."""
        return ["notify"] + ActionExecutor.get_supported_actions()