Fix test failures and workload/websocket behavior

This commit is contained in:
beabigegg
2026-01-11 08:37:21 +08:00
parent 3bdc6ff1c9
commit f5f870da56
49 changed files with 3006 additions and 1132 deletions

View File

@@ -1,10 +1,14 @@
import uuid
import logging
from typing import List, Dict, Any, Optional
from datetime import datetime, date
from decimal import Decimal, InvalidOperation
from typing import List, Dict, Any, Optional, Set, Tuple
from sqlalchemy.orm import Session
from sqlalchemy import func
from app.models import Trigger, TriggerLog, Task, User, Project
from app.models import Trigger, TriggerLog, Task, User, ProjectMember, CustomField, Role
from app.services.notification_service import NotificationService
from app.services.custom_value_service import CustomValueService
from app.services.action_executor import (
ActionExecutor,
ActionExecutionError,
@@ -17,8 +21,9 @@ logger = logging.getLogger(__name__)
class TriggerService:
"""Service for evaluating and executing triggers."""
SUPPORTED_FIELDS = ["status_id", "assignee_id", "priority"]
SUPPORTED_OPERATORS = ["equals", "not_equals", "changed_to", "changed_from"]
SUPPORTED_FIELDS = ["status_id", "assignee_id", "priority", "start_date", "due_date", "custom_fields"]
SUPPORTED_OPERATORS = ["equals", "not_equals", "changed_to", "changed_from", "before", "after", "in"]
DATE_FIELDS = {"start_date", "due_date"}
@staticmethod
def evaluate_triggers(
@@ -29,7 +34,9 @@ class TriggerService:
current_user: User,
) -> List[TriggerLog]:
"""Evaluate all active triggers for a project when task values change."""
logs = []
logs: List[TriggerLog] = []
old_values = old_values or {}
new_values = new_values or {}
# Get active field_change triggers for the project
triggers = db.query(Trigger).filter(
@@ -38,8 +45,58 @@ class TriggerService:
Trigger.trigger_type == "field_change",
).all()
if not triggers:
return logs
custom_field_ids: Set[str] = set()
needs_custom_fields = False
for trigger in triggers:
if TriggerService._check_conditions(trigger.conditions, old_values, new_values):
rules = TriggerService._extract_rules(trigger.conditions or {})
for rule in rules:
if rule.get("field") == "custom_fields":
needs_custom_fields = True
field_id = rule.get("field_id")
if field_id:
custom_field_ids.add(field_id)
custom_field_types: Dict[str, str] = {}
if custom_field_ids:
fields = db.query(CustomField).filter(CustomField.id.in_(custom_field_ids)).all()
custom_field_types = {f.id: f.field_type for f in fields}
current_custom_values = None
if needs_custom_fields:
if isinstance(new_values.get("custom_fields"), dict):
current_custom_values = new_values.get("custom_fields")
else:
current_custom_values = TriggerService._get_custom_values_map(db, task)
current_values = {
"status_id": task.status_id,
"assignee_id": task.assignee_id,
"priority": task.priority,
"start_date": task.start_date,
"due_date": task.due_date,
"custom_fields": current_custom_values or {},
}
changed_fields = TriggerService._detect_field_changes(old_values, new_values)
changed_custom_field_ids = TriggerService._detect_custom_field_changes(
old_values.get("custom_fields"),
new_values.get("custom_fields"),
)
for trigger in triggers:
if TriggerService._check_conditions(
trigger.conditions,
old_values,
new_values,
current_values=current_values,
changed_fields=changed_fields,
changed_custom_field_ids=changed_custom_field_ids,
custom_field_types=custom_field_types,
):
log = TriggerService._execute_actions(db, trigger, task, current_user, old_values, new_values)
logs.append(log)
@@ -50,29 +107,298 @@ class TriggerService:
conditions: Dict[str, Any],
old_values: Dict[str, Any],
new_values: Dict[str, Any],
current_values: Optional[Dict[str, Any]] = None,
changed_fields: Optional[Set[str]] = None,
changed_custom_field_ids: Optional[Set[str]] = None,
custom_field_types: Optional[Dict[str, str]] = None,
) -> bool:
"""Check if trigger conditions are met."""
field = conditions.get("field")
operator = conditions.get("operator")
value = conditions.get("value")
old_values = old_values or {}
new_values = new_values or {}
current_values = current_values or new_values
changed_fields = changed_fields or TriggerService._detect_field_changes(old_values, new_values)
changed_custom_field_ids = changed_custom_field_ids or TriggerService._detect_custom_field_changes(
old_values.get("custom_fields"),
new_values.get("custom_fields"),
)
custom_field_types = custom_field_types or {}
if field not in TriggerService.SUPPORTED_FIELDS:
rules = TriggerService._extract_rules(conditions)
if not rules:
return False
old_value = old_values.get(field)
new_value = new_values.get(field)
if conditions.get("rules") is not None and conditions.get("logic") != "and":
return False
any_rule_changed = False
for rule in rules:
field = rule.get("field")
operator = rule.get("operator")
value = rule.get("value")
field_id = rule.get("field_id")
if field not in TriggerService.SUPPORTED_FIELDS:
return False
if operator not in TriggerService.SUPPORTED_OPERATORS:
return False
if field == "custom_fields":
if not field_id:
return False
custom_values = current_values.get("custom_fields") or {}
old_custom = old_values.get("custom_fields") or {}
new_custom = new_values.get("custom_fields") or {}
current_value = custom_values.get(field_id)
old_value = old_custom.get(field_id)
new_value = new_custom.get(field_id)
field_type = TriggerService._normalize_field_type(custom_field_types.get(field_id))
field_changed = field_id in changed_custom_field_ids
else:
current_value = current_values.get(field)
old_value = old_values.get(field)
new_value = new_values.get(field)
field_type = "date" if field in TriggerService.DATE_FIELDS else None
field_changed = field in changed_fields
if TriggerService._evaluate_rule(
operator,
current_value,
old_value,
new_value,
value,
field_type,
field_changed,
) is False:
return False
if field_changed:
any_rule_changed = True
return any_rule_changed
@staticmethod
def _extract_rules(conditions: Dict[str, Any]) -> List[Dict[str, Any]]:
rules = conditions.get("rules")
if isinstance(rules, list):
return rules
field = conditions.get("field")
if field:
rule = {
"field": field,
"operator": conditions.get("operator"),
"value": conditions.get("value"),
}
if conditions.get("field_id"):
rule["field_id"] = conditions.get("field_id")
return [rule]
return []
@staticmethod
def _get_custom_values_map(db: Session, task: Task) -> Dict[str, Any]:
values = CustomValueService.get_custom_values_for_task(
db,
task,
include_formula_calculations=True,
)
return {cv.field_id: cv.value for cv in values}
@staticmethod
def _detect_field_changes(
old_values: Dict[str, Any],
new_values: Dict[str, Any],
) -> Set[str]:
changed = set()
for field in TriggerService.SUPPORTED_FIELDS:
if field == "custom_fields":
continue
if field in old_values or field in new_values:
if old_values.get(field) != new_values.get(field):
changed.add(field)
return changed
@staticmethod
def _detect_custom_field_changes(
old_custom_values: Any,
new_custom_values: Any,
) -> Set[str]:
if not isinstance(old_custom_values, dict) or not isinstance(new_custom_values, dict):
return set()
changed_ids = set()
field_ids = set(old_custom_values.keys()) | set(new_custom_values.keys())
for field_id in field_ids:
if old_custom_values.get(field_id) != new_custom_values.get(field_id):
changed_ids.add(field_id)
return changed_ids
@staticmethod
def _normalize_field_type(field_type: Optional[str]) -> Optional[str]:
if not field_type:
return None
if field_type == "formula":
return "number"
return field_type
@staticmethod
def _evaluate_rule(
operator: str,
current_value: Any,
old_value: Any,
new_value: Any,
target_value: Any,
field_type: Optional[str],
field_changed: bool,
) -> bool:
if operator in ("changed_to", "changed_from"):
if not field_changed:
return False
if operator == "changed_to":
return (
TriggerService._value_equals(new_value, target_value, field_type)
and not TriggerService._value_equals(old_value, target_value, field_type)
)
return (
TriggerService._value_equals(old_value, target_value, field_type)
and not TriggerService._value_equals(new_value, target_value, field_type)
)
if operator == "equals":
return new_value == value
elif operator == "not_equals":
return new_value != value
elif operator == "changed_to":
return old_value != value and new_value == value
elif operator == "changed_from":
return old_value == value and new_value != value
return TriggerService._value_equals(current_value, target_value, field_type)
if operator == "not_equals":
return not TriggerService._value_equals(current_value, target_value, field_type)
if operator == "before":
return TriggerService._compare_before(current_value, target_value, field_type)
if operator == "after":
return TriggerService._compare_after(current_value, target_value, field_type)
if operator == "in":
return TriggerService._compare_in(current_value, target_value, field_type)
return False
@staticmethod
def _value_equals(current_value: Any, target_value: Any, field_type: Optional[str]) -> bool:
if current_value is None:
return target_value is None
if field_type == "date":
current_dt, current_date_only = TriggerService._parse_datetime_value(current_value)
target_dt, target_date_only = TriggerService._parse_datetime_value(target_value)
if not current_dt or not target_dt:
return False
if current_date_only or target_date_only:
return current_dt.date() == target_dt.date()
return current_dt == target_dt
if field_type == "number":
current_num = TriggerService._parse_number_value(current_value)
target_num = TriggerService._parse_number_value(target_value)
if current_num is None or target_num is None:
return False
return current_num == target_num
if isinstance(target_value, (list, dict)):
return False
return str(current_value) == str(target_value)
@staticmethod
def _compare_before(current_value: Any, target_value: Any, field_type: Optional[str]) -> bool:
if current_value is None or target_value is None:
return False
if field_type == "date":
current_dt, current_date_only = TriggerService._parse_datetime_value(current_value)
target_dt, target_date_only = TriggerService._parse_datetime_value(target_value)
if not current_dt or not target_dt:
return False
if current_date_only or target_date_only:
return current_dt.date() < target_dt.date()
return current_dt < target_dt
if field_type == "number":
current_num = TriggerService._parse_number_value(current_value)
target_num = TriggerService._parse_number_value(target_value)
if current_num is None or target_num is None:
return False
return current_num < target_num
return False
@staticmethod
def _compare_after(current_value: Any, target_value: Any, field_type: Optional[str]) -> bool:
if current_value is None or target_value is None:
return False
if field_type == "date":
current_dt, current_date_only = TriggerService._parse_datetime_value(current_value)
target_dt, target_date_only = TriggerService._parse_datetime_value(target_value)
if not current_dt or not target_dt:
return False
if current_date_only or target_date_only:
return current_dt.date() > target_dt.date()
return current_dt > target_dt
if field_type == "number":
current_num = TriggerService._parse_number_value(current_value)
target_num = TriggerService._parse_number_value(target_value)
if current_num is None or target_num is None:
return False
return current_num > target_num
return False
@staticmethod
def _compare_in(current_value: Any, target_value: Any, field_type: Optional[str]) -> bool:
if current_value is None or target_value is None:
return False
if field_type == "date":
if not isinstance(target_value, dict):
return False
start_dt, start_date_only = TriggerService._parse_datetime_value(target_value.get("start"))
end_dt, end_date_only = TriggerService._parse_datetime_value(target_value.get("end"))
current_dt, current_date_only = TriggerService._parse_datetime_value(current_value)
if not start_dt or not end_dt or not current_dt:
return False
date_only = current_date_only or start_date_only or end_date_only
if date_only:
current_date = current_dt.date()
return start_dt.date() <= current_date <= end_dt.date()
return start_dt <= current_dt <= end_dt
if isinstance(target_value, (list, tuple, set)):
if field_type == "number":
current_num = TriggerService._parse_number_value(current_value)
if current_num is None:
return False
for item in target_value:
item_num = TriggerService._parse_number_value(item)
if item_num is not None and item_num == current_num:
return True
return False
return str(current_value) in {str(item) for item in target_value if item is not None}
return False
@staticmethod
def _parse_datetime_value(value: Any) -> Tuple[Optional[datetime], bool]:
if value is None:
return None, False
if isinstance(value, datetime):
return value, False
if isinstance(value, date):
return datetime.combine(value, datetime.min.time()), True
if isinstance(value, str):
try:
if len(value) == 10:
parsed = datetime.strptime(value, "%Y-%m-%d")
return parsed, True
parsed = datetime.fromisoformat(value.replace("Z", "+00:00"))
return parsed.replace(tzinfo=None), False
except ValueError:
return None, False
return None, False
@staticmethod
def _parse_number_value(value: Any) -> Optional[Decimal]:
if value is None or value == "":
return None
try:
return Decimal(str(value))
except (InvalidOperation, ValueError):
return None
@staticmethod
def _execute_actions(
db: Session,
@@ -185,40 +511,76 @@ class TriggerService:
target = action.get("target", "assignee")
template = action.get("template", "任務 {task_title} 已觸發自動化規則")
# Resolve target user
target_user_id = TriggerService._resolve_target(task, target)
if not target_user_id:
return
# Don't notify the user who triggered the action
if target_user_id == current_user.id:
recipients = TriggerService._resolve_targets(db, task, target)
if not recipients:
return
# Format message with variables
message = TriggerService._format_template(template, task, old_values, new_values)
NotificationService.create_notification(
db=db,
user_id=target_user_id,
notification_type="status_change",
reference_type="task",
reference_id=task.id,
title=f"自動化通知: {task.title}",
message=message,
)
for user_id in recipients:
if user_id == current_user.id:
continue
NotificationService.create_notification(
db=db,
user_id=user_id,
notification_type="status_change",
reference_type="task",
reference_id=task.id,
title=f"自動化通知: {task.title}",
message=message,
)
@staticmethod
def _resolve_target(task: Task, target: str) -> Optional[str]:
"""Resolve notification target to user ID."""
def _resolve_targets(db: Session, task: Task, target: str) -> List[str]:
"""Resolve notification target to user IDs."""
recipients: Set[str] = set()
if target == "assignee":
return task.assignee_id
if task.assignee_id:
recipients.add(task.assignee_id)
elif target == "creator":
return task.created_by
if task.created_by:
recipients.add(task.created_by)
elif target == "project_owner":
return task.project.owner_id if task.project else None
if task.project and task.project.owner_id:
recipients.add(task.project.owner_id)
elif target == "project_members":
if task.project:
if task.project.owner_id:
recipients.add(task.project.owner_id)
member_rows = db.query(ProjectMember.user_id).join(
User,
User.id == ProjectMember.user_id,
).filter(
ProjectMember.project_id == task.project_id,
User.is_active == True,
).all()
recipients.update(row[0] for row in member_rows if row and row[0])
elif target.startswith("department:"):
department_id = target.split(":", 1)[1]
if department_id:
user_rows = db.query(User.id).filter(
User.department_id == department_id,
User.is_active == True,
).all()
recipients.update(row[0] for row in user_rows if row and row[0])
elif target.startswith("role:"):
role_name = target.split(":", 1)[1].strip()
if role_name:
role = db.query(Role).filter(func.lower(Role.name) == role_name.lower()).first()
if role:
user_rows = db.query(User.id).filter(
User.role_id == role.id,
User.is_active == True,
).all()
recipients.update(row[0] for row in user_rows if row and row[0])
elif target.startswith("user:"):
return target.split(":", 1)[1]
return None
user_id = target.split(":", 1)[1]
if user_id:
recipients.add(user_id)
return list(recipients)
@staticmethod
def _format_template(