feat: implement 8 OpenSpec proposals for security, reliability, and UX improvements

## Security Enhancements (P0)
- Add input validation with max_length and numeric range constraints
- Implement WebSocket token authentication via first message
- Add path traversal prevention in file storage service (sketched below)
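
The WebSocket change appears in the websocket router diff further down, but the validation constraints and the path traversal guard are not excerpted, so the following is only a minimal sketch of their intent under assumed names (the example schema, the `resolve_safe_path` helper, and the storage root are illustrative, not the project's actual code):

```python
from pathlib import Path
from fastapi import HTTPException, status
from pydantic import BaseModel, Field


class TaskCreateExample(BaseModel):
    """Illustrative length/range constraints; the real schemas live in app/schemas/."""
    title: str = Field(..., min_length=1, max_length=255)
    estimated_hours: float | None = Field(None, ge=0, le=1000)


STORAGE_ROOT = Path("/var/app/attachments")  # assumed storage root


def resolve_safe_path(filename: str) -> Path:
    """Reject filenames that would escape the storage root (e.g. '../../etc/passwd')."""
    name = Path(filename).name  # drop any client-supplied directory components
    if not name or name in {".", ".."}:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid filename")
    candidate = (STORAGE_ROOT / name).resolve()
    # resolve() collapses '..' segments; the result must stay inside the storage root
    if not candidate.is_relative_to(STORAGE_ROOT.resolve()):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid file path")
    return candidate
```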

## Permission Enhancements (P0)
- Add project member management for cross-department access
- Implement is_department_manager flag for workload visibility

## Cycle Detection (P0)
- Add DFS-based cycle detection for task dependencies (see the sketch after this list)
- Add formula field circular reference detection
- Display user-friendly cycle path visualization
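
The custom-field diffs below surface the detected cycle to the client (`cycle_path`, plus a `cycle_description` joined with " -> "), but the detection itself lives in the services and is not excerpted here. A plain depth-first search with white/gray/black coloring, roughly as sketched below, is one way to produce such a path; the function name and graph shape are assumptions:

```python
from typing import Optional

WHITE, GRAY, BLACK = 0, 1, 2  # unvisited / on the current DFS path / finished


def find_cycle(graph: dict[str, list[str]]) -> Optional[list[str]]:
    """Return a cycle such as ['A', 'B', 'C', 'A'] if one exists, else None."""
    color: dict[str, int] = {node: WHITE for node in graph}
    stack: list[str] = []

    def dfs(node: str) -> Optional[list[str]]:
        color[node] = GRAY
        stack.append(node)
        for nxt in graph.get(node, []):
            if color.get(nxt, WHITE) == GRAY:
                # Back edge: the cycle is the stack from nxt onward, closed with nxt
                return stack[stack.index(nxt):] + [nxt]
            if color.get(nxt, WHITE) == WHITE:
                found = dfs(nxt)
                if found:
                    return found
        stack.pop()
        color[node] = BLACK
        return None

    for node in list(graph):
        if color[node] == WHITE:
            cycle = dfs(node)
            if cycle:
                return cycle
    return None


# " -> ".join(cycle) gives the user-facing description, matching the
# cycle_description field returned by the custom-field endpoints below.
```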

## Concurrency & Reliability (P1)
- Implement optimistic locking with version field (409 Conflict on mismatch)
- Add trigger retry mechanism with exponential backoff (1s, 2s, 4s; sketched below)
- Implement cascade restore for soft-deleted tasks
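
The version check and 409 Conflict response are visible in the task update diff below; the trigger retry lives in the trigger service and is not excerpted, so this is a hedged sketch only. The helper name and the four-attempt budget (three retries spaced 1s, 2s, 4s) are assumptions:

```python
import asyncio
import logging

logger = logging.getLogger(__name__)


async def run_trigger_with_retry(execute, max_attempts: int = 4, base_delay: float = 1.0):
    """Run a trigger action, retrying failures with exponential backoff (1s, 2s, 4s)."""
    for attempt in range(1, max_attempts + 1):
        try:
            return await execute()
        except Exception as exc:  # the real service presumably catches narrower errors
            if attempt == max_attempts:
                logger.error("Trigger failed after %d attempts: %s", attempt, exc)
                raise
            delay = base_delay * 2 ** (attempt - 1)  # 1s, 2s, 4s
            logger.warning("Trigger attempt %d failed (%s); retrying in %.0fs", attempt, exc, delay)
            await asyncio.sleep(delay)
```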

## Rate Limiting (P1)
- Add tiered rate limits: standard (60/min), sensitive (20/min), heavy (5/min)
- Apply rate limits to tasks, reports, attachments, and comments

## Frontend Improvements (P1)
- Add responsive sidebar with hamburger menu for mobile
- Make the UI touch-friendly with properly sized tap targets
- Complete i18n translations for all components

## Backend Reliability (P2)
- Configure database connection pool (size=10, overflow=20)
- Add Redis fallback mechanism with message queue (sketched below)
- Add blocker check before task deletion
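
The pool configuration is in the database.py diff further down; the Redis fallback is not excerpted, so the publish-or-queue shape below is an assumption about the design rather than the actual `redis_pubsub` code:

```python
import asyncio
import json
import logging

logger = logging.getLogger(__name__)

# In-process queue that buffers events while Redis is unreachable (assumed design)
_fallback_queue: "asyncio.Queue[str]" = asyncio.Queue()


async def publish_event(redis, channel: str, payload: dict) -> None:
    """Publish to Redis; if it is down, queue the event locally instead of dropping it."""
    message = json.dumps(payload)
    try:
        await redis.publish(channel, message)
    except Exception as exc:
        logger.warning("Redis unavailable (%s); buffering event for later delivery", exc)
        await _fallback_queue.put(message)


async def drain_fallback_queue(redis, channel: str) -> None:
    """Re-publish buffered events once Redis is reachable again."""
    while not _fallback_queue.empty():
        await redis.publish(channel, await _fallback_queue.get())
```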

## API Enhancements (P3)
- Add standardized response wrapper utility
- Add /health/ready and /health/live endpoints (sketched below)
- Implement project templates with status/field copying
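
The template endpoints are excerpted in full below; the health probes and the response wrapper are not. The readiness sketch here reuses `get_pool_status`, which the database.py diff does add, but the router and the exact checks are assumptions:

```python
from fastapi import APIRouter, status
from fastapi.responses import JSONResponse
from sqlalchemy import text

from app.core.database import engine, get_pool_status  # get_pool_status is added in the database.py diff

router = APIRouter(tags=["health"])


@router.get("/health/live")
async def liveness():
    """Liveness probe: the process is up and able to answer."""
    return {"status": "alive"}


@router.get("/health/ready")
async def readiness():
    """Readiness probe: confirm the database answers and report pool status."""
    try:
        with engine.connect() as conn:
            conn.execute(text("SELECT 1"))
    except Exception:
        return JSONResponse(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            content={"status": "not_ready", "database": "unreachable"},
        )
    return {"status": "ready", "database": "ok", "pool": get_pool_status()}
```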

## Tests Added
- test_input_validation.py - Schema and path traversal tests
- test_concurrency_reliability.py - Optimistic locking and retry tests
- test_backend_reliability.py - Connection pool and Redis tests
- test_api_enhancements.py - Health check and template tests

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
beabigegg
2026-01-10 22:13:43 +08:00
parent 96210c7ad4
commit 3bdc6ff1c9
106 changed files with 9704 additions and 429 deletions

View File

@@ -8,6 +8,8 @@ from sqlalchemy.orm import Session
from typing import Optional
from app.core.database import get_db
from app.core.rate_limiter import limiter
from app.core.config import settings
from app.middleware.auth import get_current_user, check_task_access, check_task_edit_access
from app.models import User, Task, Project, Attachment, AttachmentVersion, EncryptionKey, AuditAction
from app.schemas.attachment import (
@@ -156,9 +158,10 @@ def should_encrypt_file(project: Project, db: Session) -> tuple[bool, Optional[E
@router.post("/tasks/{task_id}/attachments", response_model=AttachmentResponse)
@limiter.limit(settings.RATE_LIMIT_SENSITIVE)
async def upload_attachment(
task_id: str,
request: Request,
task_id: str,
file: UploadFile = File(...),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
@@ -167,6 +170,8 @@ async def upload_attachment(
Upload a file attachment to a task.
For confidential projects, files are automatically encrypted using AES-256-GCM.
Rate limited: 20 requests per minute (sensitive tier).
"""
task = get_task_with_access_check(db, task_id, current_user, require_edit=True)

View File

@@ -1,9 +1,11 @@
import uuid
from typing import List
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi import APIRouter, Depends, HTTPException, status, Request
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.core.rate_limiter import limiter
from app.core.config import settings
from app.models import User, Task, Comment
from app.schemas.comment import (
CommentCreate, CommentUpdate, CommentResponse, CommentListResponse,
@@ -49,13 +51,19 @@ def comment_to_response(comment: Comment) -> CommentResponse:
@router.post("/api/tasks/{task_id}/comments", response_model=CommentResponse, status_code=status.HTTP_201_CREATED)
@limiter.limit(settings.RATE_LIMIT_STANDARD)
async def create_comment(
request: Request,
task_id: str,
comment_data: CommentCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""Create a new comment on a task."""
"""
Create a new comment on a task.
Rate limited: 60 requests per minute (standard tier).
"""
task = db.query(Task).filter(Task.id == task_id).first()
if not task:

View File

@@ -91,13 +91,17 @@ async def create_custom_field(
detail="Formula is required for formula fields",
)
is_valid, error_msg = FormulaService.validate_formula(
is_valid, error_msg, cycle_path = FormulaService.validate_formula_with_details(
field_data.formula, project_id, db
)
if not is_valid:
detail = {"message": error_msg}
if cycle_path:
detail["cycle_path"] = cycle_path
detail["cycle_description"] = " -> ".join(cycle_path)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=error_msg,
detail=detail,
)
# Get next position
@@ -229,13 +233,17 @@ async def update_custom_field(
# Validate formula if updating formula field
if field.field_type == "formula" and field_data.formula is not None:
is_valid, error_msg = FormulaService.validate_formula(
is_valid, error_msg, cycle_path = FormulaService.validate_formula_with_details(
field_data.formula, field.project_id, db, field_id
)
if not is_valid:
detail = {"message": error_msg}
if cycle_path:
detail["cycle_path"] = cycle_path
detail["cycle_description"] = " -> ".join(cycle_path)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=error_msg,
detail=detail,
)
# Validate options if updating dropdown field

View File

@@ -4,10 +4,17 @@ from fastapi import APIRouter, Depends, HTTPException, status, Request
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.models import User, Space, Project, TaskStatus, AuditAction
from app.models import User, Space, Project, TaskStatus, AuditAction, ProjectMember
from app.models.task_status import DEFAULT_STATUSES
from app.schemas.project import ProjectCreate, ProjectUpdate, ProjectResponse, ProjectWithDetails
from app.schemas.task_status import TaskStatusResponse
from app.schemas.project_member import (
ProjectMemberCreate,
ProjectMemberUpdate,
ProjectMemberResponse,
ProjectMemberWithDetails,
ProjectMemberListResponse,
)
from app.middleware.auth import (
get_current_user, check_space_access, check_space_edit_access,
check_project_access, check_project_edit_access
@@ -336,3 +343,271 @@ async def list_project_statuses(
).order_by(TaskStatus.position).all()
return statuses
# ============================================================================
# Project Members API - Cross-Department Collaboration
# ============================================================================
@router.get("/api/projects/{project_id}/members", response_model=ProjectMemberListResponse)
async def list_project_members(
project_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
List all members of a project.
Only users with project access can view the member list.
"""
project = db.query(Project).filter(Project.id == project_id, Project.is_active == True).first()
if not project:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Project not found",
)
if not check_project_access(current_user, project):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied",
)
members = db.query(ProjectMember).filter(
ProjectMember.project_id == project_id
).all()
member_list = []
for member in members:
user = db.query(User).filter(User.id == member.user_id).first()
added_by_user = db.query(User).filter(User.id == member.added_by).first()
member_list.append(ProjectMemberWithDetails(
id=member.id,
project_id=member.project_id,
user_id=member.user_id,
role=member.role,
added_by=member.added_by,
created_at=member.created_at,
user_name=user.name if user else None,
user_email=user.email if user else None,
user_department_id=user.department_id if user else None,
user_department_name=user.department.name if user and user.department else None,
added_by_name=added_by_user.name if added_by_user else None,
))
return ProjectMemberListResponse(
members=member_list,
total=len(member_list),
)
@router.post("/api/projects/{project_id}/members", response_model=ProjectMemberResponse, status_code=status.HTTP_201_CREATED)
async def add_project_member(
project_id: str,
member_data: ProjectMemberCreate,
request: Request,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
Add a user as a project member for cross-department collaboration.
Only project owners and members with 'admin' role can add new members.
"""
project = db.query(Project).filter(Project.id == project_id, Project.is_active == True).first()
if not project:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Project not found",
)
# Check if user has permission to add members (owner or admin member)
if not check_project_edit_access(current_user, project):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only project owner or admin members can add new members",
)
# Check if user exists
user_to_add = db.query(User).filter(User.id == member_data.user_id, User.is_active == True).first()
if not user_to_add:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found",
)
# Check if user is already a member
existing_member = db.query(ProjectMember).filter(
ProjectMember.project_id == project_id,
ProjectMember.user_id == member_data.user_id,
).first()
if existing_member:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="User is already a member of this project",
)
# Don't add the owner as a member (they already have access)
if member_data.user_id == project.owner_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Project owner cannot be added as a member",
)
# Create the membership
member = ProjectMember(
id=str(uuid.uuid4()),
project_id=project_id,
user_id=member_data.user_id,
role=member_data.role.value,
added_by=current_user.id,
)
db.add(member)
# Audit log
AuditService.log_event(
db=db,
event_type="project_member.add",
resource_type="project_member",
action=AuditAction.CREATE,
user_id=current_user.id,
resource_id=member.id,
changes=[
{"field": "user_id", "old_value": None, "new_value": member_data.user_id},
{"field": "role", "old_value": None, "new_value": member_data.role.value},
],
request_metadata=get_audit_metadata(request),
)
db.commit()
db.refresh(member)
return member
@router.patch("/api/projects/{project_id}/members/{member_id}", response_model=ProjectMemberResponse)
async def update_project_member(
project_id: str,
member_id: str,
member_data: ProjectMemberUpdate,
request: Request,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
Update a project member's role.
Only project owners and members with 'admin' role can update member roles.
"""
project = db.query(Project).filter(Project.id == project_id, Project.is_active == True).first()
if not project:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Project not found",
)
if not check_project_edit_access(current_user, project):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only project owner or admin members can update member roles",
)
member = db.query(ProjectMember).filter(
ProjectMember.id == member_id,
ProjectMember.project_id == project_id,
).first()
if not member:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Member not found",
)
old_role = member.role
member.role = member_data.role.value
# Audit log
AuditService.log_event(
db=db,
event_type="project_member.update",
resource_type="project_member",
action=AuditAction.UPDATE,
user_id=current_user.id,
resource_id=member.id,
changes=[{"field": "role", "old_value": old_role, "new_value": member_data.role.value}],
request_metadata=get_audit_metadata(request),
)
db.commit()
db.refresh(member)
return member
@router.delete("/api/projects/{project_id}/members/{member_id}", status_code=status.HTTP_204_NO_CONTENT)
async def remove_project_member(
project_id: str,
member_id: str,
request: Request,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
Remove a member from a project.
Only project owners and members with 'admin' role can remove members.
Members can also remove themselves from a project.
"""
project = db.query(Project).filter(Project.id == project_id, Project.is_active == True).first()
if not project:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Project not found",
)
member = db.query(ProjectMember).filter(
ProjectMember.id == member_id,
ProjectMember.project_id == project_id,
).first()
if not member:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Member not found",
)
# Allow self-removal or admin access
is_self_removal = member.user_id == current_user.id
if not is_self_removal and not check_project_edit_access(current_user, project):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only project owner, admin members, or the member themselves can remove membership",
)
# Audit log
AuditService.log_event(
db=db,
event_type="project_member.remove",
resource_type="project_member",
action=AuditAction.DELETE,
user_id=current_user.id,
resource_id=member.id,
changes=[
{"field": "user_id", "old_value": member.user_id, "new_value": None},
{"field": "role", "old_value": member.role, "new_value": None},
],
request_metadata=get_audit_metadata(request),
)
db.delete(member)
db.commit()
return None

View File

@@ -1,8 +1,10 @@
from fastapi import APIRouter, Depends, HTTPException, status, Query
from fastapi import APIRouter, Depends, HTTPException, status, Query, Request
from sqlalchemy.orm import Session
from typing import Optional
from app.core.database import get_db
from app.core.rate_limiter import limiter
from app.core.config import settings
from app.models import User, ReportHistory, ScheduledReport
from app.schemas.report import (
WeeklyReportContent, ReportHistoryListResponse, ReportHistoryItem,
@@ -35,12 +37,16 @@ async def preview_weekly_report(
@router.post("/api/reports/weekly/generate", response_model=GenerateReportResponse)
@limiter.limit(settings.RATE_LIMIT_HEAVY)
async def generate_weekly_report(
request: Request,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
Manually trigger weekly report generation for the current user.
Rate limited: 5 requests per minute (heavy tier).
"""
# Generate report
report_history = ReportService.generate_weekly_report(db, current_user.id)
@@ -112,13 +118,17 @@ async def list_report_history(
@router.get("/api/reports/history/{report_id}")
@limiter.limit(settings.RATE_LIMIT_SENSITIVE)
async def get_report_detail(
request: Request,
report_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
Get detailed content of a specific report.
Rate limited: 20 requests per minute (sensitive tier).
"""
report = db.query(ReportHistory).filter(ReportHistory.id == report_id).first()

View File

@@ -10,13 +10,18 @@ from sqlalchemy.orm import Session
from typing import Optional
from app.core.database import get_db
from app.core.rate_limiter import limiter
from app.core.config import settings
from app.models import User, Task, TaskDependency, AuditAction
from app.schemas.task_dependency import (
TaskDependencyCreate,
TaskDependencyUpdate,
TaskDependencyResponse,
TaskDependencyListResponse,
TaskInfo
TaskInfo,
BulkDependencyCreate,
BulkDependencyValidationResult,
BulkDependencyCreateResponse,
)
from app.middleware.auth import get_current_user, check_task_access, check_task_edit_access
from app.middleware.audit import get_audit_metadata
@@ -429,3 +434,184 @@ async def list_project_dependencies(
dependencies=[dependency_to_response(d) for d in dependencies],
total=len(dependencies)
)
@router.post(
"/api/projects/{project_id}/dependencies/validate",
response_model=BulkDependencyValidationResult
)
async def validate_bulk_dependencies(
project_id: str,
bulk_data: BulkDependencyCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
Validate a batch of dependencies without creating them.
This endpoint checks for:
- Self-references
- Cross-project dependencies
- Duplicate dependencies
- Circular dependencies (including cycles that would be created by the batch)
Returns validation results without modifying the database.
"""
from app.models import Project
from app.middleware.auth import check_project_access
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Project not found"
)
if not check_project_access(current_user, project):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied"
)
# Convert to tuple format for validation
dependencies = [
(dep.predecessor_id, dep.successor_id)
for dep in bulk_data.dependencies
]
errors = DependencyService.validate_bulk_dependencies(db, dependencies, project_id)
return BulkDependencyValidationResult(
valid=len(errors) == 0,
errors=errors
)
@router.post(
"/api/projects/{project_id}/dependencies/bulk",
response_model=BulkDependencyCreateResponse,
status_code=status.HTTP_201_CREATED
)
@limiter.limit(settings.RATE_LIMIT_HEAVY)
async def create_bulk_dependencies(
request: Request,
project_id: str,
bulk_data: BulkDependencyCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
Create multiple dependencies at once.
This endpoint:
1. Validates all dependencies together for cycle detection
2. Creates valid dependencies
3. Returns both created dependencies and any failures
Cycle detection considers all dependencies in the batch together,
so cycles that would only appear when all dependencies are added
will be caught.
Rate limited: 5 requests per minute (heavy tier).
"""
from app.models import Project
from app.middleware.auth import check_project_edit_access
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Project not found"
)
if not check_project_edit_access(current_user, project):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Permission denied"
)
# Convert to tuple format for validation
dependencies_to_validate = [
(dep.predecessor_id, dep.successor_id)
for dep in bulk_data.dependencies
]
# Validate all dependencies together
errors = DependencyService.validate_bulk_dependencies(
db, dependencies_to_validate, project_id
)
# Build a set of failed dependency pairs for quick lookup
failed_pairs = set()
for error in errors:
pair = (error.get("predecessor_id"), error.get("successor_id"))
failed_pairs.add(pair)
created_dependencies = []
failed_items = errors # Include validation errors
# Create dependencies that passed validation
for dep_data in bulk_data.dependencies:
pair = (dep_data.predecessor_id, dep_data.successor_id)
if pair in failed_pairs:
continue
# Additional check: verify dependency limit for successor
current_count = db.query(TaskDependency).filter(
TaskDependency.successor_id == dep_data.successor_id
).count()
if current_count >= DependencyService.MAX_DIRECT_DEPENDENCIES:
failed_items.append({
"error_type": "limit_exceeded",
"predecessor_id": dep_data.predecessor_id,
"successor_id": dep_data.successor_id,
"message": f"Successor task already has {DependencyService.MAX_DIRECT_DEPENDENCIES} dependencies"
})
continue
# Create the dependency
dependency = TaskDependency(
id=str(uuid.uuid4()),
predecessor_id=dep_data.predecessor_id,
successor_id=dep_data.successor_id,
dependency_type=dep_data.dependency_type.value,
lag_days=dep_data.lag_days
)
db.add(dependency)
created_dependencies.append(dependency)
# Audit log for each created dependency
AuditService.log_event(
db=db,
event_type="task.dependency.create",
resource_type="task_dependency",
action=AuditAction.CREATE,
user_id=current_user.id,
resource_id=dependency.id,
changes=[{
"field": "dependency",
"old_value": None,
"new_value": {
"predecessor_id": dependency.predecessor_id,
"successor_id": dependency.successor_id,
"dependency_type": dependency.dependency_type,
"lag_days": dependency.lag_days
}
}],
request_metadata=get_audit_metadata(request)
)
db.commit()
# Refresh created dependencies to get relationships
for dep in created_dependencies:
db.refresh(dep)
return BulkDependencyCreateResponse(
created=[dependency_to_response(d) for d in created_dependencies],
failed=failed_items,
total_created=len(created_dependencies),
total_failed=len(failed_items)
)

View File

@@ -1,16 +1,20 @@
import logging
import uuid
from datetime import datetime, timezone
from datetime import datetime, timezone, timedelta
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, status, Query, Request
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.core.redis_pubsub import publish_task_event
from app.core.rate_limiter import limiter
from app.core.config import settings
from app.models import User, Project, Task, TaskStatus, AuditAction, Blocker
from app.schemas.task import (
TaskCreate, TaskUpdate, TaskResponse, TaskWithDetails, TaskListResponse,
TaskStatusUpdate, TaskAssignUpdate, CustomValueResponse
TaskStatusUpdate, TaskAssignUpdate, CustomValueResponse,
TaskRestoreRequest, TaskRestoreResponse,
TaskDeleteWarningResponse, TaskDeleteResponse
)
from app.middleware.auth import (
get_current_user, check_project_access, check_task_access, check_task_edit_access
@@ -72,6 +76,7 @@ def task_to_response(task: Task, db: Session = None, include_custom_values: bool
created_by=task.created_by,
created_at=task.created_at,
updated_at=task.updated_at,
version=task.version,
assignee_name=task.assignee.name if task.assignee else None,
status_name=task.status.name if task.status else None,
status_color=task.status.color if task.status else None,
@@ -161,15 +166,18 @@ async def list_tasks(
@router.post("/api/projects/{project_id}/tasks", response_model=TaskResponse, status_code=status.HTTP_201_CREATED)
@limiter.limit(settings.RATE_LIMIT_STANDARD)
async def create_task(
request: Request,
project_id: str,
task_data: TaskCreate,
request: Request,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
Create a new task in a project.
Rate limited: 60 requests per minute (standard tier).
"""
project = db.query(Project).filter(Project.id == project_id).first()
@@ -367,15 +375,18 @@ async def get_task(
@router.patch("/api/tasks/{task_id}", response_model=TaskResponse)
@limiter.limit(settings.RATE_LIMIT_STANDARD)
async def update_task(
request: Request,
task_id: str,
task_data: TaskUpdate,
request: Request,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
Update a task.
Rate limited: 60 requests per minute (standard tier).
"""
task = db.query(Task).filter(Task.id == task_id).first()
@@ -391,6 +402,18 @@ async def update_task(
detail="Permission denied",
)
# Optimistic locking: validate version if provided
if task_data.version is not None:
if task_data.version != task.version:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail={
"message": "Task has been modified by another user",
"current_version": task.version,
"provided_version": task_data.version,
},
)
# Capture old values for audit and triggers
old_values = {
"title": task.title,
@@ -402,9 +425,10 @@ async def update_task(
"time_spent": task.time_spent,
}
# Update fields (exclude custom_values, handle separately)
# Update fields (exclude custom_values and version, handle separately)
update_data = task_data.model_dump(exclude_unset=True)
custom_values_data = update_data.pop("custom_values", None)
update_data.pop("version", None) # version is handled separately for optimistic locking
# Track old assignee for workload cache invalidation
old_assignee_id = task.assignee_id
@@ -501,6 +525,9 @@ async def update_task(
detail=str(e),
)
# Increment version for optimistic locking
task.version += 1
db.commit()
db.refresh(task)
@@ -551,15 +578,20 @@ async def update_task(
return task
@router.delete("/api/tasks/{task_id}", response_model=TaskResponse)
@router.delete("/api/tasks/{task_id}")
async def delete_task(
task_id: str,
request: Request,
force_delete: bool = Query(False, description="Force delete even if task has unresolved blockers"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
Soft delete a task (cascades to subtasks).
If the task has unresolved blockers and force_delete is False,
returns a warning response with status 200 and blocker count.
Use force_delete=true to delete anyway (auto-resolves blockers).
"""
task = db.query(Task).filter(Task.id == task_id).first()
@@ -581,9 +613,35 @@ async def delete_task(
detail="Permission denied",
)
# Check for unresolved blockers
unresolved_blockers = db.query(Blocker).filter(
Blocker.task_id == task.id,
Blocker.resolved_at == None,
).all()
blocker_count = len(unresolved_blockers)
# If there are unresolved blockers and force_delete is False, return warning
if blocker_count > 0 and not force_delete:
return TaskDeleteWarningResponse(
warning="Task has unresolved blockers",
blocker_count=blocker_count,
message=f"Task has {blocker_count} unresolved blocker(s). Use force_delete=true to delete anyway.",
)
# Use naive datetime for consistency with database storage
now = datetime.now(timezone.utc).replace(tzinfo=None)
# Auto-resolve blockers if force deleting
blockers_resolved = 0
if force_delete and blocker_count > 0:
for blocker in unresolved_blockers:
blocker.resolved_at = now
blocker.resolved_by = current_user.id
blocker.resolution_note = "Auto-resolved due to task deletion"
blockers_resolved += 1
logger.info(f"Auto-resolved {blockers_resolved} blocker(s) for task {task_id} during force delete")
# Soft delete the task
task.is_deleted = True
task.deleted_at = now
@@ -608,7 +666,11 @@ async def delete_task(
action=AuditAction.DELETE,
user_id=current_user.id,
resource_id=task.id,
changes=[{"field": "is_deleted", "old_value": False, "new_value": True}],
changes=[
{"field": "is_deleted", "old_value": False, "new_value": True},
{"field": "force_delete", "old_value": None, "new_value": force_delete},
{"field": "blockers_resolved", "old_value": None, "new_value": blockers_resolved},
],
request_metadata=get_audit_metadata(request),
)
@@ -635,18 +697,33 @@ async def delete_task(
except Exception as e:
logger.warning(f"Failed to publish task_deleted event: {e}")
return task
return TaskDeleteResponse(
task=task,
blockers_resolved=blockers_resolved,
force_deleted=force_delete and blocker_count > 0,
)
@router.post("/api/tasks/{task_id}/restore", response_model=TaskResponse)
@router.post("/api/tasks/{task_id}/restore", response_model=TaskRestoreResponse)
async def restore_task(
task_id: str,
request: Request,
restore_data: TaskRestoreRequest = None,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
Restore a soft-deleted task (admin only).
Supports cascade restore: when enabled (default), also restores child tasks
that were deleted at the same time as the parent task.
Args:
task_id: ID of the task to restore
restore_data: Optional restore options (cascade=True by default)
Returns:
TaskRestoreResponse with restored task and list of restored children
"""
if not current_user.is_system_admin:
raise HTTPException(
@@ -654,6 +731,10 @@ async def restore_task(
detail="Only system administrators can restore deleted tasks",
)
# Handle default for optional body
if restore_data is None:
restore_data = TaskRestoreRequest()
task = db.query(Task).filter(Task.id == task_id).first()
if not task:
@@ -668,12 +749,16 @@ async def restore_task(
detail="Task is not deleted",
)
# Restore the task
# Store the parent's deleted_at timestamp for cascade restore
parent_deleted_at = task.deleted_at
restored_children_ids = []
# Restore the parent task
task.is_deleted = False
task.deleted_at = None
task.deleted_by = None
# Audit log
# Audit log for parent task
AuditService.log_event(
db=db,
event_type="task.restore",
@@ -681,18 +766,119 @@ async def restore_task(
action=AuditAction.UPDATE,
user_id=current_user.id,
resource_id=task.id,
changes=[{"field": "is_deleted", "old_value": True, "new_value": False}],
changes=[
{"field": "is_deleted", "old_value": True, "new_value": False},
{"field": "cascade", "old_value": None, "new_value": restore_data.cascade},
],
request_metadata=get_audit_metadata(request),
)
# Cascade restore child tasks if requested
if restore_data.cascade and parent_deleted_at:
restored_children_ids = _cascade_restore_children(
db=db,
parent_task=task,
parent_deleted_at=parent_deleted_at,
current_user=current_user,
request=request,
)
db.commit()
db.refresh(task)
# Invalidate workload cache for assignee
# Invalidate workload cache for parent task assignee
if task.assignee_id:
invalidate_user_workload_cache(task.assignee_id)
return task
# Invalidate workload cache for all restored children's assignees
for child_id in restored_children_ids:
child_task = db.query(Task).filter(Task.id == child_id).first()
if child_task and child_task.assignee_id:
invalidate_user_workload_cache(child_task.assignee_id)
return TaskRestoreResponse(
restored_task=task,
restored_children_count=len(restored_children_ids),
restored_children_ids=restored_children_ids,
)
def _cascade_restore_children(
db: Session,
parent_task: Task,
parent_deleted_at: datetime,
current_user: User,
request: Request,
tolerance_seconds: int = 5,
) -> List[str]:
"""
Recursively restore child tasks that were deleted at the same time as the parent.
Args:
db: Database session
parent_task: The parent task being restored
parent_deleted_at: Timestamp when the parent was deleted
current_user: Current user performing the restore
request: HTTP request for audit metadata
tolerance_seconds: Time tolerance for matching deleted_at timestamps
Returns:
List of restored child task IDs
"""
restored_ids = []
# Find all deleted child tasks with matching deleted_at timestamp
# Use a small tolerance window to account for slight timing differences
time_window_start = parent_deleted_at - timedelta(seconds=tolerance_seconds)
time_window_end = parent_deleted_at + timedelta(seconds=tolerance_seconds)
deleted_children = db.query(Task).filter(
Task.parent_task_id == parent_task.id,
Task.is_deleted == True,
Task.deleted_at >= time_window_start,
Task.deleted_at <= time_window_end,
).all()
for child in deleted_children:
# Store child's deleted_at before restoring
child_deleted_at = child.deleted_at
# Restore the child
child.is_deleted = False
child.deleted_at = None
child.deleted_by = None
restored_ids.append(child.id)
# Audit log for child task
AuditService.log_event(
db=db,
event_type="task.restore",
resource_type="task",
action=AuditAction.UPDATE,
user_id=current_user.id,
resource_id=child.id,
changes=[
{"field": "is_deleted", "old_value": True, "new_value": False},
{"field": "restored_via_cascade", "old_value": None, "new_value": parent_task.id},
],
request_metadata=get_audit_metadata(request),
)
logger.info(f"Cascade restored child task {child.id} (parent: {parent_task.id})")
# Recursively restore grandchildren
if child_deleted_at:
grandchildren_ids = _cascade_restore_children(
db=db,
parent_task=child,
parent_deleted_at=child_deleted_at,
current_user=current_user,
request=request,
tolerance_seconds=tolerance_seconds,
)
restored_ids.extend(grandchildren_ids)
return restored_ids
@router.patch("/api/tasks/{task_id}/status", response_model=TaskResponse)

View File

@@ -0,0 +1,3 @@
from app.api.templates.router import router
__all__ = ["router"]

View File

@@ -0,0 +1,440 @@
"""Project Templates API endpoints.
Provides CRUD operations for project templates and
the ability to create projects from templates.
"""
import uuid
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, status, Request, Query
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.models import (
User, Space, Project, TaskStatus, CustomField, ProjectTemplate, AuditAction
)
from app.schemas.project_template import (
ProjectTemplateCreate,
ProjectTemplateUpdate,
ProjectTemplateResponse,
ProjectTemplateWithOwner,
ProjectTemplateListResponse,
CreateProjectFromTemplateRequest,
CreateProjectFromTemplateResponse,
)
from app.middleware.auth import get_current_user, check_space_access
from app.middleware.audit import get_audit_metadata
from app.services.audit_service import AuditService
router = APIRouter(prefix="/api/templates", tags=["Project Templates"])
def can_view_template(user: User, template: ProjectTemplate) -> bool:
"""Check if a user can view a template."""
if template.is_public:
return True
if template.owner_id == user.id:
return True
if user.is_system_admin:
return True
return False
def can_edit_template(user: User, template: ProjectTemplate) -> bool:
"""Check if a user can edit a template."""
if template.owner_id == user.id:
return True
if user.is_system_admin:
return True
return False
@router.get("", response_model=ProjectTemplateListResponse)
async def list_templates(
include_private: bool = Query(False, description="Include user's private templates"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
List available project templates.
By default, only returns public templates.
Set include_private=true to also include the user's private templates.
"""
query = db.query(ProjectTemplate).filter(ProjectTemplate.is_active == True)
if include_private:
# Public templates OR user's own templates
query = query.filter(
(ProjectTemplate.is_public == True) |
(ProjectTemplate.owner_id == current_user.id)
)
else:
# Only public templates
query = query.filter(ProjectTemplate.is_public == True)
templates = query.order_by(ProjectTemplate.name).all()
result = []
for template in templates:
result.append(ProjectTemplateWithOwner(
id=template.id,
name=template.name,
description=template.description,
is_public=template.is_public,
task_statuses=template.task_statuses,
custom_fields=template.custom_fields,
default_security_level=template.default_security_level,
owner_id=template.owner_id,
is_active=template.is_active,
created_at=template.created_at,
updated_at=template.updated_at,
owner_name=template.owner.name if template.owner else None,
))
return ProjectTemplateListResponse(
templates=result,
total=len(result),
)
@router.post("", response_model=ProjectTemplateResponse, status_code=status.HTTP_201_CREATED)
async def create_template(
template_data: ProjectTemplateCreate,
request: Request,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
Create a new project template.
The template can include predefined task statuses and custom fields
that will be copied when creating a project from this template.
"""
# Convert Pydantic models to dict for JSON storage
task_statuses_json = None
if template_data.task_statuses:
task_statuses_json = [ts.model_dump() for ts in template_data.task_statuses]
custom_fields_json = None
if template_data.custom_fields:
custom_fields_json = [cf.model_dump() for cf in template_data.custom_fields]
template = ProjectTemplate(
id=str(uuid.uuid4()),
name=template_data.name,
description=template_data.description,
owner_id=current_user.id,
is_public=template_data.is_public,
task_statuses=task_statuses_json,
custom_fields=custom_fields_json,
default_security_level=template_data.default_security_level,
)
db.add(template)
# Audit log
AuditService.log_event(
db=db,
event_type="template.create",
resource_type="project_template",
action=AuditAction.CREATE,
user_id=current_user.id,
resource_id=template.id,
changes=[{"field": "name", "old_value": None, "new_value": template.name}],
request_metadata=get_audit_metadata(request),
)
db.commit()
db.refresh(template)
return template
@router.get("/{template_id}", response_model=ProjectTemplateWithOwner)
async def get_template(
template_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
Get a project template by ID.
"""
template = db.query(ProjectTemplate).filter(
ProjectTemplate.id == template_id,
ProjectTemplate.is_active == True
).first()
if not template:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Template not found",
)
if not can_view_template(current_user, template):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied",
)
return ProjectTemplateWithOwner(
id=template.id,
name=template.name,
description=template.description,
is_public=template.is_public,
task_statuses=template.task_statuses,
custom_fields=template.custom_fields,
default_security_level=template.default_security_level,
owner_id=template.owner_id,
is_active=template.is_active,
created_at=template.created_at,
updated_at=template.updated_at,
owner_name=template.owner.name if template.owner else None,
)
@router.patch("/{template_id}", response_model=ProjectTemplateResponse)
async def update_template(
template_id: str,
template_data: ProjectTemplateUpdate,
request: Request,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
Update a project template.
Only the template owner or system admin can update a template.
"""
template = db.query(ProjectTemplate).filter(
ProjectTemplate.id == template_id,
ProjectTemplate.is_active == True
).first()
if not template:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Template not found",
)
if not can_edit_template(current_user, template):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only template owner can update",
)
# Capture old values for audit
old_values = {
"name": template.name,
"description": template.description,
"is_public": template.is_public,
}
# Update fields
update_data = template_data.model_dump(exclude_unset=True)
# Convert Pydantic models to dict for JSON storage
if "task_statuses" in update_data and update_data["task_statuses"]:
update_data["task_statuses"] = [ts.model_dump() if hasattr(ts, 'model_dump') else ts for ts in update_data["task_statuses"]]
if "custom_fields" in update_data and update_data["custom_fields"]:
update_data["custom_fields"] = [cf.model_dump() if hasattr(cf, 'model_dump') else cf for cf in update_data["custom_fields"]]
for field, value in update_data.items():
setattr(template, field, value)
# Log changes
new_values = {
"name": template.name,
"description": template.description,
"is_public": template.is_public,
}
changes = AuditService.detect_changes(old_values, new_values)
if changes:
AuditService.log_event(
db=db,
event_type="template.update",
resource_type="project_template",
action=AuditAction.UPDATE,
user_id=current_user.id,
resource_id=template.id,
changes=changes,
request_metadata=get_audit_metadata(request),
)
db.commit()
db.refresh(template)
return template
@router.delete("/{template_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_template(
template_id: str,
request: Request,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
Delete a project template (soft delete).
Only the template owner or system admin can delete a template.
"""
template = db.query(ProjectTemplate).filter(
ProjectTemplate.id == template_id,
ProjectTemplate.is_active == True
).first()
if not template:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Template not found",
)
if not can_edit_template(current_user, template):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only template owner can delete",
)
# Audit log
AuditService.log_event(
db=db,
event_type="template.delete",
resource_type="project_template",
action=AuditAction.DELETE,
user_id=current_user.id,
resource_id=template.id,
changes=[{"field": "is_active", "old_value": True, "new_value": False}],
request_metadata=get_audit_metadata(request),
)
# Soft delete
template.is_active = False
db.commit()
return None
@router.post("/create-project", response_model=CreateProjectFromTemplateResponse, status_code=status.HTTP_201_CREATED)
async def create_project_from_template(
data: CreateProjectFromTemplateRequest,
request: Request,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
Create a new project from a template.
This will:
1. Create a new project with the specified title and description
2. Copy all task statuses from the template
3. Copy all custom field definitions from the template
"""
# Get the template
template = db.query(ProjectTemplate).filter(
ProjectTemplate.id == data.template_id,
ProjectTemplate.is_active == True
).first()
if not template:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Template not found",
)
if not can_view_template(current_user, template):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied to template",
)
# Check space access
space = db.query(Space).filter(
Space.id == data.space_id,
Space.is_active == True
).first()
if not space:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Space not found",
)
if not check_space_access(current_user, space):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied to space",
)
# Create the project
project = Project(
id=str(uuid.uuid4()),
space_id=data.space_id,
title=data.title,
description=data.description,
owner_id=current_user.id,
security_level=template.default_security_level or "department",
department_id=data.department_id or current_user.department_id,
)
db.add(project)
db.flush() # Get project ID
# Copy task statuses from template
task_statuses_created = 0
if template.task_statuses:
for status_data in template.task_statuses:
task_status = TaskStatus(
id=str(uuid.uuid4()),
project_id=project.id,
name=status_data.get("name", "Unnamed"),
color=status_data.get("color", "#808080"),
position=status_data.get("position", 0),
is_done=status_data.get("is_done", False),
)
db.add(task_status)
task_statuses_created += 1
# Copy custom fields from template
custom_fields_created = 0
if template.custom_fields:
for field_data in template.custom_fields:
custom_field = CustomField(
id=str(uuid.uuid4()),
project_id=project.id,
name=field_data.get("name", "Unnamed"),
field_type=field_data.get("field_type", "text"),
options=field_data.get("options"),
formula=field_data.get("formula"),
is_required=field_data.get("is_required", False),
position=field_data.get("position", 0),
)
db.add(custom_field)
custom_fields_created += 1
# Audit log
AuditService.log_event(
db=db,
event_type="project.create_from_template",
resource_type="project",
action=AuditAction.CREATE,
user_id=current_user.id,
resource_id=project.id,
changes=[
{"field": "title", "old_value": None, "new_value": project.title},
{"field": "template_id", "old_value": None, "new_value": template.id},
],
request_metadata=get_audit_metadata(request),
)
db.commit()
return CreateProjectFromTemplateResponse(
id=project.id,
title=project.title,
template_id=template.id,
template_name=template.name,
task_statuses_created=task_statuses_created,
custom_fields_created=custom_fields_created,
)

View File

@@ -1,6 +1,7 @@
import asyncio
import logging
import time
from typing import Optional
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
from sqlalchemy.orm import Session
@@ -19,6 +20,9 @@ router = APIRouter(tags=["websocket"])
PING_INTERVAL = 60.0 # Send ping after this many seconds of no messages
PONG_TIMEOUT = 30.0 # Disconnect if no pong received within this time after ping
# Authentication timeout (10 seconds)
AUTH_TIMEOUT = 10.0
async def get_user_from_token(token: str) -> tuple[str | None, User | None]:
"""Validate token and return user_id and user object."""
@@ -47,6 +51,56 @@ async def get_user_from_token(token: str) -> tuple[str | None, User | None]:
db.close()
async def authenticate_websocket(
websocket: WebSocket,
query_token: Optional[str] = None
) -> tuple[str | None, User | None]:
"""
Authenticate WebSocket connection.
Supports two authentication methods:
1. First message authentication (preferred, more secure)
- Client sends: {"type": "auth", "token": "<jwt_token>"}
2. Query parameter authentication (deprecated, for backward compatibility)
- Client connects with: ?token=<jwt_token>
Returns (user_id, user) if authenticated, (None, None) otherwise.
"""
# If token provided via query parameter (backward compatibility)
if query_token:
logger.warning(
"WebSocket authentication via query parameter is deprecated. "
"Please use first-message authentication for better security."
)
return await get_user_from_token(query_token)
# Wait for authentication message with timeout
try:
data = await asyncio.wait_for(
websocket.receive_json(),
timeout=AUTH_TIMEOUT
)
msg_type = data.get("type")
if msg_type != "auth":
logger.warning("Expected 'auth' message type, got: %s", msg_type)
return None, None
token = data.get("token")
if not token:
logger.warning("No token provided in auth message")
return None, None
return await get_user_from_token(token)
except asyncio.TimeoutError:
logger.warning("WebSocket authentication timeout after %.1f seconds", AUTH_TIMEOUT)
return None, None
except Exception as e:
logger.error("Error during WebSocket authentication: %s", e)
return None, None
async def get_unread_notifications(user_id: str) -> list[dict]:
"""Query all unread notifications for a user."""
db = SessionLocal()
@@ -90,14 +144,22 @@ async def get_unread_count(user_id: str) -> int:
@router.websocket("/ws/notifications")
async def websocket_notifications(
websocket: WebSocket,
token: str = Query(..., description="JWT token for authentication"),
token: Optional[str] = Query(None, description="JWT token (deprecated, use first-message auth)"),
):
"""
WebSocket endpoint for real-time notifications.
Connect with: ws://host/ws/notifications?token=<jwt_token>
Authentication methods (in order of preference):
1. First message authentication (recommended):
- Connect without token: ws://host/ws/notifications
- Send: {"type": "auth", "token": "<jwt_token>"}
- Must authenticate within 10 seconds or connection will be closed
2. Query parameter (deprecated, for backward compatibility):
- Connect with: ws://host/ws/notifications?token=<jwt_token>
Messages sent by server:
- {"type": "auth_required"} - Sent when waiting for auth message
- {"type": "connected", "data": {"user_id": "...", "message": "..."}} - Connection success
- {"type": "unread_sync", "data": {"notifications": [...], "unread_count": N}} - All unread on connect
- {"type": "notification", "data": {...}} - New notification
@@ -106,9 +168,18 @@ async def websocket_notifications(
- {"type": "pong"} - Response to client ping
Messages accepted from client:
- {"type": "auth", "token": "..."} - Authentication (must be first message if no query token)
- {"type": "ping"} - Client keepalive ping
"""
user_id, user = await get_user_from_token(token)
# Accept WebSocket connection first
await websocket.accept()
# If no query token, notify client that auth is required
if not token:
await websocket.send_json({"type": "auth_required"})
# Authenticate
user_id, user = await authenticate_websocket(websocket, token)
if user_id is None:
await websocket.close(code=4001, reason="Invalid or expired token")
@@ -263,14 +334,22 @@ async def verify_project_access(user_id: str, project_id: str) -> tuple[bool, Pr
async def websocket_project_sync(
websocket: WebSocket,
project_id: str,
token: str = Query(..., description="JWT token for authentication"),
token: Optional[str] = Query(None, description="JWT token (deprecated, use first-message auth)"),
):
"""
WebSocket endpoint for project task real-time sync.
Connect with: ws://host/ws/projects/{project_id}?token=<jwt_token>
Authentication methods (in order of preference):
1. First message authentication (recommended):
- Connect without token: ws://host/ws/projects/{project_id}
- Send: {"type": "auth", "token": "<jwt_token>"}
- Must authenticate within 10 seconds or connection will be closed
2. Query parameter (deprecated, for backward compatibility):
- Connect with: ws://host/ws/projects/{project_id}?token=<jwt_token>
Messages sent by server:
- {"type": "auth_required"} - Sent when waiting for auth message
- {"type": "connected", "data": {"project_id": "...", "user_id": "..."}}
- {"type": "task_created", "data": {...}, "triggered_by": "..."}
- {"type": "task_updated", "data": {...}, "triggered_by": "..."}
@@ -280,10 +359,18 @@ async def websocket_project_sync(
- {"type": "ping"} / {"type": "pong"}
Messages accepted from client:
- {"type": "auth", "token": "..."} - Authentication (must be first message if no query token)
- {"type": "ping"} - Client keepalive ping
"""
# Accept WebSocket connection first
await websocket.accept()
# If no query token, notify client that auth is required
if not token:
await websocket.send_json({"type": "auth_required"})
# Authenticate user
user_id, user = await get_user_from_token(token)
user_id, user = await authenticate_websocket(websocket, token)
if user_id is None:
await websocket.close(code=4001, reason="Invalid or expired token")
@@ -300,8 +387,7 @@ async def websocket_project_sync(
await websocket.close(code=4004, reason="Project not found")
return
# Accept connection and join project room
await websocket.accept()
# Join project room
await manager.join_project(websocket, user_id, project_id)
# Create Redis subscriber for project task events

View File

@@ -41,22 +41,34 @@ def check_workload_access(
"""
Check if current user has access to view workload data.
Access rules:
- System admin: can access all workloads
- Department manager: can access workloads of users in their department
- Regular user: can only access their own workload
Raises HTTPException if access is denied.
"""
# System admin can access all
if current_user.is_system_admin:
return
# If querying specific user, must be self
# (Phase 1: only self access for non-admin users)
# If querying specific user
if target_user_id and target_user_id != current_user.id:
# Department manager can view subordinates' workload
if current_user.is_department_manager:
# Manager can view users in their department
# target_user_department_id must be provided for this check
if target_user_department_id and target_user_department_id == current_user.department_id:
return
# Access denied for non-manager or user not in same department
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied: Cannot view other users' workload",
)
# If querying by department, must be same department
# If querying by department
if department_id and department_id != current_user.department_id:
# Department manager can only query their own department
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied: Cannot view other departments' workload",
@@ -66,15 +78,40 @@ def check_workload_access(
def filter_accessible_users(
current_user: User,
user_ids: Optional[List[str]] = None,
db: Optional[Session] = None,
) -> Optional[List[str]]:
"""
Filter user IDs to only those accessible by current user.
Returns None if user can access all (system admin).
Access rules:
- System admin: can see all users
- Department manager: can see all users in their department
- Regular user: can only see themselves
"""
# System admin can access all
if current_user.is_system_admin:
return user_ids
# Department manager can see all users in their department
if current_user.is_department_manager and current_user.department_id and db:
# Get all users in the same department
department_users = db.query(User.id).filter(
User.department_id == current_user.department_id,
User.is_active == True
).all()
department_user_ids = {u.id for u in department_users}
if user_ids:
# Filter to only users in manager's department
accessible = [uid for uid in user_ids if uid in department_user_ids]
if not accessible:
return [current_user.id]
return accessible
else:
# No filter specified, return all department users
return list(department_user_ids)
# Regular user can only see themselves
if user_ids:
# Filter to only accessible users
@@ -111,6 +148,11 @@ async def get_heatmap(
"""
Get workload heatmap for users.
Access Rules:
- System admin: Can view all users' workload
- Department manager: Can view workload of all users in their department
- Regular user: Can only view their own workload
Returns workload summaries for users showing:
- allocated_hours: Total estimated hours from tasks due this week
- capacity_hours: User's weekly capacity
@@ -126,8 +168,8 @@ async def get_heatmap(
if department_id:
check_workload_access(current_user, department_id=department_id)
# Filter user_ids based on access
accessible_user_ids = filter_accessible_users(current_user, parsed_user_ids)
# Filter user_ids based on access (pass db for manager department lookup)
accessible_user_ids = filter_accessible_users(current_user, parsed_user_ids, db)
# Normalize week_start
if week_start is None:
@@ -181,12 +223,25 @@ async def get_user_workload(
"""
Get detailed workload for a specific user.
Access rules:
- System admin: can view any user's workload
- Department manager: can view workload of users in their department
- Regular user: can only view their own workload
Returns:
- Workload summary (same as heatmap)
- List of tasks contributing to the workload
"""
# Check access
check_workload_access(current_user, target_user_id=user_id)
# Get target user's department for manager access check
target_user = db.query(User).filter(User.id == user_id).first()
target_user_department_id = target_user.department_id if target_user else None
# Check access (pass target user's department for manager check)
check_workload_access(
current_user,
target_user_id=user_id,
target_user_department_id=target_user_department_id
)
# Calculate workload detail
detail = get_user_workload_detail(db, user_id, week_start)

View File

@@ -115,6 +115,13 @@ class Settings(BaseSettings):
"exe", "bat", "cmd", "sh", "ps1", "dll", "msi", "com", "scr", "vbs", "js"
]
# Rate Limiting Configuration
# Tiers: standard, sensitive, heavy
# Format: "{requests}/{period}" (e.g., "60/minute", "20/minute", "5/minute")
RATE_LIMIT_STANDARD: str = "60/minute" # Task CRUD, comments
RATE_LIMIT_SENSITIVE: str = "20/minute" # Attachments, password change, report export
RATE_LIMIT_HEAVY: str = "5/minute" # Report generation, bulk operations
class Config:
env_file = ".env"
case_sensitive = True

View File

@@ -1,19 +1,109 @@
from sqlalchemy import create_engine
import logging
import threading
import os
from sqlalchemy import create_engine, event
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from app.core.config import settings
logger = logging.getLogger(__name__)
# Connection pool configuration with environment variable overrides
POOL_SIZE = int(os.getenv("DB_POOL_SIZE", "10"))
MAX_OVERFLOW = int(os.getenv("DB_MAX_OVERFLOW", "20"))
POOL_TIMEOUT = int(os.getenv("DB_POOL_TIMEOUT", "30"))
POOL_STATS_INTERVAL = int(os.getenv("DB_POOL_STATS_INTERVAL", "300")) # 5 minutes
engine = create_engine(
settings.DATABASE_URL,
pool_pre_ping=True,
pool_recycle=3600,
pool_size=POOL_SIZE,
max_overflow=MAX_OVERFLOW,
pool_timeout=POOL_TIMEOUT,
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()
# Connection pool statistics tracking
_pool_stats_lock = threading.Lock()
_pool_stats = {
"checkouts": 0,
"checkins": 0,
"overflow_connections": 0,
"invalidated_connections": 0,
}
def _log_pool_statistics():
"""Log current connection pool statistics."""
pool = engine.pool
with _pool_stats_lock:
logger.info(
"Database connection pool statistics: "
"size=%d, checked_in=%d, overflow=%d, "
"total_checkouts=%d, total_checkins=%d, invalidated=%d",
pool.size(),
pool.checkedin(),
pool.overflow(),
_pool_stats["checkouts"],
_pool_stats["checkins"],
_pool_stats["invalidated_connections"],
)
def _start_pool_stats_logging():
"""Start periodic logging of connection pool statistics."""
if POOL_STATS_INTERVAL <= 0:
return
def log_stats():
_log_pool_statistics()
# Schedule next log
timer = threading.Timer(POOL_STATS_INTERVAL, log_stats)
timer.daemon = True
timer.start()
# Start the first timer
timer = threading.Timer(POOL_STATS_INTERVAL, log_stats)
timer.daemon = True
timer.start()
logger.info(
"Database connection pool initialized: pool_size=%d, max_overflow=%d, pool_timeout=%d, stats_interval=%ds",
POOL_SIZE, MAX_OVERFLOW, POOL_TIMEOUT, POOL_STATS_INTERVAL
)
# Register pool event listeners for statistics
@event.listens_for(engine, "checkout")
def _on_checkout(dbapi_conn, connection_record, connection_proxy):
"""Track connection checkout events."""
with _pool_stats_lock:
_pool_stats["checkouts"] += 1
@event.listens_for(engine, "checkin")
def _on_checkin(dbapi_conn, connection_record):
"""Track connection checkin events."""
with _pool_stats_lock:
_pool_stats["checkins"] += 1
@event.listens_for(engine, "invalidate")
def _on_invalidate(dbapi_conn, connection_record, exception):
"""Track connection invalidation events."""
with _pool_stats_lock:
_pool_stats["invalidated_connections"] += 1
if exception:
logger.warning("Database connection invalidated due to exception: %s", exception)
# Start pool statistics logging on module load
_start_pool_stats_logging()
def get_db():
"""Dependency for getting database session."""
@@ -22,3 +112,18 @@ def get_db():
yield db
finally:
db.close()
def get_pool_status() -> dict:
"""Get current connection pool status for health checks."""
pool = engine.pool
with _pool_stats_lock:
return {
"pool_size": pool.size(),
"checked_in": pool.checkedin(),
"checked_out": pool.checkedout(),
"overflow": pool.overflow(),
"total_checkouts": _pool_stats["checkouts"],
"total_checkins": _pool_stats["checkins"],
"invalidated_connections": _pool_stats["invalidated_connections"],
}
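A hedged usage sketch, assuming DATABASE_URL is configured: the pool knobs are read from the environment at import time, so overrides must be exported before app.core.database is first imported; get_pool_status() is the helper the health checks below consume.

    import os

    # Must run before the first import of app.core.database.
    os.environ.setdefault("DB_POOL_SIZE", "5")
    os.environ.setdefault("DB_MAX_OVERFLOW", "10")

    from app.core.database import get_pool_status

    status = get_pool_status()
    # Keys: pool_size, checked_in, checked_out, overflow,
    # total_checkouts, total_checkins, invalidated_connections
    print(status["pool_size"], status["checked_out"])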

View File

@@ -0,0 +1,45 @@
"""Deprecation middleware for legacy API routes.
Provides middleware to add deprecation warning headers to legacy /api/ routes
during the transition to /api/v1/.
"""
import logging
from datetime import datetime
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
logger = logging.getLogger(__name__)
class DeprecationMiddleware(BaseHTTPMiddleware):
"""Middleware to add deprecation headers to legacy API routes.
This middleware checks if a request is using a legacy /api/ route
(instead of /api/v1/) and adds appropriate deprecation headers to
encourage migration to the new versioned API.
"""
# Sunset date for legacy routes (roughly six months after this commit; adjust as needed)

SUNSET_DATE = "2026-07-01T00:00:00Z"
async def dispatch(self, request: Request, call_next):
response = await call_next(request)
# Check if this is a legacy /api/ route (not /api/v1/)
path = request.url.path
if path.startswith("/api/") and not path.startswith("/api/v1/"):
# Skip deprecation headers for health check endpoints
if path in ["/health", "/health/ready", "/health/live", "/health/detailed"]:
return response
# Add deprecation headers (RFC 8594)
response.headers["Deprecation"] = "true"
response.headers["Sunset"] = self.SUNSET_DATE
response.headers["Link"] = f'</api/v1{path[4:]}>; rel="successor-version"'
response.headers["X-Deprecation-Notice"] = (
"This API endpoint is deprecated. "
"Please migrate to /api/v1/ prefix. "
f"This endpoint will be removed after {self.SUNSET_DATE}."
)
return response
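A hedged verification sketch using a throwaway FastAPI app rather than the real one; the /api/ping route is hypothetical and exists only to exercise the middleware:

    from fastapi import FastAPI
    from fastapi.testclient import TestClient
    from app.core.deprecation import DeprecationMiddleware

    demo = FastAPI()
    demo.add_middleware(DeprecationMiddleware)

    @demo.get("/api/ping")          # legacy-style path, hypothetical
    async def ping():
        return {"pong": True}

    client = TestClient(demo)
    resp = client.get("/api/ping")
    assert resp.headers["Deprecation"] == "true"
    assert resp.headers["Link"].startswith("</api/v1/ping>")
    assert "Sunset" in resp.headers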

View File

@@ -3,11 +3,19 @@ Rate limiting configuration using slowapi with Redis backend.
This module provides rate limiting functionality to protect against
brute force attacks and DoS attempts on sensitive endpoints.
Rate Limit Tiers:
- standard: 60/minute - For normal CRUD operations (tasks, comments)
- sensitive: 20/minute - For sensitive operations (attachments, password change)
- heavy: 5/minute - For resource-intensive operations (reports, bulk operations)
"""
import logging
import os
from functools import wraps
from typing import Callable, Optional
from fastapi import Request, Response
from slowapi import Limiter
from slowapi.util import get_remote_address
@@ -60,8 +68,56 @@ _storage_uri = _get_storage_uri()
# Create limiter instance with appropriate storage
# Uses the client's remote address (IP) as the key for rate limiting
# Note: headers_enabled=False because slowapi's header injection requires Response objects,
# which conflicts with endpoints that return Pydantic models directly.
# Rate limit status can be checked via the 429 Too Many Requests response.
limiter = Limiter(
key_func=get_remote_address,
storage_uri=_storage_uri,
strategy="fixed-window", # Fixed window strategy for predictable rate limiting
headers_enabled=False, # Disabled due to compatibility issues with Pydantic model responses
)
# Convenience functions for rate limit tiers
def get_rate_limit_standard() -> str:
"""Get the standard rate limit tier (60/minute by default)."""
return settings.RATE_LIMIT_STANDARD
def get_rate_limit_sensitive() -> str:
"""Get the sensitive rate limit tier (20/minute by default)."""
return settings.RATE_LIMIT_SENSITIVE
def get_rate_limit_heavy() -> str:
"""Get the heavy rate limit tier (5/minute by default)."""
return settings.RATE_LIMIT_HEAVY
# Pre-configured rate limit decorators for common use cases
def rate_limit_standard(func: Optional[Callable] = None):
"""
Apply standard rate limit (60/minute) for normal CRUD operations.
Use for: Task creation/update, comment creation, etc.
"""
return limiter.limit(get_rate_limit_standard())(func) if func else limiter.limit(get_rate_limit_standard())
def rate_limit_sensitive(func: Optional[Callable] = None):
"""
Apply sensitive rate limit (20/minute) for sensitive operations.
Use for: File uploads, password changes, report exports, etc.
"""
return limiter.limit(get_rate_limit_sensitive())(func) if func else limiter.limit(get_rate_limit_sensitive())
def rate_limit_heavy(func: Optional[Callable] = None):
"""
Apply heavy rate limit (5/minute) for resource-intensive operations.
Use for: Report generation, bulk operations, data exports, etc.
"""
return limiter.limit(get_rate_limit_heavy())(func) if func else limiter.limit(get_rate_limit_heavy())
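A minimal usage sketch for the tier helpers; the endpoint path is hypothetical. slowapi requires the decorated handler to accept a request: Request parameter, and the application must expose app.state.limiter and register the RateLimitExceeded handler (as main.py does):

    from fastapi import APIRouter, Request
    from app.core.rate_limiter import limiter, get_rate_limit_heavy

    router = APIRouter()

    @router.post("/api/v1/reports/export")        # hypothetical endpoint
    @limiter.limit(get_rate_limit_heavy())        # 5/minute by default
    async def export_report(request: Request):
        # The Request argument lets slowapi resolve the client IP used as the key.
        return {"status": "queued"}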

View File

@@ -0,0 +1,178 @@
"""Standardized API response wrapper.
Provides utility classes and functions for consistent API response formatting
across all endpoints.
"""
from datetime import datetime
from typing import Any, Generic, Optional, TypeVar
from pydantic import BaseModel, Field
T = TypeVar("T")
class ErrorDetail(BaseModel):
"""Detailed error information."""
error_code: str = Field(..., description="Machine-readable error code")
message: str = Field(..., description="Human-readable error message")
field: Optional[str] = Field(None, description="Field that caused the error, if applicable")
details: Optional[dict] = Field(None, description="Additional error details")
class ApiResponse(BaseModel, Generic[T]):
"""Standard API response wrapper.
All API endpoints should return responses in this format for consistency.
Attributes:
success: Whether the request was successful
data: The actual response data (null for errors)
message: Human-readable message about the result
timestamp: ISO 8601 timestamp of the response
error: Error details if success is False
"""
success: bool = Field(..., description="Whether the request was successful")
data: Optional[T] = Field(None, description="Response data")
message: Optional[str] = Field(None, description="Human-readable message")
timestamp: str = Field(
default_factory=lambda: datetime.utcnow().isoformat() + "Z",
description="ISO 8601 timestamp"
)
error: Optional[ErrorDetail] = Field(None, description="Error details if failed")
class Config:
from_attributes = True
class PaginatedData(BaseModel, Generic[T]):
"""Paginated data structure."""
items: list[T] = Field(default_factory=list, description="List of items")
total: int = Field(..., description="Total number of items")
page: int = Field(..., description="Current page number (1-indexed)")
page_size: int = Field(..., description="Number of items per page")
total_pages: int = Field(..., description="Total number of pages")
class Config:
from_attributes = True
# Error codes for common scenarios
class ErrorCode:
"""Standard error codes for API responses."""
# Authentication & Authorization
UNAUTHORIZED = "AUTH_001"
FORBIDDEN = "AUTH_002"
TOKEN_EXPIRED = "AUTH_003"
INVALID_TOKEN = "AUTH_004"
# Validation
VALIDATION_ERROR = "VAL_001"
INVALID_INPUT = "VAL_002"
MISSING_FIELD = "VAL_003"
INVALID_FORMAT = "VAL_004"
# Resource
NOT_FOUND = "RES_001"
ALREADY_EXISTS = "RES_002"
CONFLICT = "RES_003"
DELETED = "RES_004"
# Business Logic
BUSINESS_ERROR = "BIZ_001"
INVALID_STATE = "BIZ_002"
LIMIT_EXCEEDED = "BIZ_003"
DEPENDENCY_ERROR = "BIZ_004"
# Server
INTERNAL_ERROR = "SRV_001"
DATABASE_ERROR = "SRV_002"
EXTERNAL_SERVICE_ERROR = "SRV_003"
RATE_LIMITED = "SRV_004"
def success_response(
data: Any = None,
message: Optional[str] = None,
) -> dict:
"""Create a successful API response.
Args:
data: The response data
message: Optional human-readable message
Returns:
Dictionary with standard response structure
"""
return {
"success": True,
"data": data,
"message": message,
"timestamp": datetime.utcnow().isoformat() + "Z",
"error": None,
}
def error_response(
error_code: str,
message: str,
field: Optional[str] = None,
details: Optional[dict] = None,
) -> dict:
"""Create an error API response.
Args:
error_code: Machine-readable error code (use ErrorCode constants)
message: Human-readable error message
field: Optional field name that caused the error
details: Optional additional error details
Returns:
Dictionary with standard error response structure
"""
return {
"success": False,
"data": None,
"message": message,
"timestamp": datetime.utcnow().isoformat() + "Z",
"error": {
"error_code": error_code,
"message": message,
"field": field,
"details": details,
},
}
def paginated_response(
items: list,
total: int,
page: int,
page_size: int,
message: Optional[str] = None,
) -> dict:
"""Create a paginated API response.
Args:
items: List of items for current page
total: Total number of items across all pages
page: Current page number (1-indexed)
page_size: Number of items per page
message: Optional human-readable message
Returns:
Dictionary with standard paginated response structure
"""
total_pages = (total + page_size - 1) // page_size if page_size > 0 else 0
return {
"success": True,
"data": {
"items": items,
"total": total,
"page": page,
"page_size": page_size,
"total_pages": total_pages,
},
"message": message,
"timestamp": datetime.utcnow().isoformat() + "Z",
"error": None,
}
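A short, hedged example of the wrapper in use; the module path app.utils.response and the widget endpoint are assumptions for illustration only:

    from fastapi import APIRouter
    from app.utils.response import ErrorCode, error_response, success_response  # path assumed

    router = APIRouter()
    _WIDGETS = {"w-1": {"id": "w-1", "name": "demo"}}   # stand-in for a real lookup

    @router.get("/api/v1/widgets/{widget_id}")          # hypothetical endpoint
    async def get_widget(widget_id: str):
        widget = _WIDGETS.get(widget_id)
        if widget is None:
            return error_response(ErrorCode.NOT_FOUND, "Widget not found", field="widget_id")
        return success_response(data=widget, message="Widget retrieved")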

View File

@@ -1,13 +1,16 @@
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request
from datetime import datetime
from fastapi import FastAPI, Request, APIRouter
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from slowapi import _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from sqlalchemy import text
from app.middleware.audit import AuditMiddleware
from app.core.scheduler import start_scheduler, shutdown_scheduler
from app.core.scheduler import start_scheduler, shutdown_scheduler, scheduler
from app.core.rate_limiter import limiter
from app.core.deprecation import DeprecationMiddleware
@asynccontextmanager
@@ -18,6 +21,8 @@ async def lifespan(app: FastAPI):
yield
# Shutdown
shutdown_scheduler()
from app.api.auth import router as auth_router
from app.api.users import router as users_router
from app.api.departments import router as departments_router
@@ -38,12 +43,17 @@ from app.api.custom_fields import router as custom_fields_router
from app.api.task_dependencies import router as task_dependencies_router
from app.api.admin import encryption_keys as admin_encryption_keys_router
from app.api.dashboard import router as dashboard_router
from app.api.templates import router as templates_router
from app.core.config import settings
from app.core.database import get_pool_status, engine
from app.core.redis import redis_client
from app.services.notification_service import get_redis_fallback_status
from app.services.file_storage_service import file_storage_service
app = FastAPI(
title="Project Control API",
description="Cross-departmental project management system API",
version="0.1.0",
version="1.0.0",
lifespan=lifespan,
)
@@ -63,13 +73,37 @@ app.add_middleware(
# Audit middleware - extracts request metadata for audit logging
app.add_middleware(AuditMiddleware)
# Include routers
# Deprecation middleware - adds deprecation headers to legacy /api/ routes
app.add_middleware(DeprecationMiddleware)
# =============================================================================
# API Version 1 Router - Primary API namespace
# =============================================================================
api_v1_router = APIRouter(prefix="/api/v1")
# Include routers under /api/v1/
api_v1_router.include_router(auth_router.router, prefix="/auth", tags=["Authentication"])
api_v1_router.include_router(users_router.router, prefix="/users", tags=["Users"])
api_v1_router.include_router(departments_router.router, prefix="/departments", tags=["Departments"])
api_v1_router.include_router(workload_router, prefix="/workload", tags=["Workload"])
api_v1_router.include_router(dashboard_router, prefix="/dashboard", tags=["Dashboard"])
api_v1_router.include_router(templates_router, tags=["Project Templates"])
# Mount the v1 router
app.include_router(api_v1_router)
# =============================================================================
# Legacy /api/ Routes (Deprecated - for backwards compatibility)
# =============================================================================
# These routes will be removed in a future version.
# All new integrations should use /api/v1/ prefix.
app.include_router(auth_router.router, prefix="/api/auth", tags=["Authentication"])
app.include_router(users_router.router, prefix="/api/users", tags=["Users"])
app.include_router(departments_router.router, prefix="/api/departments", tags=["Departments"])
app.include_router(spaces_router)
app.include_router(projects_router)
app.include_router(tasks_router)
app.include_router(spaces_router) # Has /api/spaces prefix in router
app.include_router(projects_router) # Has routes with /api prefix
app.include_router(tasks_router) # Has routes with /api prefix
app.include_router(workload_router, prefix="/api/workload", tags=["Workload"])
app.include_router(comments_router)
app.include_router(notifications_router)
@@ -79,13 +113,176 @@ app.include_router(audit_router)
app.include_router(attachments_router)
app.include_router(triggers_router)
app.include_router(reports_router)
app.include_router(health_router)
app.include_router(health_router) # Has /api/projects/health prefix
app.include_router(custom_fields_router)
app.include_router(task_dependencies_router)
app.include_router(admin_encryption_keys_router.router)
app.include_router(dashboard_router, prefix="/api/dashboard", tags=["Dashboard"])
app.include_router(templates_router) # Has /api/templates prefix in router
# =============================================================================
# Health Check Endpoints
# =============================================================================
def check_database_health() -> dict:
"""Check database connectivity and return status."""
try:
# Execute a simple query to verify connection
with engine.connect() as conn:
conn.execute(text("SELECT 1"))
pool_status = get_pool_status()
return {
"status": "healthy",
"connected": True,
**pool_status,
}
except Exception as e:
return {
"status": "unhealthy",
"connected": False,
"error": str(e),
}
def check_redis_health() -> dict:
"""Check Redis connectivity and return status."""
try:
# Ping Redis to verify connection
redis_client.ping()
redis_fallback = get_redis_fallback_status()
return {
"status": "healthy",
"connected": True,
**redis_fallback,
}
except Exception as e:
redis_fallback = get_redis_fallback_status()
return {
"status": "unhealthy",
"connected": False,
"error": str(e),
**redis_fallback,
}
def check_scheduler_health() -> dict:
"""Check scheduler status and return details."""
try:
running = scheduler.running
jobs = []
if running:
for job in scheduler.get_jobs():
jobs.append({
"id": job.id,
"name": job.name,
"next_run": job.next_run_time.isoformat() if job.next_run_time else None,
})
return {
"status": "healthy" if running else "stopped",
"running": running,
"jobs": jobs,
"job_count": len(jobs),
}
except Exception as e:
return {
"status": "unhealthy",
"running": False,
"error": str(e),
}
@app.get("/health")
async def health_check():
"""Basic health check endpoint for load balancers.
Returns a simple healthy status if the application is running.
For detailed status, use /health/detailed.
"""
return {"status": "healthy"}
@app.get("/health/live")
async def liveness_check():
"""Kubernetes liveness probe endpoint.
Returns healthy if the application process is running.
Does not check external dependencies.
"""
return {
"status": "healthy",
"timestamp": datetime.utcnow().isoformat() + "Z",
}
@app.get("/health/ready")
async def readiness_check():
"""Kubernetes readiness probe endpoint.
Returns healthy only if all critical dependencies are available.
The application should not receive traffic until ready.
"""
db_health = check_database_health()
redis_health = check_redis_health()
# Application is ready only if database is healthy
# Redis degradation is acceptable (we have fallback)
is_ready = db_health["status"] == "healthy"
return {
"status": "ready" if is_ready else "not_ready",
"timestamp": datetime.utcnow().isoformat() + "Z",
"checks": {
"database": db_health["status"],
"redis": redis_health["status"],
},
}
@app.get("/health/detailed")
async def detailed_health_check():
"""Detailed health check endpoint.
Returns comprehensive status of all system components:
- database: Connection pool status and connectivity
- redis: Connection status and fallback queue status
- storage: File storage validation status
- scheduler: Background job scheduler status
"""
db_health = check_database_health()
redis_health = check_redis_health()
scheduler_health = check_scheduler_health()
storage_status = file_storage_service.get_storage_status()
# Determine overall health
is_healthy = (
db_health["status"] == "healthy" and
storage_status.get("validated", False)
)
# Degraded if Redis or scheduler has issues but DB is ok
is_degraded = (
is_healthy and (
redis_health["status"] != "healthy" or
scheduler_health["status"] != "healthy"
)
)
overall_status = "unhealthy"
if is_healthy:
overall_status = "degraded" if is_degraded else "healthy"
return {
"status": overall_status,
"timestamp": datetime.utcnow().isoformat() + "Z",
"version": app.version,
"components": {
"database": db_health,
"redis": redis_health,
"scheduler": scheduler_health,
"storage": {
"status": "healthy" if storage_status.get("validated", False) else "unhealthy",
**storage_status,
},
},
}
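A hedged probe sketch against a locally running instance; the base URL and the httpx dependency are assumptions. It illustrates that a Redis outage degrades the redis check without flipping overall readiness, since the notification fallback covers it:

    import httpx  # assumption: available in the probing environment

    resp = httpx.get("http://localhost:8000/health/ready", timeout=2.0)
    body = resp.json()

    assert body["status"] in ("ready", "not_ready")
    print(body["checks"])   # e.g. {"database": "healthy", "redis": "unhealthy"}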

View File

@@ -207,12 +207,16 @@ def check_space_edit_access(user: User, space) -> bool:
def check_project_access(user: User, project) -> bool:
"""
Check if user has access to a project based on security level.
Check if user has access to a project based on security level and membership.
Security Levels:
- public: All logged-in users
- department: Same department users + project owner
- confidential: Only project owner (+ system admin)
Access is granted if any of the following conditions are met:
1. User is a system admin
2. User is the project owner
3. User is an explicit project member (cross-department collaboration)
4. Security level allows access:
- public: All logged-in users
- department: Same department users
- confidential: Only owner/members/admin
"""
# System admin bypasses all restrictions
if user.is_system_admin:
@@ -222,6 +226,13 @@ def check_project_access(user: User, project) -> bool:
if project.owner_id == user.id:
return True
# Check if user is an explicit project member (for cross-department collaboration)
# This allows users from other departments to access the project
if hasattr(project, 'members') and project.members:
for member in project.members:
if member.user_id == user.id:
return True
# Check by security level
security_level = project.security_level
@@ -235,20 +246,34 @@ def check_project_access(user: User, project) -> bool:
return False
else: # confidential
# Only owner has access (already checked above)
# Only owner/members have access (already checked above)
return False
def check_project_edit_access(user: User, project) -> bool:
"""
Check if user can edit/delete a project.
Edit access is granted if:
1. User is a system admin
2. User is the project owner
3. User is a project member with 'admin' role
"""
# System admin has full access
if user.is_system_admin:
return True
# Only owner can edit
return project.owner_id == user.id
# Owner can edit
if project.owner_id == user.id:
return True
# Project member with admin role can edit
if hasattr(project, 'members') and project.members:
for member in project.members:
if member.user_id == user.id and member.role == "admin":
return True
return False
def check_task_access(user: User, task, project) -> bool:

View File

@@ -23,6 +23,8 @@ from app.models.project_health import ProjectHealth, RiskLevel, ScheduleStatus,
from app.models.custom_field import CustomField, FieldType
from app.models.task_custom_value import TaskCustomValue
from app.models.task_dependency import TaskDependency, DependencyType
from app.models.project_member import ProjectMember
from app.models.project_template import ProjectTemplate
__all__ = [
"User", "Role", "Department", "Space", "Project", "TaskStatus", "Task", "WorkloadSnapshot",
@@ -33,5 +35,7 @@ __all__ = [
"ScheduledReport", "ReportType", "ReportHistory", "ReportHistoryStatus",
"ProjectHealth", "RiskLevel", "ScheduleStatus", "ResourceStatus",
"CustomField", "FieldType", "TaskCustomValue",
"TaskDependency", "DependencyType"
"TaskDependency", "DependencyType",
"ProjectMember",
"ProjectTemplate"
]

View File

@@ -42,3 +42,6 @@ class Project(Base):
triggers = relationship("Trigger", back_populates="project", cascade="all, delete-orphan")
health = relationship("ProjectHealth", back_populates="project", uselist=False, cascade="all, delete-orphan")
custom_fields = relationship("CustomField", back_populates="project", cascade="all, delete-orphan")
# Project membership for cross-department collaboration
members = relationship("ProjectMember", back_populates="project", cascade="all, delete-orphan")

View File

@@ -0,0 +1,56 @@
"""ProjectMember model for cross-department project collaboration.
This model tracks explicit project membership, allowing users from other
departments to be granted access to projects that department isolation
rules would otherwise block.
"""
import uuid
from sqlalchemy import Column, String, ForeignKey, DateTime, UniqueConstraint
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship
from app.core.database import Base
class ProjectMember(Base):
"""
Represents a user's membership in a project.
This enables cross-department collaboration by explicitly granting
project access to users regardless of their department.
Roles:
- member: Can view and edit tasks
- admin: Can manage project settings and add other members
"""
__tablename__ = "pjctrl_project_members"
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
project_id = Column(
String(36),
ForeignKey("pjctrl_projects.id", ondelete="CASCADE"),
nullable=False,
index=True
)
user_id = Column(
String(36),
ForeignKey("pjctrl_users.id", ondelete="CASCADE"),
nullable=False,
index=True
)
role = Column(String(50), nullable=False, default="member")
added_by = Column(
String(36),
ForeignKey("pjctrl_users.id"),
nullable=False
)
created_at = Column(DateTime, server_default=func.now(), nullable=False)
# Unique constraint to prevent duplicate memberships
__table_args__ = (
UniqueConstraint('project_id', 'user_id', name='uq_project_member'),
)
# Relationships
project = relationship("Project", back_populates="members")
user = relationship("User", foreign_keys=[user_id], back_populates="project_memberships")
added_by_user = relationship("User", foreign_keys=[added_by])
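An illustrative helper (not part of this commit) showing how a membership row grants cross-department access; column names follow the model above:

    from sqlalchemy.orm import Session
    from app.models import ProjectMember

    def add_cross_department_member(db: Session, project_id: str, user_id: str, added_by_id: str) -> ProjectMember:
        """Illustrative only: grant a user from another department explicit access."""
        member = ProjectMember(
            project_id=project_id,
            user_id=user_id,
            role="member",      # "admin" would also allow project management
            added_by=added_by_id,
        )
        db.add(member)
        db.commit()             # uq_project_member rejects duplicate memberships
        db.refresh(member)
        return member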

View File

@@ -0,0 +1,125 @@
"""Project Template model for reusable project configurations.
Allows users to create templates with predefined task statuses and custom fields
that can be used to quickly set up new projects.
"""
import uuid
from sqlalchemy import Column, String, Text, Boolean, DateTime, ForeignKey, JSON
from sqlalchemy.orm import relationship
from sqlalchemy.sql import func
from app.core.database import Base
class ProjectTemplate(Base):
"""Template for creating projects with predefined configurations.
A template stores:
- Basic project metadata (name, description)
- Predefined task statuses (stored as JSON)
- Predefined custom field definitions (stored as JSON)
When a project is created from a template, the TaskStatus and CustomField
records are copied to the new project.
"""
__tablename__ = "pjctrl_project_templates"
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
name = Column(String(200), nullable=False)
description = Column(Text, nullable=True)
# Template owner
owner_id = Column(String(36), ForeignKey("pjctrl_users.id"), nullable=False)
# Whether the template is available to all users or just the owner
is_public = Column(Boolean, default=False, nullable=False)
# Soft delete flag
is_active = Column(Boolean, default=True, nullable=False)
# Predefined task statuses as JSON array
# Format: [{"name": "To Do", "color": "#808080", "position": 0, "is_done": false}, ...]
task_statuses = Column(JSON, nullable=True)
# Predefined custom field definitions as JSON array
# Format: [{"name": "Priority", "field_type": "dropdown", "options": [...], ...}, ...]
custom_fields = Column(JSON, nullable=True)
# Optional default project settings
default_security_level = Column(String(20), default="department", nullable=True)
created_at = Column(DateTime, server_default=func.now(), nullable=False)
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), nullable=False)
# Relationships
owner = relationship("User", foreign_keys=[owner_id])
# Default template data for system templates
SYSTEM_TEMPLATES = [
{
"name": "Basic Project",
"description": "A simple project template with standard task statuses.",
"is_public": True,
"task_statuses": [
{"name": "To Do", "color": "#808080", "position": 0, "is_done": False},
{"name": "In Progress", "color": "#0066cc", "position": 1, "is_done": False},
{"name": "Done", "color": "#00cc66", "position": 2, "is_done": True},
],
"custom_fields": [],
},
{
"name": "Software Development",
"description": "Template for software development projects with extended workflow.",
"is_public": True,
"task_statuses": [
{"name": "Backlog", "color": "#808080", "position": 0, "is_done": False},
{"name": "To Do", "color": "#3366cc", "position": 1, "is_done": False},
{"name": "In Progress", "color": "#0066cc", "position": 2, "is_done": False},
{"name": "Code Review", "color": "#cc6600", "position": 3, "is_done": False},
{"name": "Testing", "color": "#9933cc", "position": 4, "is_done": False},
{"name": "Done", "color": "#00cc66", "position": 5, "is_done": True},
],
"custom_fields": [
{
"name": "Story Points",
"field_type": "number",
"is_required": False,
"position": 0,
},
{
"name": "Sprint",
"field_type": "dropdown",
"options": ["Sprint 1", "Sprint 2", "Sprint 3", "Backlog"],
"is_required": False,
"position": 1,
},
],
},
{
"name": "Marketing Campaign",
"description": "Template for marketing campaign management.",
"is_public": True,
"task_statuses": [
{"name": "Planning", "color": "#808080", "position": 0, "is_done": False},
{"name": "Content Creation", "color": "#cc6600", "position": 1, "is_done": False},
{"name": "Review", "color": "#9933cc", "position": 2, "is_done": False},
{"name": "Scheduled", "color": "#0066cc", "position": 3, "is_done": False},
{"name": "Published", "color": "#00cc66", "position": 4, "is_done": True},
],
"custom_fields": [
{
"name": "Channel",
"field_type": "dropdown",
"options": ["Email", "Social Media", "Website", "Print", "Event"],
"is_required": False,
"position": 0,
},
{
"name": "Target Audience",
"field_type": "text",
"is_required": False,
"position": 1,
},
],
},
]
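A rough sketch of the copy step described in the docstring; it assumes the TaskStatus and CustomField constructors accept the keys shown in the JSON format comments above, which is not verified in this hunk:

    from sqlalchemy.orm import Session
    from app.models import CustomField, Project, ProjectTemplate, TaskStatus

    def apply_template(db: Session, template: ProjectTemplate, project: Project) -> None:
        """Sketch: materialize the template's JSON definitions as rows on the new project."""
        for status_def in template.task_statuses or []:
            db.add(TaskStatus(project_id=project.id, **status_def))    # name/color/position/is_done assumed
        for field_def in template.custom_fields or []:
            db.add(CustomField(project_id=project.id, **field_def))    # name/field_type/options/... assumed
        db.commit()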

View File

@@ -37,6 +37,9 @@ class Task(Base):
created_at = Column(DateTime, server_default=func.now(), nullable=False)
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), nullable=False)
# Optimistic locking field
version = Column(Integer, default=1, nullable=False)
# Soft delete fields
is_deleted = Column(Boolean, default=False, nullable=False, index=True)
deleted_at = Column(DateTime, nullable=True)
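A hedged sketch of how the version column backs optimistic locking on update (409 on mismatch, per the commit description); the helper name and payload handling are illustrative, not the committed task API code:

    from fastapi import HTTPException
    from sqlalchemy.orm import Session

    def apply_optimistic_update(db: Session, task, payload) -> None:
        """Illustrative: reject stale writes, then bump the version."""
        if payload.version is not None and payload.version != task.version:
            raise HTTPException(
                status_code=409,
                detail={
                    "error": "version_conflict",
                    "current_version": task.version,
                    "submitted_version": payload.version,
                },
            )
        for field, value in payload.model_dump(exclude_unset=True, exclude={"version"}).items():
            setattr(task, field, value)
        task.version += 1
        db.commit()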

View File

@@ -18,6 +18,7 @@ class User(Base):
capacity = Column(Numeric(5, 2), default=40.00)
is_active = Column(Boolean, default=True)
is_system_admin = Column(Boolean, default=False)
is_department_manager = Column(Boolean, default=False)
created_at = Column(DateTime, server_default=func.now())
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now())
@@ -41,3 +42,11 @@ class User(Base):
# Automation relationships
created_triggers = relationship("Trigger", back_populates="creator")
scheduled_reports = relationship("ScheduledReport", back_populates="recipient", cascade="all, delete-orphan")
# Project membership relationships (for cross-department collaboration)
project_memberships = relationship(
"ProjectMember",
foreign_keys="ProjectMember.user_id",
back_populates="user",
cascade="all, delete-orphan"
)

View File

@@ -1,10 +1,10 @@
from pydantic import BaseModel
from pydantic import BaseModel, Field
from typing import Optional
class LoginRequest(BaseModel):
email: str
password: str
email: str = Field(..., max_length=255)
password: str = Field(..., min_length=1, max_length=128)
class LoginResponse(BaseModel):

View File

@@ -1,10 +1,10 @@
from pydantic import BaseModel
from pydantic import BaseModel, Field
from typing import Optional
from datetime import datetime
class DepartmentBase(BaseModel):
name: str
name: str = Field(..., min_length=1, max_length=200)
parent_id: Optional[str] = None
@@ -13,7 +13,7 @@ class DepartmentCreate(DepartmentBase):
class DepartmentUpdate(BaseModel):
name: Optional[str] = None
name: Optional[str] = Field(None, min_length=1, max_length=200)
parent_id: Optional[str] = None

View File

@@ -1,4 +1,4 @@
from pydantic import BaseModel
from pydantic import BaseModel, Field
from typing import Optional
from datetime import datetime, date
from decimal import Decimal
@@ -12,9 +12,9 @@ class SecurityLevel(str, Enum):
class ProjectBase(BaseModel):
title: str
description: Optional[str] = None
budget: Optional[Decimal] = None
title: str = Field(..., min_length=1, max_length=500)
description: Optional[str] = Field(None, max_length=10000)
budget: Optional[Decimal] = Field(None, ge=0, le=99999999999)
start_date: Optional[date] = None
end_date: Optional[date] = None
security_level: SecurityLevel = SecurityLevel.DEPARTMENT
@@ -25,13 +25,13 @@ class ProjectCreate(ProjectBase):
class ProjectUpdate(BaseModel):
title: Optional[str] = None
description: Optional[str] = None
budget: Optional[Decimal] = None
title: Optional[str] = Field(None, min_length=1, max_length=500)
description: Optional[str] = Field(None, max_length=10000)
budget: Optional[Decimal] = Field(None, ge=0, le=99999999999)
start_date: Optional[date] = None
end_date: Optional[date] = None
security_level: Optional[SecurityLevel] = None
status: Optional[str] = None
status: Optional[str] = Field(None, max_length=50)
department_id: Optional[str] = None

View File

@@ -0,0 +1,56 @@
"""Project member schemas for cross-department collaboration."""
from pydantic import BaseModel, Field
from typing import Optional, List
from datetime import datetime
from enum import Enum
class ProjectMemberRole(str, Enum):
"""Roles that can be assigned to project members."""
MEMBER = "member"
ADMIN = "admin"
class ProjectMemberBase(BaseModel):
"""Base schema for project member."""
user_id: str = Field(..., description="ID of the user to add as project member")
role: ProjectMemberRole = Field(
default=ProjectMemberRole.MEMBER,
description="Role of the member: 'member' (view/edit tasks) or 'admin' (manage project)"
)
class ProjectMemberCreate(ProjectMemberBase):
"""Schema for creating a project member."""
pass
class ProjectMemberUpdate(BaseModel):
"""Schema for updating a project member."""
role: ProjectMemberRole = Field(..., description="New role for the member")
class ProjectMemberResponse(ProjectMemberBase):
"""Schema for project member response."""
id: str
project_id: str
added_by: str
created_at: datetime
class Config:
from_attributes = True
class ProjectMemberWithDetails(ProjectMemberResponse):
"""Schema for project member with user details."""
user_name: Optional[str] = None
user_email: Optional[str] = None
user_department_id: Optional[str] = None
user_department_name: Optional[str] = None
added_by_name: Optional[str] = None
class ProjectMemberListResponse(BaseModel):
"""Schema for listing project members."""
members: List[ProjectMemberWithDetails]
total: int

View File

@@ -0,0 +1,95 @@
"""Schemas for project template API endpoints."""
from typing import Optional, List, Any
from datetime import datetime
from pydantic import BaseModel, Field
class TaskStatusDefinition(BaseModel):
"""Task status definition for templates."""
name: str = Field(..., min_length=1, max_length=50)
color: str = Field(default="#808080", pattern=r"^#[0-9A-Fa-f]{6}$")
position: int = Field(default=0, ge=0)
is_done: bool = Field(default=False)
class CustomFieldDefinition(BaseModel):
"""Custom field definition for templates."""
name: str = Field(..., min_length=1, max_length=100)
field_type: str = Field(..., pattern=r"^(text|number|dropdown|date|person|formula)$")
options: Optional[List[str]] = None
formula: Optional[str] = None
is_required: bool = Field(default=False)
position: int = Field(default=0, ge=0)
class ProjectTemplateBase(BaseModel):
"""Base schema for project template."""
name: str = Field(..., min_length=1, max_length=200)
description: Optional[str] = None
is_public: bool = Field(default=False)
task_statuses: Optional[List[TaskStatusDefinition]] = None
custom_fields: Optional[List[CustomFieldDefinition]] = None
default_security_level: Optional[str] = Field(
default="department",
pattern=r"^(public|department|confidential)$"
)
class ProjectTemplateCreate(ProjectTemplateBase):
"""Schema for creating a project template."""
pass
class ProjectTemplateUpdate(BaseModel):
"""Schema for updating a project template."""
name: Optional[str] = Field(None, min_length=1, max_length=200)
description: Optional[str] = None
is_public: Optional[bool] = None
task_statuses: Optional[List[TaskStatusDefinition]] = None
custom_fields: Optional[List[CustomFieldDefinition]] = None
default_security_level: Optional[str] = Field(
None,
pattern=r"^(public|department|confidential)$"
)
class ProjectTemplateResponse(ProjectTemplateBase):
"""Schema for project template response."""
id: str
owner_id: str
is_active: bool
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
class ProjectTemplateWithOwner(ProjectTemplateResponse):
"""Project template response with owner details."""
owner_name: Optional[str] = None
class ProjectTemplateListResponse(BaseModel):
"""Response schema for listing project templates."""
templates: List[ProjectTemplateWithOwner]
total: int
class CreateProjectFromTemplateRequest(BaseModel):
"""Request schema for creating a project from a template."""
template_id: str
title: str = Field(..., min_length=1, max_length=500)
description: Optional[str] = Field(None, max_length=10000)
space_id: str
department_id: Optional[str] = None
class CreateProjectFromTemplateResponse(BaseModel):
"""Response schema for project created from template."""
id: str
title: str
template_id: str
template_name: str
task_statuses_created: int
custom_fields_created: int

View File

@@ -1,11 +1,11 @@
from pydantic import BaseModel
from pydantic import BaseModel, Field
from typing import Optional
from datetime import datetime
class SpaceBase(BaseModel):
name: str
description: Optional[str] = None
name: str = Field(..., min_length=1, max_length=200)
description: Optional[str] = Field(None, max_length=2000)
class SpaceCreate(SpaceBase):
@@ -13,8 +13,8 @@ class SpaceCreate(SpaceBase):
class SpaceUpdate(BaseModel):
name: Optional[str] = None
description: Optional[str] = None
name: Optional[str] = Field(None, min_length=1, max_length=200)
description: Optional[str] = Field(None, max_length=2000)
class SpaceResponse(SpaceBase):

View File

@@ -1,4 +1,4 @@
from pydantic import BaseModel, computed_field
from pydantic import BaseModel, computed_field, Field, field_validator
from typing import Optional, List, Any, Dict
from datetime import datetime
from decimal import Decimal
@@ -28,10 +28,10 @@ class CustomValueResponse(BaseModel):
class TaskBase(BaseModel):
title: str
description: Optional[str] = None
title: str = Field(..., min_length=1, max_length=500)
description: Optional[str] = Field(None, max_length=10000)
priority: Priority = Priority.MEDIUM
original_estimate: Optional[Decimal] = None
original_estimate: Optional[Decimal] = Field(None, ge=0, le=99999)
start_date: Optional[datetime] = None
due_date: Optional[datetime] = None
@@ -44,17 +44,18 @@ class TaskCreate(TaskBase):
class TaskUpdate(BaseModel):
title: Optional[str] = None
description: Optional[str] = None
title: Optional[str] = Field(None, min_length=1, max_length=500)
description: Optional[str] = Field(None, max_length=10000)
priority: Optional[Priority] = None
status_id: Optional[str] = None
assignee_id: Optional[str] = None
original_estimate: Optional[Decimal] = None
time_spent: Optional[Decimal] = None
original_estimate: Optional[Decimal] = Field(None, ge=0, le=99999)
time_spent: Optional[Decimal] = Field(None, ge=0, le=99999)
start_date: Optional[datetime] = None
due_date: Optional[datetime] = None
position: Optional[int] = None
position: Optional[int] = Field(None, ge=0)
custom_values: Optional[List[CustomValueInput]] = None
version: Optional[int] = Field(None, ge=1, description="Version for optimistic locking")
class TaskStatusUpdate(BaseModel):
@@ -77,6 +78,7 @@ class TaskResponse(TaskBase):
created_by: str
created_at: datetime
updated_at: datetime
version: int = 1 # Optimistic locking version
class Config:
from_attributes = True
@@ -100,3 +102,32 @@ class TaskWithDetails(TaskResponse):
class TaskListResponse(BaseModel):
tasks: List[TaskWithDetails]
total: int
class TaskRestoreRequest(BaseModel):
"""Request body for restoring a soft-deleted task."""
cascade: bool = Field(
default=True,
description="If True, also restore child tasks deleted at the same time. If False, restore only the parent task."
)
class TaskRestoreResponse(BaseModel):
"""Response for task restore operation."""
restored_task: TaskResponse
restored_children_count: int = 0
restored_children_ids: List[str] = []
class TaskDeleteWarningResponse(BaseModel):
"""Response when task has unresolved blockers and force_delete is False."""
warning: str
blocker_count: int
message: str = "Task has unresolved blockers. Use force_delete=true to delete anyway."
class TaskDeleteResponse(BaseModel):
"""Response for task delete operation."""
task: TaskResponse
blockers_resolved: int = 0
force_deleted: bool = False

View File

@@ -76,3 +76,46 @@ class DependencyValidationError(BaseModel):
error_type: str # 'circular', 'self_reference', 'duplicate', 'cross_project'
message: str
details: Optional[dict] = None
class BulkDependencyItem(BaseModel):
"""Single dependency item for bulk operations."""
predecessor_id: str
successor_id: str
dependency_type: DependencyType = DependencyType.FS
lag_days: int = 0
@field_validator('lag_days')
@classmethod
def validate_lag_days(cls, v):
if v < -365 or v > 365:
raise ValueError('lag_days must be between -365 and 365')
return v
class BulkDependencyCreate(BaseModel):
"""Schema for creating multiple dependencies at once."""
dependencies: List[BulkDependencyItem]
@field_validator('dependencies')
@classmethod
def validate_dependencies(cls, v):
if not v:
raise ValueError('At least one dependency is required')
if len(v) > 50:
raise ValueError('Cannot create more than 50 dependencies at once')
return v
class BulkDependencyValidationResult(BaseModel):
"""Result of bulk dependency validation."""
valid: bool
errors: List[dict] = []
class BulkDependencyCreateResponse(BaseModel):
"""Response for bulk dependency creation."""
created: List[TaskDependencyResponse]
failed: List[dict] = []
total_created: int
total_failed: int

View File

@@ -1,12 +1,12 @@
from pydantic import BaseModel
from pydantic import BaseModel, Field
from typing import Optional
from datetime import datetime
class TaskStatusBase(BaseModel):
name: str
color: str = "#808080"
position: int = 0
name: str = Field(..., min_length=1, max_length=100)
color: str = Field("#808080", max_length=20)
position: int = Field(0, ge=0)
is_done: bool = False
@@ -15,9 +15,9 @@ class TaskStatusCreate(TaskStatusBase):
class TaskStatusUpdate(BaseModel):
name: Optional[str] = None
color: Optional[str] = None
position: Optional[int] = None
name: Optional[str] = Field(None, min_length=1, max_length=100)
color: Optional[str] = Field(None, max_length=20)
position: Optional[int] = Field(None, ge=0)
is_done: Optional[bool] = None

View File

@@ -1,16 +1,16 @@
from pydantic import BaseModel, field_validator
from pydantic import BaseModel, Field, field_validator
from typing import Optional, List
from datetime import datetime
from decimal import Decimal
class UserBase(BaseModel):
email: str
name: str
email: str = Field(..., max_length=255)
name: str = Field(..., min_length=1, max_length=200)
department_id: Optional[str] = None
role_id: Optional[str] = None
skills: Optional[List[str]] = None
capacity: Optional[Decimal] = Decimal("40.00")
capacity: Optional[Decimal] = Field(Decimal("40.00"), ge=0, le=168)
class UserCreate(UserBase):
@@ -18,11 +18,11 @@ class UserCreate(UserBase):
class UserUpdate(BaseModel):
name: Optional[str] = None
name: Optional[str] = Field(None, min_length=1, max_length=200)
department_id: Optional[str] = None
role_id: Optional[str] = None
skills: Optional[List[str]] = None
capacity: Optional[Decimal] = None
capacity: Optional[Decimal] = Field(None, ge=0, le=168)
is_active: Optional[bool] = None

View File

@@ -6,6 +6,7 @@ Handles task dependency validation including:
- Date constraint validation based on dependency types
- Self-reference prevention
- Cross-project dependency prevention
- Bulk dependency operations with cycle detection
"""
from typing import List, Optional, Set, Tuple, Dict, Any
from collections import defaultdict
@@ -25,6 +26,27 @@ class DependencyValidationError(Exception):
super().__init__(message)
class CycleDetectionResult:
"""Result of cycle detection with detailed path information."""
def __init__(
self,
has_cycle: bool,
cycle_path: Optional[List[str]] = None,
cycle_task_titles: Optional[List[str]] = None
):
self.has_cycle = has_cycle
self.cycle_path = cycle_path or []
self.cycle_task_titles = cycle_task_titles or []
def get_cycle_description(self) -> str:
"""Get a human-readable description of the cycle."""
if not self.has_cycle or not self.cycle_task_titles:
return ""
# Format: Task A -> Task B -> Task C -> Task A
return " -> ".join(self.cycle_task_titles)
class DependencyService:
"""Service for managing task dependencies with validation."""
@@ -53,9 +75,36 @@ class DependencyService:
Returns:
List of task IDs forming the cycle if circular, None otherwise
"""
# If adding predecessor -> successor, check if successor can reach predecessor
# This would mean predecessor depends (transitively) on successor, creating a cycle
result = DependencyService.detect_circular_dependency_detailed(
db, predecessor_id, successor_id, project_id
)
return result.cycle_path if result.has_cycle else None
@staticmethod
def detect_circular_dependency_detailed(
db: Session,
predecessor_id: str,
successor_id: str,
project_id: str,
additional_edges: Optional[List[Tuple[str, str]]] = None
) -> CycleDetectionResult:
"""
Detect if adding a dependency would create a circular reference.
Uses DFS over the existing predecessor edges (plus the simulated new edge)
to check whether the traversal from the successor can reach back to it,
which would close a cycle.
Args:
db: Database session
predecessor_id: The task that must complete first
successor_id: The task that depends on the predecessor
project_id: Project ID to scope the query
additional_edges: Optional list of additional (predecessor_id, successor_id)
edges to consider (for bulk operations)
Returns:
CycleDetectionResult with detailed cycle information
"""
# Build adjacency list for the project's dependencies
dependencies = db.query(TaskDependency).join(
Task, TaskDependency.successor_id == Task.id
@@ -71,6 +120,20 @@ class DependencyService:
# Simulate adding the new edge
graph[successor_id].append(predecessor_id)
# Add any additional edges for bulk operations
if additional_edges:
for pred_id, succ_id in additional_edges:
graph[succ_id].append(pred_id)
# Build task title map for readable error messages
task_ids_in_graph = set()
for succ_id, pred_ids in graph.items():
task_ids_in_graph.add(succ_id)
task_ids_in_graph.update(pred_ids)
tasks = db.query(Task).filter(Task.id.in_(task_ids_in_graph)).all()
task_title_map: Dict[str, str] = {t.id: t.title for t in tasks}
# DFS to find if there's a path from predecessor back to successor
# (which would complete a cycle)
visited: Set[str] = set()
@@ -101,7 +164,18 @@ class DependencyService:
return None
# Start DFS from the successor to check if we can reach back to it
return dfs(successor_id)
cycle_path = dfs(successor_id)
if cycle_path:
# Build task titles for the cycle
cycle_titles = [task_title_map.get(task_id, task_id) for task_id in cycle_path]
return CycleDetectionResult(
has_cycle=True,
cycle_path=cycle_path,
cycle_task_titles=cycle_titles
)
return CycleDetectionResult(has_cycle=False)
@staticmethod
def validate_dependency(
@@ -183,15 +257,19 @@ class DependencyService:
)
# Check circular dependency
cycle = DependencyService.detect_circular_dependency(
cycle_result = DependencyService.detect_circular_dependency_detailed(
db, predecessor_id, successor_id, predecessor.project_id
)
if cycle:
if cycle_result.has_cycle:
raise DependencyValidationError(
error_type="circular",
message="Adding this dependency would create a circular reference",
details={"cycle": cycle}
message=f"Adding this dependency would create a circular reference: {cycle_result.get_cycle_description()}",
details={
"cycle": cycle_result.cycle_path,
"cycle_description": cycle_result.get_cycle_description(),
"cycle_task_titles": cycle_result.cycle_task_titles
}
)
@staticmethod
@@ -422,3 +500,202 @@ class DependencyService:
queue.append(dep.successor_id)
return successors
@staticmethod
def validate_bulk_dependencies(
db: Session,
dependencies: List[Tuple[str, str]],
project_id: str
) -> List[Dict[str, Any]]:
"""
Validate a batch of dependencies for cycle detection.
This method validates multiple dependencies together to detect cycles
that would only appear when all dependencies are added together.
Args:
db: Database session
dependencies: List of (predecessor_id, successor_id) tuples
project_id: Project ID to scope the query
Returns:
List of validation errors (empty if all valid)
"""
errors: List[Dict[str, Any]] = []
if not dependencies:
return errors
# First, validate each dependency individually for basic checks
for predecessor_id, successor_id in dependencies:
# Check self-reference
if predecessor_id == successor_id:
errors.append({
"error_type": "self_reference",
"predecessor_id": predecessor_id,
"successor_id": successor_id,
"message": "A task cannot depend on itself"
})
continue
# Get tasks to validate project membership
predecessor = db.query(Task).filter(Task.id == predecessor_id).first()
successor = db.query(Task).filter(Task.id == successor_id).first()
if not predecessor:
errors.append({
"error_type": "not_found",
"predecessor_id": predecessor_id,
"successor_id": successor_id,
"message": f"Predecessor task not found: {predecessor_id}"
})
continue
if not successor:
errors.append({
"error_type": "not_found",
"predecessor_id": predecessor_id,
"successor_id": successor_id,
"message": f"Successor task not found: {successor_id}"
})
continue
if predecessor.project_id != project_id or successor.project_id != project_id:
errors.append({
"error_type": "cross_project",
"predecessor_id": predecessor_id,
"successor_id": successor_id,
"message": "All tasks must be in the same project"
})
continue
# Check for duplicates within the batch
existing = db.query(TaskDependency).filter(
TaskDependency.predecessor_id == predecessor_id,
TaskDependency.successor_id == successor_id
).first()
if existing:
errors.append({
"error_type": "duplicate",
"predecessor_id": predecessor_id,
"successor_id": successor_id,
"message": "This dependency already exists"
})
# If there are basic validation errors, return them first
if errors:
return errors
# Now check for cycles considering all dependencies together
# Build the graph incrementally and check for cycles
accumulated_edges: List[Tuple[str, str]] = []
for predecessor_id, successor_id in dependencies:
# Check if adding this edge (plus all previously accumulated edges)
# would create a cycle
cycle_result = DependencyService.detect_circular_dependency_detailed(
db,
predecessor_id,
successor_id,
project_id,
additional_edges=accumulated_edges
)
if cycle_result.has_cycle:
errors.append({
"error_type": "circular",
"predecessor_id": predecessor_id,
"successor_id": successor_id,
"message": f"Adding this dependency would create a circular reference: {cycle_result.get_cycle_description()}",
"cycle": cycle_result.cycle_path,
"cycle_description": cycle_result.get_cycle_description(),
"cycle_task_titles": cycle_result.cycle_task_titles
})
else:
# Add this edge to accumulated edges for subsequent checks
accumulated_edges.append((predecessor_id, successor_id))
return errors
@staticmethod
def detect_cycles_in_graph(
db: Session,
project_id: str
) -> List[CycleDetectionResult]:
"""
Detect all cycles in the existing dependency graph for a project.
This is useful for auditing or cleanup operations.
Args:
db: Database session
project_id: Project ID to check
Returns:
List of CycleDetectionResult for each cycle found
"""
cycles: List[CycleDetectionResult] = []
# Get all dependencies for the project
dependencies = db.query(TaskDependency).join(
Task, TaskDependency.successor_id == Task.id
).filter(Task.project_id == project_id).all()
if not dependencies:
return cycles
# Build the graph
graph: Dict[str, List[str]] = defaultdict(list)
for dep in dependencies:
graph[dep.successor_id].append(dep.predecessor_id)
# Get task titles
task_ids = set()
for succ_id, pred_ids in graph.items():
task_ids.add(succ_id)
task_ids.update(pred_ids)
tasks = db.query(Task).filter(Task.id.in_(task_ids)).all()
task_title_map: Dict[str, str] = {t.id: t.title for t in tasks}
# Find all cycles using DFS
visited: Set[str] = set()
found_cycles: Set[Tuple[str, ...]] = set()
def find_cycles_dfs(node: str, path: List[str], in_path: Set[str]):
"""DFS to find all cycles."""
if node in in_path:
# Found a cycle
cycle_start = path.index(node)
cycle = tuple(sorted(path[cycle_start:])) # Normalize for dedup
if cycle not in found_cycles:
found_cycles.add(cycle)
actual_cycle = path[cycle_start:] + [node]
cycle_titles = [task_title_map.get(tid, tid) for tid in actual_cycle]
cycles.append(CycleDetectionResult(
has_cycle=True,
cycle_path=actual_cycle,
cycle_task_titles=cycle_titles
))
return
if node in visited:
return
visited.add(node)
in_path.add(node)
path.append(node)
for neighbor in graph.get(node, []):
find_cycles_dfs(neighbor, path.copy(), in_path.copy())
path.pop()
in_path.remove(node)
# Start DFS from all nodes
for start_node in graph.keys():
if start_node not in visited:
find_cycles_dfs(start_node, [], set())
return cycles
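For illustration, what a detected cycle surfaces as to the caller; the IDs and titles are made up, and the module path is assumed. This exercises CycleDetectionResult directly rather than the DB-backed detector:

    from app.services.dependency_service import CycleDetectionResult  # path assumed

    result = CycleDetectionResult(
        has_cycle=True,
        cycle_path=["task-a", "task-b", "task-c", "task-a"],
        cycle_task_titles=["Design", "Build", "Review", "Design"],
    )
    print(result.get_cycle_description())
    # Design -> Build -> Review -> Design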

View File

@@ -1,26 +1,271 @@
import os
import hashlib
import shutil
import logging
import tempfile
from pathlib import Path
from typing import BinaryIO, Optional, Tuple
from fastapi import UploadFile, HTTPException
from app.core.config import settings
logger = logging.getLogger(__name__)
class PathTraversalError(Exception):
"""Raised when a path traversal attempt is detected."""
pass
class StorageValidationError(Exception):
"""Raised when storage validation fails."""
pass
class FileStorageService:
"""Service for handling file storage operations."""
# Common NAS mount points to detect
NAS_MOUNT_INDICATORS = [
"/mnt/", "/mount/", "/nas/", "/nfs/", "/smb/", "/cifs/",
"/Volumes/", "/media/", "/srv/", "/storage/"
]
def __init__(self):
self.base_dir = Path(settings.UPLOAD_DIR)
self._ensure_base_dir()
self.base_dir = Path(settings.UPLOAD_DIR).resolve()
self._storage_status = {
"validated": False,
"path_exists": False,
"writable": False,
"is_nas": False,
"error": None,
}
self._validate_storage_on_init()
def _validate_storage_on_init(self):
"""Validate storage configuration on service initialization."""
try:
# Step 1: Ensure directory exists
self._ensure_base_dir()
self._storage_status["path_exists"] = True
# Step 2: Check write permissions
self._check_write_permissions()
self._storage_status["writable"] = True
# Step 3: Check if using NAS
is_nas = self._detect_nas_storage()
self._storage_status["is_nas"] = is_nas
if not is_nas:
logger.warning(
"Storage directory '%s' appears to be local storage, not NAS. "
"Consider configuring UPLOAD_DIR to a NAS mount point for production use.",
self.base_dir
)
self._storage_status["validated"] = True
logger.info(
"Storage validated successfully: path=%s, is_nas=%s",
self.base_dir, is_nas
)
except StorageValidationError as e:
self._storage_status["error"] = str(e)
logger.error("Storage validation failed: %s", e)
raise
except Exception as e:
self._storage_status["error"] = str(e)
logger.error("Unexpected error during storage validation: %s", e)
raise StorageValidationError(f"Storage validation failed: {e}")
def _check_write_permissions(self):
"""Check if the storage directory has write permissions."""
test_file = self.base_dir / f".write_test_{os.getpid()}"
try:
# Try to create and write to a test file
test_file.write_text("write_test")
# Verify we can read it back
content = test_file.read_text()
if content != "write_test":
raise StorageValidationError(
f"Write verification failed for directory: {self.base_dir}"
)
# Clean up
test_file.unlink()
logger.debug("Write permission check passed for %s", self.base_dir)
except PermissionError as e:
raise StorageValidationError(
f"No write permission for storage directory '{self.base_dir}': {e}"
)
except OSError as e:
raise StorageValidationError(
f"Failed to verify write permissions for '{self.base_dir}': {e}"
)
finally:
# Ensure test file is removed even on partial failure
if test_file.exists():
try:
test_file.unlink()
except Exception:
pass
def _detect_nas_storage(self) -> bool:
"""
Detect if the storage directory appears to be on a NAS mount.
This is a best-effort detection based on common mount point patterns.
"""
path_str = str(self.base_dir)
# Check common NAS mount point patterns
for indicator in self.NAS_MOUNT_INDICATORS:
if indicator in path_str:
logger.debug("NAS storage detected: path contains '%s'", indicator)
return True
# Check if it's a mount point (Unix-like systems)
try:
if self.base_dir.is_mount():
logger.debug("NAS storage detected: path is a mount point")
return True
except Exception:
pass
# Check mount info on Linux
try:
with open("/proc/mounts", "r") as f:
mounts = f.read()
if path_str in mounts:
# Check for network filesystem types
for line in mounts.splitlines():
if path_str in line:
fs_type = line.split()[2] if len(line.split()) > 2 else ""
if fs_type in ["nfs", "nfs4", "cifs", "smb", "smbfs"]:
logger.debug("NAS storage detected: mounted as %s", fs_type)
return True
except FileNotFoundError:
pass # Not on Linux
except Exception as e:
logger.debug("Could not check /proc/mounts: %s", e)
return False
def get_storage_status(self) -> dict:
"""Get current storage status for health checks."""
return {
**self._storage_status,
"base_dir": str(self.base_dir),
"exists": self.base_dir.exists(),
"is_directory": self.base_dir.is_dir() if self.base_dir.exists() else False,
}
def _ensure_base_dir(self):
"""Ensure the base upload directory exists."""
self.base_dir.mkdir(parents=True, exist_ok=True)
def _validate_path_component(self, component: str, component_name: str) -> None:
"""
Validate a path component to prevent path traversal attacks.
Args:
component: The path component to validate (e.g., project_id, task_id)
component_name: Name of the component for error messages
Raises:
PathTraversalError: If the component contains path traversal patterns
"""
if not component:
raise PathTraversalError(f"Empty {component_name} is not allowed")
# Check for path traversal patterns
dangerous_patterns = ['..', '/', '\\', '\x00']
for pattern in dangerous_patterns:
if pattern in component:
logger.warning(
"Path traversal attempt detected in %s: %r",
component_name,
component
)
raise PathTraversalError(
f"Invalid characters in {component_name}: path traversal not allowed"
)
# Additional check: component should not start with special characters
if component.startswith('.') or component.startswith('-'):
logger.warning(
"Suspicious path component in %s: %r",
component_name,
component
)
raise PathTraversalError(
f"Invalid {component_name}: cannot start with '.' or '-'"
)
def _validate_path_in_base_dir(self, path: Path, context: str = "") -> Path:
"""
Validate that a resolved path is within the base directory.
Args:
path: The path to validate
context: Additional context for logging
Returns:
The resolved path if valid
Raises:
PathTraversalError: If the path is outside the base directory
"""
resolved_path = path.resolve()
# Check if the resolved path is within the base directory
try:
resolved_path.relative_to(self.base_dir)
except ValueError:
logger.warning(
"Path traversal attempt detected: path %s is outside base directory %s. Context: %s",
resolved_path,
self.base_dir,
context
)
raise PathTraversalError(
"Access denied: path is outside the allowed directory"
)
return resolved_path
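# --- Illustrative sketch: the containment check above relies on Path.resolve()
# collapsing any ".." segments before the relative_to() comparison. A standalone
# miniature (the base path is an example):
from pathlib import Path

def is_inside(base: Path, candidate: Path) -> bool:
    try:
        candidate.resolve().relative_to(base.resolve())
        return True
    except ValueError:
        return False

base = Path("/srv/uploads")
print(is_inside(base, base / "proj" / "task" / "1"))    # True
print(is_inside(base, base / ".." / "etc" / "passwd"))  # False once resolved
# --- end sketch ---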
def _get_file_path(self, project_id: str, task_id: str, attachment_id: str, version: int) -> Path:
"""Generate the file path for an attachment version."""
return self.base_dir / project_id / task_id / attachment_id / str(version)
"""
Generate the file path for an attachment version.
Args:
project_id: The project ID
task_id: The task ID
attachment_id: The attachment ID
version: The version number
Returns:
Safe path within the base directory
Raises:
PathTraversalError: If any component contains path traversal patterns
"""
# Validate all path components
self._validate_path_component(project_id, "project_id")
self._validate_path_component(task_id, "task_id")
self._validate_path_component(attachment_id, "attachment_id")
if version < 0:
raise PathTraversalError("Version must be non-negative")
# Build the path
path = self.base_dir / project_id / task_id / attachment_id / str(version)
# Validate the final path is within base directory
return self._validate_path_in_base_dir(
path,
f"project={project_id}, task={task_id}, attachment={attachment_id}, version={version}"
)
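# --- Illustrative sketch: combined with save_file() below, the validated path
# yields an on-disk layout like (extension "pdf" is just an example):
#
#     <base_dir>/<project_id>/<task_id>/<attachment_id>/<version>/file.pdf
#
# Every component is checked by _validate_path_component() first, so a crafted
# ID such as "../../etc" is rejected before any filesystem call is made.
# --- end sketch ---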
@staticmethod
def calculate_checksum(file: BinaryIO) -> str:
@@ -89,6 +334,10 @@ class FileStorageService:
"""
Save uploaded file to storage.
Returns (file_path, file_size, checksum).
Raises:
PathTraversalError: If path traversal is detected
HTTPException: If file validation fails
"""
# Validate file
extension, _ = self.validate_file(file)
@@ -96,14 +345,22 @@ class FileStorageService:
# Calculate checksum first
checksum = self.calculate_checksum(file.file)
# Create directory structure
dir_path = self._get_file_path(project_id, task_id, attachment_id, version)
# Create directory structure (path validation is done in _get_file_path)
try:
dir_path = self._get_file_path(project_id, task_id, attachment_id, version)
except PathTraversalError as e:
logger.error("Path traversal attempt during file save: %s", e)
raise HTTPException(status_code=400, detail=str(e))
dir_path.mkdir(parents=True, exist_ok=True)
# Save file with original extension
filename = f"file.{extension}" if extension else "file"
file_path = dir_path / filename
# Final validation of the file path
self._validate_path_in_base_dir(file_path, f"saving file {filename}")
# Get file size
file.file.seek(0, 2)
file_size = file.file.tell()
@@ -125,8 +382,15 @@ class FileStorageService:
"""
Get the file path for an attachment version.
Returns None if file doesn't exist.
If path traversal is detected in any component, the attempt is logged
and None is returned as well.
"""
dir_path = self._get_file_path(project_id, task_id, attachment_id, version)
try:
dir_path = self._get_file_path(project_id, task_id, attachment_id, version)
except PathTraversalError as e:
logger.error("Path traversal attempt during file retrieval: %s", e)
return None
if not dir_path.exists():
return None
@@ -139,21 +403,48 @@ class FileStorageService:
return files[0]
def get_file_by_path(self, file_path: str) -> Optional[Path]:
"""Get file by stored path. Handles both absolute and relative paths."""
"""
Get file by stored path. Handles both absolute and relative paths.
This method validates that the requested path is within the base directory
to prevent path traversal attacks.
Args:
file_path: The stored file path
Returns:
Path object if file exists and is within base directory, None otherwise
"""
if not file_path:
return None
path = Path(file_path)
# If path is absolute and exists, return it directly
if path.is_absolute() and path.exists():
return path
# For absolute paths, validate they are within base_dir
if path.is_absolute():
try:
validated_path = self._validate_path_in_base_dir(
path,
f"get_file_by_path absolute: {file_path}"
)
if validated_path.exists():
return validated_path
except PathTraversalError:
return None
return None
# If path is relative, try prepending base_dir
# For relative paths, resolve from base_dir
full_path = self.base_dir / path
if full_path.exists():
return full_path
# Fallback: check if original path exists (e.g., relative from current dir)
if path.exists():
return path
try:
validated_path = self._validate_path_in_base_dir(
full_path,
f"get_file_by_path relative: {file_path}"
)
if validated_path.exists():
return validated_path
except PathTraversalError:
return None
return None
@@ -168,13 +459,29 @@ class FileStorageService:
Delete file(s) from storage.
If version is None, deletes all versions.
Returns True if successful.
If path traversal is detected, the attempt is logged and False is returned.
"""
if version is not None:
# Delete specific version
dir_path = self._get_file_path(project_id, task_id, attachment_id, version)
else:
# Delete all versions (attachment directory)
dir_path = self.base_dir / project_id / task_id / attachment_id
try:
if version is not None:
# Delete specific version
dir_path = self._get_file_path(project_id, task_id, attachment_id, version)
else:
# Delete all versions (attachment directory)
# Validate components first
self._validate_path_component(project_id, "project_id")
self._validate_path_component(task_id, "task_id")
self._validate_path_component(attachment_id, "attachment_id")
dir_path = self.base_dir / project_id / task_id / attachment_id
dir_path = self._validate_path_in_base_dir(
dir_path,
f"delete attachment: project={project_id}, task={task_id}, attachment={attachment_id}"
)
except PathTraversalError as e:
logger.error("Path traversal attempt during file deletion: %s", e)
return False
if dir_path.exists():
shutil.rmtree(dir_path)
@@ -182,8 +489,26 @@ class FileStorageService:
return False
def delete_task_files(self, project_id: str, task_id: str) -> bool:
"""Delete all files for a task."""
dir_path = self.base_dir / project_id / task_id
"""
Delete all files for a task.
Returns True if successful. If path traversal is detected, the attempt
is logged and False is returned.
"""
try:
# Validate components
self._validate_path_component(project_id, "project_id")
self._validate_path_component(task_id, "task_id")
dir_path = self.base_dir / project_id / task_id
dir_path = self._validate_path_in_base_dir(
dir_path,
f"delete task files: project={project_id}, task={task_id}"
)
except PathTraversalError as e:
logger.error("Path traversal attempt during task file deletion: %s", e)
return False
if dir_path.exists():
shutil.rmtree(dir_path)
return True

View File

@@ -29,7 +29,17 @@ class FormulaError(Exception):
class CircularReferenceError(FormulaError):
"""Exception raised when circular references are detected in formulas."""
pass
def __init__(self, message: str, cycle_path: Optional[List[str]] = None):
super().__init__(message)
self.message = message
self.cycle_path = cycle_path or []
def get_cycle_description(self) -> str:
"""Get a human-readable description of the cycle."""
if not self.cycle_path:
return ""
return " -> ".join(self.cycle_path)
class FormulaService:
@@ -140,24 +150,43 @@ class FormulaService:
field_id: str,
references: Set[str],
visited: Optional[Set[str]] = None,
path: Optional[List[str]] = None,
) -> None:
"""
Check for circular references in formula fields.
Raises CircularReferenceError if a cycle is detected.
Args:
db: Database session
project_id: Project ID to scope the query
field_id: The current field being validated
references: Set of field names referenced in the formula
visited: Set of visited field IDs (for cycle detection)
path: Current path of field names (for error reporting)
"""
if visited is None:
visited = set()
if path is None:
path = []
# Get the current field's name
current_field = db.query(CustomField).filter(
CustomField.id == field_id
).first()
current_field_name = current_field.name if current_field else "unknown"
# Add current field to path if not already there
if current_field_name not in path:
path = path + [current_field_name]
if current_field:
if current_field.name in references:
cycle_path = path + [current_field.name]
raise CircularReferenceError(
f"Circular reference detected: field cannot reference itself"
f"Circular reference detected: field '{current_field.name}' cannot reference itself",
cycle_path=cycle_path
)
# Get all referenced formula fields
@@ -173,22 +202,199 @@ class FormulaService:
for field in formula_fields:
if field.id in visited:
# Found a cycle
cycle_path = path + [field.name]
raise CircularReferenceError(
f"Circular reference detected involving field '{field.name}'"
f"Circular reference detected: {' -> '.join(cycle_path)}",
cycle_path=cycle_path
)
visited.add(field.id)
new_path = path + [field.name]
if field.formula:
nested_refs = FormulaService.extract_field_references(field.formula)
if current_field and current_field.name in nested_refs:
cycle_path = new_path + [current_field.name]
raise CircularReferenceError(
f"Circular reference detected: '{field.name}' references the current field"
f"Circular reference detected: {' -> '.join(cycle_path)}",
cycle_path=cycle_path
)
FormulaService._check_circular_references(
db, project_id, field_id, nested_refs, visited
db, project_id, field_id, nested_refs, visited, new_path
)
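# --- Illustrative sketch: a worked example of the walk above with two
# hypothetical formula fields:
#
#     margin = {revenue} - {cost}
#     cost   = {margin} * 0.4
#
# Validating "cost" visits "margin", finds that "margin" refers back to "cost",
# and raises CircularReferenceError with cycle_path == ["cost", "margin", "cost"].
# --- end sketch ---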
@staticmethod
def build_formula_dependency_graph(
db: Session,
project_id: str
) -> Dict[str, Set[str]]:
"""
Build a dependency graph for all formula fields in a project.
Returns a dict where keys are field names and values are sets of
field names that the key field depends on.
Args:
db: Database session
project_id: Project ID to scope the query
Returns:
Dict mapping field names to their dependencies
"""
graph: Dict[str, Set[str]] = {}
formula_fields = db.query(CustomField).filter(
CustomField.project_id == project_id,
CustomField.field_type == "formula",
).all()
for field in formula_fields:
if field.formula:
refs = FormulaService.extract_field_references(field.formula)
# Only include custom field references (not builtin fields)
custom_refs = refs - FormulaService.BUILTIN_FIELDS
graph[field.name] = custom_refs
else:
graph[field.name] = set()
return graph
@staticmethod
def detect_formula_cycles(
db: Session,
project_id: str
) -> List[List[str]]:
"""
Detect all cycles in the formula dependency graph for a project.
This is useful for auditing or cleanup operations.
Args:
db: Database session
project_id: Project ID to check
Returns:
List of cycles, where each cycle is a list of field names
"""
graph = FormulaService.build_formula_dependency_graph(db, project_id)
if not graph:
return []
cycles: List[List[str]] = []
visited: Set[str] = set()
found_cycles: Set[Tuple[str, ...]] = set()
def dfs(node: str, path: List[str], in_path: Set[str]):
"""DFS to find cycles."""
if node in in_path:
# Found a cycle
cycle_start = path.index(node)
cycle = path[cycle_start:] + [node]
# Normalize for deduplication
normalized = tuple(sorted(cycle[:-1]))
if normalized not in found_cycles:
found_cycles.add(normalized)
cycles.append(cycle)
return
if node in visited:
return
if node not in graph:
return
visited.add(node)
in_path.add(node)
path.append(node)
for neighbor in graph.get(node, set()):
dfs(neighbor, path.copy(), in_path.copy())
path.pop()
in_path.discard(node)
for start_node in graph.keys():
if start_node not in visited:
dfs(start_node, [], set())
return cycles
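# --- Illustrative sketch: a standalone miniature of the walk above, using the
# same {field: dependencies} shape that build_formula_dependency_graph()
# returns (field names are made up):
def find_cycles(graph):
    cycles, seen = [], set()

    def dfs(node, path, in_path):
        if node in in_path:
            start = path.index(node)
            cycle = path[start:] + [node]
            key = tuple(sorted(cycle[:-1]))
            if key not in seen:
                seen.add(key)
                cycles.append(cycle)
            return
        if node not in graph:
            return
        in_path.add(node)
        path.append(node)
        for dep in graph[node]:
            dfs(dep, path.copy(), in_path.copy())

    for start in graph:
        dfs(start, [], set())
    return cycles

print(find_cycles({"total": {"subtotal"}, "subtotal": {"total"}, "tax": set()}))
# -> [['total', 'subtotal', 'total']]
# --- end sketch ---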
@staticmethod
def validate_formula_with_details(
formula: str,
project_id: str,
db: Session,
current_field_id: Optional[str] = None,
) -> Tuple[bool, Optional[str], Optional[List[str]]]:
"""
Validate a formula expression with detailed error information.
Similar to validate_formula but returns cycle path on circular reference errors.
Args:
formula: The formula expression to validate
project_id: Project ID to scope field lookups
db: Database session
current_field_id: Optional ID of the field being edited (for self-reference check)
Returns:
Tuple of (is_valid, error_message, cycle_path)
"""
if not formula or not formula.strip():
return False, "Formula cannot be empty", None
# Extract field references
references = FormulaService.extract_field_references(formula)
if not references:
return False, "Formula must reference at least one field", None
# Validate syntax by trying to parse
try:
# Replace field references with dummy numbers for syntax check
test_formula = formula
for ref in references:
test_formula = test_formula.replace(f"{{{ref}}}", "1")
# Try to parse and evaluate with safe operations
FormulaService._safe_eval(test_formula)
except Exception as e:
return False, f"Invalid formula syntax: {str(e)}", None
# Separate builtin and custom field references
custom_references = references - FormulaService.BUILTIN_FIELDS
# Validate custom field references exist and are numeric types
if custom_references:
fields = db.query(CustomField).filter(
CustomField.project_id == project_id,
CustomField.name.in_(custom_references),
).all()
found_names = {f.name for f in fields}
missing = custom_references - found_names
if missing:
return False, f"Unknown field references: {', '.join(missing)}", None
# Check field types (must be number or formula)
for field in fields:
if field.field_type not in ("number", "formula"):
return False, f"Field '{field.name}' is not a numeric type", None
# Check for circular references
if current_field_id:
try:
FormulaService._check_circular_references(
db, project_id, current_field_id, references
)
except CircularReferenceError as e:
return False, str(e), e.cycle_path
return True, None, None
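# --- Illustrative sketch: expected call pattern for the helper above (the
# session, project and field objects are placeholders):
#
#     ok, error, cycle = FormulaService.validate_formula_with_details(
#         formula="{price} * {quantity}",
#         project_id=project.id,
#         db=db,
#         current_field_id=field.id,
#     )
#     if not ok and cycle:
#         detail = " -> ".join(cycle)   # e.g. "total -> subtotal -> total"
# --- end sketch ---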
@staticmethod
def _safe_eval(expression: str) -> Decimal:
"""

View File

@@ -4,8 +4,10 @@ import re
import asyncio
import logging
import threading
import os
from datetime import datetime, timezone
from typing import List, Optional, Dict, Set
from collections import deque
from sqlalchemy.orm import Session
from sqlalchemy import event
@@ -22,9 +24,152 @@ _pending_publish: Dict[int, List[dict]] = {}
# Track which sessions have handlers registered
_registered_sessions: Set[int] = set()
# Redis fallback queue configuration
REDIS_FALLBACK_MAX_QUEUE_SIZE = int(os.getenv("REDIS_FALLBACK_MAX_QUEUE_SIZE", "1000"))
REDIS_FALLBACK_RETRY_INTERVAL = int(os.getenv("REDIS_FALLBACK_RETRY_INTERVAL", "5")) # seconds
REDIS_FALLBACK_MAX_RETRIES = int(os.getenv("REDIS_FALLBACK_MAX_RETRIES", "10"))
# Redis fallback queue for failed publishes
_redis_fallback_lock = threading.Lock()
_redis_fallback_queue: deque = deque(maxlen=REDIS_FALLBACK_MAX_QUEUE_SIZE)
_redis_retry_timer: Optional[threading.Timer] = None
_redis_available = True
_redis_consecutive_failures = 0
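# --- Illustrative sketch: the fallback behaviour is tuned through environment
# variables; the values below are examples, not recommendations.
#
#     REDIS_FALLBACK_MAX_QUEUE_SIZE=5000
#     REDIS_FALLBACK_RETRY_INTERVAL=10
#     REDIS_FALLBACK_MAX_RETRIES=20
#
# With the defaults above, an unreachable Redis buffers at most 1000
# notifications, retried every 5 seconds, up to 10 times each before a
# notification is dropped.
# --- end sketch ---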
def _add_to_fallback_queue(user_id: str, data: dict, retry_count: int = 0) -> bool:
"""
Add a failed notification to the fallback queue.
Returns True if added successfully, False if queue is full.
"""
global _redis_consecutive_failures
with _redis_fallback_lock:
if len(_redis_fallback_queue) >= REDIS_FALLBACK_MAX_QUEUE_SIZE:
logger.warning(
"Redis fallback queue is full (%d items), dropping notification for user %s",
REDIS_FALLBACK_MAX_QUEUE_SIZE, user_id
)
return False
_redis_fallback_queue.append({
"user_id": user_id,
"data": data,
"retry_count": retry_count,
"queued_at": datetime.now(timezone.utc).isoformat(),
})
_redis_consecutive_failures += 1
queue_size = len(_redis_fallback_queue)
logger.debug("Added notification to fallback queue (size: %d)", queue_size)
# Start retry mechanism if not already running
_ensure_retry_timer_running()
return True
def _ensure_retry_timer_running():
"""Ensure the retry timer is running if there are items in the queue."""
global _redis_retry_timer
if _redis_retry_timer is None or not _redis_retry_timer.is_alive():
_redis_retry_timer = threading.Timer(REDIS_FALLBACK_RETRY_INTERVAL, _process_fallback_queue)
_redis_retry_timer.daemon = True
_redis_retry_timer.start()
def _process_fallback_queue():
"""Process the fallback queue and retry sending notifications to Redis."""
global _redis_available, _redis_consecutive_failures, _redis_retry_timer
items_to_retry = []
with _redis_fallback_lock:
# Get all items from queue
while _redis_fallback_queue:
items_to_retry.append(_redis_fallback_queue.popleft())
if not items_to_retry:
_redis_retry_timer = None
return
logger.info("Processing %d items from Redis fallback queue", len(items_to_retry))
failed_items = []
success_count = 0
for item in items_to_retry:
user_id = item["user_id"]
data = item["data"]
retry_count = item["retry_count"]
if retry_count >= REDIS_FALLBACK_MAX_RETRIES:
logger.warning(
"Notification for user %s exceeded max retries (%d), dropping",
user_id, REDIS_FALLBACK_MAX_RETRIES
)
continue
try:
redis_client = get_redis_sync()
channel = get_channel_name(user_id)
message = json.dumps(data, default=str)
redis_client.publish(channel, message)
success_count += 1
except Exception as e:
logger.debug("Retry failed for user %s: %s", user_id, e)
failed_items.append({
**item,
"retry_count": retry_count + 1,
})
# Re-queue failed items
if failed_items:
with _redis_fallback_lock:
for item in failed_items:
if len(_redis_fallback_queue) < REDIS_FALLBACK_MAX_QUEUE_SIZE:
_redis_fallback_queue.append(item)
# Log recovery if we had successes
if success_count > 0:
with _redis_fallback_lock:
_redis_consecutive_failures = 0
if not _redis_fallback_queue:
_redis_available = True
logger.info(
"Redis connection recovered. Successfully processed %d notifications from fallback queue",
success_count
)
# Schedule next retry if queue is not empty
with _redis_fallback_lock:
if _redis_fallback_queue:
_redis_retry_timer = threading.Timer(REDIS_FALLBACK_RETRY_INTERVAL, _process_fallback_queue)
_redis_retry_timer.daemon = True
_redis_retry_timer.start()
else:
_redis_retry_timer = None
def get_redis_fallback_status() -> dict:
"""Get current Redis fallback queue status for health checks."""
with _redis_fallback_lock:
return {
"queue_size": len(_redis_fallback_queue),
"max_queue_size": REDIS_FALLBACK_MAX_QUEUE_SIZE,
"redis_available": _redis_available,
"consecutive_failures": _redis_consecutive_failures,
"retry_interval_seconds": REDIS_FALLBACK_RETRY_INTERVAL,
"max_retries": REDIS_FALLBACK_MAX_RETRIES,
}
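# --- Illustrative sketch: one way a readiness probe could consume the status
# above; the route and import path are assumptions, not part of this module.
#
#     from fastapi import APIRouter
#     from app.services.notification import get_redis_fallback_status  # assumed path
#
#     router = APIRouter()
#
#     @router.get("/health/redis-fallback")
#     def redis_fallback_health():
#         status = get_redis_fallback_status()
#         return {"degraded": not status["redis_available"], **status}
# --- end sketch ---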
def _sync_publish(user_id: str, data: dict):
"""Sync fallback to publish notification via Redis when no event loop available."""
global _redis_available
try:
redis_client = get_redis_sync()
channel = get_channel_name(user_id)
@@ -33,6 +178,10 @@ def _sync_publish(user_id: str, data: dict):
logger.debug(f"Sync published notification to channel {channel}")
except Exception as e:
logger.error(f"Failed to sync publish notification to Redis: {e}")
# Add to fallback queue for retry
with _redis_fallback_lock:
_redis_available = False
_add_to_fallback_queue(user_id, data)
def _cleanup_session(session_id: int, remove_registration: bool = True):
@@ -86,10 +235,16 @@ def _register_session_handlers(db: Session, session_id: int):
async def _async_publish(user_id: str, data: dict):
"""Async helper to publish notification to Redis."""
global _redis_available
try:
await redis_publish(user_id, data)
except Exception as e:
logger.error(f"Failed to publish notification to Redis: {e}")
# Add to fallback queue for retry
with _redis_fallback_lock:
_redis_available = False
_add_to_fallback_queue(user_id, data)
class NotificationService:

View File

@@ -7,6 +7,7 @@ scheduled triggers based on their cron schedule, including deadline reminders.
import uuid
import logging
import time
from datetime import datetime, timezone, timedelta
from typing import Optional, List, Dict, Any, Tuple, Set
@@ -22,6 +23,10 @@ logger = logging.getLogger(__name__)
# Key prefix for tracking deadline reminders already sent
DEADLINE_REMINDER_LOG_TYPE = "deadline_reminder"
# Retry configuration
MAX_RETRIES = 3
BASE_DELAY_SECONDS = 1 # 1s, 2s, 4s exponential backoff
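# Illustrative note: the delay before the next attempt is computed below as
#     delay = BASE_DELAY_SECONDS * 2 ** (attempt - 1)
# i.e. 1s, 2s, 4s, ... after attempts 1, 2, 3, ...; a sleep only happens while
# attempts remain, so with MAX_RETRIES = 3 the scheduler waits 1s and then 2s
# before the final attempt.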
class TriggerSchedulerService:
"""Service for scheduling and executing cron-based triggers."""
@@ -220,50 +225,170 @@ class TriggerSchedulerService:
@staticmethod
def _execute_trigger(db: Session, trigger: Trigger) -> TriggerLog:
"""
Execute a scheduled trigger's actions.
Execute a scheduled trigger's actions with a retry mechanism.
Transient failures are retried with exponential backoff (1s, 2s, 4s).
After all retries are exhausted, the execution is marked as permanently failed and an alert is sent.
Args:
db: Database session
trigger: The trigger to execute
Returns:
TriggerLog entry for this execution
"""
return TriggerSchedulerService._execute_trigger_with_retry(
db=db,
trigger=trigger,
task_id=None,
log_type="schedule",
)
@staticmethod
def _execute_trigger_with_retry(
db: Session,
trigger: Trigger,
task_id: Optional[str] = None,
log_type: str = "schedule",
) -> TriggerLog:
"""
Execute trigger actions with exponential backoff retry.
Args:
db: Database session
trigger: The trigger to execute
task_id: Optional task ID for context (deadline reminders)
log_type: Type of trigger execution for logging
Returns:
TriggerLog entry for this execution
"""
actions = trigger.actions if isinstance(trigger.actions, list) else [trigger.actions]
executed_actions = []
error_message = None
last_error = None
attempt = 0
try:
for action in actions:
action_type = action.get("type")
while attempt < MAX_RETRIES:
attempt += 1
executed_actions = []
last_error = None
if action_type == "notify":
TriggerSchedulerService._execute_notify_action(db, action, trigger)
executed_actions.append({"type": action_type, "status": "success"})
try:
logger.info(
f"Executing trigger {trigger.id} (attempt {attempt}/{MAX_RETRIES})"
)
# Add more action types here as needed
for action in actions:
action_type = action.get("type")
status = "success"
if action_type == "notify":
TriggerSchedulerService._execute_notify_action(db, action, trigger)
executed_actions.append({"type": action_type, "status": "success"})
except Exception as e:
status = "failed"
error_message = str(e)
executed_actions.append({"type": "error", "message": str(e)})
logger.error(f"Error executing trigger {trigger.id} actions: {e}")
# Add more action types here as needed
# Success - return log
logger.info(f"Trigger {trigger.id} executed successfully on attempt {attempt}")
return TriggerSchedulerService._log_execution(
db=db,
trigger=trigger,
status="success",
details={
"trigger_name": trigger.name,
"trigger_type": log_type,
"cron_expression": trigger.conditions.get("cron_expression") if trigger.conditions else None,
"actions_executed": executed_actions,
"attempts": attempt,
},
error_message=None,
task_id=task_id,
)
except Exception as e:
last_error = e
executed_actions.append({"type": "error", "message": str(e)})
logger.warning(
f"Trigger {trigger.id} failed on attempt {attempt}/{MAX_RETRIES}: {e}"
)
# Calculate exponential backoff delay
if attempt < MAX_RETRIES:
delay = BASE_DELAY_SECONDS * (2 ** (attempt - 1))
logger.info(f"Retrying trigger {trigger.id} in {delay}s...")
time.sleep(delay)
# All retries exhausted - permanent failure
logger.error(
f"Trigger {trigger.id} permanently failed after {MAX_RETRIES} attempts: {last_error}"
)
# Send alert notification for permanent failure
TriggerSchedulerService._send_failure_alert(db, trigger, str(last_error), MAX_RETRIES)
return TriggerSchedulerService._log_execution(
db=db,
trigger=trigger,
status=status,
status="permanently_failed",
details={
"trigger_name": trigger.name,
"trigger_type": "schedule",
"cron_expression": trigger.conditions.get("cron_expression"),
"trigger_type": log_type,
"cron_expression": trigger.conditions.get("cron_expression") if trigger.conditions else None,
"actions_executed": executed_actions,
"attempts": MAX_RETRIES,
"permanent_failure": True,
},
error_message=error_message,
error_message=f"Failed after {MAX_RETRIES} retries: {last_error}",
task_id=task_id,
)
@staticmethod
def _send_failure_alert(
db: Session,
trigger: Trigger,
error_message: str,
attempts: int,
) -> None:
"""
Send alert notification when trigger exhausts all retries.
Args:
db: Database session
trigger: The failed trigger
error_message: The last error message
attempts: Number of attempts made
"""
try:
# Notify the project owner about the failure
project = trigger.project
if not project:
logger.warning(f"Cannot send failure alert: trigger {trigger.id} has no project")
return
target_user_id = project.owner_id
if not target_user_id:
logger.warning(f"Cannot send failure alert: project {project.id} has no owner")
return
message = (
f"Trigger '{trigger.name}' has permanently failed after {attempts} attempts. "
f"Last error: {error_message}"
)
NotificationService.create_notification(
db=db,
user_id=target_user_id,
notification_type="trigger_failure",
reference_type="trigger",
reference_id=trigger.id,
title=f"Trigger Failed: {trigger.name}",
message=message,
)
logger.info(f"Sent failure alert for trigger {trigger.id} to user {target_user_id}")
except Exception as e:
logger.error(f"Failed to send failure alert for trigger {trigger.id}: {e}")
@staticmethod
def _execute_notify_action(db: Session, action: Dict[str, Any], trigger: Trigger) -> None:
"""