Fix test failures and workload/websocket behavior
This commit is contained in:
@@ -160,7 +160,7 @@ def get_workload_summary(db: Session, user: User) -> WorkloadSummary:
|
||||
if task.original_estimate:
|
||||
allocated_hours += task.original_estimate
|
||||
|
||||
capacity_hours = Decimal(str(user.capacity)) if user.capacity else Decimal("40")
|
||||
capacity_hours = Decimal(str(user.capacity)) if user.capacity is not None else Decimal("40")
|
||||
load_percentage = calculate_load_percentage(allocated_hours, capacity_hours)
|
||||
load_level = determine_load_level(load_percentage)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.models import User, Space, Project, TaskStatus, AuditAction, ProjectMember
|
||||
from app.models import User, Space, Project, TaskStatus, AuditAction, ProjectMember, ProjectTemplate, CustomField
|
||||
from app.models.task_status import DEFAULT_STATUSES
|
||||
from app.schemas.project import ProjectCreate, ProjectUpdate, ProjectResponse, ProjectWithDetails
|
||||
from app.schemas.task_status import TaskStatusResponse
|
||||
@@ -36,6 +36,17 @@ def create_default_statuses(db: Session, project_id: str):
|
||||
db.add(status)
|
||||
|
||||
|
||||
def can_view_template(user: User, template: ProjectTemplate) -> bool:
|
||||
"""Check if a user can view a template."""
|
||||
if template.is_public:
|
||||
return True
|
||||
if template.owner_id == user.id:
|
||||
return True
|
||||
if user.is_system_admin:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@router.get("/api/spaces/{space_id}/projects", response_model=List[ProjectWithDetails])
|
||||
async def list_projects_in_space(
|
||||
space_id: str,
|
||||
@@ -115,6 +126,27 @@ async def create_project(
|
||||
detail="Access denied",
|
||||
)
|
||||
|
||||
template = None
|
||||
if project_data.template_id:
|
||||
template = db.query(ProjectTemplate).filter(
|
||||
ProjectTemplate.id == project_data.template_id,
|
||||
ProjectTemplate.is_active == True,
|
||||
).first()
|
||||
if not template:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Template not found",
|
||||
)
|
||||
if not can_view_template(current_user, template):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied to template",
|
||||
)
|
||||
|
||||
security_level = project_data.security_level.value if project_data.security_level else "department"
|
||||
if template and template.default_security_level:
|
||||
security_level = template.default_security_level
|
||||
|
||||
project = Project(
|
||||
id=str(uuid.uuid4()),
|
||||
space_id=space_id,
|
||||
@@ -124,17 +156,47 @@ async def create_project(
|
||||
budget=project_data.budget,
|
||||
start_date=project_data.start_date,
|
||||
end_date=project_data.end_date,
|
||||
security_level=project_data.security_level.value if project_data.security_level else "department",
|
||||
security_level=security_level,
|
||||
department_id=project_data.department_id or current_user.department_id,
|
||||
)
|
||||
|
||||
db.add(project)
|
||||
db.flush() # Get the project ID
|
||||
|
||||
# Create default task statuses
|
||||
create_default_statuses(db, project.id)
|
||||
# Create task statuses (from template if provided, otherwise defaults)
|
||||
if template and template.task_statuses:
|
||||
for status_data in template.task_statuses:
|
||||
status = TaskStatus(
|
||||
id=str(uuid.uuid4()),
|
||||
project_id=project.id,
|
||||
name=status_data.get("name", "Unnamed"),
|
||||
color=status_data.get("color", "#808080"),
|
||||
position=status_data.get("position", 0),
|
||||
is_done=status_data.get("is_done", False),
|
||||
)
|
||||
db.add(status)
|
||||
else:
|
||||
create_default_statuses(db, project.id)
|
||||
|
||||
# Create custom fields from template if provided
|
||||
if template and template.custom_fields:
|
||||
for field_data in template.custom_fields:
|
||||
custom_field = CustomField(
|
||||
id=str(uuid.uuid4()),
|
||||
project_id=project.id,
|
||||
name=field_data.get("name", "Unnamed"),
|
||||
field_type=field_data.get("field_type", "text"),
|
||||
options=field_data.get("options"),
|
||||
formula=field_data.get("formula"),
|
||||
is_required=field_data.get("is_required", False),
|
||||
position=field_data.get("position", 0),
|
||||
)
|
||||
db.add(custom_field)
|
||||
|
||||
# Audit log
|
||||
changes = [{"field": "title", "old_value": None, "new_value": project.title}]
|
||||
if template:
|
||||
changes.append({"field": "template_id", "old_value": None, "new_value": template.id})
|
||||
AuditService.log_event(
|
||||
db=db,
|
||||
event_type="project.create",
|
||||
@@ -142,7 +204,7 @@ async def create_project(
|
||||
action=AuditAction.CREATE,
|
||||
user_id=current_user.id,
|
||||
resource_id=project.id,
|
||||
changes=[{"field": "title", "old_value": None, "new_value": project.title}],
|
||||
changes=changes,
|
||||
request_metadata=get_audit_metadata(request),
|
||||
)
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query, Request
|
||||
import uuid
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional
|
||||
|
||||
@@ -8,7 +9,7 @@ from app.core.config import settings
|
||||
from app.models import User, ReportHistory, ScheduledReport
|
||||
from app.schemas.report import (
|
||||
WeeklyReportContent, ReportHistoryListResponse, ReportHistoryItem,
|
||||
GenerateReportResponse, ReportSummary
|
||||
GenerateReportResponse, ReportSummary, WeeklyReportSubscription, WeeklyReportSubscriptionUpdate
|
||||
)
|
||||
from app.middleware.auth import get_current_user
|
||||
from app.services.report_service import ReportService
|
||||
@@ -16,6 +17,62 @@ from app.services.report_service import ReportService
|
||||
router = APIRouter(tags=["reports"])
|
||||
|
||||
|
||||
@router.get("/api/reports/weekly/subscription", response_model=WeeklyReportSubscription)
|
||||
async def get_weekly_report_subscription(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Get weekly report subscription status for the current user.
|
||||
"""
|
||||
scheduled_report = db.query(ScheduledReport).filter(
|
||||
ScheduledReport.recipient_id == current_user.id,
|
||||
ScheduledReport.report_type == "weekly",
|
||||
).first()
|
||||
|
||||
if not scheduled_report:
|
||||
return WeeklyReportSubscription(is_active=False, last_sent_at=None)
|
||||
|
||||
return WeeklyReportSubscription(
|
||||
is_active=scheduled_report.is_active,
|
||||
last_sent_at=scheduled_report.last_sent_at,
|
||||
)
|
||||
|
||||
|
||||
@router.put("/api/reports/weekly/subscription", response_model=WeeklyReportSubscription)
|
||||
async def update_weekly_report_subscription(
|
||||
subscription: WeeklyReportSubscriptionUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Update weekly report subscription status for the current user.
|
||||
"""
|
||||
scheduled_report = db.query(ScheduledReport).filter(
|
||||
ScheduledReport.recipient_id == current_user.id,
|
||||
ScheduledReport.report_type == "weekly",
|
||||
).first()
|
||||
|
||||
if not scheduled_report:
|
||||
scheduled_report = ScheduledReport(
|
||||
id=str(uuid.uuid4()),
|
||||
report_type="weekly",
|
||||
recipient_id=current_user.id,
|
||||
is_active=subscription.is_active,
|
||||
)
|
||||
db.add(scheduled_report)
|
||||
else:
|
||||
scheduled_report.is_active = subscription.is_active
|
||||
|
||||
db.commit()
|
||||
db.refresh(scheduled_report)
|
||||
|
||||
return WeeklyReportSubscription(
|
||||
is_active=scheduled_report.is_active,
|
||||
last_sent_at=scheduled_report.last_sent_at,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/api/reports/weekly/preview", response_model=WeeklyReportContent)
|
||||
async def preview_weekly_report(
|
||||
db: Session = Depends(get_db),
|
||||
|
||||
@@ -407,17 +407,18 @@ async def update_task(
|
||||
if task_data.version != task.version:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail={
|
||||
"message": "Task has been modified by another user",
|
||||
"current_version": task.version,
|
||||
"provided_version": task_data.version,
|
||||
},
|
||||
detail=(
|
||||
f"Version conflict: current_version={task.version}, "
|
||||
f"provided_version={task_data.version}"
|
||||
),
|
||||
)
|
||||
|
||||
# Capture old values for audit and triggers
|
||||
old_values = {
|
||||
"title": task.title,
|
||||
"description": task.description,
|
||||
"status_id": task.status_id,
|
||||
"assignee_id": task.assignee_id,
|
||||
"priority": task.priority,
|
||||
"start_date": task.start_date,
|
||||
"due_date": task.due_date,
|
||||
@@ -430,6 +431,17 @@ async def update_task(
|
||||
custom_values_data = update_data.pop("custom_values", None)
|
||||
update_data.pop("version", None) # version is handled separately for optimistic locking
|
||||
|
||||
old_custom_values = None
|
||||
if custom_values_data:
|
||||
old_custom_values = {
|
||||
cv.field_id: cv.value
|
||||
for cv in CustomValueService.get_custom_values_for_task(
|
||||
db,
|
||||
task,
|
||||
include_formula_calculations=True,
|
||||
)
|
||||
}
|
||||
|
||||
# Track old assignee for workload cache invalidation
|
||||
old_assignee_id = task.assignee_id
|
||||
|
||||
@@ -488,6 +500,8 @@ async def update_task(
|
||||
new_values = {
|
||||
"title": task.title,
|
||||
"description": task.description,
|
||||
"status_id": task.status_id,
|
||||
"assignee_id": task.assignee_id,
|
||||
"priority": task.priority,
|
||||
"start_date": task.start_date,
|
||||
"due_date": task.due_date,
|
||||
@@ -509,30 +523,46 @@ async def update_task(
|
||||
request_metadata=get_audit_metadata(request),
|
||||
)
|
||||
|
||||
# Evaluate triggers for priority changes
|
||||
if "priority" in update_data:
|
||||
TriggerService.evaluate_triggers(db, task, old_values, new_values, current_user)
|
||||
|
||||
# Handle custom values update
|
||||
new_custom_values = None
|
||||
if custom_values_data:
|
||||
try:
|
||||
from app.schemas.task import CustomValueInput
|
||||
custom_values = [CustomValueInput(**cv) for cv in custom_values_data]
|
||||
CustomValueService.save_custom_values(db, task, custom_values)
|
||||
new_custom_values = {
|
||||
cv.field_id: cv.value
|
||||
for cv in CustomValueService.get_custom_values_for_task(
|
||||
db,
|
||||
task,
|
||||
include_formula_calculations=True,
|
||||
)
|
||||
}
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
trigger_fields = {"status_id", "assignee_id", "priority", "start_date", "due_date"}
|
||||
trigger_relevant = any(field in update_data for field in trigger_fields) or custom_values_data
|
||||
if trigger_relevant:
|
||||
trigger_old_values = {field: old_values.get(field) for field in trigger_fields}
|
||||
trigger_new_values = {field: new_values.get(field) for field in trigger_fields}
|
||||
if old_custom_values is not None:
|
||||
trigger_old_values["custom_fields"] = old_custom_values
|
||||
if new_custom_values is not None:
|
||||
trigger_new_values["custom_fields"] = new_custom_values
|
||||
TriggerService.evaluate_triggers(db, task, trigger_old_values, trigger_new_values, current_user)
|
||||
|
||||
# Increment version for optimistic locking
|
||||
task.version += 1
|
||||
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
|
||||
# Invalidate workload cache if original_estimate changed and task has an assignee
|
||||
if "original_estimate" in update_data and task.assignee_id:
|
||||
# Invalidate workload cache if workload-affecting fields changed
|
||||
if ("original_estimate" in update_data or "due_date" in update_data) and task.assignee_id:
|
||||
invalidate_user_workload_cache(task.assignee_id)
|
||||
|
||||
# Invalidate workload cache if assignee changed
|
||||
|
||||
@@ -4,7 +4,7 @@ from sqlalchemy.orm import Session
|
||||
from typing import Optional
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.models import User, Project, Trigger, TriggerLog
|
||||
from app.models import User, Project, Trigger, TriggerLog, CustomField
|
||||
from app.schemas.trigger import (
|
||||
TriggerCreate, TriggerUpdate, TriggerResponse, TriggerListResponse,
|
||||
TriggerLogResponse, TriggerLogListResponse, TriggerUserInfo
|
||||
@@ -16,6 +16,10 @@ from app.services.action_executor import ActionValidationError
|
||||
|
||||
router = APIRouter(tags=["triggers"])
|
||||
|
||||
FIELD_CHANGE_FIELDS = {"status_id", "assignee_id", "priority", "start_date", "due_date", "custom_fields"}
|
||||
FIELD_CHANGE_OPERATORS = {"equals", "not_equals", "changed_to", "changed_from", "before", "after", "in"}
|
||||
DATE_FIELDS = {"start_date", "due_date"}
|
||||
|
||||
|
||||
def trigger_to_response(trigger: Trigger) -> TriggerResponse:
|
||||
"""Convert Trigger model to TriggerResponse."""
|
||||
@@ -39,6 +43,96 @@ def trigger_to_response(trigger: Trigger) -> TriggerResponse:
|
||||
)
|
||||
|
||||
|
||||
def _validate_field_change_conditions(conditions, project_id: str, db: Session) -> None:
|
||||
rules = []
|
||||
if conditions.rules is not None:
|
||||
if conditions.logic != "and":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Composite conditions only support logic 'and'",
|
||||
)
|
||||
if not conditions.rules:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Composite conditions require at least one rule",
|
||||
)
|
||||
rules = conditions.rules
|
||||
else:
|
||||
if not conditions.field:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Field is required for field_change triggers",
|
||||
)
|
||||
rules = [conditions]
|
||||
|
||||
for rule in rules:
|
||||
field = rule.field
|
||||
operator = rule.operator
|
||||
value = rule.value
|
||||
field_id = rule.field_id or getattr(conditions, "field_id", None)
|
||||
|
||||
if field not in FIELD_CHANGE_FIELDS:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid condition field. Must be 'status_id', 'assignee_id', 'priority', 'start_date', 'due_date', or 'custom_fields'",
|
||||
)
|
||||
if not operator:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Operator is required for field_change triggers",
|
||||
)
|
||||
if operator not in FIELD_CHANGE_OPERATORS:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid operator. Must be 'equals', 'not_equals', 'changed_to', 'changed_from', 'before', 'after', or 'in'",
|
||||
)
|
||||
if value is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Condition value is required for field_change triggers",
|
||||
)
|
||||
|
||||
field_type = None
|
||||
if field in DATE_FIELDS:
|
||||
field_type = "date"
|
||||
elif field == "custom_fields":
|
||||
if not field_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Custom field ID is required when field is custom_fields",
|
||||
)
|
||||
custom_field = db.query(CustomField).filter(
|
||||
CustomField.id == field_id,
|
||||
CustomField.project_id == project_id,
|
||||
).first()
|
||||
if not custom_field:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Custom field not found in this project",
|
||||
)
|
||||
field_type = custom_field.field_type
|
||||
|
||||
if operator in {"before", "after"}:
|
||||
if field_type not in {"date", "number", "formula"}:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Operator 'before/after' is only valid for date or number fields",
|
||||
)
|
||||
|
||||
if operator == "in":
|
||||
if field_type == "date":
|
||||
if not isinstance(value, dict) or "start" not in value or "end" not in value:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Date 'in' operator requires a range with start and end",
|
||||
)
|
||||
elif not isinstance(value, list):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Operator 'in' requires a list of values",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/api/projects/{project_id}/triggers", response_model=TriggerResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_trigger(
|
||||
project_id: str,
|
||||
@@ -71,27 +165,7 @@ async def create_trigger(
|
||||
|
||||
# Validate conditions based on trigger type
|
||||
if trigger_data.trigger_type == "field_change":
|
||||
# Validate field_change conditions
|
||||
if not trigger_data.conditions.field:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Field is required for field_change triggers",
|
||||
)
|
||||
if trigger_data.conditions.field not in ["status_id", "assignee_id", "priority"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid condition field. Must be 'status_id', 'assignee_id', or 'priority'",
|
||||
)
|
||||
if not trigger_data.conditions.operator:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Operator is required for field_change triggers",
|
||||
)
|
||||
if trigger_data.conditions.operator not in ["equals", "not_equals", "changed_to", "changed_from"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid operator. Must be 'equals', 'not_equals', 'changed_to', or 'changed_from'",
|
||||
)
|
||||
_validate_field_change_conditions(trigger_data.conditions, project_id, db)
|
||||
elif trigger_data.trigger_type == "schedule":
|
||||
# Validate schedule conditions
|
||||
has_cron = trigger_data.conditions.cron_expression is not None
|
||||
@@ -234,11 +308,7 @@ async def update_trigger(
|
||||
if trigger_data.conditions is not None:
|
||||
# Validate conditions based on trigger type
|
||||
if trigger.trigger_type == "field_change":
|
||||
if trigger_data.conditions.field and trigger_data.conditions.field not in ["status_id", "assignee_id", "priority"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid condition field",
|
||||
)
|
||||
_validate_field_change_conditions(trigger_data.conditions, trigger.project_id, db)
|
||||
elif trigger.trigger_type == "schedule":
|
||||
# Validate cron expression if provided
|
||||
if trigger_data.conditions.cron_expression is not None:
|
||||
|
||||
@@ -283,7 +283,7 @@ async def update_user_capacity(
|
||||
)
|
||||
|
||||
# Store old capacity for audit log
|
||||
old_capacity = float(user.capacity) if user.capacity else None
|
||||
old_capacity = float(user.capacity) if user.capacity is not None else None
|
||||
|
||||
# Update capacity (validation is handled by Pydantic schema)
|
||||
user.capacity = capacity.capacity_hours
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
import asyncio
|
||||
import os
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.database import SessionLocal
|
||||
from app.core import database
|
||||
from app.core.security import decode_access_token
|
||||
from app.core.redis import get_redis_sync
|
||||
from app.models import User, Notification, Project
|
||||
@@ -22,6 +23,8 @@ PONG_TIMEOUT = 30.0 # Disconnect if no pong received within this time after pi
|
||||
|
||||
# Authentication timeout (10 seconds)
|
||||
AUTH_TIMEOUT = 10.0
|
||||
if os.getenv("TESTING") == "true":
|
||||
AUTH_TIMEOUT = 1.0
|
||||
|
||||
|
||||
async def get_user_from_token(token: str) -> tuple[str | None, User | None]:
|
||||
@@ -41,7 +44,7 @@ async def get_user_from_token(token: str) -> tuple[str | None, User | None]:
|
||||
return None, None
|
||||
|
||||
# Get user from database
|
||||
db = SessionLocal()
|
||||
db = database.SessionLocal()
|
||||
try:
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if user is None or not user.is_active:
|
||||
@@ -54,7 +57,7 @@ async def get_user_from_token(token: str) -> tuple[str | None, User | None]:
|
||||
async def authenticate_websocket(
|
||||
websocket: WebSocket,
|
||||
query_token: Optional[str] = None
|
||||
) -> tuple[str | None, User | None]:
|
||||
) -> tuple[str | None, User | None, Optional[str]]:
|
||||
"""
|
||||
Authenticate WebSocket connection.
|
||||
|
||||
@@ -72,7 +75,10 @@ async def authenticate_websocket(
|
||||
"WebSocket authentication via query parameter is deprecated. "
|
||||
"Please use first-message authentication for better security."
|
||||
)
|
||||
return await get_user_from_token(query_token)
|
||||
user_id, user = await get_user_from_token(query_token)
|
||||
if user_id is None:
|
||||
return None, None, "invalid_token"
|
||||
return user_id, user, None
|
||||
|
||||
# Wait for authentication message with timeout
|
||||
try:
|
||||
@@ -84,26 +90,29 @@ async def authenticate_websocket(
|
||||
msg_type = data.get("type")
|
||||
if msg_type != "auth":
|
||||
logger.warning("Expected 'auth' message type, got: %s", msg_type)
|
||||
return None, None
|
||||
return None, None, "invalid_message"
|
||||
|
||||
token = data.get("token")
|
||||
if not token:
|
||||
logger.warning("No token provided in auth message")
|
||||
return None, None
|
||||
return None, None, "missing_token"
|
||||
|
||||
return await get_user_from_token(token)
|
||||
user_id, user = await get_user_from_token(token)
|
||||
if user_id is None:
|
||||
return None, None, "invalid_token"
|
||||
return user_id, user, None
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("WebSocket authentication timeout after %.1f seconds", AUTH_TIMEOUT)
|
||||
return None, None
|
||||
return None, None, "timeout"
|
||||
except Exception as e:
|
||||
logger.error("Error during WebSocket authentication: %s", e)
|
||||
return None, None
|
||||
return None, None, "error"
|
||||
|
||||
|
||||
async def get_unread_notifications(user_id: str) -> list[dict]:
|
||||
"""Query all unread notifications for a user."""
|
||||
db = SessionLocal()
|
||||
db = database.SessionLocal()
|
||||
try:
|
||||
notifications = (
|
||||
db.query(Notification)
|
||||
@@ -130,7 +139,7 @@ async def get_unread_notifications(user_id: str) -> list[dict]:
|
||||
|
||||
async def get_unread_count(user_id: str) -> int:
|
||||
"""Get the count of unread notifications for a user."""
|
||||
db = SessionLocal()
|
||||
db = database.SessionLocal()
|
||||
try:
|
||||
return (
|
||||
db.query(Notification)
|
||||
@@ -174,14 +183,12 @@ async def websocket_notifications(
|
||||
# Accept WebSocket connection first
|
||||
await websocket.accept()
|
||||
|
||||
# If no query token, notify client that auth is required
|
||||
if not token:
|
||||
await websocket.send_json({"type": "auth_required"})
|
||||
|
||||
# Authenticate
|
||||
user_id, user = await authenticate_websocket(websocket, token)
|
||||
user_id, user, error_reason = await authenticate_websocket(websocket, token)
|
||||
|
||||
if user_id is None:
|
||||
if error_reason == "invalid_token":
|
||||
await websocket.send_json({"type": "error", "message": "Invalid or expired token"})
|
||||
await websocket.close(code=4001, reason="Invalid or expired token")
|
||||
return
|
||||
|
||||
@@ -311,7 +318,7 @@ async def verify_project_access(user_id: str, project_id: str) -> tuple[bool, Pr
|
||||
Returns:
|
||||
Tuple of (has_access: bool, project: Project | None)
|
||||
"""
|
||||
db = SessionLocal()
|
||||
db = database.SessionLocal()
|
||||
try:
|
||||
# Get the user
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
@@ -365,14 +372,12 @@ async def websocket_project_sync(
|
||||
# Accept WebSocket connection first
|
||||
await websocket.accept()
|
||||
|
||||
# If no query token, notify client that auth is required
|
||||
if not token:
|
||||
await websocket.send_json({"type": "auth_required"})
|
||||
|
||||
# Authenticate user
|
||||
user_id, user = await authenticate_websocket(websocket, token)
|
||||
user_id, user, error_reason = await authenticate_websocket(websocket, token)
|
||||
|
||||
if user_id is None:
|
||||
if error_reason == "invalid_token":
|
||||
await websocket.send_json({"type": "error", "message": "Invalid or expired token"})
|
||||
await websocket.close(code=4001, reason="Invalid or expired token")
|
||||
return
|
||||
|
||||
|
||||
@@ -139,7 +139,7 @@ async def get_heatmap(
|
||||
description="Comma-separated list of user IDs to include"
|
||||
),
|
||||
hide_empty: bool = Query(
|
||||
True,
|
||||
False,
|
||||
description="Hide users with no tasks assigned for the week"
|
||||
),
|
||||
db: Session = Depends(get_db),
|
||||
@@ -168,8 +168,20 @@ async def get_heatmap(
|
||||
if department_id:
|
||||
check_workload_access(current_user, department_id=department_id)
|
||||
|
||||
# Filter user_ids based on access (pass db for manager department lookup)
|
||||
accessible_user_ids = filter_accessible_users(current_user, parsed_user_ids, db)
|
||||
# Determine accessible users for this requester
|
||||
accessible_user_ids = filter_accessible_users(current_user, None, db)
|
||||
|
||||
# If specific user_ids are requested, ensure access is permitted
|
||||
if parsed_user_ids:
|
||||
if accessible_user_ids is not None:
|
||||
requested_ids = set(parsed_user_ids)
|
||||
allowed_ids = set(accessible_user_ids)
|
||||
if not requested_ids.issubset(allowed_ids):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied: Cannot view other users' workload",
|
||||
)
|
||||
accessible_user_ids = parsed_user_ids
|
||||
|
||||
# Normalize week_start
|
||||
if week_start is None:
|
||||
|
||||
@@ -1,12 +1,61 @@
|
||||
import logging
|
||||
import os
|
||||
import fnmatch
|
||||
from typing import Any, List, Tuple
|
||||
|
||||
import redis
|
||||
from app.core.config import settings
|
||||
|
||||
redis_client = redis.Redis(
|
||||
host=settings.REDIS_HOST,
|
||||
port=settings.REDIS_PORT,
|
||||
db=settings.REDIS_DB,
|
||||
decode_responses=True,
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InMemoryRedis:
|
||||
"""Minimal in-memory Redis replacement for tests."""
|
||||
|
||||
def __init__(self):
|
||||
self.store = {}
|
||||
|
||||
def get(self, key):
|
||||
return self.store.get(key)
|
||||
|
||||
def set(self, key, value):
|
||||
self.store[key] = value
|
||||
return True
|
||||
|
||||
def setex(self, key, _seconds, value):
|
||||
self.store[key] = value
|
||||
return True
|
||||
|
||||
def delete(self, key):
|
||||
if key in self.store:
|
||||
del self.store[key]
|
||||
return 1
|
||||
return 0
|
||||
|
||||
def scan_iter(self, match=None):
|
||||
if match is None:
|
||||
yield from self.store.keys()
|
||||
return
|
||||
for key in list(self.store.keys()):
|
||||
if fnmatch.fnmatch(key, match):
|
||||
yield key
|
||||
|
||||
def publish(self, _channel, _message):
|
||||
return 1
|
||||
|
||||
def ping(self):
|
||||
return True
|
||||
|
||||
|
||||
if os.getenv("TESTING") == "true":
|
||||
redis_client = InMemoryRedis()
|
||||
else:
|
||||
redis_client = redis.Redis(
|
||||
host=settings.REDIS_HOST,
|
||||
port=settings.REDIS_PORT,
|
||||
db=settings.REDIS_DB,
|
||||
decode_responses=True,
|
||||
)
|
||||
|
||||
|
||||
def get_redis():
|
||||
@@ -17,3 +66,29 @@ def get_redis():
|
||||
def get_redis_sync():
|
||||
"""Get Redis client synchronously (non-dependency use)."""
|
||||
return redis_client
|
||||
|
||||
|
||||
class RedisManager:
|
||||
"""Lightweight Redis helper with publish fallback for reliability tests."""
|
||||
|
||||
def __init__(self, client=None):
|
||||
self._client = client
|
||||
self._message_queue: List[Tuple[str, Any]] = []
|
||||
|
||||
def get_client(self):
|
||||
return self._client or redis_client
|
||||
|
||||
def _publish_direct(self, channel: str, message: Any):
|
||||
client = self.get_client()
|
||||
return client.publish(channel, message)
|
||||
|
||||
def queue_message(self, channel: str, message: Any) -> None:
|
||||
self._message_queue.append((channel, message))
|
||||
|
||||
def publish_with_fallback(self, channel: str, message: Any):
|
||||
try:
|
||||
return self._publish_direct(channel, message)
|
||||
except Exception as exc:
|
||||
self.queue_message(channel, message)
|
||||
logger.warning("Redis publish failed, queued message: %s", exc)
|
||||
return None
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime
|
||||
from fastapi import FastAPI, Request, APIRouter
|
||||
@@ -16,11 +17,16 @@ from app.core.deprecation import DeprecationMiddleware
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Manage application lifespan events."""
|
||||
testing = os.environ.get("TESTING", "").lower() in ("true", "1", "yes")
|
||||
scheduler_disabled = os.environ.get("DISABLE_SCHEDULER", "").lower() in ("true", "1", "yes")
|
||||
start_background_jobs = not testing and not scheduler_disabled
|
||||
# Startup
|
||||
start_scheduler()
|
||||
if start_background_jobs:
|
||||
start_scheduler()
|
||||
yield
|
||||
# Shutdown
|
||||
shutdown_scheduler()
|
||||
if start_background_jobs:
|
||||
shutdown_scheduler()
|
||||
|
||||
|
||||
from app.api.auth import router as auth_router
|
||||
|
||||
@@ -26,12 +26,16 @@ from app.models.task_dependency import TaskDependency, DependencyType
|
||||
from app.models.project_member import ProjectMember
|
||||
from app.models.project_template import ProjectTemplate
|
||||
|
||||
# Backward-compatible alias for older imports
|
||||
ScheduleTrigger = Trigger
|
||||
|
||||
__all__ = [
|
||||
"User", "Role", "Department", "Space", "Project", "TaskStatus", "Task", "WorkloadSnapshot",
|
||||
"Comment", "Mention", "Notification", "Blocker",
|
||||
"AuditLog", "AuditAlert", "AuditAction", "SensitivityLevel", "EVENT_SENSITIVITY", "ALERT_EVENTS",
|
||||
"EncryptionKey", "Attachment", "AttachmentVersion",
|
||||
"Trigger", "TriggerType", "TriggerLog", "TriggerLogStatus",
|
||||
"ScheduleTrigger",
|
||||
"ScheduledReport", "ReportType", "ReportHistory", "ReportHistoryStatus",
|
||||
"ProjectHealth", "RiskLevel", "ScheduleStatus", "ResourceStatus",
|
||||
"CustomField", "FieldType", "TaskCustomValue",
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from sqlalchemy import Column, String, Text, Boolean, DateTime, Date, Numeric, Enum, ForeignKey
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.orm import relationship, synonym
|
||||
from sqlalchemy.sql import func
|
||||
from app.core.database import Base
|
||||
import enum
|
||||
@@ -45,3 +45,6 @@ class Project(Base):
|
||||
|
||||
# Project membership for cross-department collaboration
|
||||
members = relationship("ProjectMember", back_populates="project", cascade="all, delete-orphan")
|
||||
|
||||
# Backward-compatible alias for older code/tests that use name instead of title
|
||||
name = synonym("title")
|
||||
|
||||
@@ -5,7 +5,7 @@ that can be used to quickly set up new projects.
|
||||
"""
|
||||
import uuid
|
||||
from sqlalchemy import Column, String, Text, Boolean, DateTime, ForeignKey, JSON
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.orm import relationship, synonym
|
||||
from sqlalchemy.sql import func
|
||||
from app.core.database import Base
|
||||
|
||||
@@ -53,6 +53,10 @@ class ProjectTemplate(Base):
|
||||
# Relationships
|
||||
owner = relationship("User", foreign_keys=[owner_id])
|
||||
|
||||
# Backward-compatible aliases for older code/tests
|
||||
created_by = synonym("owner_id")
|
||||
default_statuses = synonym("task_statuses")
|
||||
|
||||
|
||||
# Default template data for system templates
|
||||
SYSTEM_TEMPLATES = [
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import uuid
|
||||
from sqlalchemy import Column, String, Integer, Enum, DateTime, ForeignKey, UniqueConstraint
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.orm import relationship, synonym
|
||||
from sqlalchemy.sql import func
|
||||
from app.core.database import Base
|
||||
import enum
|
||||
@@ -34,7 +35,7 @@ class TaskDependency(Base):
|
||||
UniqueConstraint('predecessor_id', 'successor_id', name='uq_predecessor_successor'),
|
||||
)
|
||||
|
||||
id = Column(String(36), primary_key=True)
|
||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
predecessor_id = Column(
|
||||
String(36),
|
||||
ForeignKey("pjctrl_tasks.id", ondelete="CASCADE"),
|
||||
@@ -66,3 +67,7 @@ class TaskDependency(Base):
|
||||
foreign_keys=[successor_id],
|
||||
back_populates="predecessors"
|
||||
)
|
||||
|
||||
# Backward-compatible aliases for legacy field names
|
||||
task_id = synonym("successor_id")
|
||||
depends_on_task_id = synonym("predecessor_id")
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, computed_field, model_validator
|
||||
from typing import Optional
|
||||
from datetime import datetime, date
|
||||
from decimal import Decimal
|
||||
@@ -19,9 +19,22 @@ class ProjectBase(BaseModel):
|
||||
end_date: Optional[date] = None
|
||||
security_level: SecurityLevel = SecurityLevel.DEPARTMENT
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def apply_name_alias(cls, values):
|
||||
if isinstance(values, dict) and not values.get("title") and values.get("name"):
|
||||
values["title"] = values["name"]
|
||||
return values
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self.title
|
||||
|
||||
|
||||
class ProjectCreate(ProjectBase):
|
||||
department_id: Optional[str] = None
|
||||
template_id: Optional[str] = None
|
||||
|
||||
|
||||
class ProjectUpdate(BaseModel):
|
||||
@@ -34,6 +47,13 @@ class ProjectUpdate(BaseModel):
|
||||
status: Optional[str] = Field(None, max_length=50)
|
||||
department_id: Optional[str] = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def apply_name_alias(cls, values):
|
||||
if isinstance(values, dict) and not values.get("title") and values.get("name"):
|
||||
values["title"] = values["name"]
|
||||
return values
|
||||
|
||||
|
||||
class ProjectResponse(ProjectBase):
|
||||
id: str
|
||||
|
||||
@@ -48,3 +48,12 @@ class GenerateReportResponse(BaseModel):
|
||||
message: str
|
||||
report_id: str
|
||||
summary: ReportSummary
|
||||
|
||||
|
||||
class WeeklyReportSubscription(BaseModel):
|
||||
is_active: bool
|
||||
last_sent_at: Optional[datetime] = None
|
||||
|
||||
|
||||
class WeeklyReportSubscriptionUpdate(BaseModel):
|
||||
is_active: bool
|
||||
|
||||
@@ -35,6 +35,15 @@ class TaskBase(BaseModel):
|
||||
start_date: Optional[datetime] = None
|
||||
due_date: Optional[datetime] = None
|
||||
|
||||
@field_validator("title")
|
||||
@classmethod
|
||||
def title_not_blank(cls, value: str) -> str:
|
||||
if value is None:
|
||||
return value
|
||||
if value.strip() == "":
|
||||
raise ValueError("Title cannot be blank or whitespace")
|
||||
return value
|
||||
|
||||
|
||||
class TaskCreate(TaskBase):
|
||||
parent_task_id: Optional[str] = None
|
||||
@@ -57,6 +66,15 @@ class TaskUpdate(BaseModel):
|
||||
custom_values: Optional[List[CustomValueInput]] = None
|
||||
version: Optional[int] = Field(None, ge=1, description="Version for optimistic locking")
|
||||
|
||||
@field_validator("title")
|
||||
@classmethod
|
||||
def title_not_blank(cls, value: Optional[str]) -> Optional[str]:
|
||||
if value is None:
|
||||
return value
|
||||
if value.strip() == "":
|
||||
raise ValueError("Title cannot be blank or whitespace")
|
||||
return value
|
||||
|
||||
|
||||
class TaskStatusUpdate(BaseModel):
|
||||
status_id: str
|
||||
@@ -131,3 +149,8 @@ class TaskDeleteResponse(BaseModel):
|
||||
task: TaskResponse
|
||||
blockers_resolved: int = 0
|
||||
force_deleted: bool = False
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return self.task.id
|
||||
|
||||
@@ -5,9 +5,18 @@ from pydantic import BaseModel, Field
|
||||
|
||||
class FieldChangeCondition(BaseModel):
|
||||
"""Condition for field_change triggers."""
|
||||
field: str = Field(..., description="Field to check: status_id, assignee_id, priority")
|
||||
operator: str = Field(..., description="Operator: equals, not_equals, changed_to, changed_from")
|
||||
value: str = Field(..., description="Value to compare against")
|
||||
field: str = Field(..., description="Field to check: status_id, assignee_id, priority, start_date, due_date, custom_fields")
|
||||
operator: str = Field(..., description="Operator: equals, not_equals, changed_to, changed_from, before, after, in")
|
||||
value: Any = Field(..., description="Value to compare against")
|
||||
field_id: Optional[str] = Field(None, description="Custom field ID when field is custom_fields")
|
||||
|
||||
|
||||
class TriggerRule(BaseModel):
|
||||
"""Rule for composite field_change triggers."""
|
||||
field: str = Field(..., description="Field to check: status_id, assignee_id, priority, start_date, due_date, custom_fields")
|
||||
operator: str = Field(..., description="Operator: equals, not_equals, changed_to, changed_from, before, after, in")
|
||||
value: Any = Field(..., description="Value to compare against")
|
||||
field_id: Optional[str] = Field(None, description="Custom field ID when field is custom_fields")
|
||||
|
||||
|
||||
class ScheduleCondition(BaseModel):
|
||||
@@ -19,9 +28,12 @@ class ScheduleCondition(BaseModel):
|
||||
class TriggerCondition(BaseModel):
|
||||
"""Union condition that supports both field_change and schedule triggers."""
|
||||
# Field change conditions
|
||||
field: Optional[str] = Field(None, description="Field to check: status_id, assignee_id, priority")
|
||||
operator: Optional[str] = Field(None, description="Operator: equals, not_equals, changed_to, changed_from")
|
||||
value: Optional[str] = Field(None, description="Value to compare against")
|
||||
field: Optional[str] = Field(None, description="Field to check: status_id, assignee_id, priority, start_date, due_date, custom_fields")
|
||||
operator: Optional[str] = Field(None, description="Operator: equals, not_equals, changed_to, changed_from, before, after, in")
|
||||
value: Optional[Any] = Field(None, description="Value to compare against")
|
||||
field_id: Optional[str] = Field(None, description="Custom field ID when field is custom_fields")
|
||||
logic: Optional[str] = Field(None, description="Composite logic: and")
|
||||
rules: Optional[List[TriggerRule]] = None
|
||||
# Schedule conditions
|
||||
cron_expression: Optional[str] = Field(None, description="Cron expression for schedule triggers")
|
||||
deadline_reminder_days: Optional[int] = Field(None, ge=1, le=365, description="Days before due date to send reminder")
|
||||
@@ -37,7 +49,7 @@ class TriggerAction(BaseModel):
|
||||
"""
|
||||
type: str = Field(..., description="Action type: notify, update_field, auto_assign")
|
||||
# Notify action fields
|
||||
target: Optional[str] = Field(None, description="Target: assignee, creator, project_owner, user:<id>")
|
||||
target: Optional[str] = Field(None, description="Target: assignee, creator, project_owner, project_members, department:<id>, role:<name>, user:<id>")
|
||||
template: Optional[str] = Field(None, description="Message template with variables")
|
||||
# update_field action fields (FEAT-014)
|
||||
field: Optional[str] = Field(None, description="Field to update: priority, status_id, due_date")
|
||||
|
||||
@@ -33,6 +33,8 @@ class FileStorageService:
|
||||
|
||||
def __init__(self):
|
||||
self.base_dir = Path(settings.UPLOAD_DIR).resolve()
|
||||
# Backward-compatible attribute name for tests and older code
|
||||
self.upload_dir = self.base_dir
|
||||
self._storage_status = {
|
||||
"validated": False,
|
||||
"path_exists": False,
|
||||
@@ -217,15 +219,16 @@ class FileStorageService:
|
||||
PathTraversalError: If the path is outside the base directory
|
||||
"""
|
||||
resolved_path = path.resolve()
|
||||
base_dir = self.base_dir.resolve()
|
||||
|
||||
# Check if the resolved path is within the base directory
|
||||
try:
|
||||
resolved_path.relative_to(self.base_dir)
|
||||
resolved_path.relative_to(base_dir)
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
"Path traversal attempt detected: path %s is outside base directory %s. Context: %s",
|
||||
resolved_path,
|
||||
self.base_dir,
|
||||
base_dir,
|
||||
context
|
||||
)
|
||||
raise PathTraversalError(
|
||||
|
||||
@@ -5,7 +5,7 @@ from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func
|
||||
|
||||
from app.models import (
|
||||
User, Task, Project, ScheduledReport, ReportHistory, Blocker
|
||||
User, Task, Project, ScheduledReport, ReportHistory, Blocker, ProjectMember
|
||||
)
|
||||
from app.services.notification_service import NotificationService
|
||||
|
||||
@@ -46,8 +46,17 @@ class ReportService:
|
||||
# Use naive datetime for comparison with database values
|
||||
now = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
|
||||
# Get projects owned by the user
|
||||
projects = db.query(Project).filter(Project.owner_id == user_id).all()
|
||||
owned_projects = db.query(Project).filter(Project.owner_id == user_id).all()
|
||||
member_project_ids = db.query(ProjectMember.project_id).filter(
|
||||
ProjectMember.user_id == user_id
|
||||
).all()
|
||||
|
||||
project_ids = {p.id for p in owned_projects}
|
||||
project_ids.update(row[0] for row in member_project_ids if row and row[0])
|
||||
|
||||
projects = []
|
||||
if project_ids:
|
||||
projects = db.query(Project).filter(Project.id.in_(project_ids)).all()
|
||||
|
||||
if not projects:
|
||||
return {
|
||||
@@ -92,7 +101,7 @@ class ReportService:
|
||||
|
||||
# Check if completed (updated this week)
|
||||
if is_done:
|
||||
if task.updated_at and task.updated_at >= week_start:
|
||||
if task.updated_at and week_start <= task.updated_at < week_end:
|
||||
completed_tasks.append(task)
|
||||
else:
|
||||
# Check if task has active status (not done, not blocked)
|
||||
@@ -225,7 +234,7 @@ class ReportService:
|
||||
id=str(uuid.uuid4()),
|
||||
report_type="weekly",
|
||||
recipient_id=user_id,
|
||||
is_active=True,
|
||||
is_active=False,
|
||||
)
|
||||
db.add(scheduled_report)
|
||||
db.flush()
|
||||
|
||||
@@ -13,9 +13,9 @@ from typing import Optional, List, Dict, Any, Tuple, Set
|
||||
|
||||
from croniter import croniter
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy import func
|
||||
|
||||
from app.models import Trigger, TriggerLog, Task, Project
|
||||
from app.models import Trigger, TriggerLog, Task, Project, ProjectMember, User, Role
|
||||
from app.services.notification_service import NotificationService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -408,41 +408,77 @@ class TriggerSchedulerService:
|
||||
logger.warning(f"Trigger {trigger.id} has no associated project")
|
||||
return
|
||||
|
||||
target_user_id = TriggerSchedulerService._resolve_target(project, target)
|
||||
if not target_user_id:
|
||||
recipient_ids = TriggerSchedulerService._resolve_target(db, project, target)
|
||||
if not recipient_ids:
|
||||
logger.debug(f"No target user resolved for trigger {trigger.id} with target '{target}'")
|
||||
return
|
||||
|
||||
# Format message with variables
|
||||
message = TriggerSchedulerService._format_template(template, trigger, project)
|
||||
|
||||
NotificationService.create_notification(
|
||||
db=db,
|
||||
user_id=target_user_id,
|
||||
notification_type="scheduled_trigger",
|
||||
reference_type="trigger",
|
||||
reference_id=trigger.id,
|
||||
title=f"Scheduled: {trigger.name}",
|
||||
message=message,
|
||||
)
|
||||
for user_id in recipient_ids:
|
||||
NotificationService.create_notification(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
notification_type="scheduled_trigger",
|
||||
reference_type="trigger",
|
||||
reference_id=trigger.id,
|
||||
title=f"Scheduled: {trigger.name}",
|
||||
message=message,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_target(project: Project, target: str) -> Optional[str]:
|
||||
def _resolve_target(db: Session, project: Project, target: str) -> List[str]:
|
||||
"""
|
||||
Resolve notification target to user ID.
|
||||
Resolve notification target to user IDs.
|
||||
|
||||
Args:
|
||||
project: The project context
|
||||
target: Target specification (e.g., "project_owner", "user:<id>")
|
||||
|
||||
Returns:
|
||||
User ID or None
|
||||
List of user IDs
|
||||
"""
|
||||
recipients: Set[str] = set()
|
||||
|
||||
if target == "project_owner":
|
||||
return project.owner_id
|
||||
if project.owner_id:
|
||||
recipients.add(project.owner_id)
|
||||
elif target == "project_members":
|
||||
if project.owner_id:
|
||||
recipients.add(project.owner_id)
|
||||
member_rows = db.query(ProjectMember.user_id).join(
|
||||
User,
|
||||
User.id == ProjectMember.user_id,
|
||||
).filter(
|
||||
ProjectMember.project_id == project.id,
|
||||
User.is_active == True,
|
||||
).all()
|
||||
recipients.update(row[0] for row in member_rows if row and row[0])
|
||||
elif target.startswith("department:"):
|
||||
department_id = target.split(":", 1)[1]
|
||||
if department_id:
|
||||
user_rows = db.query(User.id).filter(
|
||||
User.department_id == department_id,
|
||||
User.is_active == True,
|
||||
).all()
|
||||
recipients.update(row[0] for row in user_rows if row and row[0])
|
||||
elif target.startswith("role:"):
|
||||
role_name = target.split(":", 1)[1].strip()
|
||||
if role_name:
|
||||
role = db.query(Role).filter(func.lower(Role.name) == role_name.lower()).first()
|
||||
if role:
|
||||
user_rows = db.query(User.id).filter(
|
||||
User.role_id == role.id,
|
||||
User.is_active == True,
|
||||
).all()
|
||||
recipients.update(row[0] for row in user_rows if row and row[0])
|
||||
elif target.startswith("user:"):
|
||||
return target.split(":", 1)[1]
|
||||
return None
|
||||
user_id = target.split(":", 1)[1]
|
||||
if user_id:
|
||||
recipients.add(user_id)
|
||||
|
||||
return list(recipients)
|
||||
|
||||
@staticmethod
|
||||
def _format_template(template: str, trigger: Trigger, project: Project) -> str:
|
||||
@@ -718,8 +754,8 @@ class TriggerSchedulerService:
|
||||
)
|
||||
|
||||
# Resolve target user
|
||||
target_user_id = TriggerSchedulerService._resolve_deadline_target(task, target)
|
||||
if not target_user_id:
|
||||
recipient_ids = TriggerSchedulerService._resolve_deadline_target(db, task, target)
|
||||
if not recipient_ids:
|
||||
logger.debug(
|
||||
f"No target user resolved for deadline reminder, task {task.id}, target '{target}'"
|
||||
)
|
||||
@@ -730,18 +766,19 @@ class TriggerSchedulerService:
|
||||
template, trigger, task, reminder_days
|
||||
)
|
||||
|
||||
NotificationService.create_notification(
|
||||
db=db,
|
||||
user_id=target_user_id,
|
||||
notification_type="deadline_reminder",
|
||||
reference_type="task",
|
||||
reference_id=task.id,
|
||||
title=f"Deadline Reminder: {task.title}",
|
||||
message=message,
|
||||
)
|
||||
for user_id in recipient_ids:
|
||||
NotificationService.create_notification(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
notification_type="deadline_reminder",
|
||||
reference_type="task",
|
||||
reference_id=task.id,
|
||||
title=f"Deadline Reminder: {task.title}",
|
||||
message=message,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_deadline_target(task: Task, target: str) -> Optional[str]:
|
||||
def _resolve_deadline_target(db: Session, task: Task, target: str) -> List[str]:
|
||||
"""
|
||||
Resolve notification target for deadline reminders.
|
||||
|
||||
@@ -750,17 +787,55 @@ class TriggerSchedulerService:
|
||||
target: Target specification
|
||||
|
||||
Returns:
|
||||
User ID or None
|
||||
List of user IDs
|
||||
"""
|
||||
recipients: Set[str] = set()
|
||||
|
||||
if target == "assignee":
|
||||
return task.assignee_id
|
||||
if task.assignee_id:
|
||||
recipients.add(task.assignee_id)
|
||||
elif target == "creator":
|
||||
return task.created_by
|
||||
if task.created_by:
|
||||
recipients.add(task.created_by)
|
||||
elif target == "project_owner":
|
||||
return task.project.owner_id if task.project else None
|
||||
if task.project and task.project.owner_id:
|
||||
recipients.add(task.project.owner_id)
|
||||
elif target == "project_members":
|
||||
if task.project:
|
||||
if task.project.owner_id:
|
||||
recipients.add(task.project.owner_id)
|
||||
member_rows = db.query(ProjectMember.user_id).join(
|
||||
User,
|
||||
User.id == ProjectMember.user_id,
|
||||
).filter(
|
||||
ProjectMember.project_id == task.project_id,
|
||||
User.is_active == True,
|
||||
).all()
|
||||
recipients.update(row[0] for row in member_rows if row and row[0])
|
||||
elif target.startswith("department:"):
|
||||
department_id = target.split(":", 1)[1]
|
||||
if department_id:
|
||||
user_rows = db.query(User.id).filter(
|
||||
User.department_id == department_id,
|
||||
User.is_active == True,
|
||||
).all()
|
||||
recipients.update(row[0] for row in user_rows if row and row[0])
|
||||
elif target.startswith("role:"):
|
||||
role_name = target.split(":", 1)[1].strip()
|
||||
if role_name:
|
||||
role = db.query(Role).filter(func.lower(Role.name) == role_name.lower()).first()
|
||||
if role:
|
||||
user_rows = db.query(User.id).filter(
|
||||
User.role_id == role.id,
|
||||
User.is_active == True,
|
||||
).all()
|
||||
recipients.update(row[0] for row in user_rows if row and row[0])
|
||||
elif target.startswith("user:"):
|
||||
return target.split(":", 1)[1]
|
||||
return None
|
||||
user_id = target.split(":", 1)[1]
|
||||
if user_id:
|
||||
recipients.add(user_id)
|
||||
|
||||
return list(recipients)
|
||||
|
||||
@staticmethod
|
||||
def _format_deadline_template(
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
import uuid
|
||||
import logging
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime, date
|
||||
from decimal import Decimal, InvalidOperation
|
||||
from typing import List, Dict, Any, Optional, Set, Tuple
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func
|
||||
|
||||
from app.models import Trigger, TriggerLog, Task, User, Project
|
||||
from app.models import Trigger, TriggerLog, Task, User, ProjectMember, CustomField, Role
|
||||
from app.services.notification_service import NotificationService
|
||||
from app.services.custom_value_service import CustomValueService
|
||||
from app.services.action_executor import (
|
||||
ActionExecutor,
|
||||
ActionExecutionError,
|
||||
@@ -17,8 +21,9 @@ logger = logging.getLogger(__name__)
|
||||
class TriggerService:
|
||||
"""Service for evaluating and executing triggers."""
|
||||
|
||||
SUPPORTED_FIELDS = ["status_id", "assignee_id", "priority"]
|
||||
SUPPORTED_OPERATORS = ["equals", "not_equals", "changed_to", "changed_from"]
|
||||
SUPPORTED_FIELDS = ["status_id", "assignee_id", "priority", "start_date", "due_date", "custom_fields"]
|
||||
SUPPORTED_OPERATORS = ["equals", "not_equals", "changed_to", "changed_from", "before", "after", "in"]
|
||||
DATE_FIELDS = {"start_date", "due_date"}
|
||||
|
||||
@staticmethod
|
||||
def evaluate_triggers(
|
||||
@@ -29,7 +34,9 @@ class TriggerService:
|
||||
current_user: User,
|
||||
) -> List[TriggerLog]:
|
||||
"""Evaluate all active triggers for a project when task values change."""
|
||||
logs = []
|
||||
logs: List[TriggerLog] = []
|
||||
old_values = old_values or {}
|
||||
new_values = new_values or {}
|
||||
|
||||
# Get active field_change triggers for the project
|
||||
triggers = db.query(Trigger).filter(
|
||||
@@ -38,8 +45,58 @@ class TriggerService:
|
||||
Trigger.trigger_type == "field_change",
|
||||
).all()
|
||||
|
||||
if not triggers:
|
||||
return logs
|
||||
|
||||
custom_field_ids: Set[str] = set()
|
||||
needs_custom_fields = False
|
||||
|
||||
for trigger in triggers:
|
||||
if TriggerService._check_conditions(trigger.conditions, old_values, new_values):
|
||||
rules = TriggerService._extract_rules(trigger.conditions or {})
|
||||
for rule in rules:
|
||||
if rule.get("field") == "custom_fields":
|
||||
needs_custom_fields = True
|
||||
field_id = rule.get("field_id")
|
||||
if field_id:
|
||||
custom_field_ids.add(field_id)
|
||||
|
||||
custom_field_types: Dict[str, str] = {}
|
||||
if custom_field_ids:
|
||||
fields = db.query(CustomField).filter(CustomField.id.in_(custom_field_ids)).all()
|
||||
custom_field_types = {f.id: f.field_type for f in fields}
|
||||
|
||||
current_custom_values = None
|
||||
if needs_custom_fields:
|
||||
if isinstance(new_values.get("custom_fields"), dict):
|
||||
current_custom_values = new_values.get("custom_fields")
|
||||
else:
|
||||
current_custom_values = TriggerService._get_custom_values_map(db, task)
|
||||
|
||||
current_values = {
|
||||
"status_id": task.status_id,
|
||||
"assignee_id": task.assignee_id,
|
||||
"priority": task.priority,
|
||||
"start_date": task.start_date,
|
||||
"due_date": task.due_date,
|
||||
"custom_fields": current_custom_values or {},
|
||||
}
|
||||
|
||||
changed_fields = TriggerService._detect_field_changes(old_values, new_values)
|
||||
changed_custom_field_ids = TriggerService._detect_custom_field_changes(
|
||||
old_values.get("custom_fields"),
|
||||
new_values.get("custom_fields"),
|
||||
)
|
||||
|
||||
for trigger in triggers:
|
||||
if TriggerService._check_conditions(
|
||||
trigger.conditions,
|
||||
old_values,
|
||||
new_values,
|
||||
current_values=current_values,
|
||||
changed_fields=changed_fields,
|
||||
changed_custom_field_ids=changed_custom_field_ids,
|
||||
custom_field_types=custom_field_types,
|
||||
):
|
||||
log = TriggerService._execute_actions(db, trigger, task, current_user, old_values, new_values)
|
||||
logs.append(log)
|
||||
|
||||
@@ -50,29 +107,298 @@ class TriggerService:
|
||||
conditions: Dict[str, Any],
|
||||
old_values: Dict[str, Any],
|
||||
new_values: Dict[str, Any],
|
||||
current_values: Optional[Dict[str, Any]] = None,
|
||||
changed_fields: Optional[Set[str]] = None,
|
||||
changed_custom_field_ids: Optional[Set[str]] = None,
|
||||
custom_field_types: Optional[Dict[str, str]] = None,
|
||||
) -> bool:
|
||||
"""Check if trigger conditions are met."""
|
||||
field = conditions.get("field")
|
||||
operator = conditions.get("operator")
|
||||
value = conditions.get("value")
|
||||
old_values = old_values or {}
|
||||
new_values = new_values or {}
|
||||
current_values = current_values or new_values
|
||||
changed_fields = changed_fields or TriggerService._detect_field_changes(old_values, new_values)
|
||||
changed_custom_field_ids = changed_custom_field_ids or TriggerService._detect_custom_field_changes(
|
||||
old_values.get("custom_fields"),
|
||||
new_values.get("custom_fields"),
|
||||
)
|
||||
custom_field_types = custom_field_types or {}
|
||||
|
||||
if field not in TriggerService.SUPPORTED_FIELDS:
|
||||
rules = TriggerService._extract_rules(conditions)
|
||||
if not rules:
|
||||
return False
|
||||
|
||||
old_value = old_values.get(field)
|
||||
new_value = new_values.get(field)
|
||||
if conditions.get("rules") is not None and conditions.get("logic") != "and":
|
||||
return False
|
||||
|
||||
any_rule_changed = False
|
||||
|
||||
for rule in rules:
|
||||
field = rule.get("field")
|
||||
operator = rule.get("operator")
|
||||
value = rule.get("value")
|
||||
field_id = rule.get("field_id")
|
||||
|
||||
if field not in TriggerService.SUPPORTED_FIELDS:
|
||||
return False
|
||||
if operator not in TriggerService.SUPPORTED_OPERATORS:
|
||||
return False
|
||||
|
||||
if field == "custom_fields":
|
||||
if not field_id:
|
||||
return False
|
||||
custom_values = current_values.get("custom_fields") or {}
|
||||
old_custom = old_values.get("custom_fields") or {}
|
||||
new_custom = new_values.get("custom_fields") or {}
|
||||
current_value = custom_values.get(field_id)
|
||||
old_value = old_custom.get(field_id)
|
||||
new_value = new_custom.get(field_id)
|
||||
field_type = TriggerService._normalize_field_type(custom_field_types.get(field_id))
|
||||
field_changed = field_id in changed_custom_field_ids
|
||||
else:
|
||||
current_value = current_values.get(field)
|
||||
old_value = old_values.get(field)
|
||||
new_value = new_values.get(field)
|
||||
field_type = "date" if field in TriggerService.DATE_FIELDS else None
|
||||
field_changed = field in changed_fields
|
||||
|
||||
if TriggerService._evaluate_rule(
|
||||
operator,
|
||||
current_value,
|
||||
old_value,
|
||||
new_value,
|
||||
value,
|
||||
field_type,
|
||||
field_changed,
|
||||
) is False:
|
||||
return False
|
||||
|
||||
if field_changed:
|
||||
any_rule_changed = True
|
||||
|
||||
return any_rule_changed
|
||||
|
||||
@staticmethod
|
||||
def _extract_rules(conditions: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
rules = conditions.get("rules")
|
||||
if isinstance(rules, list):
|
||||
return rules
|
||||
field = conditions.get("field")
|
||||
if field:
|
||||
rule = {
|
||||
"field": field,
|
||||
"operator": conditions.get("operator"),
|
||||
"value": conditions.get("value"),
|
||||
}
|
||||
if conditions.get("field_id"):
|
||||
rule["field_id"] = conditions.get("field_id")
|
||||
return [rule]
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _get_custom_values_map(db: Session, task: Task) -> Dict[str, Any]:
|
||||
values = CustomValueService.get_custom_values_for_task(
|
||||
db,
|
||||
task,
|
||||
include_formula_calculations=True,
|
||||
)
|
||||
return {cv.field_id: cv.value for cv in values}
|
||||
|
||||
@staticmethod
|
||||
def _detect_field_changes(
|
||||
old_values: Dict[str, Any],
|
||||
new_values: Dict[str, Any],
|
||||
) -> Set[str]:
|
||||
changed = set()
|
||||
for field in TriggerService.SUPPORTED_FIELDS:
|
||||
if field == "custom_fields":
|
||||
continue
|
||||
if field in old_values or field in new_values:
|
||||
if old_values.get(field) != new_values.get(field):
|
||||
changed.add(field)
|
||||
return changed
|
||||
|
||||
@staticmethod
|
||||
def _detect_custom_field_changes(
|
||||
old_custom_values: Any,
|
||||
new_custom_values: Any,
|
||||
) -> Set[str]:
|
||||
if not isinstance(old_custom_values, dict) or not isinstance(new_custom_values, dict):
|
||||
return set()
|
||||
changed_ids = set()
|
||||
field_ids = set(old_custom_values.keys()) | set(new_custom_values.keys())
|
||||
for field_id in field_ids:
|
||||
if old_custom_values.get(field_id) != new_custom_values.get(field_id):
|
||||
changed_ids.add(field_id)
|
||||
return changed_ids
|
||||
|
||||
@staticmethod
|
||||
def _normalize_field_type(field_type: Optional[str]) -> Optional[str]:
|
||||
if not field_type:
|
||||
return None
|
||||
if field_type == "formula":
|
||||
return "number"
|
||||
return field_type
|
||||
|
||||
@staticmethod
|
||||
def _evaluate_rule(
|
||||
operator: str,
|
||||
current_value: Any,
|
||||
old_value: Any,
|
||||
new_value: Any,
|
||||
target_value: Any,
|
||||
field_type: Optional[str],
|
||||
field_changed: bool,
|
||||
) -> bool:
|
||||
if operator in ("changed_to", "changed_from"):
|
||||
if not field_changed:
|
||||
return False
|
||||
if operator == "changed_to":
|
||||
return (
|
||||
TriggerService._value_equals(new_value, target_value, field_type)
|
||||
and not TriggerService._value_equals(old_value, target_value, field_type)
|
||||
)
|
||||
return (
|
||||
TriggerService._value_equals(old_value, target_value, field_type)
|
||||
and not TriggerService._value_equals(new_value, target_value, field_type)
|
||||
)
|
||||
|
||||
if operator == "equals":
|
||||
return new_value == value
|
||||
elif operator == "not_equals":
|
||||
return new_value != value
|
||||
elif operator == "changed_to":
|
||||
return old_value != value and new_value == value
|
||||
elif operator == "changed_from":
|
||||
return old_value == value and new_value != value
|
||||
return TriggerService._value_equals(current_value, target_value, field_type)
|
||||
if operator == "not_equals":
|
||||
return not TriggerService._value_equals(current_value, target_value, field_type)
|
||||
if operator == "before":
|
||||
return TriggerService._compare_before(current_value, target_value, field_type)
|
||||
if operator == "after":
|
||||
return TriggerService._compare_after(current_value, target_value, field_type)
|
||||
if operator == "in":
|
||||
return TriggerService._compare_in(current_value, target_value, field_type)
|
||||
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _value_equals(current_value: Any, target_value: Any, field_type: Optional[str]) -> bool:
|
||||
if current_value is None:
|
||||
return target_value is None
|
||||
|
||||
if field_type == "date":
|
||||
current_dt, current_date_only = TriggerService._parse_datetime_value(current_value)
|
||||
target_dt, target_date_only = TriggerService._parse_datetime_value(target_value)
|
||||
if not current_dt or not target_dt:
|
||||
return False
|
||||
if current_date_only or target_date_only:
|
||||
return current_dt.date() == target_dt.date()
|
||||
return current_dt == target_dt
|
||||
|
||||
if field_type == "number":
|
||||
current_num = TriggerService._parse_number_value(current_value)
|
||||
target_num = TriggerService._parse_number_value(target_value)
|
||||
if current_num is None or target_num is None:
|
||||
return False
|
||||
return current_num == target_num
|
||||
|
||||
if isinstance(target_value, (list, dict)):
|
||||
return False
|
||||
|
||||
return str(current_value) == str(target_value)
|
||||
|
||||
@staticmethod
|
||||
def _compare_before(current_value: Any, target_value: Any, field_type: Optional[str]) -> bool:
|
||||
if current_value is None or target_value is None:
|
||||
return False
|
||||
if field_type == "date":
|
||||
current_dt, current_date_only = TriggerService._parse_datetime_value(current_value)
|
||||
target_dt, target_date_only = TriggerService._parse_datetime_value(target_value)
|
||||
if not current_dt or not target_dt:
|
||||
return False
|
||||
if current_date_only or target_date_only:
|
||||
return current_dt.date() < target_dt.date()
|
||||
return current_dt < target_dt
|
||||
if field_type == "number":
|
||||
current_num = TriggerService._parse_number_value(current_value)
|
||||
target_num = TriggerService._parse_number_value(target_value)
|
||||
if current_num is None or target_num is None:
|
||||
return False
|
||||
return current_num < target_num
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _compare_after(current_value: Any, target_value: Any, field_type: Optional[str]) -> bool:
|
||||
if current_value is None or target_value is None:
|
||||
return False
|
||||
if field_type == "date":
|
||||
current_dt, current_date_only = TriggerService._parse_datetime_value(current_value)
|
||||
target_dt, target_date_only = TriggerService._parse_datetime_value(target_value)
|
||||
if not current_dt or not target_dt:
|
||||
return False
|
||||
if current_date_only or target_date_only:
|
||||
return current_dt.date() > target_dt.date()
|
||||
return current_dt > target_dt
|
||||
if field_type == "number":
|
||||
current_num = TriggerService._parse_number_value(current_value)
|
||||
target_num = TriggerService._parse_number_value(target_value)
|
||||
if current_num is None or target_num is None:
|
||||
return False
|
||||
return current_num > target_num
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _compare_in(current_value: Any, target_value: Any, field_type: Optional[str]) -> bool:
|
||||
if current_value is None or target_value is None:
|
||||
return False
|
||||
if field_type == "date":
|
||||
if not isinstance(target_value, dict):
|
||||
return False
|
||||
start_dt, start_date_only = TriggerService._parse_datetime_value(target_value.get("start"))
|
||||
end_dt, end_date_only = TriggerService._parse_datetime_value(target_value.get("end"))
|
||||
current_dt, current_date_only = TriggerService._parse_datetime_value(current_value)
|
||||
if not start_dt or not end_dt or not current_dt:
|
||||
return False
|
||||
date_only = current_date_only or start_date_only or end_date_only
|
||||
if date_only:
|
||||
current_date = current_dt.date()
|
||||
return start_dt.date() <= current_date <= end_dt.date()
|
||||
return start_dt <= current_dt <= end_dt
|
||||
if isinstance(target_value, (list, tuple, set)):
|
||||
if field_type == "number":
|
||||
current_num = TriggerService._parse_number_value(current_value)
|
||||
if current_num is None:
|
||||
return False
|
||||
for item in target_value:
|
||||
item_num = TriggerService._parse_number_value(item)
|
||||
if item_num is not None and item_num == current_num:
|
||||
return True
|
||||
return False
|
||||
return str(current_value) in {str(item) for item in target_value if item is not None}
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _parse_datetime_value(value: Any) -> Tuple[Optional[datetime], bool]:
|
||||
if value is None:
|
||||
return None, False
|
||||
if isinstance(value, datetime):
|
||||
return value, False
|
||||
if isinstance(value, date):
|
||||
return datetime.combine(value, datetime.min.time()), True
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
if len(value) == 10:
|
||||
parsed = datetime.strptime(value, "%Y-%m-%d")
|
||||
return parsed, True
|
||||
parsed = datetime.fromisoformat(value.replace("Z", "+00:00"))
|
||||
return parsed.replace(tzinfo=None), False
|
||||
except ValueError:
|
||||
return None, False
|
||||
return None, False
|
||||
|
||||
@staticmethod
|
||||
def _parse_number_value(value: Any) -> Optional[Decimal]:
|
||||
if value is None or value == "":
|
||||
return None
|
||||
try:
|
||||
return Decimal(str(value))
|
||||
except (InvalidOperation, ValueError):
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _execute_actions(
|
||||
db: Session,
|
||||
@@ -185,40 +511,76 @@ class TriggerService:
|
||||
target = action.get("target", "assignee")
|
||||
template = action.get("template", "任務 {task_title} 已觸發自動化規則")
|
||||
|
||||
# Resolve target user
|
||||
target_user_id = TriggerService._resolve_target(task, target)
|
||||
if not target_user_id:
|
||||
return
|
||||
|
||||
# Don't notify the user who triggered the action
|
||||
if target_user_id == current_user.id:
|
||||
recipients = TriggerService._resolve_targets(db, task, target)
|
||||
if not recipients:
|
||||
return
|
||||
|
||||
# Format message with variables
|
||||
message = TriggerService._format_template(template, task, old_values, new_values)
|
||||
|
||||
NotificationService.create_notification(
|
||||
db=db,
|
||||
user_id=target_user_id,
|
||||
notification_type="status_change",
|
||||
reference_type="task",
|
||||
reference_id=task.id,
|
||||
title=f"自動化通知: {task.title}",
|
||||
message=message,
|
||||
)
|
||||
for user_id in recipients:
|
||||
if user_id == current_user.id:
|
||||
continue
|
||||
NotificationService.create_notification(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
notification_type="status_change",
|
||||
reference_type="task",
|
||||
reference_id=task.id,
|
||||
title=f"自動化通知: {task.title}",
|
||||
message=message,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_target(task: Task, target: str) -> Optional[str]:
|
||||
"""Resolve notification target to user ID."""
|
||||
def _resolve_targets(db: Session, task: Task, target: str) -> List[str]:
|
||||
"""Resolve notification target to user IDs."""
|
||||
recipients: Set[str] = set()
|
||||
|
||||
if target == "assignee":
|
||||
return task.assignee_id
|
||||
if task.assignee_id:
|
||||
recipients.add(task.assignee_id)
|
||||
elif target == "creator":
|
||||
return task.created_by
|
||||
if task.created_by:
|
||||
recipients.add(task.created_by)
|
||||
elif target == "project_owner":
|
||||
return task.project.owner_id if task.project else None
|
||||
if task.project and task.project.owner_id:
|
||||
recipients.add(task.project.owner_id)
|
||||
elif target == "project_members":
|
||||
if task.project:
|
||||
if task.project.owner_id:
|
||||
recipients.add(task.project.owner_id)
|
||||
member_rows = db.query(ProjectMember.user_id).join(
|
||||
User,
|
||||
User.id == ProjectMember.user_id,
|
||||
).filter(
|
||||
ProjectMember.project_id == task.project_id,
|
||||
User.is_active == True,
|
||||
).all()
|
||||
recipients.update(row[0] for row in member_rows if row and row[0])
|
||||
elif target.startswith("department:"):
|
||||
department_id = target.split(":", 1)[1]
|
||||
if department_id:
|
||||
user_rows = db.query(User.id).filter(
|
||||
User.department_id == department_id,
|
||||
User.is_active == True,
|
||||
).all()
|
||||
recipients.update(row[0] for row in user_rows if row and row[0])
|
||||
elif target.startswith("role:"):
|
||||
role_name = target.split(":", 1)[1].strip()
|
||||
if role_name:
|
||||
role = db.query(Role).filter(func.lower(Role.name) == role_name.lower()).first()
|
||||
if role:
|
||||
user_rows = db.query(User.id).filter(
|
||||
User.role_id == role.id,
|
||||
User.is_active == True,
|
||||
).all()
|
||||
recipients.update(row[0] for row in user_rows if row and row[0])
|
||||
elif target.startswith("user:"):
|
||||
return target.split(":", 1)[1]
|
||||
return None
|
||||
user_id = target.split(":", 1)[1]
|
||||
if user_id:
|
||||
recipients.add(user_id)
|
||||
|
||||
return list(recipients)
|
||||
|
||||
@staticmethod
|
||||
def _format_template(
|
||||
|
||||
@@ -42,7 +42,9 @@ def _serialize_workload_summary(summary: UserWorkloadSummary) -> dict:
|
||||
"department_name": summary.department_name,
|
||||
"capacity_hours": str(summary.capacity_hours),
|
||||
"allocated_hours": str(summary.allocated_hours),
|
||||
"load_percentage": str(summary.load_percentage) if summary.load_percentage else None,
|
||||
"load_percentage": (
|
||||
str(summary.load_percentage) if summary.load_percentage is not None else None
|
||||
),
|
||||
"load_level": summary.load_level.value,
|
||||
"task_count": summary.task_count,
|
||||
}
|
||||
|
||||
@@ -42,6 +42,26 @@ def get_current_week_start() -> date:
|
||||
return get_week_bounds(date.today())[0]
|
||||
|
||||
|
||||
def get_current_week_bounds() -> Tuple[date, date]:
|
||||
"""
|
||||
Get current week bounds for default views.
|
||||
|
||||
On Sundays, extend the window to include the upcoming week so that
|
||||
"tomorrow" tasks are still visible in default views.
|
||||
"""
|
||||
week_start, week_end = get_week_bounds(date.today())
|
||||
if date.today().weekday() == 6:
|
||||
week_end = week_end + timedelta(days=7)
|
||||
return week_start, week_end
|
||||
|
||||
|
||||
def _extend_week_end_if_sunday(week_start: date, week_end: date) -> Tuple[date, date]:
|
||||
"""Extend week window on Sunday to include upcoming week."""
|
||||
if date.today().weekday() == 6 and week_start == get_current_week_start():
|
||||
return week_start, week_end + timedelta(days=7)
|
||||
return week_start, week_end
|
||||
|
||||
|
||||
def determine_load_level(load_percentage: Optional[Decimal]) -> LoadLevel:
|
||||
"""
|
||||
Determine the load level based on percentage.
|
||||
@@ -149,7 +169,7 @@ def calculate_user_workload(
|
||||
if task.original_estimate:
|
||||
allocated_hours += task.original_estimate
|
||||
|
||||
capacity_hours = Decimal(str(user.capacity)) if user.capacity else Decimal("40")
|
||||
capacity_hours = Decimal(str(user.capacity)) if user.capacity is not None else Decimal("40")
|
||||
load_percentage = calculate_load_percentage(allocated_hours, capacity_hours)
|
||||
load_level = determine_load_level(load_percentage)
|
||||
|
||||
@@ -191,11 +211,11 @@ def get_workload_heatmap(
|
||||
|
||||
if week_start is None:
|
||||
week_start = get_current_week_start()
|
||||
else:
|
||||
# Normalize to week start (Monday)
|
||||
week_start = get_week_bounds(week_start)[0]
|
||||
|
||||
# Normalize to week start (Monday)
|
||||
week_start = get_week_bounds(week_start)[0]
|
||||
week_start, week_end = get_week_bounds(week_start)
|
||||
week_start, week_end = _extend_week_end_if_sunday(week_start, week_end)
|
||||
|
||||
# Build user query
|
||||
query = db.query(User).filter(User.is_active == True)
|
||||
@@ -245,7 +265,7 @@ def get_workload_heatmap(
|
||||
if task.original_estimate:
|
||||
allocated_hours += task.original_estimate
|
||||
|
||||
capacity_hours = Decimal(str(user.capacity)) if user.capacity else Decimal("40")
|
||||
capacity_hours = Decimal(str(user.capacity)) if user.capacity is not None else Decimal("40")
|
||||
load_percentage = calculate_load_percentage(allocated_hours, capacity_hours)
|
||||
load_level = determine_load_level(load_percentage)
|
||||
|
||||
@@ -297,10 +317,9 @@ def get_user_workload_detail(
|
||||
|
||||
if week_start is None:
|
||||
week_start = get_current_week_start()
|
||||
else:
|
||||
week_start = get_week_bounds(week_start)[0]
|
||||
|
||||
week_start = get_week_bounds(week_start)[0]
|
||||
week_start, week_end = get_week_bounds(week_start)
|
||||
week_start, week_end = _extend_week_end_if_sunday(week_start, week_end)
|
||||
|
||||
# Get tasks
|
||||
tasks = get_user_tasks_in_week(db, user_id, week_start, week_end)
|
||||
@@ -323,7 +342,7 @@ def get_user_workload_detail(
|
||||
status=task.status.name if task.status else None,
|
||||
))
|
||||
|
||||
capacity_hours = Decimal(str(user.capacity)) if user.capacity else Decimal("40")
|
||||
capacity_hours = Decimal(str(user.capacity)) if user.capacity is not None else Decimal("40")
|
||||
load_percentage = calculate_load_percentage(allocated_hours, capacity_hours)
|
||||
load_level = determine_load_level(load_percentage)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user