Fix test failures and workload/websocket behavior

This commit is contained in:
beabigegg
2026-01-11 08:37:21 +08:00
parent 3bdc6ff1c9
commit f5f870da56
49 changed files with 3006 additions and 1132 deletions

View File

@@ -160,7 +160,7 @@ def get_workload_summary(db: Session, user: User) -> WorkloadSummary:
if task.original_estimate:
allocated_hours += task.original_estimate
capacity_hours = Decimal(str(user.capacity)) if user.capacity else Decimal("40")
capacity_hours = Decimal(str(user.capacity)) if user.capacity is not None else Decimal("40")
load_percentage = calculate_load_percentage(allocated_hours, capacity_hours)
load_level = determine_load_level(load_percentage)

View File

@@ -4,7 +4,7 @@ from fastapi import APIRouter, Depends, HTTPException, status, Request
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.models import User, Space, Project, TaskStatus, AuditAction, ProjectMember
from app.models import User, Space, Project, TaskStatus, AuditAction, ProjectMember, ProjectTemplate, CustomField
from app.models.task_status import DEFAULT_STATUSES
from app.schemas.project import ProjectCreate, ProjectUpdate, ProjectResponse, ProjectWithDetails
from app.schemas.task_status import TaskStatusResponse
@@ -36,6 +36,17 @@ def create_default_statuses(db: Session, project_id: str):
db.add(status)
def can_view_template(user: User, template: ProjectTemplate) -> bool:
"""Check if a user can view a template."""
if template.is_public:
return True
if template.owner_id == user.id:
return True
if user.is_system_admin:
return True
return False
@router.get("/api/spaces/{space_id}/projects", response_model=List[ProjectWithDetails])
async def list_projects_in_space(
space_id: str,
@@ -115,6 +126,27 @@ async def create_project(
detail="Access denied",
)
template = None
if project_data.template_id:
template = db.query(ProjectTemplate).filter(
ProjectTemplate.id == project_data.template_id,
ProjectTemplate.is_active == True,
).first()
if not template:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Template not found",
)
if not can_view_template(current_user, template):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied to template",
)
security_level = project_data.security_level.value if project_data.security_level else "department"
if template and template.default_security_level:
security_level = template.default_security_level
project = Project(
id=str(uuid.uuid4()),
space_id=space_id,
@@ -124,17 +156,47 @@ async def create_project(
budget=project_data.budget,
start_date=project_data.start_date,
end_date=project_data.end_date,
security_level=project_data.security_level.value if project_data.security_level else "department",
security_level=security_level,
department_id=project_data.department_id or current_user.department_id,
)
db.add(project)
db.flush() # Get the project ID
# Create default task statuses
create_default_statuses(db, project.id)
# Create task statuses (from template if provided, otherwise defaults)
if template and template.task_statuses:
for status_data in template.task_statuses:
status = TaskStatus(
id=str(uuid.uuid4()),
project_id=project.id,
name=status_data.get("name", "Unnamed"),
color=status_data.get("color", "#808080"),
position=status_data.get("position", 0),
is_done=status_data.get("is_done", False),
)
db.add(status)
else:
create_default_statuses(db, project.id)
# Create custom fields from template if provided
if template and template.custom_fields:
for field_data in template.custom_fields:
custom_field = CustomField(
id=str(uuid.uuid4()),
project_id=project.id,
name=field_data.get("name", "Unnamed"),
field_type=field_data.get("field_type", "text"),
options=field_data.get("options"),
formula=field_data.get("formula"),
is_required=field_data.get("is_required", False),
position=field_data.get("position", 0),
)
db.add(custom_field)
# Audit log
changes = [{"field": "title", "old_value": None, "new_value": project.title}]
if template:
changes.append({"field": "template_id", "old_value": None, "new_value": template.id})
AuditService.log_event(
db=db,
event_type="project.create",
@@ -142,7 +204,7 @@ async def create_project(
action=AuditAction.CREATE,
user_id=current_user.id,
resource_id=project.id,
changes=[{"field": "title", "old_value": None, "new_value": project.title}],
changes=changes,
request_metadata=get_audit_metadata(request),
)

View File

@@ -1,4 +1,5 @@
from fastapi import APIRouter, Depends, HTTPException, status, Query, Request
import uuid
from sqlalchemy.orm import Session
from typing import Optional
@@ -8,7 +9,7 @@ from app.core.config import settings
from app.models import User, ReportHistory, ScheduledReport
from app.schemas.report import (
WeeklyReportContent, ReportHistoryListResponse, ReportHistoryItem,
GenerateReportResponse, ReportSummary
GenerateReportResponse, ReportSummary, WeeklyReportSubscription, WeeklyReportSubscriptionUpdate
)
from app.middleware.auth import get_current_user
from app.services.report_service import ReportService
@@ -16,6 +17,62 @@ from app.services.report_service import ReportService
router = APIRouter(tags=["reports"])
@router.get("/api/reports/weekly/subscription", response_model=WeeklyReportSubscription)
async def get_weekly_report_subscription(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
Get weekly report subscription status for the current user.
"""
scheduled_report = db.query(ScheduledReport).filter(
ScheduledReport.recipient_id == current_user.id,
ScheduledReport.report_type == "weekly",
).first()
if not scheduled_report:
return WeeklyReportSubscription(is_active=False, last_sent_at=None)
return WeeklyReportSubscription(
is_active=scheduled_report.is_active,
last_sent_at=scheduled_report.last_sent_at,
)
@router.put("/api/reports/weekly/subscription", response_model=WeeklyReportSubscription)
async def update_weekly_report_subscription(
subscription: WeeklyReportSubscriptionUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
Update weekly report subscription status for the current user.
"""
scheduled_report = db.query(ScheduledReport).filter(
ScheduledReport.recipient_id == current_user.id,
ScheduledReport.report_type == "weekly",
).first()
if not scheduled_report:
scheduled_report = ScheduledReport(
id=str(uuid.uuid4()),
report_type="weekly",
recipient_id=current_user.id,
is_active=subscription.is_active,
)
db.add(scheduled_report)
else:
scheduled_report.is_active = subscription.is_active
db.commit()
db.refresh(scheduled_report)
return WeeklyReportSubscription(
is_active=scheduled_report.is_active,
last_sent_at=scheduled_report.last_sent_at,
)
@router.get("/api/reports/weekly/preview", response_model=WeeklyReportContent)
async def preview_weekly_report(
db: Session = Depends(get_db),

View File

@@ -407,17 +407,18 @@ async def update_task(
if task_data.version != task.version:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail={
"message": "Task has been modified by another user",
"current_version": task.version,
"provided_version": task_data.version,
},
detail=(
f"Version conflict: current_version={task.version}, "
f"provided_version={task_data.version}"
),
)
# Capture old values for audit and triggers
old_values = {
"title": task.title,
"description": task.description,
"status_id": task.status_id,
"assignee_id": task.assignee_id,
"priority": task.priority,
"start_date": task.start_date,
"due_date": task.due_date,
@@ -430,6 +431,17 @@ async def update_task(
custom_values_data = update_data.pop("custom_values", None)
update_data.pop("version", None) # version is handled separately for optimistic locking
old_custom_values = None
if custom_values_data:
old_custom_values = {
cv.field_id: cv.value
for cv in CustomValueService.get_custom_values_for_task(
db,
task,
include_formula_calculations=True,
)
}
# Track old assignee for workload cache invalidation
old_assignee_id = task.assignee_id
@@ -488,6 +500,8 @@ async def update_task(
new_values = {
"title": task.title,
"description": task.description,
"status_id": task.status_id,
"assignee_id": task.assignee_id,
"priority": task.priority,
"start_date": task.start_date,
"due_date": task.due_date,
@@ -509,30 +523,46 @@ async def update_task(
request_metadata=get_audit_metadata(request),
)
# Evaluate triggers for priority changes
if "priority" in update_data:
TriggerService.evaluate_triggers(db, task, old_values, new_values, current_user)
# Handle custom values update
new_custom_values = None
if custom_values_data:
try:
from app.schemas.task import CustomValueInput
custom_values = [CustomValueInput(**cv) for cv in custom_values_data]
CustomValueService.save_custom_values(db, task, custom_values)
new_custom_values = {
cv.field_id: cv.value
for cv in CustomValueService.get_custom_values_for_task(
db,
task,
include_formula_calculations=True,
)
}
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e),
)
trigger_fields = {"status_id", "assignee_id", "priority", "start_date", "due_date"}
trigger_relevant = any(field in update_data for field in trigger_fields) or custom_values_data
if trigger_relevant:
trigger_old_values = {field: old_values.get(field) for field in trigger_fields}
trigger_new_values = {field: new_values.get(field) for field in trigger_fields}
if old_custom_values is not None:
trigger_old_values["custom_fields"] = old_custom_values
if new_custom_values is not None:
trigger_new_values["custom_fields"] = new_custom_values
TriggerService.evaluate_triggers(db, task, trigger_old_values, trigger_new_values, current_user)
# Increment version for optimistic locking
task.version += 1
db.commit()
db.refresh(task)
# Invalidate workload cache if original_estimate changed and task has an assignee
if "original_estimate" in update_data and task.assignee_id:
# Invalidate workload cache if workload-affecting fields changed
if ("original_estimate" in update_data or "due_date" in update_data) and task.assignee_id:
invalidate_user_workload_cache(task.assignee_id)
# Invalidate workload cache if assignee changed

View File

@@ -4,7 +4,7 @@ from sqlalchemy.orm import Session
from typing import Optional
from app.core.database import get_db
from app.models import User, Project, Trigger, TriggerLog
from app.models import User, Project, Trigger, TriggerLog, CustomField
from app.schemas.trigger import (
TriggerCreate, TriggerUpdate, TriggerResponse, TriggerListResponse,
TriggerLogResponse, TriggerLogListResponse, TriggerUserInfo
@@ -16,6 +16,10 @@ from app.services.action_executor import ActionValidationError
router = APIRouter(tags=["triggers"])
FIELD_CHANGE_FIELDS = {"status_id", "assignee_id", "priority", "start_date", "due_date", "custom_fields"}
FIELD_CHANGE_OPERATORS = {"equals", "not_equals", "changed_to", "changed_from", "before", "after", "in"}
DATE_FIELDS = {"start_date", "due_date"}
def trigger_to_response(trigger: Trigger) -> TriggerResponse:
"""Convert Trigger model to TriggerResponse."""
@@ -39,6 +43,96 @@ def trigger_to_response(trigger: Trigger) -> TriggerResponse:
)
def _validate_field_change_conditions(conditions, project_id: str, db: Session) -> None:
rules = []
if conditions.rules is not None:
if conditions.logic != "and":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Composite conditions only support logic 'and'",
)
if not conditions.rules:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Composite conditions require at least one rule",
)
rules = conditions.rules
else:
if not conditions.field:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Field is required for field_change triggers",
)
rules = [conditions]
for rule in rules:
field = rule.field
operator = rule.operator
value = rule.value
field_id = rule.field_id or getattr(conditions, "field_id", None)
if field not in FIELD_CHANGE_FIELDS:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid condition field. Must be 'status_id', 'assignee_id', 'priority', 'start_date', 'due_date', or 'custom_fields'",
)
if not operator:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Operator is required for field_change triggers",
)
if operator not in FIELD_CHANGE_OPERATORS:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid operator. Must be 'equals', 'not_equals', 'changed_to', 'changed_from', 'before', 'after', or 'in'",
)
if value is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Condition value is required for field_change triggers",
)
field_type = None
if field in DATE_FIELDS:
field_type = "date"
elif field == "custom_fields":
if not field_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Custom field ID is required when field is custom_fields",
)
custom_field = db.query(CustomField).filter(
CustomField.id == field_id,
CustomField.project_id == project_id,
).first()
if not custom_field:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Custom field not found in this project",
)
field_type = custom_field.field_type
if operator in {"before", "after"}:
if field_type not in {"date", "number", "formula"}:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Operator 'before/after' is only valid for date or number fields",
)
if operator == "in":
if field_type == "date":
if not isinstance(value, dict) or "start" not in value or "end" not in value:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Date 'in' operator requires a range with start and end",
)
elif not isinstance(value, list):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Operator 'in' requires a list of values",
)
@router.post("/api/projects/{project_id}/triggers", response_model=TriggerResponse, status_code=status.HTTP_201_CREATED)
async def create_trigger(
project_id: str,
@@ -71,27 +165,7 @@ async def create_trigger(
# Validate conditions based on trigger type
if trigger_data.trigger_type == "field_change":
# Validate field_change conditions
if not trigger_data.conditions.field:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Field is required for field_change triggers",
)
if trigger_data.conditions.field not in ["status_id", "assignee_id", "priority"]:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid condition field. Must be 'status_id', 'assignee_id', or 'priority'",
)
if not trigger_data.conditions.operator:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Operator is required for field_change triggers",
)
if trigger_data.conditions.operator not in ["equals", "not_equals", "changed_to", "changed_from"]:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid operator. Must be 'equals', 'not_equals', 'changed_to', or 'changed_from'",
)
_validate_field_change_conditions(trigger_data.conditions, project_id, db)
elif trigger_data.trigger_type == "schedule":
# Validate schedule conditions
has_cron = trigger_data.conditions.cron_expression is not None
@@ -234,11 +308,7 @@ async def update_trigger(
if trigger_data.conditions is not None:
# Validate conditions based on trigger type
if trigger.trigger_type == "field_change":
if trigger_data.conditions.field and trigger_data.conditions.field not in ["status_id", "assignee_id", "priority"]:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid condition field",
)
_validate_field_change_conditions(trigger_data.conditions, trigger.project_id, db)
elif trigger.trigger_type == "schedule":
# Validate cron expression if provided
if trigger_data.conditions.cron_expression is not None:

