Fix test failures and workload/websocket behavior
This commit is contained in:
@@ -160,7 +160,7 @@ def get_workload_summary(db: Session, user: User) -> WorkloadSummary:
|
||||
if task.original_estimate:
|
||||
allocated_hours += task.original_estimate
|
||||
|
||||
capacity_hours = Decimal(str(user.capacity)) if user.capacity else Decimal("40")
|
||||
capacity_hours = Decimal(str(user.capacity)) if user.capacity is not None else Decimal("40")
|
||||
load_percentage = calculate_load_percentage(allocated_hours, capacity_hours)
|
||||
load_level = determine_load_level(load_percentage)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.models import User, Space, Project, TaskStatus, AuditAction, ProjectMember
|
||||
from app.models import User, Space, Project, TaskStatus, AuditAction, ProjectMember, ProjectTemplate, CustomField
|
||||
from app.models.task_status import DEFAULT_STATUSES
|
||||
from app.schemas.project import ProjectCreate, ProjectUpdate, ProjectResponse, ProjectWithDetails
|
||||
from app.schemas.task_status import TaskStatusResponse
|
||||
@@ -36,6 +36,17 @@ def create_default_statuses(db: Session, project_id: str):
|
||||
db.add(status)
|
||||
|
||||
|
||||
def can_view_template(user: User, template: ProjectTemplate) -> bool:
|
||||
"""Check if a user can view a template."""
|
||||
if template.is_public:
|
||||
return True
|
||||
if template.owner_id == user.id:
|
||||
return True
|
||||
if user.is_system_admin:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@router.get("/api/spaces/{space_id}/projects", response_model=List[ProjectWithDetails])
|
||||
async def list_projects_in_space(
|
||||
space_id: str,
|
||||
@@ -115,6 +126,27 @@ async def create_project(
|
||||
detail="Access denied",
|
||||
)
|
||||
|
||||
template = None
|
||||
if project_data.template_id:
|
||||
template = db.query(ProjectTemplate).filter(
|
||||
ProjectTemplate.id == project_data.template_id,
|
||||
ProjectTemplate.is_active == True,
|
||||
).first()
|
||||
if not template:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Template not found",
|
||||
)
|
||||
if not can_view_template(current_user, template):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied to template",
|
||||
)
|
||||
|
||||
security_level = project_data.security_level.value if project_data.security_level else "department"
|
||||
if template and template.default_security_level:
|
||||
security_level = template.default_security_level
|
||||
|
||||
project = Project(
|
||||
id=str(uuid.uuid4()),
|
||||
space_id=space_id,
|
||||
@@ -124,17 +156,47 @@ async def create_project(
|
||||
budget=project_data.budget,
|
||||
start_date=project_data.start_date,
|
||||
end_date=project_data.end_date,
|
||||
security_level=project_data.security_level.value if project_data.security_level else "department",
|
||||
security_level=security_level,
|
||||
department_id=project_data.department_id or current_user.department_id,
|
||||
)
|
||||
|
||||
db.add(project)
|
||||
db.flush() # Get the project ID
|
||||
|
||||
# Create default task statuses
|
||||
create_default_statuses(db, project.id)
|
||||
# Create task statuses (from template if provided, otherwise defaults)
|
||||
if template and template.task_statuses:
|
||||
for status_data in template.task_statuses:
|
||||
status = TaskStatus(
|
||||
id=str(uuid.uuid4()),
|
||||
project_id=project.id,
|
||||
name=status_data.get("name", "Unnamed"),
|
||||
color=status_data.get("color", "#808080"),
|
||||
position=status_data.get("position", 0),
|
||||
is_done=status_data.get("is_done", False),
|
||||
)
|
||||
db.add(status)
|
||||
else:
|
||||
create_default_statuses(db, project.id)
|
||||
|
||||
# Create custom fields from template if provided
|
||||
if template and template.custom_fields:
|
||||
for field_data in template.custom_fields:
|
||||
custom_field = CustomField(
|
||||
id=str(uuid.uuid4()),
|
||||
project_id=project.id,
|
||||
name=field_data.get("name", "Unnamed"),
|
||||
field_type=field_data.get("field_type", "text"),
|
||||
options=field_data.get("options"),
|
||||
formula=field_data.get("formula"),
|
||||
is_required=field_data.get("is_required", False),
|
||||
position=field_data.get("position", 0),
|
||||
)
|
||||
db.add(custom_field)
|
||||
|
||||
# Audit log
|
||||
changes = [{"field": "title", "old_value": None, "new_value": project.title}]
|
||||
if template:
|
||||
changes.append({"field": "template_id", "old_value": None, "new_value": template.id})
|
||||
AuditService.log_event(
|
||||
db=db,
|
||||
event_type="project.create",
|
||||
@@ -142,7 +204,7 @@ async def create_project(
|
||||
action=AuditAction.CREATE,
|
||||
user_id=current_user.id,
|
||||
resource_id=project.id,
|
||||
changes=[{"field": "title", "old_value": None, "new_value": project.title}],
|
||||
changes=changes,
|
||||
request_metadata=get_audit_metadata(request),
|
||||
)
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query, Request
|
||||
import uuid
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional
|
||||
|
||||
@@ -8,7 +9,7 @@ from app.core.config import settings
|
||||
from app.models import User, ReportHistory, ScheduledReport
|
||||
from app.schemas.report import (
|
||||
WeeklyReportContent, ReportHistoryListResponse, ReportHistoryItem,
|
||||
GenerateReportResponse, ReportSummary
|
||||
GenerateReportResponse, ReportSummary, WeeklyReportSubscription, WeeklyReportSubscriptionUpdate
|
||||
)
|
||||
from app.middleware.auth import get_current_user
|
||||
from app.services.report_service import ReportService
|
||||
@@ -16,6 +17,62 @@ from app.services.report_service import ReportService
|
||||
router = APIRouter(tags=["reports"])
|
||||
|
||||
|
||||
@router.get("/api/reports/weekly/subscription", response_model=WeeklyReportSubscription)
|
||||
async def get_weekly_report_subscription(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Get weekly report subscription status for the current user.
|
||||
"""
|
||||
scheduled_report = db.query(ScheduledReport).filter(
|
||||
ScheduledReport.recipient_id == current_user.id,
|
||||
ScheduledReport.report_type == "weekly",
|
||||
).first()
|
||||
|
||||
if not scheduled_report:
|
||||
return WeeklyReportSubscription(is_active=False, last_sent_at=None)
|
||||
|
||||
return WeeklyReportSubscription(
|
||||
is_active=scheduled_report.is_active,
|
||||
last_sent_at=scheduled_report.last_sent_at,
|
||||
)
|
||||
|
||||
|
||||
@router.put("/api/reports/weekly/subscription", response_model=WeeklyReportSubscription)
|
||||
async def update_weekly_report_subscription(
|
||||
subscription: WeeklyReportSubscriptionUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Update weekly report subscription status for the current user.
|
||||
"""
|
||||
scheduled_report = db.query(ScheduledReport).filter(
|
||||
ScheduledReport.recipient_id == current_user.id,
|
||||
ScheduledReport.report_type == "weekly",
|
||||
).first()
|
||||
|
||||
if not scheduled_report:
|
||||
scheduled_report = ScheduledReport(
|
||||
id=str(uuid.uuid4()),
|
||||
report_type="weekly",
|
||||
recipient_id=current_user.id,
|
||||
is_active=subscription.is_active,
|
||||
)
|
||||
db.add(scheduled_report)
|
||||
else:
|
||||
scheduled_report.is_active = subscription.is_active
|
||||
|
||||
db.commit()
|
||||
db.refresh(scheduled_report)
|
||||
|
||||
return WeeklyReportSubscription(
|
||||
is_active=scheduled_report.is_active,
|
||||
last_sent_at=scheduled_report.last_sent_at,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/api/reports/weekly/preview", response_model=WeeklyReportContent)
|
||||
async def preview_weekly_report(
|
||||
db: Session = Depends(get_db),
|
||||
|
||||
@@ -407,17 +407,18 @@ async def update_task(
|
||||
if task_data.version != task.version:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail={
|
||||
"message": "Task has been modified by another user",
|
||||
"current_version": task.version,
|
||||
"provided_version": task_data.version,
|
||||
},
|
||||
detail=(
|
||||
f"Version conflict: current_version={task.version}, "
|
||||
f"provided_version={task_data.version}"
|
||||
),
|
||||
)
|
||||
|
||||
# Capture old values for audit and triggers
|
||||
old_values = {
|
||||
"title": task.title,
|
||||
"description": task.description,
|
||||
"status_id": task.status_id,
|
||||
"assignee_id": task.assignee_id,
|
||||
"priority": task.priority,
|
||||
"start_date": task.start_date,
|
||||
"due_date": task.due_date,
|
||||
@@ -430,6 +431,17 @@ async def update_task(
|
||||
custom_values_data = update_data.pop("custom_values", None)
|
||||
update_data.pop("version", None) # version is handled separately for optimistic locking
|
||||
|
||||
old_custom_values = None
|
||||
if custom_values_data:
|
||||
old_custom_values = {
|
||||
cv.field_id: cv.value
|
||||
for cv in CustomValueService.get_custom_values_for_task(
|
||||
db,
|
||||
task,
|
||||
include_formula_calculations=True,
|
||||
)
|
||||
}
|
||||
|
||||
# Track old assignee for workload cache invalidation
|
||||
old_assignee_id = task.assignee_id
|
||||
|
||||
@@ -488,6 +500,8 @@ async def update_task(
|
||||
new_values = {
|
||||
"title": task.title,
|
||||
"description": task.description,
|
||||
"status_id": task.status_id,
|
||||
"assignee_id": task.assignee_id,
|
||||
"priority": task.priority,
|
||||
"start_date": task.start_date,
|
||||
"due_date": task.due_date,
|
||||
@@ -509,30 +523,46 @@ async def update_task(
|
||||
request_metadata=get_audit_metadata(request),
|
||||
)
|
||||
|
||||
# Evaluate triggers for priority changes
|
||||
if "priority" in update_data:
|
||||
TriggerService.evaluate_triggers(db, task, old_values, new_values, current_user)
|
||||
|
||||
# Handle custom values update
|
||||
new_custom_values = None
|
||||
if custom_values_data:
|
||||
try:
|
||||
from app.schemas.task import CustomValueInput
|
||||
custom_values = [CustomValueInput(**cv) for cv in custom_values_data]
|
||||
CustomValueService.save_custom_values(db, task, custom_values)
|
||||
new_custom_values = {
|
||||
cv.field_id: cv.value
|
||||
for cv in CustomValueService.get_custom_values_for_task(
|
||||
db,
|
||||
task,
|
||||
include_formula_calculations=True,
|
||||
)
|
||||
}
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
trigger_fields = {"status_id", "assignee_id", "priority", "start_date", "due_date"}
|
||||
trigger_relevant = any(field in update_data for field in trigger_fields) or custom_values_data
|
||||
if trigger_relevant:
|
||||
trigger_old_values = {field: old_values.get(field) for field in trigger_fields}
|
||||
trigger_new_values = {field: new_values.get(field) for field in trigger_fields}
|
||||
if old_custom_values is not None:
|
||||
trigger_old_values["custom_fields"] = old_custom_values
|
||||
if new_custom_values is not None:
|
||||
trigger_new_values["custom_fields"] = new_custom_values
|
||||
TriggerService.evaluate_triggers(db, task, trigger_old_values, trigger_new_values, current_user)
|
||||
|
||||
# Increment version for optimistic locking
|
||||
task.version += 1
|
||||
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
|
||||
# Invalidate workload cache if original_estimate changed and task has an assignee
|
||||
if "original_estimate" in update_data and task.assignee_id:
|
||||
# Invalidate workload cache if workload-affecting fields changed
|
||||
if ("original_estimate" in update_data or "due_date" in update_data) and task.assignee_id:
|
||||
invalidate_user_workload_cache(task.assignee_id)
|
||||
|
||||
# Invalidate workload cache if assignee changed
|
||||
|
||||
@@ -4,7 +4,7 @@ from sqlalchemy.orm import Session
|
||||
from typing import Optional
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.models import User, Project, Trigger, TriggerLog
|
||||
from app.models import User, Project, Trigger, TriggerLog, CustomField
|
||||
from app.schemas.trigger import (
|
||||
TriggerCreate, TriggerUpdate, TriggerResponse, TriggerListResponse,
|
||||
TriggerLogResponse, TriggerLogListResponse, TriggerUserInfo
|
||||
@@ -16,6 +16,10 @@ from app.services.action_executor import ActionValidationError
|
||||
|
||||
router = APIRouter(tags=["triggers"])
|
||||
|
||||
FIELD_CHANGE_FIELDS = {"status_id", "assignee_id", "priority", "start_date", "due_date", "custom_fields"}
|
||||
FIELD_CHANGE_OPERATORS = {"equals", "not_equals", "changed_to", "changed_from", "before", "after", "in"}
|
||||
DATE_FIELDS = {"start_date", "due_date"}
|
||||
|
||||
|
||||
def trigger_to_response(trigger: Trigger) -> TriggerResponse:
|
||||
"""Convert Trigger model to TriggerResponse."""
|
||||
@@ -39,6 +43,96 @@ def trigger_to_response(trigger: Trigger) -> TriggerResponse:
|
||||
)
|
||||
|
||||
|
||||
def _validate_field_change_conditions(conditions, project_id: str, db: Session) -> None:
|
||||
rules = []
|
||||
if conditions.rules is not None:
|
||||
if conditions.logic != "and":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Composite conditions only support logic 'and'",
|
||||
)
|
||||
if not conditions.rules:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Composite conditions require at least one rule",
|
||||
)
|
||||
rules = conditions.rules
|
||||
else:
|
||||
if not conditions.field:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Field is required for field_change triggers",
|
||||
)
|
||||
rules = [conditions]
|
||||
|
||||
for rule in rules:
|
||||
field = rule.field
|
||||
operator = rule.operator
|
||||
value = rule.value
|
||||
field_id = rule.field_id or getattr(conditions, "field_id", None)
|
||||
|
||||
if field not in FIELD_CHANGE_FIELDS:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid condition field. Must be 'status_id', 'assignee_id', 'priority', 'start_date', 'due_date', or 'custom_fields'",
|
||||
)
|
||||
if not operator:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Operator is required for field_change triggers",
|
||||
)
|
||||
if operator not in FIELD_CHANGE_OPERATORS:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid operator. Must be 'equals', 'not_equals', 'changed_to', 'changed_from', 'before', 'after', or 'in'",
|
||||
)
|
||||
if value is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Condition value is required for field_change triggers",
|
||||
)
|
||||
|
||||
field_type = None
|
||||
if field in DATE_FIELDS:
|
||||
field_type = "date"
|
||||
elif field == "custom_fields":
|
||||
if not field_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Custom field ID is required when field is custom_fields",
|
||||
)
|
||||
custom_field = db.query(CustomField).filter(
|
||||
CustomField.id == field_id,
|
||||
CustomField.project_id == project_id,
|
||||
).first()
|
||||
if not custom_field:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Custom field not found in this project",
|
||||
)
|
||||
field_type = custom_field.field_type
|
||||
|
||||
if operator in {"before", "after"}:
|
||||
if field_type not in {"date", "number", "formula"}:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Operator 'before/after' is only valid for date or number fields",
|
||||
)
|
||||
|
||||
if operator == "in":
|
||||
if field_type == "date":
|
||||
if not isinstance(value, dict) or "start" not in value or "end" not in value:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Date 'in' operator requires a range with start and end",
|
||||
)
|
||||
elif not isinstance(value, list):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Operator 'in' requires a list of values",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/api/projects/{project_id}/triggers", response_model=TriggerResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_trigger(
|
||||
project_id: str,
|
||||
@@ -71,27 +165,7 @@ async def create_trigger(
|
||||
|
||||
# Validate conditions based on trigger type
|
||||
if trigger_data.trigger_type == "field_change":
|
||||
# Validate field_change conditions
|
||||
if not trigger_data.conditions.field:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Field is required for field_change triggers",
|
||||
)
|
||||
if trigger_data.conditions.field not in ["status_id", "assignee_id", "priority"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid condition field. Must be 'status_id', 'assignee_id', or 'priority'",
|
||||
)
|
||||
if not trigger_data.conditions.operator:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Operator is required for field_change triggers",
|
||||
)
|
||||
if trigger_data.conditions.operator not in ["equals", "not_equals", "changed_to", "changed_from"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid operator. Must be 'equals', 'not_equals', 'changed_to', or 'changed_from'",
|
||||
)
|
||||
_validate_field_change_conditions(trigger_data.conditions, project_id, db)
|
||||
elif trigger_data.trigger_type == "schedule":
|
||||
# Validate schedule conditions
|
||||
has_cron = trigger_data.conditions.cron_expression is not None
|
||||
@@ -234,11 +308,7 @@ async def update_trigger(
|
||||
if trigger_data.conditions is not None:
|
||||
# Validate conditions based on trigger type
|
||||
if trigger.trigger_type == "field_change":
|
||||
if trigger_data.conditions.field and trigger_data.conditions.field not in ["status_id", "assignee_id", "priority"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid condition field",
|
||||
)
|
||||
_validate_field_change_conditions(trigger_data.conditions, trigger.project_id, db)
|
||||
elif trigger.trigger_type == "schedule":
|
||||
# Validate cron expression if provided
|
||||
if trigger_data.conditions.cron_expression is not None:
|
||||
|
||||
@@ -283,7 +283,7 @@ async def update_user_capacity(
|
||||
)
|
||||
|
||||
# Store old capacity for audit log
|
||||
old_capacity = float(user.capacity) if user.capacity else None
|
||||
old_capacity = float(user.capacity) if user.capacity is not None else None
|
||||
|
||||
# Update capacity (validation is handled by Pydantic schema)
|
||||
user.capacity = capacity.capacity_hours
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
import asyncio
|
||||
import os
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.database import SessionLocal
|
||||
from app.core import database
|
||||
from app.core.security import decode_access_token
|
||||
from app.core.redis import get_redis_sync
|
||||
from app.models import User, Notification, Project
|
||||
@@ -22,6 +23,8 @@ PONG_TIMEOUT = 30.0 # Disconnect if no pong received within this time after pi
|
||||
|
||||
# Authentication timeout (10 seconds)
|
||||
AUTH_TIMEOUT = 10.0
|
||||
if os.getenv("TESTING") == "true":
|
||||
AUTH_TIMEOUT = 1.0
|
||||
|
||||
|
||||
async def get_user_from_token(token: str) -> tuple[str | None, User | None]:
|
||||
@@ -41,7 +44,7 @@ async def get_user_from_token(token: str) -> tuple[str | None, User | None]:
|
||||
return None, None
|
||||
|
||||
# Get user from database
|
||||
db = SessionLocal()
|
||||
db = database.SessionLocal()
|
||||
try:
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if user is None or not user.is_active:
|
||||
@@ -54,7 +57,7 @@ async def get_user_from_token(token: str) -> tuple[str | None, User | None]:
|
||||
async def authenticate_websocket(
|
||||
websocket: WebSocket,
|
||||
query_token: Optional[str] = None
|
||||
) -> tuple[str | None, User | None]:
|
||||
) -> tuple[str | None, User | None, Optional[str]]:
|
||||
"""
|
||||
Authenticate WebSocket connection.
|
||||
|
||||
@@ -72,7 +75,10 @@ async def authenticate_websocket(
|
||||
"WebSocket authentication via query parameter is deprecated. "
|
||||
"Please use first-message authentication for better security."
|
||||
)
|
||||
return await get_user_from_token(query_token)
|
||||
user_id, user = await get_user_from_token(query_token)
|
||||
if user_id is None:
|
||||
return None, None, "invalid_token"
|
||||
return user_id, user, None
|
||||
|
||||
# Wait for authentication message with timeout
|
||||
try:
|
||||
@@ -84,26 +90,29 @@ async def authenticate_websocket(
|
||||
msg_type = data.get("type")
|
||||
if msg_type != "auth":
|
||||
logger.warning("Expected 'auth' message type, got: %s", msg_type)
|
||||
return None, None
|
||||
return None, None, "invalid_message"
|
||||
|
||||
token = data.get("token")
|
||||
if not token:
|
||||
logger.warning("No token provided in auth message")
|
||||
return None, None
|
||||
return None, None, "missing_token"
|
||||
|
||||
return await get_user_from_token(token)
|
||||
user_id, user = await get_user_from_token(token)
|
||||
if user_id is None:
|
||||
return None, None, "invalid_token"
|
||||
return user_id, user, None
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("WebSocket authentication timeout after %.1f seconds", AUTH_TIMEOUT)
|
||||
return None, None
|
||||
return None, None, "timeout"
|
||||
except Exception as e:
|
||||
logger.error("Error during WebSocket authentication: %s", e)
|
||||
return None, None
|
||||
return None, None, "error"
|
||||
|
||||
|
||||
async def get_unread_notifications(user_id: str) -> list[dict]:
|
||||
"""Query all unread notifications for a user."""
|
||||
db = SessionLocal()
|
||||
db = database.SessionLocal()
|
||||
try:
|
||||
notifications = (
|
||||
db.query(Notification)
|
||||
@@ -130,7 +139,7 @@ async def get_unread_notifications(user_id: str) -> list[dict]:
|
||||
|
||||
async def get_unread_count(user_id: str) -> int:
|
||||
"""Get the count of unread notifications for a user."""
|
||||
db = SessionLocal()
|
||||
db = database.SessionLocal()
|
||||
try:
|
||||
return (
|
||||
db.query(Notification)
|
||||
@@ -174,14 +183,12 @@ async def websocket_notifications(
|
||||
# Accept WebSocket connection first
|
||||
await websocket.accept()
|
||||
|
||||
# If no query token, notify client that auth is required
|
||||
if not token:
|
||||
await websocket.send_json({"type": "auth_required"})
|
||||
|
||||
# Authenticate
|
||||
user_id, user = await authenticate_websocket(websocket, token)
|
||||
user_id, user, error_reason = await authenticate_websocket(websocket, token)
|
||||
|
||||
if user_id is None:
|
||||
if error_reason == "invalid_token":
|
||||
await websocket.send_json({"type": "error", "message": "Invalid or expired token"})
|
||||
await websocket.close(code=4001, reason="Invalid or expired token")
|
||||
return
|
||||
|
||||
@@ -311,7 +318,7 @@ async def verify_project_access(user_id: str, project_id: str) -> tuple[bool, Pr
|
||||
Returns:
|
||||
Tuple of (has_access: bool, project: Project | None)
|
||||
"""
|
||||
db = SessionLocal()
|
||||
db = database.SessionLocal()
|
||||
try:
|
||||
# Get the user
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
@@ -365,14 +372,12 @@ async def websocket_project_sync(
|
||||
# Accept WebSocket connection first
|
||||
await websocket.accept()
|
||||
|
||||
# If no query token, notify client that auth is required
|
||||
if not token:
|
||||
await websocket.send_json({"type": "auth_required"})
|
||||
|
||||
# Authenticate user
|
||||
user_id, user = await authenticate_websocket(websocket, token)
|
||||
user_id, user, error_reason = await authenticate_websocket(websocket, token)
|
||||
|
||||
if user_id is None:
|
||||
if error_reason == "invalid_token":
|
||||
await websocket.send_json({"type": "error", "message": "Invalid or expired token"})
|
||||
await websocket.close(code=4001, reason="Invalid or expired token")
|
||||
return
|
||||
|
||||
|
||||
@@ -139,7 +139,7 @@ async def get_heatmap(
|
||||
description="Comma-separated list of user IDs to include"
|
||||
),
|
||||
hide_empty: bool = Query(
|
||||
True,
|
||||
False,
|
||||
description="Hide users with no tasks assigned for the week"
|
||||
),
|
||||
db: Session = Depends(get_db),
|
||||
@@ -168,8 +168,20 @@ async def get_heatmap(
|
||||
if department_id:
|
||||
check_workload_access(current_user, department_id=department_id)
|
||||
|
||||
# Filter user_ids based on access (pass db for manager department lookup)
|
||||
accessible_user_ids = filter_accessible_users(current_user, parsed_user_ids, db)
|
||||
# Determine accessible users for this requester
|
||||
accessible_user_ids = filter_accessible_users(current_user, None, db)
|
||||
|
||||
# If specific user_ids are requested, ensure access is permitted
|
||||
if parsed_user_ids:
|
||||
if accessible_user_ids is not None:
|
||||
requested_ids = set(parsed_user_ids)
|
||||
allowed_ids = set(accessible_user_ids)
|
||||
if not requested_ids.issubset(allowed_ids):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied: Cannot view other users' workload",
|
||||
)
|
||||
accessible_user_ids = parsed_user_ids
|
||||
|
||||
# Normalize week_start
|
||||
if week_start is None:
|
||||
|
||||
Reference in New Issue
Block a user