View File

@@ -283,7 +283,7 @@ async def update_user_capacity(
)
# Store old capacity for audit log
old_capacity = float(user.capacity) if user.capacity else None
old_capacity = float(user.capacity) if user.capacity is not None else None
# Update capacity (validation is handled by Pydantic schema)
user.capacity = capacity.capacity_hours

View File

@@ -1,11 +1,12 @@
import asyncio
import os
import logging
import time
from typing import Optional
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
from sqlalchemy.orm import Session
from app.core.database import SessionLocal
from app.core import database
from app.core.security import decode_access_token
from app.core.redis import get_redis_sync
from app.models import User, Notification, Project
@@ -22,6 +23,8 @@ PONG_TIMEOUT = 30.0 # Disconnect if no pong received within this time after pi
# Authentication timeout (10 seconds)
AUTH_TIMEOUT = 10.0
if os.getenv("TESTING") == "true":
AUTH_TIMEOUT = 1.0
async def get_user_from_token(token: str) -> tuple[str | None, User | None]:
@@ -41,7 +44,7 @@ async def get_user_from_token(token: str) -> tuple[str | None, User | None]:
return None, None
# Get user from database
db = SessionLocal()
db = database.SessionLocal()
try:
user = db.query(User).filter(User.id == user_id).first()
if user is None or not user.is_active:
@@ -54,7 +57,7 @@ async def get_user_from_token(token: str) -> tuple[str | None, User | None]:
async def authenticate_websocket(
websocket: WebSocket,
query_token: Optional[str] = None
) -> tuple[str | None, User | None]:
) -> tuple[str | None, User | None, Optional[str]]:
"""
Authenticate WebSocket connection.
@@ -72,7 +75,10 @@ async def authenticate_websocket(
"WebSocket authentication via query parameter is deprecated. "
"Please use first-message authentication for better security."
)
return await get_user_from_token(query_token)
user_id, user = await get_user_from_token(query_token)
if user_id is None:
return None, None, "invalid_token"
return user_id, user, None
# Wait for authentication message with timeout
try:
@@ -84,26 +90,29 @@ async def authenticate_websocket(
msg_type = data.get("type")
if msg_type != "auth":
logger.warning("Expected 'auth' message type, got: %s", msg_type)
return None, None
return None, None, "invalid_message"
token = data.get("token")
if not token:
logger.warning("No token provided in auth message")
return None, None
return None, None, "missing_token"
return await get_user_from_token(token)
user_id, user = await get_user_from_token(token)
if user_id is None:
return None, None, "invalid_token"
return user_id, user, None
except asyncio.TimeoutError:
logger.warning("WebSocket authentication timeout after %.1f seconds", AUTH_TIMEOUT)
return None, None
return None, None, "timeout"
except Exception as e:
logger.error("Error during WebSocket authentication: %s", e)
return None, None
return None, None, "error"
async def get_unread_notifications(user_id: str) -> list[dict]:
"""Query all unread notifications for a user."""
db = SessionLocal()
db = database.SessionLocal()
try:
notifications = (
db.query(Notification)
@@ -130,7 +139,7 @@ async def get_unread_notifications(user_id: str) -> list[dict]:
async def get_unread_count(user_id: str) -> int:
"""Get the count of unread notifications for a user."""
db = SessionLocal()
db = database.SessionLocal()
try:
return (
db.query(Notification)
@@ -174,14 +183,12 @@ async def websocket_notifications(
# Accept WebSocket connection first
await websocket.accept()
# If no query token, notify client that auth is required
if not token:
await websocket.send_json({"type": "auth_required"})
# Authenticate
user_id, user = await authenticate_websocket(websocket, token)
user_id, user, error_reason = await authenticate_websocket(websocket, token)
if user_id is None:
if error_reason == "invalid_token":
await websocket.send_json({"type": "error", "message": "Invalid or expired token"})
await websocket.close(code=4001, reason="Invalid or expired token")
return
@@ -311,7 +318,7 @@ async def verify_project_access(user_id: str, project_id: str) -> tuple[bool, Pr
Returns:
Tuple of (has_access: bool, project: Project | None)
"""
db = SessionLocal()
db = database.SessionLocal()
try:
# Get the user
user = db.query(User).filter(User.id == user_id).first()
@@ -365,14 +372,12 @@ async def websocket_project_sync(
# Accept WebSocket connection first
await websocket.accept()
# If no query token, notify client that auth is required
if not token:
await websocket.send_json({"type": "auth_required"})
# Authenticate user
user_id, user = await authenticate_websocket(websocket, token)
user_id, user, error_reason = await authenticate_websocket(websocket, token)
if user_id is None:
if error_reason == "invalid_token":
await websocket.send_json({"type": "error", "message": "Invalid or expired token"})
await websocket.close(code=4001, reason="Invalid or expired token")
return

View File

@@ -139,7 +139,7 @@ async def get_heatmap(
description="Comma-separated list of user IDs to include"
),
hide_empty: bool = Query(
True,
False,
description="Hide users with no tasks assigned for the week"
),
db: Session = Depends(get_db),
@@ -168,8 +168,20 @@ async def get_heatmap(
if department_id:
check_workload_access(current_user, department_id=department_id)
# Filter user_ids based on access (pass db for manager department lookup)
accessible_user_ids = filter_accessible_users(current_user, parsed_user_ids, db)
# Determine accessible users for this requester
accessible_user_ids = filter_accessible_users(current_user, None, db)
# If specific user_ids are requested, ensure access is permitted
if parsed_user_ids:
if accessible_user_ids is not None:
requested_ids = set(parsed_user_ids)
allowed_ids = set(accessible_user_ids)
if not requested_ids.issubset(allowed_ids):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied: Cannot view other users' workload",
)
accessible_user_ids = parsed_user_ids
# Normalize week_start
if week_start is None: