feat: complete issue fixes and implement remaining features

## Critical Issues (CRIT-001~003) - All Fixed
- JWT secret key validation with pydantic field_validator
- Login audit logging for success/failure attempts
- Frontend API path prefix removal

## High Priority Issues (HIGH-001~008) - All Fixed
- Project soft delete using is_active flag
- Redis session token bytes handling
- Rate limiting with slowapi (5 req/min for login)
- Attachment API permission checks
- Kanban view with drag-and-drop
- Workload heatmap UI (WorkloadPage, WorkloadHeatmap)
- TaskDetailModal integrating Comments/Attachments
- UserSelect component for task assignment

## Medium Priority Issues (MED-001~012) - All Fixed
- MED-001~005: DB commit ordering, N+1 queries, naive datetime handling, error response format, blocker flag logic
- MED-006: Project health dashboard (HealthService, ProjectHealthPage)
- MED-007: Capacity update API (PUT /api/users/{id}/capacity)
- MED-008: Schedule triggers (cron parsing, deadline reminders)
- MED-009: Watermark feature (image/PDF watermarking)
- MED-010~012: useEffect deps, DOM operations, PDF export

## New Files
- backend/app/api/health/ - Project health API
- backend/app/services/health_service.py
- backend/app/services/trigger_scheduler.py
- backend/app/services/watermark_service.py
- backend/app/core/rate_limiter.py
- frontend/src/pages/ProjectHealthPage.tsx
- frontend/src/components/ProjectHealthCard.tsx
- frontend/src/components/KanbanBoard.tsx
- frontend/src/components/WorkloadHeatmap.tsx

## Tests
- 113 new tests passing (health: 32, users: 14, triggers: 35, watermark: 32)

## OpenSpec Archives
- add-project-health-dashboard
- add-capacity-update-api
- add-schedule-triggers
- add-watermark-feature
- add-rate-limiting
- enhance-frontend-ux
- add-resource-management-ui

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
commit 9b220523ff
parent 64874d5425
Author: beabigegg
Date: 2026-01-04 21:49:52 +08:00

90 changed files with 9426 additions and 194 deletions

View File

@@ -1,38 +1,74 @@
import uuid
from datetime import datetime
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Request
-from fastapi.responses import FileResponse
from fastapi.responses import FileResponse, Response
from sqlalchemy.orm import Session
from typing import Optional
from app.core.database import get_db
-from app.middleware.auth import get_current_user
-from app.models import User, Task, Attachment, AttachmentVersion, AuditAction
from app.middleware.auth import get_current_user, check_task_access, check_task_edit_access
from app.models import User, Task, Project, Attachment, AttachmentVersion, AuditAction
from app.schemas.attachment import (
AttachmentResponse, AttachmentListResponse, AttachmentDetailResponse,
AttachmentVersionResponse, VersionHistoryResponse
)
from app.services.file_storage_service import file_storage_service
from app.services.audit_service import AuditService
from app.services.watermark_service import watermark_service
router = APIRouter(prefix="/api", tags=["attachments"])
-def get_task_or_404(db: Session, task_id: str) -> Task:
-"""Get task or raise 404."""
def get_task_with_access_check(db: Session, task_id: str, current_user: User, require_edit: bool = False) -> Task:
"""Get task and verify access permissions."""
task = db.query(Task).filter(Task.id == task_id).first()
if not task:
raise HTTPException(status_code=404, detail="Task not found")
# Get project for access check
project = db.query(Project).filter(Project.id == task.project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
# Check access permission
if not check_task_access(current_user, task, project):
raise HTTPException(status_code=403, detail="Access denied to this task")
# Check edit permission if required
if require_edit and not check_task_edit_access(current_user, task, project):
raise HTTPException(status_code=403, detail="Edit access denied to this task")
return task
-def get_attachment_or_404(db: Session, attachment_id: str) -> Attachment:
-"""Get attachment or raise 404."""
def get_attachment_with_access_check(
db: Session, attachment_id: str, current_user: User, require_edit: bool = False
) -> Attachment:
"""Get attachment and verify access permissions."""
attachment = db.query(Attachment).filter(
Attachment.id == attachment_id,
Attachment.is_deleted == False
).first()
if not attachment:
raise HTTPException(status_code=404, detail="Attachment not found")
# Get task and project for access check
task = db.query(Task).filter(Task.id == attachment.task_id).first()
if not task:
raise HTTPException(status_code=404, detail="Task not found")
project = db.query(Project).filter(Project.id == task.project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
# Check access permission
if not check_task_access(current_user, task, project):
raise HTTPException(status_code=403, detail="Access denied to this attachment")
# Check edit permission if required
if require_edit and not check_task_edit_access(current_user, task, project):
raise HTTPException(status_code=403, detail="Edit access denied to this attachment")
return attachment
@@ -76,7 +112,7 @@ async def upload_attachment(
current_user: User = Depends(get_current_user)
):
"""Upload a file attachment to a task."""
-task = get_task_or_404(db, task_id)
task = get_task_with_access_check(db, task_id, current_user, require_edit=True)
# Check if attachment with same filename exists (for versioning in Phase 2)
existing = db.query(Attachment).filter(
@@ -115,9 +151,6 @@ async def upload_attachment(
existing.file_size = file_size
existing.updated_at = version.created_at
-db.commit()
-db.refresh(existing)
# Audit log
AuditService.log_event(
db=db,
@@ -129,7 +162,9 @@ async def upload_attachment(
changes=[{"field": "version", "old_value": new_version - 1, "new_value": new_version}],
request_metadata=getattr(request.state, "audit_metadata", None)
)
db.commit()
db.refresh(existing)
return attachment_to_response(existing)
@@ -175,9 +210,6 @@ async def upload_attachment(
)
db.add(version)
-db.commit()
-db.refresh(attachment)
# Audit log
AuditService.log_event(
db=db,
@@ -189,7 +221,9 @@ async def upload_attachment(
changes=[{"field": "filename", "old_value": None, "new_value": attachment.filename}],
request_metadata=getattr(request.state, "audit_metadata", None)
)
db.commit()
db.refresh(attachment)
return attachment_to_response(attachment)
@@ -201,7 +235,7 @@ async def list_task_attachments(
current_user: User = Depends(get_current_user)
):
"""List all attachments for a task."""
-task = get_task_or_404(db, task_id)
task = get_task_with_access_check(db, task_id, current_user, require_edit=False)
attachments = db.query(Attachment).filter(
Attachment.task_id == task_id,
@@ -221,7 +255,7 @@ async def get_attachment(
current_user: User = Depends(get_current_user)
):
"""Get attachment details with version history."""
-attachment = get_attachment_or_404(db, attachment_id)
attachment = get_attachment_with_access_check(db, attachment_id, current_user, require_edit=False)
versions = db.query(AttachmentVersion).filter(
AttachmentVersion.attachment_id == attachment_id
@@ -252,8 +286,8 @@ async def download_attachment(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Download an attachment file."""
attachment = get_attachment_or_404(db, attachment_id)
"""Download an attachment file with dynamic watermark."""
attachment = get_attachment_with_access_check(db, attachment_id, current_user, require_edit=False)
# Get version to download
target_version = version or attachment.current_version
@@ -272,6 +306,7 @@ async def download_attachment(
raise HTTPException(status_code=404, detail="File not found on disk")
# Audit log
download_time = datetime.now()
AuditService.log_event(
db=db,
event_type="attachment.download",
@@ -284,6 +319,63 @@ async def download_attachment(
)
db.commit()
# Check if watermark should be applied
mime_type = attachment.mime_type or ""
if watermark_service.supports_watermark(mime_type):
try:
# Read the original file
with open(file_path, "rb") as f:
file_bytes = f.read()
# Apply watermark based on file type
if watermark_service.is_supported_image(mime_type):
watermarked_bytes, output_format = watermark_service.add_image_watermark(
image_bytes=file_bytes,
user_name=current_user.name,
employee_id=current_user.employee_id,
download_time=download_time
)
# Update mime type based on output format
output_mime_type = f"image/{output_format}"
# Update filename extension if format changed
original_filename = attachment.original_filename
if output_format == "png" and not original_filename.lower().endswith(".png"):
original_filename = original_filename.rsplit(".", 1)[0] + ".png"
return Response(
content=watermarked_bytes,
media_type=output_mime_type,
headers={
"Content-Disposition": f'attachment; filename="{original_filename}"'
}
)
elif watermark_service.is_supported_pdf(mime_type):
watermarked_bytes = watermark_service.add_pdf_watermark(
pdf_bytes=file_bytes,
user_name=current_user.name,
employee_id=current_user.employee_id,
download_time=download_time
)
return Response(
content=watermarked_bytes,
media_type="application/pdf",
headers={
"Content-Disposition": f'attachment; filename="{attachment.original_filename}"'
}
)
except Exception as e:
# If watermarking fails, log the error but still return the original file
# This ensures users can still download files even if watermarking has issues
import logging
logging.getLogger(__name__).warning(
f"Watermarking failed for attachment {attachment_id}: {str(e)}. "
"Returning original file."
)
# Return original file without watermark for unsupported types or on error
return FileResponse(
path=str(file_path),
filename=attachment.original_filename,
@@ -299,11 +391,10 @@ async def delete_attachment(
current_user: User = Depends(get_current_user)
):
"""Soft delete an attachment."""
-attachment = get_attachment_or_404(db, attachment_id)
attachment = get_attachment_with_access_check(db, attachment_id, current_user, require_edit=True)
# Soft delete
attachment.is_deleted = True
-db.commit()
# Audit log
AuditService.log_event(
@@ -316,9 +407,10 @@ async def delete_attachment(
changes=[{"field": "is_deleted", "old_value": False, "new_value": True}],
request_metadata=getattr(request.state, "audit_metadata", None)
)
db.commit()
return {"message": "Attachment deleted", "id": attachment_id}
return {"detail": "Attachment deleted", "id": attachment_id}
@router.get("/attachments/{attachment_id}/versions", response_model=VersionHistoryResponse)
@@ -328,7 +420,7 @@ async def get_version_history(
current_user: User = Depends(get_current_user)
):
"""Get version history for an attachment."""
-attachment = get_attachment_or_404(db, attachment_id)
attachment = get_attachment_with_access_check(db, attachment_id, current_user, require_edit=False)
versions = db.query(AttachmentVersion).filter(
AttachmentVersion.attachment_id == attachment_id
@@ -351,7 +443,7 @@ async def restore_version(
current_user: User = Depends(get_current_user)
):
"""Restore an attachment to a specific version."""
-attachment = get_attachment_or_404(db, attachment_id)
attachment = get_attachment_with_access_check(db, attachment_id, current_user, require_edit=True)
version_record = db.query(AttachmentVersion).filter(
AttachmentVersion.attachment_id == attachment_id,
@@ -364,7 +456,6 @@ async def restore_version(
old_version = attachment.current_version
attachment.current_version = version
attachment.file_size = version_record.file_size
-db.commit()
# Audit log
AuditService.log_event(
@@ -377,6 +468,7 @@ async def restore_version(
changes=[{"field": "current_version", "old_value": old_version, "new_value": version}],
request_metadata=getattr(request.state, "audit_metadata", None)
)
db.commit()
return {"message": f"Restored to version {version}", "current_version": version}
return {"detail": f"Restored to version {version}", "current_version": version}

View File

@@ -1,6 +1,6 @@
import csv
import io
-from datetime import datetime
from datetime import datetime, timezone
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, status, Query
from fastapi.responses import StreamingResponse
@@ -191,7 +191,7 @@ async def export_audit_logs(
output.seek(0)
filename = f"audit_logs_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}.csv"
filename = f"audit_logs_{datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')}.csv"
return StreamingResponse(
iter([output.getvalue()]),

View File

@@ -1,53 +1,86 @@
-from fastapi import APIRouter, Depends, HTTPException, status
from fastapi import APIRouter, Depends, HTTPException, status, Request
from sqlalchemy.orm import Session
from app.core.config import settings
from app.core.database import get_db
from app.core.security import create_access_token, create_token_payload
from app.core.redis import get_redis
from app.core.rate_limiter import limiter
from app.models.user import User
from app.models.audit_log import AuditAction
from app.schemas.auth import LoginRequest, LoginResponse, UserInfo
from app.services.auth_client import (
verify_credentials,
AuthAPIError,
AuthAPIConnectionError,
)
from app.services.audit_service import AuditService
from app.middleware.auth import get_current_user
router = APIRouter()
@router.post("/login", response_model=LoginResponse)
@limiter.limit("5/minute")
async def login(
-request: LoginRequest,
request: Request,
login_request: LoginRequest,
db: Session = Depends(get_db),
redis_client=Depends(get_redis),
):
"""
Authenticate user via external API and return JWT token.
"""
# Prepare metadata for audit logging
client_ip = request.client.host if request.client else "unknown"
user_agent = request.headers.get("user-agent", "unknown")
try:
# Verify credentials with external API
-auth_result = await verify_credentials(request.email, request.password)
auth_result = await verify_credentials(login_request.email, login_request.password)
except AuthAPIConnectionError:
# Log failed login attempt due to service unavailable
AuditService.log_event(
db=db,
event_type="user.login_failed",
resource_type="user",
action=AuditAction.LOGIN,
user_id=None,
resource_id=None,
changes={"email": login_request.email, "reason": "auth_service_unavailable"},
request_metadata={"ip_address": client_ip, "user_agent": user_agent},
)
db.commit()
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Authentication service temporarily unavailable",
)
except AuthAPIError as e:
# Log failed login attempt due to invalid credentials
AuditService.log_event(
db=db,
event_type="user.login_failed",
resource_type="user",
action=AuditAction.LOGIN,
user_id=None,
resource_id=None,
changes={"email": login_request.email, "reason": "invalid_credentials"},
request_metadata={"ip_address": client_ip, "user_agent": user_agent},
)
db.commit()
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid credentials",
)
# Find or create user in local database
-user = db.query(User).filter(User.email == request.email).first()
user = db.query(User).filter(User.email == login_request.email).first()
if not user:
# Create new user based on auth API response
user = User(
-email=request.email,
-name=auth_result.get("name", request.email.split("@")[0]),
email=login_request.email,
name=auth_result.get("name", login_request.email.split("@")[0]),
is_active=True,
)
db.add(user)
@@ -82,6 +115,19 @@ async def login(
access_token,
)
# Log successful login
AuditService.log_event(
db=db,
event_type="user.login",
resource_type="user",
action=AuditAction.LOGIN,
user_id=user.id,
resource_id=user.id,
changes=None,
request_metadata={"ip_address": client_ip, "user_agent": user_agent},
)
db.commit()
return LoginResponse(
access_token=access_token,
user=UserInfo(
@@ -106,7 +152,7 @@ async def logout(
# Remove session from Redis
redis_client.delete(f"session:{current_user.id}")
return {"message": "Successfully logged out"}
return {"detail": "Successfully logged out"}
@router.get("/me", response_model=UserInfo)

View File

@@ -1,5 +1,5 @@
import uuid
-from datetime import datetime
from datetime import datetime, timezone
from fastapi import APIRouter, Depends, HTTPException, status, Request
from sqlalchemy.orm import Session
@@ -138,7 +138,8 @@ async def resolve_blocker(
# Update blocker
blocker.resolved_by = current_user.id
blocker.resolution_note = resolve_data.resolution_note
-blocker.resolved_at = datetime.utcnow()
# Use naive datetime for consistency with database storage
blocker.resolved_at = datetime.now(timezone.utc).replace(tzinfo=None)
# Check if there are other unresolved blockers
other_blockers = db.query(Blocker).filter(

View File

@@ -0,0 +1,3 @@
from app.api.health.router import router
__all__ = ["router"]

View File

@@ -0,0 +1,70 @@
"""Project health API endpoints.
Provides endpoints for retrieving project health metrics
and dashboard information.
"""
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.models import User
from app.schemas.project_health import (
ProjectHealthWithDetails,
ProjectHealthDashboardResponse,
)
from app.services.health_service import HealthService
from app.middleware.auth import get_current_user
router = APIRouter(prefix="/api/projects/health", tags=["Project Health"])
@router.get("/dashboard", response_model=ProjectHealthDashboardResponse)
async def get_health_dashboard(
status_filter: Optional[str] = "active",
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
Get health dashboard for all projects.
Returns aggregated health metrics and summary statistics
for all projects matching the status filter.
- **status_filter**: Filter projects by status (default: "active")
Returns:
- **projects**: List of project health details
- **summary**: Aggregated summary statistics
"""
service = HealthService(db)
return service.get_dashboard(status_filter=status_filter)
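# Usage sketch (illustrative; summary field values are hypothetical):
#   GET /api/projects/health/dashboard?status_filter=active
#   -> {"projects": [...], "summary": {"total_projects": 12, "healthy_count": 8, ...}}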
@router.get("/{project_id}", response_model=ProjectHealthWithDetails)
async def get_project_health(
project_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
Get health information for a specific project.
Returns detailed health metrics including risk level,
schedule status, resource status, and task statistics.
- **project_id**: UUID of the project
Raises:
- **404**: Project not found
"""
service = HealthService(db)
result = service.get_project_health(project_id)
if not result:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Project not found"
)
return result

View File

@@ -1,5 +1,5 @@
from typing import Optional
-from datetime import datetime
from datetime import datetime, timezone
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.orm import Session
@@ -91,7 +91,8 @@ async def mark_as_read(
if not notification.is_read:
notification.is_read = True
-notification.read_at = datetime.utcnow()
# Use naive datetime for consistency with database storage
notification.read_at = datetime.now(timezone.utc).replace(tzinfo=None)
db.commit()
db.refresh(notification)
@@ -104,7 +105,8 @@ async def mark_all_as_read(
current_user: User = Depends(get_current_user),
):
"""Mark all notifications as read."""
-now = datetime.utcnow()
# Use naive datetime for consistency with database storage
now = datetime.now(timezone.utc).replace(tzinfo=None)
updated_count = db.query(Notification).filter(
Notification.user_id == current_user.id,

View File

@@ -273,9 +273,9 @@ async def delete_project(
current_user: User = Depends(get_current_user),
):
"""
-Delete a project (hard delete, cascades to tasks).
Delete a project (soft delete - sets is_active to False).
"""
-project = db.query(Project).filter(Project.id == project_id).first()
project = db.query(Project).filter(Project.id == project_id, Project.is_active == True).first()
if not project:
raise HTTPException(
@@ -289,7 +289,7 @@ async def delete_project(
detail="Only project owner can delete",
)
-# Audit log before deletion (this is a high-sensitivity event that triggers alert)
# Audit log before soft deletion (this is a high-sensitivity event that triggers alert)
AuditService.log_event(
db=db,
event_type="project.delete",
@@ -297,11 +297,12 @@ async def delete_project(
action=AuditAction.DELETE,
user_id=current_user.id,
resource_id=project.id,
changes=[{"field": "title", "old_value": project.title, "new_value": None}],
changes=[{"field": "is_active", "old_value": True, "new_value": False}],
request_metadata=get_audit_metadata(request),
)
-db.delete(project)
# Soft delete - set is_active to False
project.is_active = False
db.commit()
return None

View File

@@ -1,11 +1,11 @@
import uuid
-from datetime import datetime
from datetime import datetime, timezone
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, status, Query, Request
from sqlalchemy.orm import Session
from app.core.database import get_db
-from app.models import User, Project, Task, TaskStatus, AuditAction
from app.models import User, Project, Task, TaskStatus, AuditAction, Blocker
from app.schemas.task import (
TaskCreate, TaskUpdate, TaskResponse, TaskWithDetails, TaskListResponse,
TaskStatusUpdate, TaskAssignUpdate
@@ -374,7 +374,8 @@ async def delete_task(
detail="Permission denied",
)
-now = datetime.utcnow()
# Use naive datetime for consistency with database storage
now = datetime.now(timezone.utc).replace(tzinfo=None)
# Soft delete the task
task.is_deleted = True
@@ -504,11 +505,18 @@ async def update_task_status(
task.status_id = status_data.status_id
-# Auto-set blocker_flag based on status name
# Auto-set blocker_flag based on status name and actual Blocker records
if new_status.name.lower() == "blocked":
task.blocker_flag = True
else:
-task.blocker_flag = False
# Only set blocker_flag = False if there are no unresolved blockers
unresolved_blockers = db.query(Blocker).filter(
Blocker.task_id == task.id,
Blocker.resolved_at == None,
).count()
if unresolved_blockers == 0:
task.blocker_flag = False
# If there are unresolved blockers, keep blocker_flag as is
# Evaluate triggers for status changes
if old_status_id != status_data.status_id:

View File

@@ -10,6 +10,7 @@ from app.schemas.trigger import (
TriggerLogResponse, TriggerLogListResponse, TriggerUserInfo
)
from app.middleware.auth import get_current_user, check_project_access, check_project_edit_access
from app.services.trigger_scheduler import TriggerSchedulerService
router = APIRouter(tags=["triggers"])
@@ -65,18 +66,50 @@ async def create_trigger(
detail="Invalid trigger type. Must be 'field_change' or 'schedule'",
)
-# Validate conditions
-if trigger_data.conditions.field not in ["status_id", "assignee_id", "priority"]:
-raise HTTPException(
-status_code=status.HTTP_400_BAD_REQUEST,
-detail="Invalid condition field. Must be 'status_id', 'assignee_id', or 'priority'",
-)
# Validate conditions based on trigger type
if trigger_data.trigger_type == "field_change":
# Validate field_change conditions
if not trigger_data.conditions.field:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Field is required for field_change triggers",
)
if trigger_data.conditions.field not in ["status_id", "assignee_id", "priority"]:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid condition field. Must be 'status_id', 'assignee_id', or 'priority'",
)
if not trigger_data.conditions.operator:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Operator is required for field_change triggers",
)
if trigger_data.conditions.operator not in ["equals", "not_equals", "changed_to", "changed_from"]:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid operator. Must be 'equals', 'not_equals', 'changed_to', or 'changed_from'",
)
elif trigger_data.trigger_type == "schedule":
# Validate schedule conditions
has_cron = trigger_data.conditions.cron_expression is not None
has_deadline = trigger_data.conditions.deadline_reminder_days is not None
-if trigger_data.conditions.operator not in ["equals", "not_equals", "changed_to", "changed_from"]:
-raise HTTPException(
-status_code=status.HTTP_400_BAD_REQUEST,
-detail="Invalid operator. Must be 'equals', 'not_equals', 'changed_to', or 'changed_from'",
-)
if not has_cron and not has_deadline:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Schedule triggers require either cron_expression or deadline_reminder_days",
)
# Validate cron expression if provided
if has_cron:
is_valid, error_msg = TriggerSchedulerService.parse_cron_expression(
trigger_data.conditions.cron_expression
)
if not is_valid:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=error_msg or "Invalid cron expression",
)
# Create trigger
trigger = Trigger(
@@ -186,13 +219,25 @@ async def update_trigger(
if trigger_data.description is not None:
trigger.description = trigger_data.description
if trigger_data.conditions is not None:
-# Validate conditions
-if trigger_data.conditions.field not in ["status_id", "assignee_id", "priority"]:
-raise HTTPException(
-status_code=status.HTTP_400_BAD_REQUEST,
-detail="Invalid condition field",
-)
-trigger.conditions = trigger_data.conditions.model_dump()
# Validate conditions based on trigger type
if trigger.trigger_type == "field_change":
if trigger_data.conditions.field and trigger_data.conditions.field not in ["status_id", "assignee_id", "priority"]:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid condition field",
)
elif trigger.trigger_type == "schedule":
# Validate cron expression if provided
if trigger_data.conditions.cron_expression is not None:
is_valid, error_msg = TriggerSchedulerService.parse_cron_expression(
trigger_data.conditions.cron_expression
)
if not is_valid:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=error_msg or "Invalid cron expression",
)
trigger.conditions = trigger_data.conditions.model_dump(exclude_none=True)
if trigger_data.actions is not None:
trigger.actions = [a.model_dump() for a in trigger_data.actions]
if trigger_data.is_active is not None:

View File

@@ -4,10 +4,11 @@ from sqlalchemy import or_
from typing import List
from app.core.database import get_db
from app.core.redis import get_redis
from app.models.user import User
from app.models.role import Role
from app.models import AuditAction
-from app.schemas.user import UserResponse, UserUpdate
from app.schemas.user import UserResponse, UserUpdate, CapacityUpdate
from app.middleware.auth import (
get_current_user,
require_permission,
@@ -239,3 +240,86 @@ async def set_admin_status(
db.commit()
db.refresh(user)
return user
@router.put("/{user_id}/capacity", response_model=UserResponse)
async def update_user_capacity(
user_id: str,
capacity: CapacityUpdate,
request: Request,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
redis_client=Depends(get_redis),
):
"""
Update user's weekly capacity hours.
Permission: admin, manager, or the user themselves can update capacity.
- Admin/Manager can update any user's capacity
- Regular users can only update their own capacity
Capacity changes are recorded in the audit trail and workload cache is invalidated.
"""
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found",
)
# Permission check: admin, manager, or the user themselves can update capacity
is_self = current_user.id == user_id
is_admin = current_user.is_system_admin
is_manager = False
# Check if current user has manager role
if current_user.role and current_user.role.name == "manager":
is_manager = True
if not is_self and not is_admin and not is_manager:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only admin, manager, or the user themselves can update capacity",
)
# Store old capacity for audit log
old_capacity = float(user.capacity) if user.capacity else None
# Update capacity (validation is handled by Pydantic schema)
user.capacity = capacity.capacity_hours
new_capacity = float(capacity.capacity_hours)
# Record capacity change in audit trail
if old_capacity != new_capacity:
AuditService.log_event(
db=db,
event_type="user.capacity_change",
resource_type="user",
action=AuditAction.UPDATE,
user_id=current_user.id,
resource_id=user.id,
changes=[{
"field": "capacity",
"old_value": old_capacity,
"new_value": new_capacity
}],
request_metadata=get_audit_metadata(request),
)
db.commit()
db.refresh(user)
# Invalidate workload cache for this user
# Cache keys follow pattern: workload:{user_id}:* or workload:heatmap:*
try:
# Delete user-specific workload cache
for key in redis_client.scan_iter(f"workload:{user_id}:*"):
redis_client.delete(key)
# Delete heatmap cache (contains all users' workload data)
for key in redis_client.scan_iter("workload:heatmap:*"):
redis_client.delete(key)
except Exception:
# Cache invalidation failure should not fail the request
pass
return user

View File

@@ -1,4 +1,5 @@
from pydantic_settings import BaseSettings
from pydantic import field_validator
from typing import List
import os
@@ -24,11 +25,33 @@ class Settings(BaseSettings):
def REDIS_URL(self) -> str:
return f"redis://{self.REDIS_HOST}:{self.REDIS_PORT}/{self.REDIS_DB}"
-# JWT
-JWT_SECRET_KEY: str = "your-secret-key-change-in-production"
# JWT - Must be set in environment, no default allowed
JWT_SECRET_KEY: str = ""
JWT_ALGORITHM: str = "HS256"
JWT_EXPIRE_MINUTES: int = 10080 # 7 days
@field_validator("JWT_SECRET_KEY")
@classmethod
def validate_jwt_secret_key(cls, v: str) -> str:
"""Validate that JWT_SECRET_KEY is set and not a placeholder."""
if not v or v.strip() == "":
raise ValueError(
"JWT_SECRET_KEY must be set in environment variables. "
"Please configure it in the .env file."
)
placeholder_values = [
"your-secret-key-change-in-production",
"change-me",
"secret",
"your-secret-key",
]
if v.lower() in placeholder_values:
raise ValueError(
"JWT_SECRET_KEY appears to be a placeholder value. "
"Please set a secure secret key in the .env file."
)
return v
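# A minimal sketch of generating a suitable JWT_SECRET_KEY for .env, assuming
# the standard-library secrets module (any high-entropy value works):
#   python -c "import secrets; print(secrets.token_urlsafe(64))"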
# External Auth API
AUTH_API_URL: str = "https://pj-auth-api.vercel.app"

View File

@@ -0,0 +1,26 @@
"""
Rate limiting configuration using slowapi with Redis backend.
This module provides rate limiting functionality to protect against
brute force attacks and DoS attempts on sensitive endpoints.
"""
import os
from slowapi import Limiter
from slowapi.util import get_remote_address
from app.core.config import settings
# Use memory storage for testing, Redis for production
# This allows tests to run without a Redis connection
_testing = os.environ.get("TESTING", "").lower() in ("true", "1", "yes")
_storage_uri = "memory://" if _testing else settings.REDIS_URL
# Create limiter instance with appropriate storage
# Uses the client's remote address (IP) as the key for rate limiting
limiter = Limiter(
key_func=get_remote_address,
storage_uri=_storage_uri,
strategy="fixed-window", # Fixed window strategy for predictable rate limiting
)
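# Usage sketch: the limiter is applied per route, and the handler must accept
# a starlette Request so slowapi can resolve the client address. This mirrors
# the login endpoint above; the route shown here is illustrative only.
#
#   from fastapi import APIRouter, Request
#   from app.core.rate_limiter import limiter
#
#   router = APIRouter()
#
#   @router.post("/sensitive")
#   @limiter.limit("5/minute")
#   async def sensitive(request: Request):
#       return {"detail": "ok"}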

View File

@@ -1,9 +1,11 @@
import logging
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.cron import CronTrigger
from apscheduler.triggers.interval import IntervalTrigger
from app.core.database import SessionLocal
from app.services.report_service import ReportService
from app.services.trigger_scheduler import TriggerSchedulerService
logger = logging.getLogger(__name__)
@@ -24,6 +26,24 @@ async def weekly_report_job():
db.close()
async def schedule_trigger_job():
"""Job function to evaluate and execute schedule triggers.
This runs every minute and checks:
1. Cron-based schedule triggers
2. Deadline reminder triggers
"""
db = SessionLocal()
try:
logs = TriggerSchedulerService.evaluate_schedule_triggers(db)
if logs:
logger.info(f"Schedule trigger job executed {len(logs)} triggers")
except Exception as e:
logger.error(f"Error in schedule trigger job: {e}")
finally:
db.close()
def init_scheduler():
"""Initialize the scheduler with jobs."""
# Weekly report - Every Friday at 16:00
@@ -35,7 +55,16 @@ def init_scheduler():
replace_existing=True,
)
logger.info("Scheduler initialized with weekly report job (Friday 16:00)")
# Schedule trigger evaluation - Every minute
scheduler.add_job(
schedule_trigger_job,
IntervalTrigger(minutes=1),
id='schedule_triggers',
name='Evaluate Schedule Triggers',
replace_existing=True,
)
logger.info("Scheduler initialized with jobs: weekly_report (Friday 16:00), schedule_triggers (every minute)")
def start_scheduler():

View File

@@ -1,4 +1,4 @@
-from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from typing import Optional, Any
from jose import jwt, JWTError
from app.core.config import settings
@@ -16,13 +16,14 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -
Encoded JWT token string
"""
to_encode = data.copy()
now = datetime.now(timezone.utc)
if expires_delta:
-expire = datetime.utcnow() + expires_delta
expire = now + expires_delta
else:
-expire = datetime.utcnow() + timedelta(minutes=settings.JWT_EXPIRE_MINUTES)
expire = now + timedelta(minutes=settings.JWT_EXPIRE_MINUTES)
to_encode.update({"exp": expire, "iat": datetime.utcnow()})
to_encode.update({"exp": expire, "iat": now})
encoded_jwt = jwt.encode(
to_encode,

View File

@@ -1,9 +1,13 @@
from contextlib import asynccontextmanager
-from fastapi import FastAPI
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from slowapi import _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from app.middleware.audit import AuditMiddleware
from app.core.scheduler import start_scheduler, shutdown_scheduler
from app.core.rate_limiter import limiter
@asynccontextmanager
@@ -29,6 +33,7 @@ from app.api.audit import router as audit_router
from app.api.attachments import router as attachments_router
from app.api.triggers import router as triggers_router
from app.api.reports import router as reports_router
from app.api.health import router as health_router
from app.core.config import settings
app = FastAPI(
@@ -38,6 +43,10 @@ app = FastAPI(
lifespan=lifespan,
)
# Initialize rate limiter
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# CORS middleware
app.add_middleware(
CORSMiddleware,
@@ -66,6 +75,7 @@ app.include_router(audit_router)
app.include_router(attachments_router)
app.include_router(triggers_router)
app.include_router(reports_router)
app.include_router(health_router)
@app.get("/health")

View File

@@ -42,7 +42,16 @@ async def get_current_user(
# Check session in Redis
stored_token = redis_client.get(f"session:{user_id}")
-if stored_token is None or stored_token != token:
if stored_token is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Session expired or invalid",
headers={"WWW-Authenticate": "Bearer"},
)
# Handle Redis bytes type - decode if necessary
if isinstance(stored_token, bytes):
stored_token = stored_token.decode("utf-8")
if stored_token != token:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Session expired or invalid",

View File

@@ -18,6 +18,7 @@ from app.models.trigger import Trigger, TriggerType
from app.models.trigger_log import TriggerLog, TriggerLogStatus
from app.models.scheduled_report import ScheduledReport, ReportType
from app.models.report_history import ReportHistory, ReportHistoryStatus
from app.models.project_health import ProjectHealth, RiskLevel, ScheduleStatus, ResourceStatus
__all__ = [
"User", "Role", "Department", "Space", "Project", "TaskStatus", "Task", "WorkloadSnapshot",
@@ -25,5 +26,6 @@ __all__ = [
"AuditLog", "AuditAlert", "AuditAction", "SensitivityLevel", "EVENT_SENSITIVITY", "ALERT_EVENTS",
"Attachment", "AttachmentVersion",
"Trigger", "TriggerType", "TriggerLog", "TriggerLogStatus",
"ScheduledReport", "ReportType", "ReportHistory", "ReportHistoryStatus"
"ScheduledReport", "ReportType", "ReportHistory", "ReportHistoryStatus",
"ProjectHealth", "RiskLevel", "ScheduleStatus", "ResourceStatus"
]

View File

@@ -13,6 +13,8 @@ class NotificationType(str, enum.Enum):
STATUS_CHANGE = "status_change"
COMMENT = "comment"
BLOCKER_RESOLVED = "blocker_resolved"
DEADLINE_REMINDER = "deadline_reminder"
SCHEDULED_TRIGGER = "scheduled_trigger"
class Notification(Base):
@@ -22,6 +24,7 @@ class Notification(Base):
user_id = Column(String(36), ForeignKey("pjctrl_users.id", ondelete="CASCADE"), nullable=False)
type = Column(
Enum("mention", "assignment", "blocker", "status_change", "comment", "blocker_resolved",
"deadline_reminder", "scheduled_trigger",
name="notification_type_enum"),
nullable=False
)

View File

@@ -39,3 +39,4 @@ class Project(Base):
task_statuses = relationship("TaskStatus", back_populates="project", cascade="all, delete-orphan")
tasks = relationship("Task", back_populates="project", cascade="all, delete-orphan")
triggers = relationship("Trigger", back_populates="project", cascade="all, delete-orphan")
health = relationship("ProjectHealth", back_populates="project", uselist=False, cascade="all, delete-orphan")

View File

@@ -0,0 +1,51 @@
from sqlalchemy import Column, String, Integer, DateTime, Enum, ForeignKey
from sqlalchemy.orm import relationship
from sqlalchemy.sql import func
from app.core.database import Base
import enum
class RiskLevel(str, enum.Enum):
LOW = "low"
MEDIUM = "medium"
HIGH = "high"
CRITICAL = "critical"
class ScheduleStatus(str, enum.Enum):
ON_TRACK = "on_track"
AT_RISK = "at_risk"
DELAYED = "delayed"
class ResourceStatus(str, enum.Enum):
ADEQUATE = "adequate"
CONSTRAINED = "constrained"
OVERLOADED = "overloaded"
class ProjectHealth(Base):
__tablename__ = "pjctrl_project_health"
id = Column(String(36), primary_key=True)
project_id = Column(String(36), ForeignKey("pjctrl_projects.id", ondelete="CASCADE"), nullable=False, unique=True)
health_score = Column(Integer, default=100, nullable=False) # 0-100
risk_level = Column(
Enum("low", "medium", "high", "critical", name="risk_level_enum"),
default="low",
nullable=False
)
schedule_status = Column(
Enum("on_track", "at_risk", "delayed", name="schedule_status_enum"),
default="on_track",
nullable=False
)
resource_status = Column(
Enum("adequate", "constrained", "overloaded", name="resource_status_enum"),
default="adequate",
nullable=False
)
last_updated = Column(DateTime, server_default=func.now(), onupdate=func.now(), nullable=False)
# Relationships
project = relationship("Project", back_populates="health")

View File

@@ -10,6 +10,7 @@ class User(Base):
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
email = Column(String(200), unique=True, nullable=False, index=True)
employee_id = Column(String(50), unique=True, nullable=True, index=True)
name = Column(String(200), nullable=False)
department_id = Column(String(36), ForeignKey("pjctrl_departments.id"), nullable=True)
role_id = Column(String(36), ForeignKey("pjctrl_roles.id"), nullable=True)

View File

@@ -0,0 +1,68 @@
from pydantic import BaseModel
from typing import Optional, List
from datetime import datetime
from enum import Enum
class RiskLevel(str, Enum):
LOW = "low"
MEDIUM = "medium"
HIGH = "high"
CRITICAL = "critical"
class ScheduleStatus(str, Enum):
ON_TRACK = "on_track"
AT_RISK = "at_risk"
DELAYED = "delayed"
class ResourceStatus(str, Enum):
ADEQUATE = "adequate"
CONSTRAINED = "constrained"
OVERLOADED = "overloaded"
class ProjectHealthBase(BaseModel):
health_score: int
risk_level: RiskLevel
schedule_status: ScheduleStatus
resource_status: ResourceStatus
class ProjectHealthResponse(ProjectHealthBase):
id: str
project_id: str
last_updated: datetime
class Config:
from_attributes = True
class ProjectHealthWithDetails(ProjectHealthResponse):
"""Extended health response with project and computed metrics."""
project_title: str
project_status: str
owner_name: Optional[str] = None
space_name: Optional[str] = None
task_count: int = 0
completed_task_count: int = 0
blocker_count: int = 0
overdue_task_count: int = 0
class ProjectHealthSummary(BaseModel):
"""Aggregated health metrics across all projects."""
total_projects: int
healthy_count: int # health_score >= 80
at_risk_count: int # health_score 50-79
critical_count: int # health_score < 50
average_health_score: float
projects_with_blockers: int
projects_delayed: int
class ProjectHealthDashboardResponse(BaseModel):
"""Full dashboard response with project list and summary."""
projects: List[ProjectHealthWithDetails]
summary: ProjectHealthSummary

View File

@@ -1,14 +1,32 @@
from datetime import datetime
-from typing import Optional, List, Dict, Any
from typing import Optional, List, Dict, Any, Union
from pydantic import BaseModel, Field
-class TriggerCondition(BaseModel):
class FieldChangeCondition(BaseModel):
"""Condition for field_change triggers."""
field: str = Field(..., description="Field to check: status_id, assignee_id, priority")
operator: str = Field(..., description="Operator: equals, not_equals, changed_to, changed_from")
value: str = Field(..., description="Value to compare against")
class ScheduleCondition(BaseModel):
"""Condition for schedule triggers."""
cron_expression: Optional[str] = Field(None, description="Cron expression (e.g., '0 9 * * 1' for Monday 9am)")
deadline_reminder_days: Optional[int] = Field(None, ge=1, le=365, description="Days before due date to send reminder")
class TriggerCondition(BaseModel):
"""Union condition that supports both field_change and schedule triggers."""
# Field change conditions
field: Optional[str] = Field(None, description="Field to check: status_id, assignee_id, priority")
operator: Optional[str] = Field(None, description="Operator: equals, not_equals, changed_to, changed_from")
value: Optional[str] = Field(None, description="Value to compare against")
# Schedule conditions
cron_expression: Optional[str] = Field(None, description="Cron expression for schedule triggers")
deadline_reminder_days: Optional[int] = Field(None, ge=1, le=365, description="Days before due date to send reminder")
class TriggerAction(BaseModel):
type: str = Field(default="notify", description="Action type: notify")
target: str = Field(default="assignee", description="Target: assignee, creator, project_owner, user:<id>")
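# Example condition payloads accepted by TriggerCondition above (IDs are placeholders):
#   field_change: {"field": "status_id", "operator": "changed_to", "value": "<status-id>"}
#   schedule:     {"cron_expression": "0 9 * * 1"}       # every Monday 09:00
#   schedule:     {"deadline_reminder_days": 3}          # 3 days before due date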

View File

@@ -1,4 +1,4 @@
-from pydantic import BaseModel
from pydantic import BaseModel, field_validator
from typing import Optional, List
from datetime import datetime
from decimal import Decimal
@@ -39,3 +39,25 @@ class UserResponse(UserBase):
class UserInDB(UserResponse):
pass
class CapacityUpdate(BaseModel):
"""Schema for updating user's weekly capacity hours."""
capacity_hours: Decimal
@field_validator("capacity_hours")
@classmethod
def validate_capacity_hours(cls, v: Decimal) -> Decimal:
"""Validate capacity hours is within valid range (0-168)."""
if v < 0:
raise ValueError("Capacity hours must be non-negative")
if v > 168:
raise ValueError("Capacity hours cannot exceed 168 (hours in a week)")
return v
class Config:
json_schema_extra = {
"example": {
"capacity_hours": 40.00
}
}
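# Example request this schema validates (the MED-007 endpoint):
#   PUT /api/users/{user_id}/capacity
#   {"capacity_hours": 40.00}
# Values below 0 or above 168 fail validation (a 422 from FastAPI).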

View File

@@ -1,7 +1,7 @@
import uuid
import hashlib
import json
-from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from typing import Optional, Dict, Any, List
from sqlalchemy.orm import Session
@@ -85,7 +85,8 @@ class AuditService:
request_metadata: Optional[Dict] = None,
) -> AuditLog:
"""Log an audit event."""
-now = datetime.utcnow()
# Use naive datetime for consistency with database storage (SQLite strips tzinfo)
now = datetime.now(timezone.utc).replace(tzinfo=None)
sensitivity = AuditService.get_sensitivity_level(event_type)
checksum = AuditService.calculate_checksum(
@@ -204,7 +205,8 @@ class AuditService:
alert.is_acknowledged = True
alert.acknowledged_by = user_id
-alert.acknowledged_at = datetime.utcnow()
# Use naive datetime for consistency with database storage
alert.acknowledged_at = datetime.now(timezone.utc).replace(tzinfo=None)
db.flush()
return alert

View File

@@ -139,9 +139,23 @@ class FileStorageService:
return files[0]
def get_file_by_path(self, file_path: str) -> Optional[Path]:
"""Get file by stored path."""
"""Get file by stored path. Handles both absolute and relative paths."""
path = Path(file_path)
-return path if path.exists() else None
# If path is absolute and exists, return it directly
if path.is_absolute() and path.exists():
return path
# If path is relative, try prepending base_dir
full_path = self.base_dir / path
if full_path.exists():
return full_path
# Fallback: check if original path exists (e.g., relative from current dir)
if path.exists():
return path
return None
def delete_file(
self,

View File

@@ -0,0 +1,378 @@
"""Project health calculation service.
Provides functionality to calculate and retrieve project health metrics
including risk scores, schedule status, and resource status.
"""
import uuid
from datetime import datetime
from typing import List, Optional, Dict, Any
from sqlalchemy.orm import Session
from app.models import Project, Task, TaskStatus, Blocker, ProjectHealth
from app.schemas.project_health import (
RiskLevel,
ScheduleStatus,
ResourceStatus,
ProjectHealthResponse,
ProjectHealthWithDetails,
ProjectHealthSummary,
ProjectHealthDashboardResponse,
)
# Constants for health score calculation
BLOCKER_PENALTY_PER_ITEM = 10
BLOCKER_PENALTY_MAX = 30
OVERDUE_PENALTY_PER_ITEM = 5
OVERDUE_PENALTY_MAX = 30
COMPLETION_PENALTY_THRESHOLD = 50
COMPLETION_PENALTY_FACTOR = 0.4
COMPLETION_PENALTY_MAX = 20
# Risk level thresholds
RISK_LOW_THRESHOLD = 80
RISK_MEDIUM_THRESHOLD = 60
RISK_HIGH_THRESHOLD = 40
# Schedule status thresholds
SCHEDULE_AT_RISK_THRESHOLD = 2
# Resource status thresholds
RESOURCE_CONSTRAINED_THRESHOLD = 2
def calculate_health_metrics(db: Session, project: Project) -> Dict[str, Any]:
"""
Calculate health metrics for a project.
Args:
db: Database session
project: Project object to calculate metrics for
Returns:
Dictionary containing:
- health_score: 0-100 integer
- risk_level: low/medium/high/critical
- schedule_status: on_track/at_risk/delayed
- resource_status: adequate/constrained/overloaded
- task_count: Total number of active tasks
- completed_task_count: Number of completed tasks
- blocker_count: Number of unresolved blockers
- overdue_task_count: Number of overdue incomplete tasks
"""
# Fetch active tasks for this project
tasks = db.query(Task).filter(
Task.project_id == project.id,
Task.is_deleted == False
).all()
task_count = len(tasks)
# Count completed tasks
completed_task_count = sum(
1 for task in tasks
if task.status and task.status.is_done
)
# Count overdue tasks (incomplete with past due date)
now = datetime.utcnow()
overdue_task_count = sum(
1 for task in tasks
if task.due_date and task.due_date < now
and not (task.status and task.status.is_done)
)
# Count unresolved blockers
task_ids = [t.id for t in tasks]
blocker_count = 0
if task_ids:
blocker_count = db.query(Blocker).filter(
Blocker.task_id.in_(task_ids),
Blocker.resolved_at.is_(None)
).count()
# Calculate completion rate
completion_rate = 0.0
if task_count > 0:
completion_rate = (completed_task_count / task_count) * 100
# Calculate health score (start at 100, subtract penalties)
health_score = 100
# Apply blocker penalty
blocker_penalty = min(blocker_count * BLOCKER_PENALTY_PER_ITEM, BLOCKER_PENALTY_MAX)
health_score -= blocker_penalty
# Apply overdue penalty
overdue_penalty = min(overdue_task_count * OVERDUE_PENALTY_PER_ITEM, OVERDUE_PENALTY_MAX)
health_score -= overdue_penalty
# Apply completion penalty (if below threshold)
if task_count > 0 and completion_rate < COMPLETION_PENALTY_THRESHOLD:
completion_penalty = int(
(COMPLETION_PENALTY_THRESHOLD - completion_rate) * COMPLETION_PENALTY_FACTOR
)
health_score -= min(completion_penalty, COMPLETION_PENALTY_MAX)
# Ensure health score stays within bounds
health_score = max(0, min(100, health_score))
# Determine risk level based on health score
risk_level = _determine_risk_level(health_score)
# Determine schedule status based on overdue count
schedule_status = _determine_schedule_status(overdue_task_count)
# Determine resource status based on blocker count
resource_status = _determine_resource_status(blocker_count)
return {
"health_score": health_score,
"risk_level": risk_level,
"schedule_status": schedule_status,
"resource_status": resource_status,
"task_count": task_count,
"completed_task_count": completed_task_count,
"blocker_count": blocker_count,
"overdue_task_count": overdue_task_count,
}
def _determine_risk_level(health_score: int) -> str:
"""Determine risk level based on health score."""
if health_score >= RISK_LOW_THRESHOLD:
return "low"
elif health_score >= RISK_MEDIUM_THRESHOLD:
return "medium"
elif health_score >= RISK_HIGH_THRESHOLD:
return "high"
else:
return "critical"
def _determine_schedule_status(overdue_task_count: int) -> str:
"""Determine schedule status based on overdue task count."""
if overdue_task_count == 0:
return "on_track"
elif overdue_task_count <= SCHEDULE_AT_RISK_THRESHOLD:
return "at_risk"
else:
return "delayed"
def _determine_resource_status(blocker_count: int) -> str:
"""Determine resource status based on blocker count."""
if blocker_count == 0:
return "adequate"
elif blocker_count <= RESOURCE_CONSTRAINED_THRESHOLD:
return "constrained"
else:
return "overloaded"
def get_or_create_project_health(db: Session, project: Project) -> ProjectHealth:
"""
Get existing project health record or create a new one.
Args:
db: Database session
project: Project object
Returns:
ProjectHealth record
"""
health = db.query(ProjectHealth).filter(
ProjectHealth.project_id == project.id
).first()
if not health:
health = ProjectHealth(
id=str(uuid.uuid4()),
project_id=project.id
)
db.add(health)
return health
def update_project_health(
db: Session,
project: Project,
metrics: Dict[str, Any]
) -> ProjectHealth:
"""
Update project health record with calculated metrics.
Args:
db: Database session
project: Project object
metrics: Calculated health metrics
Returns:
Updated ProjectHealth record
"""
health = get_or_create_project_health(db, project)
health.health_score = metrics["health_score"]
health.risk_level = metrics["risk_level"]
health.schedule_status = metrics["schedule_status"]
health.resource_status = metrics["resource_status"]
return health
def get_project_health(
db: Session,
project_id: str
) -> Optional[ProjectHealthWithDetails]:
"""
Get health information for a single project.
Args:
db: Database session
project_id: Project ID
Returns:
ProjectHealthWithDetails or None if project not found
"""
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
return None
metrics = calculate_health_metrics(db, project)
health = update_project_health(db, project, metrics)
db.commit()
db.refresh(health)
return _build_health_with_details(project, health, metrics)
def get_all_projects_health(
db: Session,
status_filter: Optional[str] = "active"
) -> ProjectHealthDashboardResponse:
"""
Get health information for all projects.
Args:
db: Database session
status_filter: Filter projects by status (default: "active")
Returns:
ProjectHealthDashboardResponse with projects list and summary
"""
query = db.query(Project)
if status_filter:
query = query.filter(Project.status == status_filter)
projects = query.all()
projects_health: List[ProjectHealthWithDetails] = []
for project in projects:
metrics = calculate_health_metrics(db, project)
health = update_project_health(db, project, metrics)
project_health = _build_health_with_details(project, health, metrics)
projects_health.append(project_health)
db.commit()
# Calculate summary statistics
summary = _calculate_summary(projects_health)
return ProjectHealthDashboardResponse(
projects=projects_health,
summary=summary
)
def _build_health_with_details(
project: Project,
health: ProjectHealth,
metrics: Dict[str, Any]
) -> ProjectHealthWithDetails:
"""Build ProjectHealthWithDetails from project, health, and metrics."""
return ProjectHealthWithDetails(
id=health.id,
project_id=project.id,
health_score=metrics["health_score"],
risk_level=RiskLevel(metrics["risk_level"]),
schedule_status=ScheduleStatus(metrics["schedule_status"]),
resource_status=ResourceStatus(metrics["resource_status"]),
last_updated=health.last_updated or datetime.utcnow(),
project_title=project.title,
project_status=project.status,
owner_name=project.owner.name if project.owner else None,
space_name=project.space.name if project.space else None,
task_count=metrics["task_count"],
completed_task_count=metrics["completed_task_count"],
blocker_count=metrics["blocker_count"],
overdue_task_count=metrics["overdue_task_count"],
)
def _calculate_summary(
projects_health: List[ProjectHealthWithDetails]
) -> ProjectHealthSummary:
"""Calculate summary statistics for health dashboard."""
total_projects = len(projects_health)
healthy_count = sum(1 for p in projects_health if p.health_score >= 80)
at_risk_count = sum(1 for p in projects_health if 50 <= p.health_score < 80)
critical_count = sum(1 for p in projects_health if p.health_score < 50)
average_health_score = 0.0
if total_projects > 0:
average_health_score = sum(p.health_score for p in projects_health) / total_projects
projects_with_blockers = sum(1 for p in projects_health if p.blocker_count > 0)
projects_delayed = sum(
1 for p in projects_health
if p.schedule_status == ScheduleStatus.DELAYED
)
return ProjectHealthSummary(
total_projects=total_projects,
healthy_count=healthy_count,
at_risk_count=at_risk_count,
critical_count=critical_count,
average_health_score=round(average_health_score, 1),
projects_with_blockers=projects_with_blockers,
projects_delayed=projects_delayed,
)
class HealthService:
"""
Service class for project health operations.
Provides a class-based interface for health calculations,
following the service pattern used in the codebase.
"""
def __init__(self, db: Session):
"""Initialize HealthService with database session."""
self.db = db
def calculate_metrics(self, project: Project) -> Dict[str, Any]:
"""Calculate health metrics for a project."""
return calculate_health_metrics(self.db, project)
def get_project_health(self, project_id: str) -> Optional[ProjectHealthWithDetails]:
"""Get health information for a single project."""
return get_project_health(self.db, project_id)
def get_dashboard(
self,
status_filter: Optional[str] = "active"
) -> ProjectHealthDashboardResponse:
"""Get health dashboard for all projects."""
return get_all_projects_health(self.db, status_filter)
def refresh_project_health(self, project: Project) -> ProjectHealth:
"""Refresh and persist health data for a project."""
metrics = calculate_health_metrics(self.db, project)
health = update_project_health(self.db, project, metrics)
self.db.commit()
self.db.refresh(health)
return health
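# Usage sketch:
#   service = HealthService(db)
#   dashboard = service.get_dashboard(status_filter="active")
#   single = service.get_project_health(project_id)  # None if project missing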

View File

@@ -4,7 +4,7 @@ import re
import asyncio
import logging
import threading
-from datetime import datetime
from datetime import datetime, timezone
from typing import List, Optional, Dict, Set
from sqlalchemy.orm import Session
from sqlalchemy import event
@@ -102,7 +102,7 @@ class NotificationService:
"""Convert a Notification to a dict for publishing."""
created_at = notification.created_at
if created_at is None:
-created_at = datetime.utcnow()
created_at = datetime.now(timezone.utc).replace(tzinfo=None)
return {
"id": notification.id,
"type": notification.type,

View File

@@ -1,5 +1,5 @@
import uuid
-from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from typing import Dict, Any, List, Optional
from sqlalchemy.orm import Session
from sqlalchemy import func
@@ -15,9 +15,15 @@ class ReportService:
@staticmethod
def get_week_start(date: Optional[datetime] = None) -> datetime:
"""Get the start of the week (Monday) for a given date."""
"""Get the start of the week (Monday) for a given date.
Returns a naive datetime for compatibility with database values.
"""
if date is None:
-date = datetime.utcnow()
date = datetime.now(timezone.utc).replace(tzinfo=None)
elif date.tzinfo is not None:
# Convert to naive datetime for consistency
date = date.replace(tzinfo=None)
# Get Monday of the current week
days_since_monday = date.weekday()
week_start = date - timedelta(days=days_since_monday)
@@ -37,7 +43,8 @@ class ReportService:
week_end = week_start + timedelta(days=7)
next_week_start = week_end
next_week_end = next_week_start + timedelta(days=7)
-now = datetime.utcnow()
# Use naive datetime for comparison with database values
now = datetime.now(timezone.utc).replace(tzinfo=None)
# Get projects owned by the user
projects = db.query(Project).filter(Project.owner_id == user_id).all()
@@ -189,7 +196,7 @@ class ReportService:
return {
"week_start": week_start.isoformat(),
"week_end": week_end.isoformat(),
"generated_at": datetime.utcnow().isoformat(),
"generated_at": datetime.now(timezone.utc).replace(tzinfo=None).isoformat(),
"projects": project_details,
"summary": {
"completed_count": len(completed_tasks),
@@ -235,7 +242,8 @@ class ReportService:
db.add(report_history)
# Update last_sent_at
# Use naive datetime for consistency with database storage
scheduled_report.last_sent_at = datetime.now(timezone.utc).replace(tzinfo=None)
db.commit()
@@ -304,7 +312,8 @@ class ReportService:
db.add(history)
# Update last_sent_at
# Use naive datetime for consistency with database storage
scheduled_report.last_sent_at = datetime.now(timezone.utc).replace(tzinfo=None)
# Send notification
ReportService.send_report_notification(db, scheduled_report.recipient_id, content)

View File

@@ -0,0 +1,701 @@
"""
Scheduled Trigger Execution Service
This module provides functionality for parsing cron expressions and executing
scheduled triggers based on their cron schedule, including deadline reminders.
"""
import uuid
import logging
from datetime import datetime, timezone, timedelta
from typing import Optional, List, Dict, Any, Tuple, Set
from croniter import croniter
from sqlalchemy.orm import Session
from sqlalchemy import and_
from app.models import Trigger, TriggerLog, Task, Project
from app.services.notification_service import NotificationService
logger = logging.getLogger(__name__)
# Key prefix for tracking deadline reminders already sent
DEADLINE_REMINDER_LOG_TYPE = "deadline_reminder"
class TriggerSchedulerService:
"""Service for scheduling and executing cron-based triggers."""
@staticmethod
def parse_cron_expression(expression: str) -> Tuple[bool, Optional[str]]:
"""
Validate a cron expression.
Args:
expression: A cron expression string (e.g., "0 9 * * 1-5" for weekdays at 9am)
Returns:
Tuple of (is_valid, error_message)
- is_valid: True if the expression is valid
- error_message: None if valid, otherwise an error description
"""
try:
# croniter requires a base time for initialization
base_time = datetime.now(timezone.utc)
croniter(expression, base_time)
return True, None
except (ValueError, KeyError) as e:
return False, f"Invalid cron expression: {str(e)}"
@staticmethod
def get_next_run_time(expression: str, base_time: Optional[datetime] = None) -> Optional[datetime]:
"""
Get the next scheduled run time for a cron expression.
Args:
expression: A cron expression string
base_time: The base time to calculate from (defaults to now)
Returns:
The next datetime when the schedule matches, or None if invalid
"""
try:
if base_time is None:
base_time = datetime.now(timezone.utc)
cron = croniter(expression, base_time)
return cron.get_next(datetime)
except (ValueError, KeyError):
return None
@staticmethod
def get_previous_run_time(expression: str, base_time: Optional[datetime] = None) -> Optional[datetime]:
"""
Get the previous scheduled run time for a cron expression.
Args:
expression: A cron expression string
base_time: The base time to calculate from (defaults to now)
Returns:
The previous datetime when the schedule matched, or None if invalid
"""
try:
if base_time is None:
base_time = datetime.now(timezone.utc)
cron = croniter(expression, base_time)
return cron.get_prev(datetime)
except (ValueError, KeyError):
return None
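    # Quick illustrative check of the two helpers above (resulting values
    # depend on the clock at call time):
    #
    #     ok, err = TriggerSchedulerService.parse_cron_expression("0 9 * * 1-5")
    #     assert ok and err is None
    #     nxt = TriggerSchedulerService.get_next_run_time("0 9 * * 1-5")
    #     # e.g. the next weekday 09:00 UTC after "now"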
@staticmethod
def should_trigger(
trigger: Trigger,
current_time: datetime,
last_execution_time: Optional[datetime] = None,
) -> bool:
"""
Check if a schedule trigger should fire based on its cron expression.
A trigger should fire if:
1. It's a schedule-type trigger and is active
2. Its conditions contain a valid cron expression
3. The cron schedule has matched since the last execution
Args:
trigger: The trigger to evaluate
current_time: The current time to check against
last_execution_time: The time of the last successful execution
Returns:
True if the trigger should fire, False otherwise
"""
# Only process schedule triggers
if trigger.trigger_type != "schedule":
return False
if not trigger.is_active:
return False
# Get cron expression from conditions
conditions = trigger.conditions or {}
cron_expression = conditions.get("cron_expression")
if not cron_expression:
logger.warning(f"Trigger {trigger.id} has no cron_expression in conditions")
return False
# Validate cron expression
is_valid, error = TriggerSchedulerService.parse_cron_expression(cron_expression)
if not is_valid:
logger.warning(f"Trigger {trigger.id} has invalid cron: {error}")
return False
# Get the previous scheduled time before current_time
prev_scheduled = TriggerSchedulerService.get_previous_run_time(cron_expression, current_time)
if prev_scheduled is None:
return False
# If no last execution, check if we're within the execution window (5 minutes)
if last_execution_time is None:
# Only trigger if the scheduled time was within the last 5 minutes
window_seconds = 300 # 5 minutes
time_since_scheduled = (current_time - prev_scheduled).total_seconds()
return 0 <= time_since_scheduled < window_seconds
# Trigger if the previous scheduled time is after the last execution
return prev_scheduled > last_execution_time
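    # Worked example of the window logic above: with a daily "0 9 * * *"
    # schedule and no prior execution, a check at 09:03 UTC fires (180s is
    # inside the 300s window) while a check at 09:06 UTC does not. Once a
    # successful run is logged, the trigger stays quiet until the next
    # scheduled slot falls after that execution time.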
@staticmethod
def get_last_execution_time(db: Session, trigger_id: str) -> Optional[datetime]:
"""
Get the last successful execution time for a trigger.
Args:
db: Database session
trigger_id: The trigger ID
Returns:
The datetime of the last successful execution, or None
"""
last_log = db.query(TriggerLog).filter(
TriggerLog.trigger_id == trigger_id,
TriggerLog.status == "success",
).order_by(TriggerLog.executed_at.desc()).first()
return last_log.executed_at if last_log else None
@staticmethod
def execute_scheduled_triggers(db: Session) -> List[TriggerLog]:
"""
Main execution function that evaluates and executes all scheduled triggers.
This function should be called periodically (e.g., every minute) by a scheduler.
Args:
db: Database session
Returns:
List of TriggerLog entries for executed triggers
"""
logs: List[TriggerLog] = []
current_time = datetime.now(timezone.utc)
# Get all active schedule-type triggers
triggers = db.query(Trigger).filter(
Trigger.trigger_type == "schedule",
Trigger.is_active == True,
).all()
logger.info(f"Evaluating {len(triggers)} scheduled triggers at {current_time}")
for trigger in triggers:
try:
# Get last execution time
last_execution = TriggerSchedulerService.get_last_execution_time(db, trigger.id)
# Check if trigger should fire
if TriggerSchedulerService.should_trigger(trigger, current_time, last_execution):
logger.info(f"Executing scheduled trigger: {trigger.name} (ID: {trigger.id})")
log = TriggerSchedulerService._execute_trigger(db, trigger)
logs.append(log)
except Exception as e:
logger.error(f"Error evaluating trigger {trigger.id}: {e}")
# Log the error
error_log = TriggerSchedulerService._log_execution(
db=db,
trigger=trigger,
status="failed",
details={"error_type": type(e).__name__},
error_message=str(e),
)
logs.append(error_log)
if logs:
db.commit()
logger.info(f"Executed {len(logs)} scheduled triggers")
return logs
@staticmethod
def _execute_trigger(db: Session, trigger: Trigger) -> TriggerLog:
"""
Execute a scheduled trigger's actions.
Args:
db: Database session
trigger: The trigger to execute
Returns:
TriggerLog entry for this execution
"""
actions = trigger.actions if isinstance(trigger.actions, list) else [trigger.actions]
executed_actions = []
error_message = None
try:
for action in actions:
action_type = action.get("type")
if action_type == "notify":
TriggerSchedulerService._execute_notify_action(db, action, trigger)
executed_actions.append({"type": action_type, "status": "success"})
# Add more action types here as needed
status = "success"
except Exception as e:
status = "failed"
error_message = str(e)
executed_actions.append({"type": "error", "message": str(e)})
logger.error(f"Error executing trigger {trigger.id} actions: {e}")
return TriggerSchedulerService._log_execution(
db=db,
trigger=trigger,
status=status,
details={
"trigger_name": trigger.name,
"trigger_type": "schedule",
"cron_expression": trigger.conditions.get("cron_expression"),
"actions_executed": executed_actions,
},
error_message=error_message,
)
@staticmethod
def _execute_notify_action(db: Session, action: Dict[str, Any], trigger: Trigger) -> None:
"""
Execute a notify action for a scheduled trigger.
Args:
db: Database session
action: The action configuration
trigger: The parent trigger
"""
target = action.get("target", "project_owner")
template = action.get("template", "Scheduled trigger '{trigger_name}' has fired")
# For scheduled triggers, we typically notify project-level users
project = trigger.project
if not project:
logger.warning(f"Trigger {trigger.id} has no associated project")
return
target_user_id = TriggerSchedulerService._resolve_target(project, target)
if not target_user_id:
logger.debug(f"No target user resolved for trigger {trigger.id} with target '{target}'")
return
# Format message with variables
message = TriggerSchedulerService._format_template(template, trigger, project)
NotificationService.create_notification(
db=db,
user_id=target_user_id,
notification_type="scheduled_trigger",
reference_type="trigger",
reference_id=trigger.id,
title=f"Scheduled: {trigger.name}",
message=message,
)
@staticmethod
def _resolve_target(project: Project, target: str) -> Optional[str]:
"""
Resolve notification target to user ID.
Args:
project: The project context
target: Target specification (e.g., "project_owner", "user:<id>")
Returns:
User ID or None
"""
if target == "project_owner":
return project.owner_id
elif target.startswith("user:"):
return target.split(":", 1)[1]
return None
@staticmethod
def _format_template(template: str, trigger: Trigger, project: Project) -> str:
"""
Format message template with trigger/project variables.
Args:
template: Template string with {variable} placeholders
trigger: The trigger context
project: The project context
Returns:
Formatted message string
"""
replacements = {
"{trigger_name}": trigger.name,
"{trigger_id}": trigger.id,
"{project_name}": project.title if project else "Unknown",
"{project_id}": project.id if project else "Unknown",
}
result = template
for key, value in replacements.items():
result = result.replace(key, str(value))
return result
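    # Example: with trigger.name == "Daily Reminder" and project.title ==
    # "Apollo", the template "Trigger '{trigger_name}' fired for {project_name}"
    # formats to "Trigger 'Daily Reminder' fired for Apollo".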
@staticmethod
def _log_execution(
db: Session,
trigger: Trigger,
status: str,
details: Optional[Dict[str, Any]] = None,
error_message: Optional[str] = None,
task_id: Optional[str] = None,
) -> TriggerLog:
"""
Create a trigger execution log entry.
Args:
db: Database session
trigger: The trigger that was executed
status: Execution status ("success" or "failed")
details: Optional execution details
error_message: Optional error message if failed
task_id: Optional task ID for deadline reminders
Returns:
The created TriggerLog entry
"""
log = TriggerLog(
id=str(uuid.uuid4()),
trigger_id=trigger.id,
task_id=task_id,
status=status,
details=details,
error_message=error_message,
)
db.add(log)
return log
# =========================================================================
# Deadline Reminder Methods
# =========================================================================
@staticmethod
def execute_deadline_reminders(db: Session) -> List[TriggerLog]:
"""
Check all deadline reminder triggers and send notifications for tasks
that are within N days of their due date.
Each task only receives one reminder per trigger configuration.
Args:
db: Database session
Returns:
List of TriggerLog entries for sent reminders
"""
logs: List[TriggerLog] = []
current_time = datetime.now(timezone.utc)
today = current_time.date()
# Get all active schedule triggers with deadline_reminder_days
triggers = db.query(Trigger).filter(
Trigger.trigger_type == "schedule",
Trigger.is_active == True,
).all()
# Filter triggers that have deadline_reminder_days configured
deadline_triggers = [
t for t in triggers
if t.conditions and t.conditions.get("deadline_reminder_days") is not None
]
if not deadline_triggers:
return logs
logger.info(f"Evaluating {len(deadline_triggers)} deadline reminder triggers")
for trigger in deadline_triggers:
try:
reminder_days = trigger.conditions.get("deadline_reminder_days")
if not isinstance(reminder_days, int) or reminder_days < 1:
continue
# Calculate the target date range
# We want to find tasks whose due_date is exactly N days from today
target_date = today + timedelta(days=reminder_days)
# Get tasks in this project that:
# 1. Have a due_date matching the target date
# 2. Are not deleted
# 3. Have not already received a reminder for this trigger
tasks = TriggerSchedulerService._get_tasks_for_deadline_reminder(
db, trigger, target_date
)
for task in tasks:
try:
log = TriggerSchedulerService._send_deadline_reminder(
db, trigger, task, reminder_days
)
logs.append(log)
except Exception as e:
logger.error(
f"Error sending deadline reminder for task {task.id}: {e}"
)
error_log = TriggerSchedulerService._log_execution(
db=db,
trigger=trigger,
status="failed",
details={
"trigger_type": DEADLINE_REMINDER_LOG_TYPE,
"task_id": task.id,
"reminder_days": reminder_days,
},
error_message=str(e),
task_id=task.id,
)
logs.append(error_log)
except Exception as e:
logger.error(f"Error processing deadline trigger {trigger.id}: {e}")
if logs:
db.commit()
logger.info(f"Processed {len(logs)} deadline reminders")
return logs
@staticmethod
def _get_tasks_for_deadline_reminder(
db: Session,
trigger: Trigger,
target_date,
) -> List[Task]:
"""
Get tasks that need deadline reminders for a specific trigger.
Args:
db: Database session
trigger: The deadline reminder trigger
target_date: The date that matches (today + N days)
Returns:
List of tasks that need reminders
"""
# Get IDs of tasks that already received reminders for this trigger
already_notified = db.query(TriggerLog.task_id).filter(
TriggerLog.trigger_id == trigger.id,
TriggerLog.status == "success",
TriggerLog.task_id.isnot(None),
).all()
notified_task_ids: Set[str] = {t[0] for t in already_notified if t[0]}
# Use date range comparison for cross-database compatibility
# target_date is a date object, we need to find tasks due on that date
target_start = datetime.combine(target_date, datetime.min.time()).replace(tzinfo=timezone.utc)
target_end = datetime.combine(target_date, datetime.max.time()).replace(tzinfo=timezone.utc)
# Query tasks matching criteria
tasks = db.query(Task).filter(
Task.project_id == trigger.project_id,
Task.is_deleted == False,
Task.due_date.isnot(None),
Task.due_date >= target_start,
Task.due_date <= target_end,
).all()
# Filter out tasks that already received reminders
return [t for t in tasks if t.id not in notified_task_ids]
@staticmethod
def _send_deadline_reminder(
db: Session,
trigger: Trigger,
task: Task,
reminder_days: int,
) -> TriggerLog:
"""
Send a deadline reminder notification for a task.
Args:
db: Database session
trigger: The trigger configuration
task: The task approaching its deadline
reminder_days: Number of days before deadline
Returns:
TriggerLog entry for this reminder
"""
actions = trigger.actions if isinstance(trigger.actions, list) else [trigger.actions]
executed_actions = []
error_message = None
try:
for action in actions:
action_type = action.get("type")
if action_type == "notify":
TriggerSchedulerService._execute_deadline_notify_action(
db, action, trigger, task, reminder_days
)
executed_actions.append({"type": action_type, "status": "success"})
status = "success"
except Exception as e:
status = "failed"
error_message = str(e)
executed_actions.append({"type": "error", "message": str(e)})
logger.error(f"Error executing deadline reminder for task {task.id}: {e}")
return TriggerSchedulerService._log_execution(
db=db,
trigger=trigger,
status=status,
details={
"trigger_name": trigger.name,
"trigger_type": DEADLINE_REMINDER_LOG_TYPE,
"reminder_days": reminder_days,
"task_title": task.title,
"due_date": str(task.due_date),
"actions_executed": executed_actions,
},
error_message=error_message,
task_id=task.id,
)
@staticmethod
def _execute_deadline_notify_action(
db: Session,
action: Dict[str, Any],
trigger: Trigger,
task: Task,
reminder_days: int,
) -> None:
"""
Execute a notify action for a deadline reminder.
Args:
db: Database session
action: The action configuration
trigger: The parent trigger
task: The task with approaching deadline
reminder_days: Days until deadline
"""
target = action.get("target", "assignee")
template = action.get(
"template",
"Task '{task_title}' is due in {reminder_days} days"
)
# Resolve target user
target_user_id = TriggerSchedulerService._resolve_deadline_target(task, target)
if not target_user_id:
logger.debug(
f"No target user resolved for deadline reminder, task {task.id}, target '{target}'"
)
return
# Format message with variables
message = TriggerSchedulerService._format_deadline_template(
template, trigger, task, reminder_days
)
NotificationService.create_notification(
db=db,
user_id=target_user_id,
notification_type="deadline_reminder",
reference_type="task",
reference_id=task.id,
title=f"Deadline Reminder: {task.title}",
message=message,
)
@staticmethod
def _resolve_deadline_target(task: Task, target: str) -> Optional[str]:
"""
Resolve notification target for deadline reminders.
Args:
task: The task context
target: Target specification
Returns:
User ID or None
"""
if target == "assignee":
return task.assignee_id
elif target == "creator":
return task.created_by
elif target == "project_owner":
return task.project.owner_id if task.project else None
elif target.startswith("user:"):
return target.split(":", 1)[1]
return None
@staticmethod
def _format_deadline_template(
template: str,
trigger: Trigger,
task: Task,
reminder_days: int,
) -> str:
"""
Format message template for deadline reminders.
Args:
template: Template string with {variable} placeholders
trigger: The trigger context
task: The task context
reminder_days: Days until deadline
Returns:
Formatted message string
"""
project = trigger.project
replacements = {
"{trigger_name}": trigger.name,
"{trigger_id}": trigger.id,
"{task_title}": task.title,
"{task_id}": task.id,
"{due_date}": str(task.due_date.date()) if task.due_date else "N/A",
"{reminder_days}": str(reminder_days),
"{project_name}": project.title if project else "Unknown",
"{project_id}": project.id if project else "Unknown",
}
result = template
for key, value in replacements.items():
result = result.replace(key, str(value))
return result
@staticmethod
def evaluate_schedule_triggers(db: Session) -> List[TriggerLog]:
"""
Main entry point for evaluating all schedule triggers.
This method runs both cron-based triggers and deadline reminders.
Should be called every minute by the scheduler.
Args:
db: Database session
Returns:
Combined list of TriggerLog entries from all evaluations
"""
all_logs: List[TriggerLog] = []
# Execute cron-based schedule triggers
cron_logs = TriggerSchedulerService.execute_scheduled_triggers(db)
all_logs.extend(cron_logs)
# Execute deadline reminder triggers
deadline_logs = TriggerSchedulerService.execute_deadline_reminders(db)
all_logs.extend(deadline_logs)
return all_logs
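# Wiring sketch (an assumption; the scheduler setup is not part of this file):
# evaluate_schedule_triggers could be driven by APScheduler, which is pinned in
# requirements.txt. `SessionLocal` is an assumed session-factory name.
#
#     from apscheduler.schedulers.background import BackgroundScheduler
#     from app.core.database import SessionLocal
#
#     def _tick() -> None:
#         db = SessionLocal()
#         try:
#             TriggerSchedulerService.evaluate_schedule_triggers(db)
#         finally:
#             db.close()
#
#     scheduler = BackgroundScheduler()
#     scheduler.add_job(_tick, "interval", minutes=1)
#     scheduler.start()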

View File

@@ -0,0 +1,327 @@
"""
Watermark Service for MED-009: Dynamic Watermark for Downloads
This service provides functions to add watermarks to image and PDF files
containing user information for audit and tracking purposes.
Watermark content includes:
- User name
- Employee ID (or email as fallback)
- Download timestamp
"""
import io
import logging
import math
from datetime import datetime
from typing import Optional, Tuple
import fitz # PyMuPDF
from PIL import Image, ImageDraw, ImageFont
logger = logging.getLogger(__name__)
class WatermarkService:
"""Service for adding watermarks to downloaded files."""
# Watermark configuration
WATERMARK_OPACITY = 0.3 # 30% opacity for semi-transparency
WATERMARK_ANGLE = -45 # Diagonal angle in degrees
WATERMARK_FONT_SIZE = 24
WATERMARK_COLOR = (128, 128, 128) # Gray color for watermark
WATERMARK_SPACING = 200 # Spacing between repeated watermarks
@staticmethod
def _format_watermark_text(
user_name: str,
employee_id: Optional[str] = None,
download_time: Optional[datetime] = None
) -> str:
"""
Format the watermark text with user information.
Args:
user_name: Name of the user
            employee_id: Employee ID (staff number) - uses 'N/A' if not provided
download_time: Time of download (defaults to now)
Returns:
Formatted watermark text
"""
if download_time is None:
download_time = datetime.now()
time_str = download_time.strftime("%Y-%m-%d %H:%M:%S")
emp_id = employee_id if employee_id else "N/A"
return f"{user_name} ({emp_id}) - {time_str}"
@staticmethod
    def _get_font(size: int = 24) -> "ImageFont.FreeTypeFont | ImageFont.ImageFont":
"""Get a font for the watermark. Falls back to default if custom font not available."""
try:
# Try to use a common system font (macOS)
return ImageFont.truetype("/System/Library/Fonts/Helvetica.ttc", size)
except (OSError, IOError):
try:
# Try Linux font
return ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", size)
except (OSError, IOError):
try:
# Try Windows font
return ImageFont.truetype("C:/Windows/Fonts/arial.ttf", size)
except (OSError, IOError):
# Fall back to default bitmap font
return ImageFont.load_default()
def add_image_watermark(
self,
image_bytes: bytes,
user_name: str,
employee_id: Optional[str] = None,
download_time: Optional[datetime] = None
) -> Tuple[bytes, str]:
"""
Add a semi-transparent diagonal watermark to an image.
Args:
image_bytes: The original image as bytes
user_name: Name of the user downloading the file
            employee_id: Employee ID of the user (staff number)
download_time: Time of download (defaults to now)
Returns:
Tuple of (watermarked image bytes, output format)
Raises:
Exception: If watermarking fails
"""
# Open the image
original = Image.open(io.BytesIO(image_bytes))
# Convert to RGBA if necessary for transparency support
if original.mode != 'RGBA':
image = original.convert('RGBA')
else:
image = original.copy()
# Create a transparent overlay for the watermark
watermark_layer = Image.new('RGBA', image.size, (255, 255, 255, 0))
draw = ImageDraw.Draw(watermark_layer)
# Get watermark text
watermark_text = self._format_watermark_text(user_name, employee_id, download_time)
# Get font
font = self._get_font(self.WATERMARK_FONT_SIZE)
# Calculate text size
bbox = draw.textbbox((0, 0), watermark_text, font=font)
text_width = bbox[2] - bbox[0]
text_height = bbox[3] - bbox[1]
# Create a larger canvas for the rotated text pattern
diagonal = int(math.sqrt(image.size[0]**2 + image.size[1]**2))
pattern_size = diagonal * 2
# Create pattern layer
pattern = Image.new('RGBA', (pattern_size, pattern_size), (255, 255, 255, 0))
pattern_draw = ImageDraw.Draw(pattern)
# Draw repeated watermark text across the pattern
opacity = int(255 * self.WATERMARK_OPACITY)
watermark_color = (*self.WATERMARK_COLOR, opacity)
y = 0
row = 0
while y < pattern_size:
x = -text_width if row % 2 else 0 # Offset alternate rows
while x < pattern_size:
pattern_draw.text((x, y), watermark_text, font=font, fill=watermark_color)
x += text_width + self.WATERMARK_SPACING
y += text_height + self.WATERMARK_SPACING
row += 1
# Rotate the pattern
rotated_pattern = pattern.rotate(
self.WATERMARK_ANGLE,
expand=False,
center=(pattern_size // 2, pattern_size // 2)
)
# Crop to original image size (centered)
crop_x = (pattern_size - image.size[0]) // 2
crop_y = (pattern_size - image.size[1]) // 2
cropped_pattern = rotated_pattern.crop((
crop_x, crop_y,
crop_x + image.size[0],
crop_y + image.size[1]
))
# Composite the watermark onto the image
watermarked = Image.alpha_composite(image, cropped_pattern)
# Determine output format
original_format = original.format or 'PNG'
if original_format.upper() == 'JPEG':
# Convert back to RGB for JPEG (no alpha channel)
watermarked = watermarked.convert('RGB')
output_format = 'JPEG'
else:
output_format = 'PNG'
# Save to bytes
output = io.BytesIO()
watermarked.save(output, format=output_format, quality=95)
output.seek(0)
logger.info(
f"Image watermark applied successfully for user {user_name} "
f"(employee_id: {employee_id})"
)
return output.getvalue(), output_format.lower()
def add_pdf_watermark(
self,
pdf_bytes: bytes,
user_name: str,
employee_id: Optional[str] = None,
download_time: Optional[datetime] = None
) -> bytes:
"""
Add a semi-transparent diagonal watermark to a PDF using PyMuPDF.
Args:
pdf_bytes: The original PDF as bytes
user_name: Name of the user downloading the file
            employee_id: Employee ID of the user (staff number)
download_time: Time of download (defaults to now)
Returns:
Watermarked PDF as bytes
Raises:
Exception: If watermarking fails
"""
# Get watermark text
watermark_text = self._format_watermark_text(user_name, employee_id, download_time)
# Open the PDF with PyMuPDF
doc = fitz.open(stream=pdf_bytes, filetype="pdf")
page_count = len(doc)
# Process each page
for page_num in range(page_count):
page = doc[page_num]
page_rect = page.rect
page_width = page_rect.width
page_height = page_rect.height
# Calculate text width for spacing estimation
text_length = fitz.get_text_length(
watermark_text,
fontname="helv",
fontsize=self.WATERMARK_FONT_SIZE
)
# Calculate diagonal for watermark coverage
diagonal = math.sqrt(page_width**2 + page_height**2)
# Set watermark color with opacity (gray with 30% opacity)
color = (0.5, 0.5, 0.5) # Gray
# Calculate rotation angle in radians
angle_rad = math.radians(self.WATERMARK_ANGLE)
# Draw watermark pattern using shape with proper rotation
# We use insert_textbox with a morph transform for rotation
spacing_x = text_length + self.WATERMARK_SPACING
spacing_y = self.WATERMARK_FONT_SIZE + self.WATERMARK_SPACING
# Create watermark by drawing rotated text lines
# We'll use a simpler approach: draw text and apply rotation via morph
shape = page.new_shape()
# Calculate grid positions to cover the page when rotated
center = fitz.Point(page_width / 2, page_height / 2)
# Calculate start and end points for coverage
start = -diagonal
end = diagonal * 2
y = start
row = 0
while y < end:
x = start + (spacing_x / 2 if row % 2 else 0)
while x < end:
# Create text position
text_point = fitz.Point(x, y)
# Apply rotation around center
cos_a = math.cos(angle_rad)
sin_a = math.sin(angle_rad)
# Translate to origin, rotate, translate back
rx = text_point.x - center.x
ry = text_point.y - center.y
new_x = rx * cos_a - ry * sin_a + center.x
new_y = rx * sin_a + ry * cos_a + center.y
# Check if the rotated point is within page bounds (with margin)
margin = 50
if (-margin <= new_x <= page_width + margin and
-margin <= new_y <= page_height + margin):
# Insert text using shape with rotation via morph
text_rect = fitz.Rect(new_x, new_y, new_x + text_length + 10, new_y + 30)
# Use insert_textbox with morph for rotation
pivot = fitz.Point(new_x, new_y)
morph = (pivot, fitz.Matrix(1, 0, 0, 1, 0, 0).prerotate(self.WATERMARK_ANGLE))
shape.insert_textbox(
text_rect,
watermark_text,
fontname="helv",
fontsize=self.WATERMARK_FONT_SIZE,
color=color,
fill_opacity=self.WATERMARK_OPACITY,
morph=morph
)
x += spacing_x
y += spacing_y
row += 1
# Commit the shape drawings
shape.commit(overlay=True)
# Save to bytes
output = io.BytesIO()
doc.save(output)
doc.close()
output.seek(0)
logger.info(
f"PDF watermark applied successfully for user {user_name} "
f"(employee_id: {employee_id}), pages: {page_count}"
)
return output.getvalue()
def is_supported_image(self, mime_type: str) -> bool:
"""Check if the mime type is a supported image format."""
supported_types = {'image/png', 'image/jpeg', 'image/jpg'}
return mime_type.lower() in supported_types
def is_supported_pdf(self, mime_type: str) -> bool:
"""Check if the mime type is a PDF."""
return mime_type.lower() == 'application/pdf'
def supports_watermark(self, mime_type: str) -> bool:
"""Check if the file type supports watermarking."""
return self.is_supported_image(mime_type) or self.is_supported_pdf(mime_type)
# Singleton instance
watermark_service = WatermarkService()
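# Usage sketch (illustrative; the real integration lives in the attachment
# download endpoint, and the attribute names on `attachment` and the storage
# helper are assumptions):
#
#     raw_bytes = file_storage_service.read(attachment.file_path)  # assumed helper
#     if watermark_service.supports_watermark(attachment.mime_type):
#         if watermark_service.is_supported_pdf(attachment.mime_type):
#             data = watermark_service.add_pdf_watermark(
#                 raw_bytes, current_user.name, current_user.employee_id
#             )
#         else:
#             data, fmt = watermark_service.add_image_watermark(
#                 raw_bytes, current_user.name, current_user.employee_id
#             )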

View File

@@ -184,12 +184,17 @@ def get_workload_heatmap(
Returns:
List of UserWorkloadSummary objects
"""
from datetime import datetime
from collections import defaultdict
    if week_start is None:
        week_start = get_current_week_start()
    # Normalize to the week bounds (Monday start)
    week_start, week_end = get_week_bounds(week_start)
# Build user query
query = db.query(User).filter(User.is_active == True)
@@ -201,10 +206,58 @@ def get_workload_heatmap(
users = query.options(joinedload(User.department)).all()
if not users:
return []
# Batch query: fetch all tasks for all users in one query
user_id_list = [user.id for user in users]
week_start_dt = datetime.combine(week_start, datetime.min.time())
week_end_dt = datetime.combine(week_end, datetime.max.time())
all_tasks = (
db.query(Task)
.join(Task.status, isouter=True)
.filter(
Task.assignee_id.in_(user_id_list),
Task.due_date >= week_start_dt,
Task.due_date <= week_end_dt,
# Exclude completed tasks
(TaskStatus.is_done == False) | (Task.status_id == None)
)
.all()
)
# Group tasks by assignee_id in memory
tasks_by_user: dict = defaultdict(list)
for task in all_tasks:
tasks_by_user[task.assignee_id].append(task)
# Calculate workload for each user using pre-fetched tasks
results = []
for user in users:
user_tasks = tasks_by_user.get(user.id, [])
# Calculate allocated hours from original_estimate
allocated_hours = Decimal("0")
for task in user_tasks:
if task.original_estimate:
allocated_hours += task.original_estimate
capacity_hours = Decimal(str(user.capacity)) if user.capacity else Decimal("40")
load_percentage = calculate_load_percentage(allocated_hours, capacity_hours)
load_level = determine_load_level(load_percentage)
summary = UserWorkloadSummary(
user_id=user.id,
user_name=user.name,
department_id=user.department_id,
department_name=user.department.name if user.department else None,
capacity_hours=capacity_hours,
allocated_hours=allocated_hours,
load_percentage=load_percentage,
load_level=load_level,
task_count=len(user_tasks),
)
results.append(summary)
return results
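# Worked example of the load calculation above (illustrative): a user with
# capacity_hours == 40 and 30 hours of original_estimate on open tasks due this
# week yields load_percentage == 75; 44 hours yields 110, which
# determine_load_level (defined elsewhere in this module) would typically map
# to an overloaded level (the exact thresholds live in that helper).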

View File

@@ -0,0 +1,42 @@
"""Create project health table
Revision ID: 009
Revises: 008
Create Date: 2026-01-04
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '009'
down_revision = '008'
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create project_health table
op.create_table(
'pjctrl_project_health',
sa.Column('id', sa.String(36), primary_key=True),
sa.Column('project_id', sa.String(36), sa.ForeignKey('pjctrl_projects.id', ondelete='CASCADE'), nullable=False, unique=True),
sa.Column('health_score', sa.Integer, server_default='100', nullable=False),
sa.Column('risk_level', sa.Enum('low', 'medium', 'high', 'critical', name='risk_level_enum'), server_default='low', nullable=False),
sa.Column('schedule_status', sa.Enum('on_track', 'at_risk', 'delayed', name='schedule_status_enum'), server_default='on_track', nullable=False),
sa.Column('resource_status', sa.Enum('adequate', 'constrained', 'overloaded', name='resource_status_enum'), server_default='adequate', nullable=False),
sa.Column('last_updated', sa.DateTime, server_default=sa.func.now(), nullable=False),
)
# Create indexes
op.create_index('idx_project_health_project', 'pjctrl_project_health', ['project_id'])
op.create_index('idx_project_health_risk', 'pjctrl_project_health', ['risk_level'])
def downgrade() -> None:
op.drop_index('idx_project_health_risk', table_name='pjctrl_project_health')
op.drop_index('idx_project_health_project', table_name='pjctrl_project_health')
op.drop_table('pjctrl_project_health')
op.execute("DROP TYPE IF EXISTS risk_level_enum")
op.execute("DROP TYPE IF EXISTS schedule_status_enum")
op.execute("DROP TYPE IF EXISTS resource_status_enum")

View File

@@ -0,0 +1,32 @@
"""Add employee_id to users table for watermark feature
Revision ID: 010
Revises: 009
Create Date: 2026-01-04
MED-009: Add employee_id field to support dynamic watermark with user identification.
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
revision: str = '010'
down_revision: Union[str, None] = '009'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Add employee_id column to pjctrl_users table
op.add_column(
'pjctrl_users',
sa.Column('employee_id', sa.String(50), nullable=True, unique=True)
)
# Create index for employee_id lookups
op.create_index('ix_pjctrl_users_employee_id', 'pjctrl_users', ['employee_id'])
def downgrade() -> None:
op.drop_index('ix_pjctrl_users_employee_id', table_name='pjctrl_users')
op.drop_column('pjctrl_users', 'employee_id')

View File

@@ -14,3 +14,10 @@ pydantic-settings==2.1.0
pytest==7.4.4
pytest-asyncio==0.23.3
pytest-cov==4.1.0
slowapi==0.1.9
croniter==2.0.1
APScheduler==3.10.4
Pillow==10.2.0
PyPDF2==3.0.1
reportlab==4.1.0
PyMuPDF==1.26.7

View File

@@ -1,3 +1,8 @@
import os
# Set testing environment before importing app modules
os.environ["TESTING"] = "true"
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
@@ -103,6 +108,18 @@ def mock_redis():
@pytest.fixture(scope="function")
def client(db, mock_redis):
"""Create test client with overridden dependencies."""
# Reset rate limiter storage before each test
from app.core.rate_limiter import limiter
if hasattr(limiter, '_storage') and limiter._storage:
try:
limiter._storage.reset()
except Exception:
pass # Memory storage might not have reset method
# For memory storage, clear internal state
if hasattr(limiter, '_limiter') and hasattr(limiter._limiter, '_storage'):
storage = limiter._limiter._storage
if hasattr(storage, 'storage'):
storage.storage.clear()
def override_get_db():
try:

View File

@@ -0,0 +1,672 @@
"""Tests for project health API and service."""
import pytest
from datetime import datetime, timedelta
from decimal import Decimal
from app.models import User, Department, Space, Project, Task, Blocker
from app.models.task_status import TaskStatus
from app.models.project_health import ProjectHealth
from app.services.health_service import (
calculate_health_metrics,
get_or_create_project_health,
update_project_health,
get_project_health,
get_all_projects_health,
HealthService,
_determine_risk_level,
_determine_schedule_status,
_determine_resource_status,
BLOCKER_PENALTY_PER_ITEM,
BLOCKER_PENALTY_MAX,
OVERDUE_PENALTY_PER_ITEM,
OVERDUE_PENALTY_MAX,
)
from app.schemas.project_health import RiskLevel, ScheduleStatus, ResourceStatus
class TestRiskLevelDetermination:
"""Tests for risk level determination logic."""
def test_low_risk(self):
"""Health score >= 80 should be low risk."""
assert _determine_risk_level(100) == "low"
assert _determine_risk_level(80) == "low"
def test_medium_risk(self):
"""Health score 60-79 should be medium risk."""
assert _determine_risk_level(79) == "medium"
assert _determine_risk_level(60) == "medium"
def test_high_risk(self):
"""Health score 40-59 should be high risk."""
assert _determine_risk_level(59) == "high"
assert _determine_risk_level(40) == "high"
def test_critical_risk(self):
"""Health score < 40 should be critical risk."""
assert _determine_risk_level(39) == "critical"
assert _determine_risk_level(0) == "critical"
class TestScheduleStatusDetermination:
"""Tests for schedule status determination logic."""
def test_on_track(self):
"""No overdue tasks means on track."""
assert _determine_schedule_status(0) == "on_track"
def test_at_risk(self):
"""1-2 overdue tasks means at risk."""
assert _determine_schedule_status(1) == "at_risk"
assert _determine_schedule_status(2) == "at_risk"
def test_delayed(self):
"""More than 2 overdue tasks means delayed."""
assert _determine_schedule_status(3) == "delayed"
assert _determine_schedule_status(10) == "delayed"
class TestResourceStatusDetermination:
"""Tests for resource status determination logic."""
def test_adequate(self):
"""No blockers means adequate resources."""
assert _determine_resource_status(0) == "adequate"
def test_constrained(self):
"""1-2 blockers means constrained resources."""
assert _determine_resource_status(1) == "constrained"
assert _determine_resource_status(2) == "constrained"
def test_overloaded(self):
"""More than 2 blockers means overloaded."""
assert _determine_resource_status(3) == "overloaded"
assert _determine_resource_status(10) == "overloaded"
class TestHealthMetricsCalculation:
"""Tests for health metrics calculation with database."""
def setup_test_data(self, db):
"""Set up test data for health tests."""
# Create department
dept = Department(
id="dept-health-001",
name="Health Test Department",
)
db.add(dept)
# Create space
space = Space(
id="space-health-001",
name="Health Test Space",
owner_id="00000000-0000-0000-0000-000000000001",
is_active=True,
)
db.add(space)
# Create project
project = Project(
id="project-health-001",
space_id="space-health-001",
title="Health Test Project",
owner_id="00000000-0000-0000-0000-000000000001",
department_id="dept-health-001",
security_level="department",
status="active",
)
db.add(project)
# Create task statuses
status_todo = TaskStatus(
id="status-health-todo",
project_id="project-health-001",
name="To Do",
is_done=False,
)
db.add(status_todo)
status_done = TaskStatus(
id="status-health-done",
project_id="project-health-001",
name="Done",
is_done=True,
)
db.add(status_done)
db.commit()
return {
"department": dept,
"space": space,
"project": project,
"status_todo": status_todo,
"status_done": status_done,
}
def create_task(self, db, data, task_id, done=False, overdue=False, has_blocker=False):
"""Helper to create a task with optional characteristics."""
        if overdue:
            due_date = datetime.utcnow() - timedelta(days=3)
        else:
            due_date = datetime.utcnow() + timedelta(days=3)
task = Task(
id=task_id,
project_id=data["project"].id,
title=f"Task {task_id}",
status_id=data["status_done"].id if done else data["status_todo"].id,
due_date=due_date,
created_by="00000000-0000-0000-0000-000000000001",
is_deleted=False,
)
db.add(task)
db.commit()
if has_blocker:
blocker = Blocker(
id=f"blocker-{task_id}",
task_id=task_id,
reported_by="00000000-0000-0000-0000-000000000001",
reason="Test blocker",
resolved_at=None,
)
db.add(blocker)
db.commit()
return task
def test_calculate_metrics_no_tasks(self, db):
"""Project with no tasks should have 100 health score."""
data = self.setup_test_data(db)
metrics = calculate_health_metrics(db, data["project"])
assert metrics["health_score"] == 100
assert metrics["risk_level"] == "low"
assert metrics["schedule_status"] == "on_track"
assert metrics["resource_status"] == "adequate"
assert metrics["task_count"] == 0
assert metrics["completed_task_count"] == 0
assert metrics["blocker_count"] == 0
assert metrics["overdue_task_count"] == 0
def test_calculate_metrics_all_completed(self, db):
"""Project with all completed tasks should have high health score."""
data = self.setup_test_data(db)
self.create_task(db, data, "task-c1", done=True)
self.create_task(db, data, "task-c2", done=True)
self.create_task(db, data, "task-c3", done=True)
metrics = calculate_health_metrics(db, data["project"])
assert metrics["health_score"] == 100
assert metrics["task_count"] == 3
assert metrics["completed_task_count"] == 3
assert metrics["overdue_task_count"] == 0
def test_calculate_metrics_with_blockers(self, db):
"""Blockers should reduce health score."""
data = self.setup_test_data(db)
# Create 3 tasks with blockers
self.create_task(db, data, "task-b1", has_blocker=True)
self.create_task(db, data, "task-b2", has_blocker=True)
self.create_task(db, data, "task-b3", has_blocker=True)
metrics = calculate_health_metrics(db, data["project"])
        # 3 blockers x BLOCKER_PENALTY_PER_ITEM, plus a low-completion penalty
        expected_blocker_penalty = min(3 * BLOCKER_PENALTY_PER_ITEM, BLOCKER_PENALTY_MAX)
        assert metrics["blocker_count"] == 3
        assert metrics["resource_status"] == "overloaded"
        # The score should drop by at least the blocker penalty
        assert metrics["health_score"] <= 100 - expected_blocker_penalty
def test_calculate_metrics_with_overdue_tasks(self, db):
"""Overdue tasks should reduce health score."""
data = self.setup_test_data(db)
# Create 3 overdue tasks
self.create_task(db, data, "task-o1", overdue=True)
self.create_task(db, data, "task-o2", overdue=True)
self.create_task(db, data, "task-o3", overdue=True)
metrics = calculate_health_metrics(db, data["project"])
assert metrics["overdue_task_count"] == 3
assert metrics["schedule_status"] == "delayed"
assert metrics["health_score"] < 100
def test_calculate_metrics_overdue_completed_not_counted(self, db):
"""Completed overdue tasks should not count as overdue."""
data = self.setup_test_data(db)
# Create task that is overdue but completed
task = Task(
id="task-oc1",
project_id=data["project"].id,
title="Overdue Completed Task",
status_id=data["status_done"].id,
due_date=datetime.utcnow() - timedelta(days=5),
created_by="00000000-0000-0000-0000-000000000001",
is_deleted=False,
)
db.add(task)
db.commit()
metrics = calculate_health_metrics(db, data["project"])
assert metrics["overdue_task_count"] == 0
assert metrics["completed_task_count"] == 1
def test_calculate_metrics_deleted_tasks_excluded(self, db):
"""Soft-deleted tasks should be excluded from calculations."""
data = self.setup_test_data(db)
# Create a normal task
self.create_task(db, data, "task-normal")
# Create a deleted task
deleted_task = Task(
id="task-deleted",
project_id=data["project"].id,
title="Deleted Task",
status_id=data["status_todo"].id,
due_date=datetime.utcnow() - timedelta(days=5), # Overdue
created_by="00000000-0000-0000-0000-000000000001",
is_deleted=True,
deleted_at=datetime.utcnow(),
)
db.add(deleted_task)
db.commit()
metrics = calculate_health_metrics(db, data["project"])
assert metrics["task_count"] == 1 # Only non-deleted task
assert metrics["overdue_task_count"] == 0 # Deleted task not counted
def test_calculate_metrics_combined_penalties(self, db):
"""Multiple issues should stack penalties correctly."""
data = self.setup_test_data(db)
# Create mixed tasks: 2 overdue with blockers
self.create_task(db, data, "task-mix1", overdue=True, has_blocker=True)
self.create_task(db, data, "task-mix2", overdue=True, has_blocker=True)
metrics = calculate_health_metrics(db, data["project"])
assert metrics["blocker_count"] == 2
assert metrics["overdue_task_count"] == 2
# Should have penalties from both
# 2 blockers = 20 penalty, 2 overdue = 10 penalty, plus completion penalty
assert metrics["health_score"] < 80
class TestHealthServiceClass:
"""Tests for HealthService class."""
def setup_test_data(self, db):
"""Set up test data for health service tests."""
# Create department
dept = Department(
id="dept-svc-001",
name="Service Test Department",
)
db.add(dept)
# Create space
space = Space(
id="space-svc-001",
name="Service Test Space",
owner_id="00000000-0000-0000-0000-000000000001",
is_active=True,
)
db.add(space)
# Create project
project = Project(
id="project-svc-001",
space_id="space-svc-001",
title="Service Test Project",
owner_id="00000000-0000-0000-0000-000000000001",
department_id="dept-svc-001",
security_level="department",
status="active",
)
db.add(project)
# Create inactive project
inactive_project = Project(
id="project-svc-inactive",
space_id="space-svc-001",
title="Inactive Project",
owner_id="00000000-0000-0000-0000-000000000001",
department_id="dept-svc-001",
security_level="department",
status="archived",
)
db.add(inactive_project)
db.commit()
return {
"department": dept,
"space": space,
"project": project,
"inactive_project": inactive_project,
}
def test_get_or_create_health_creates_new(self, db):
"""Should create new ProjectHealth if none exists."""
data = self.setup_test_data(db)
health = get_or_create_project_health(db, data["project"])
db.commit()
assert health is not None
assert health.project_id == data["project"].id
assert health.health_score == 100 # Default
def test_get_or_create_health_returns_existing(self, db):
"""Should return existing ProjectHealth if one exists."""
data = self.setup_test_data(db)
# Create initial health record
health1 = get_or_create_project_health(db, data["project"])
health1.health_score = 75
db.commit()
# Should return same record
health2 = get_or_create_project_health(db, data["project"])
assert health2.id == health1.id
assert health2.health_score == 75
def test_get_project_health(self, db):
"""Should return health details for a project."""
data = self.setup_test_data(db)
result = get_project_health(db, data["project"].id)
assert result is not None
assert result.project_id == data["project"].id
assert result.project_title == "Service Test Project"
assert result.health_score == 100
def test_get_project_health_not_found(self, db):
"""Should return None for non-existent project."""
data = self.setup_test_data(db)
result = get_project_health(db, "non-existent-id")
assert result is None
def test_get_all_projects_health_active_only(self, db):
"""Dashboard should only include active projects by default."""
data = self.setup_test_data(db)
result = get_all_projects_health(db, status_filter="active")
project_ids = [p.project_id for p in result.projects]
assert data["project"].id in project_ids
assert data["inactive_project"].id not in project_ids
def test_get_all_projects_health_summary(self, db):
"""Dashboard should include correct summary statistics."""
data = self.setup_test_data(db)
result = get_all_projects_health(db, status_filter="active")
assert result.summary.total_projects >= 1
assert result.summary.average_health_score <= 100
def test_health_service_class_interface(self, db):
"""HealthService class should provide same functionality."""
data = self.setup_test_data(db)
service = HealthService(db)
# Test get_project_health
health = service.get_project_health(data["project"].id)
assert health is not None
assert health.project_id == data["project"].id
# Test get_dashboard
dashboard = service.get_dashboard()
assert dashboard.summary.total_projects >= 1
# Test calculate_metrics
metrics = service.calculate_metrics(data["project"])
assert "health_score" in metrics
assert "risk_level" in metrics
class TestHealthAPI:
"""Tests for health API endpoints."""
def setup_test_data(self, db):
"""Set up test data for API tests."""
# Create department
dept = Department(
id="dept-api-001",
name="API Test Department",
)
db.add(dept)
# Create space
space = Space(
id="space-api-001",
name="API Test Space",
owner_id="00000000-0000-0000-0000-000000000001",
is_active=True,
)
db.add(space)
# Create projects
project1 = Project(
id="project-api-001",
space_id="space-api-001",
title="API Test Project 1",
owner_id="00000000-0000-0000-0000-000000000001",
department_id="dept-api-001",
security_level="department",
status="active",
)
db.add(project1)
project2 = Project(
id="project-api-002",
space_id="space-api-001",
title="API Test Project 2",
owner_id="00000000-0000-0000-0000-000000000001",
department_id="dept-api-001",
security_level="department",
status="active",
)
db.add(project2)
# Create task statuses
status_todo = TaskStatus(
id="status-api-todo",
project_id="project-api-001",
name="To Do",
is_done=False,
)
db.add(status_todo)
# Create a task with blocker for project1
task = Task(
id="task-api-001",
project_id="project-api-001",
title="API Test Task",
status_id="status-api-todo",
due_date=datetime.utcnow() - timedelta(days=2), # Overdue
created_by="00000000-0000-0000-0000-000000000001",
is_deleted=False,
)
db.add(task)
blocker = Blocker(
id="blocker-api-001",
task_id="task-api-001",
reported_by="00000000-0000-0000-0000-000000000001",
reason="Test blocker",
resolved_at=None,
)
db.add(blocker)
db.commit()
return {
"department": dept,
"space": space,
"project1": project1,
"project2": project2,
"task": task,
"blocker": blocker,
}
def test_get_dashboard(self, client, db, admin_token):
"""Admin should be able to get health dashboard."""
data = self.setup_test_data(db)
response = client.get(
"/api/projects/health/dashboard",
headers={"Authorization": f"Bearer {admin_token}"},
)
assert response.status_code == 200
result = response.json()
assert "projects" in result
assert "summary" in result
assert result["summary"]["total_projects"] >= 2
def test_get_dashboard_summary_fields(self, client, db, admin_token):
"""Dashboard summary should include all expected fields."""
data = self.setup_test_data(db)
response = client.get(
"/api/projects/health/dashboard",
headers={"Authorization": f"Bearer {admin_token}"},
)
assert response.status_code == 200
summary = response.json()["summary"]
assert "total_projects" in summary
assert "healthy_count" in summary
assert "at_risk_count" in summary
assert "critical_count" in summary
assert "average_health_score" in summary
assert "projects_with_blockers" in summary
assert "projects_delayed" in summary
def test_get_project_health(self, client, db, admin_token):
"""Admin should be able to get single project health."""
data = self.setup_test_data(db)
response = client.get(
f"/api/projects/health/{data['project1'].id}",
headers={"Authorization": f"Bearer {admin_token}"},
)
assert response.status_code == 200
result = response.json()
assert result["project_id"] == data["project1"].id
assert result["project_title"] == "API Test Project 1"
assert "health_score" in result
assert "risk_level" in result
assert "schedule_status" in result
assert "resource_status" in result
def test_get_project_health_not_found(self, client, db, admin_token):
"""Should return 404 for non-existent project."""
self.setup_test_data(db)
response = client.get(
"/api/projects/health/non-existent-id",
headers={"Authorization": f"Bearer {admin_token}"},
)
assert response.status_code == 404
assert response.json()["detail"] == "Project not found"
def test_get_project_health_with_issues(self, client, db, admin_token):
"""Project with issues should have correct metrics."""
data = self.setup_test_data(db)
response = client.get(
f"/api/projects/health/{data['project1'].id}",
headers={"Authorization": f"Bearer {admin_token}"},
)
assert response.status_code == 200
result = response.json()
# Project1 has 1 overdue task with 1 blocker
assert result["blocker_count"] == 1
assert result["overdue_task_count"] == 1
assert result["health_score"] < 100 # Should be penalized
def test_unauthorized_access(self, client, db):
"""Unauthenticated requests should fail."""
response = client.get("/api/projects/health/dashboard")
assert response.status_code == 403
def test_dashboard_with_status_filter(self, client, db, admin_token):
"""Dashboard should respect status filter."""
data = self.setup_test_data(db)
# Create an archived project
archived = Project(
id="project-archived",
space_id="space-api-001",
title="Archived Project",
owner_id="00000000-0000-0000-0000-000000000001",
department_id="dept-api-001",
security_level="department",
status="archived",
)
db.add(archived)
db.commit()
# Default filter should exclude archived
response = client.get(
"/api/projects/health/dashboard",
headers={"Authorization": f"Bearer {admin_token}"},
)
assert response.status_code == 200
project_ids = [p["project_id"] for p in response.json()["projects"]]
assert "project-archived" not in project_ids
def test_project_health_response_structure(self, client, db, admin_token):
"""Response should match ProjectHealthWithDetails schema."""
data = self.setup_test_data(db)
response = client.get(
f"/api/projects/health/{data['project1'].id}",
headers={"Authorization": f"Bearer {admin_token}"},
)
assert response.status_code == 200
result = response.json()
# Required fields from schema
required_fields = [
"id", "project_id", "health_score", "risk_level",
"schedule_status", "resource_status", "last_updated",
"project_title", "project_status", "task_count",
"completed_task_count", "blocker_count", "overdue_task_count"
]
for field in required_fields:
assert field in result, f"Missing field: {field}"
# Check enum values
assert result["risk_level"] in ["low", "medium", "high", "critical"]
assert result["schedule_status"] in ["on_track", "at_risk", "delayed"]
assert result["resource_status"] in ["adequate", "constrained", "overloaded"]

View File

@@ -0,0 +1,124 @@
"""
Test suite for rate limiting functionality.
Tests the rate limiting feature on the login endpoint to ensure
protection against brute force attacks.
"""
import pytest
from unittest.mock import patch, MagicMock, AsyncMock
from app.services.auth_client import AuthAPIError
class TestRateLimiting:
"""Test rate limiting on the login endpoint."""
def test_login_rate_limit_exceeded(self, client):
"""
Test that the login endpoint returns 429 after exceeding rate limit.
GIVEN a client IP has made 5 login attempts within 1 minute
WHEN the client attempts another login
THEN the system returns HTTP 429 Too Many Requests
        AND the response body includes error details
"""
# Mock the external auth service to return auth error
with patch("app.api.auth.router.verify_credentials", new_callable=AsyncMock) as mock_verify:
mock_verify.side_effect = AuthAPIError("Invalid credentials")
login_data = {"email": "test@example.com", "password": "wrongpassword"}
# Make 5 requests (the limit)
for i in range(5):
response = client.post("/api/auth/login", json=login_data)
# These should fail due to invalid credentials (401), but not rate limit
assert response.status_code == 401, f"Request {i+1} expected 401, got {response.status_code}"
# The 6th request should be rate limited
response = client.post("/api/auth/login", json=login_data)
assert response.status_code == 429, f"Expected 429 Too Many Requests, got {response.status_code}"
# Response should contain error details
data = response.json()
assert "error" in data or "detail" in data, "Response should contain error details"
def test_login_within_rate_limit(self, client):
"""
Test that requests within the rate limit are allowed.
GIVEN a client IP has not exceeded the rate limit
WHEN the client makes login requests
THEN the requests are processed normally (not rate limited)
"""
with patch("app.api.auth.router.verify_credentials", new_callable=AsyncMock) as mock_verify:
mock_verify.side_effect = AuthAPIError("Invalid credentials")
login_data = {"email": "test@example.com", "password": "wrongpassword"}
# Make requests within the limit
for i in range(3):
response = client.post("/api/auth/login", json=login_data)
# These should fail due to invalid credentials (401), but not be rate limited
assert response.status_code == 401, f"Request {i+1} expected 401, got {response.status_code}"
def test_rate_limit_response_format(self, client):
"""
Test that the 429 response format matches API standards.
GIVEN the rate limit has been exceeded
WHEN the client receives a 429 response
THEN the response body contains appropriate error information
"""
with patch("app.api.auth.router.verify_credentials", new_callable=AsyncMock) as mock_verify:
mock_verify.side_effect = AuthAPIError("Invalid credentials")
login_data = {"email": "test@example.com", "password": "wrongpassword"}
# Exhaust the rate limit
for _ in range(5):
client.post("/api/auth/login", json=login_data)
# The next request should be rate limited
response = client.post("/api/auth/login", json=login_data)
assert response.status_code == 429
# Check response body contains error information
data = response.json()
assert "error" in data or "detail" in data, "Response should contain error details"
class TestRateLimiterConfiguration:
"""Test rate limiter configuration."""
def test_limiter_uses_redis_storage(self):
"""
Test that the limiter is configured with Redis storage.
GIVEN the rate limiter configuration
WHEN we inspect the storage URI
THEN it should be configured to use Redis
"""
from app.core.rate_limiter import limiter
from app.core.config import settings
# The limiter should be configured
assert limiter is not None
# Verify Redis URL is properly configured
assert settings.REDIS_URL.startswith("redis://")
def test_limiter_uses_remote_address_key(self):
"""
Test that the limiter uses client IP as the key.
GIVEN the rate limiter configuration
WHEN we check the key function
THEN it should use get_remote_address
"""
from app.core.rate_limiter import limiter
from slowapi.util import get_remote_address
# The key function should be get_remote_address
assert limiter._key_func == get_remote_address
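# Reference sketch (illustrative) of the configuration these tests assume,
# using slowapi's documented API:
#
#     from slowapi import Limiter
#     from slowapi.util import get_remote_address
#     from app.core.config import settings
#
#     limiter = Limiter(
#         key_func=get_remote_address,
#         storage_uri=settings.REDIS_URL,
#     )
#
#     # The login route is then decorated in the auth router:
#     # @limiter.limit("5/minute")
#     # async def login(request: Request, ...):
#     #     ...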

View File

@@ -0,0 +1,664 @@
"""
Tests for Schedule Triggers functionality.
This module tests:
- Cron expression parsing and validation
- Deadline reminder logic
- Schedule trigger execution
"""
import pytest
import uuid
from datetime import datetime, timezone, timedelta
from app.models import User, Space, Project, Task, TaskStatus, Trigger, TriggerLog, Notification
from app.services.trigger_scheduler import TriggerSchedulerService
# ============================================================================
# Fixtures
# ============================================================================
@pytest.fixture
def test_user(db):
"""Create a test user."""
user = User(
id=str(uuid.uuid4()),
email="scheduleuser@example.com",
name="Schedule Test User",
role_id="00000000-0000-0000-0000-000000000003",
is_active=True,
is_system_admin=False,
)
db.add(user)
db.commit()
return user
@pytest.fixture
def test_user_token(client, mock_redis, test_user):
"""Get a token for test user."""
from app.core.security import create_access_token, create_token_payload
token_data = create_token_payload(
user_id=test_user.id,
email=test_user.email,
role="engineer",
department_id=None,
is_system_admin=False,
)
token = create_access_token(token_data)
mock_redis.setex(f"session:{test_user.id}", 900, token)
return token
@pytest.fixture
def test_space(db, test_user):
"""Create a test space."""
space = Space(
id=str(uuid.uuid4()),
name="Schedule Test Space",
description="Test space for schedule triggers",
owner_id=test_user.id,
)
db.add(space)
db.commit()
return space
@pytest.fixture
def test_project(db, test_space, test_user):
"""Create a test project."""
project = Project(
id=str(uuid.uuid4()),
space_id=test_space.id,
title="Schedule Test Project",
description="Test project for schedule triggers",
owner_id=test_user.id,
)
db.add(project)
db.commit()
return project
@pytest.fixture
def test_status(db, test_project):
"""Create test task statuses."""
status = TaskStatus(
id=str(uuid.uuid4()),
project_id=test_project.id,
name="To Do",
color="#808080",
position=0,
)
db.add(status)
db.commit()
return status
@pytest.fixture
def cron_trigger(db, test_project, test_user):
"""Create a cron-based schedule trigger."""
trigger = Trigger(
id=str(uuid.uuid4()),
project_id=test_project.id,
name="Daily Reminder",
description="Daily reminder at 9am",
trigger_type="schedule",
conditions={
"cron_expression": "0 9 * * *", # Every day at 9am
},
actions=[{
"type": "notify",
"target": "project_owner",
"template": "Daily scheduled trigger fired for {project_name}",
}],
is_active=True,
created_by=test_user.id,
)
db.add(trigger)
db.commit()
return trigger
@pytest.fixture
def deadline_trigger(db, test_project, test_user):
"""Create a deadline reminder trigger."""
trigger = Trigger(
id=str(uuid.uuid4()),
project_id=test_project.id,
name="Deadline Reminder",
description="Remind 3 days before deadline",
trigger_type="schedule",
conditions={
"deadline_reminder_days": 3,
},
actions=[{
"type": "notify",
"target": "assignee",
"template": "Task '{task_title}' is due in {reminder_days} days",
}],
is_active=True,
created_by=test_user.id,
)
db.add(trigger)
db.commit()
return trigger
@pytest.fixture
def task_with_deadline(db, test_project, test_user, test_status):
"""Create a task with a deadline 3 days from now."""
due_date = datetime.now(timezone.utc) + timedelta(days=3)
task = Task(
id=str(uuid.uuid4()),
project_id=test_project.id,
title="Task with Deadline",
description="This task has a deadline",
status_id=test_status.id,
created_by=test_user.id,
assignee_id=test_user.id,
due_date=due_date,
)
db.add(task)
db.commit()
return task
# ============================================================================
# Tests: Cron Expression Parsing
# ============================================================================
class TestCronExpressionParsing:
"""Tests for cron expression parsing and validation."""
def test_parse_valid_cron_expression(self):
"""Test parsing a valid cron expression."""
is_valid, error = TriggerSchedulerService.parse_cron_expression("0 9 * * 1")
assert is_valid is True
assert error is None
def test_parse_valid_cron_every_minute(self):
"""Test parsing every minute cron expression."""
is_valid, error = TriggerSchedulerService.parse_cron_expression("* * * * *")
assert is_valid is True
def test_parse_valid_cron_weekdays(self):
"""Test parsing weekdays-only cron expression."""
is_valid, error = TriggerSchedulerService.parse_cron_expression("0 9 * * 1-5")
assert is_valid is True
def test_parse_valid_cron_monthly(self):
"""Test parsing monthly cron expression."""
is_valid, error = TriggerSchedulerService.parse_cron_expression("0 0 1 * *")
assert is_valid is True
def test_parse_invalid_cron_expression(self):
"""Test parsing an invalid cron expression."""
is_valid, error = TriggerSchedulerService.parse_cron_expression("invalid")
assert is_valid is False
assert error is not None
assert "Invalid cron expression" in error
def test_parse_invalid_cron_too_many_fields(self):
"""Test parsing cron with too many fields."""
is_valid, error = TriggerSchedulerService.parse_cron_expression("0 0 0 0 0 0 0")
assert is_valid is False
def test_parse_invalid_cron_bad_range(self):
"""Test parsing cron with invalid range."""
is_valid, error = TriggerSchedulerService.parse_cron_expression("0 25 * * *")
assert is_valid is False
def test_get_next_run_time(self):
"""Test getting next run time from cron expression."""
base_time = datetime(2025, 1, 1, 8, 0, 0, tzinfo=timezone.utc)
next_time = TriggerSchedulerService.get_next_run_time("0 9 * * *", base_time)
assert next_time is not None
assert next_time.hour == 9
assert next_time.minute == 0
def test_get_previous_run_time(self):
"""Test getting previous run time from cron expression."""
base_time = datetime(2025, 1, 1, 10, 0, 0, tzinfo=timezone.utc)
prev_time = TriggerSchedulerService.get_previous_run_time("0 9 * * *", base_time)
assert prev_time is not None
assert prev_time.hour == 9
assert prev_time.minute == 0
def test_get_next_run_time_invalid_cron(self):
"""Test getting next run time with invalid cron returns None."""
result = TriggerSchedulerService.get_next_run_time("invalid")
assert result is None
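
# --- Editor's sketch -------------------------------------------------------
# The parsing/next-run/previous-run behaviour above is what a croniter-backed
# implementation would give; a hedged sketch of those three helpers follows.
# The real TriggerSchedulerService may differ in details.
from croniter import croniter


def parse_cron_expression_sketch(expression):
    """Return (is_valid, error) for a 5-field cron expression."""
    try:
        croniter(expression)
        return True, None
    except (ValueError, KeyError) as exc:
        return False, f"Invalid cron expression: {exc}"


def get_next_run_time_sketch(expression, base_time=None):
    """Next scheduled datetime strictly after base_time, or None if invalid."""
    base = base_time or datetime.now(timezone.utc)
    try:
        return croniter(expression, base).get_next(datetime)
    except (ValueError, KeyError):
        return None


def get_previous_run_time_sketch(expression, base_time=None):
    """Most recent scheduled datetime before base_time, or None if invalid."""
    base = base_time or datetime.now(timezone.utc)
    try:
        return croniter(expression, base).get_prev(datetime)
    except (ValueError, KeyError):
        return None
# ---------------------------------------------------------------------------
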
# ============================================================================
# Tests: Schedule Trigger Should Fire Logic
# ============================================================================
class TestScheduleTriggerShouldFire:
"""Tests for schedule trigger firing logic."""
def test_should_trigger_within_window(self, db, cron_trigger):
"""Test trigger should fire when within execution window."""
# Set current time to just after scheduled time
scheduled_time = datetime(2025, 1, 1, 9, 0, 0, tzinfo=timezone.utc)
current_time = scheduled_time + timedelta(minutes=2)
result = TriggerSchedulerService.should_trigger(
cron_trigger, current_time, last_execution_time=None
)
assert result is True
def test_should_not_trigger_outside_window(self, db, cron_trigger):
"""Test trigger should not fire when outside execution window."""
# Set current time to well after scheduled time (more than 5 minutes)
scheduled_time = datetime(2025, 1, 1, 9, 0, 0, tzinfo=timezone.utc)
current_time = scheduled_time + timedelta(minutes=10)
result = TriggerSchedulerService.should_trigger(
cron_trigger, current_time, last_execution_time=None
)
assert result is False
def test_should_not_trigger_if_already_executed(self, db, cron_trigger):
"""Test trigger should not fire if already executed after last schedule."""
scheduled_time = datetime(2025, 1, 1, 9, 0, 0, tzinfo=timezone.utc)
current_time = scheduled_time + timedelta(minutes=2)
last_execution = scheduled_time + timedelta(minutes=1)
result = TriggerSchedulerService.should_trigger(
cron_trigger, current_time, last_execution_time=last_execution
)
assert result is False
def test_should_trigger_if_new_schedule_since_last_execution(self, db, cron_trigger):
"""Test trigger should fire if a new schedule time has passed since last execution."""
# Last execution was yesterday at 9:01
last_execution = datetime(2025, 1, 1, 9, 1, 0, tzinfo=timezone.utc)
# Current time is today at 9:02 (new schedule at 9:00 passed)
current_time = datetime(2025, 1, 2, 9, 2, 0, tzinfo=timezone.utc)
result = TriggerSchedulerService.should_trigger(
cron_trigger, current_time, last_execution_time=last_execution
)
assert result is True
def test_should_not_trigger_inactive(self, db, cron_trigger):
"""Test inactive trigger should not fire."""
cron_trigger.is_active = False
db.commit()
current_time = datetime(2025, 1, 1, 9, 1, 0, tzinfo=timezone.utc)
result = TriggerSchedulerService.should_trigger(
cron_trigger, current_time, last_execution_time=None
)
assert result is False
def test_should_not_trigger_field_change_type(self, db, test_project, test_user):
"""Test field_change trigger type should not be evaluated as schedule trigger."""
trigger = Trigger(
id=str(uuid.uuid4()),
project_id=test_project.id,
name="Field Change Trigger",
trigger_type="field_change",
conditions={
"field": "status_id",
"operator": "equals",
"value": "some-id",
},
actions=[{"type": "notify"}],
is_active=True,
created_by=test_user.id,
)
db.add(trigger)
db.commit()
result = TriggerSchedulerService.should_trigger(
trigger, datetime.now(timezone.utc), last_execution_time=None
)
assert result is False
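
# --- Editor's sketch -------------------------------------------------------
# The firing contract pinned down above: an active schedule trigger fires when
# "now" falls within a short window (five minutes, per the tests) after the
# latest cron occurrence, unless it already ran for that occurrence. The
# window constant and structure are assumptions consistent with the assertions.
EXECUTION_WINDOW = timedelta(minutes=5)


def should_trigger_sketch(trigger, current_time, last_execution_time):
    if not trigger.is_active or trigger.trigger_type != "schedule":
        return False
    cron = (trigger.conditions or {}).get("cron_expression")
    if not cron:
        return False
    prev = get_previous_run_time_sketch(cron, current_time)
    if prev is None or current_time - prev > EXECUTION_WINDOW:
        return False
    # Deduplicate: skip if an execution already happened for this occurrence.
    return last_execution_time is None or last_execution_time < prev
# ---------------------------------------------------------------------------
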
# ============================================================================
# Tests: Deadline Reminder Logic
# ============================================================================
class TestDeadlineReminderLogic:
"""Tests for deadline reminder functionality."""
def test_deadline_reminder_finds_matching_tasks(
self, db, deadline_trigger, task_with_deadline, test_user
):
"""Test that deadline reminder finds tasks due in N days."""
# Execute deadline reminders
logs = TriggerSchedulerService.execute_deadline_reminders(db)
db.commit()
assert len(logs) == 1
assert logs[0].status == "success"
assert logs[0].task_id == task_with_deadline.id
assert logs[0].details["trigger_type"] == "deadline_reminder"
assert logs[0].details["reminder_days"] == 3
def test_deadline_reminder_creates_notification(
self, db, deadline_trigger, task_with_deadline, test_user
):
"""Test that deadline reminder creates a notification."""
logs = TriggerSchedulerService.execute_deadline_reminders(db)
db.commit()
# Check notification was created
notifications = db.query(Notification).filter(
Notification.user_id == test_user.id,
Notification.type == "deadline_reminder",
).all()
assert len(notifications) == 1
assert task_with_deadline.title in notifications[0].message
def test_deadline_reminder_only_sends_once(
self, db, deadline_trigger, task_with_deadline
):
"""Test that deadline reminder only sends once per task per trigger."""
# First execution
logs1 = TriggerSchedulerService.execute_deadline_reminders(db)
db.commit()
assert len(logs1) == 1
# Second execution should not send again
logs2 = TriggerSchedulerService.execute_deadline_reminders(db)
db.commit()
assert len(logs2) == 0
def test_deadline_reminder_ignores_deleted_tasks(
self, db, deadline_trigger, task_with_deadline
):
"""Test that deadline reminder ignores soft-deleted tasks."""
task_with_deadline.is_deleted = True
db.commit()
logs = TriggerSchedulerService.execute_deadline_reminders(db)
assert len(logs) == 0
def test_deadline_reminder_ignores_tasks_without_due_date(
self, db, deadline_trigger, test_project, test_user, test_status
):
"""Test that deadline reminder ignores tasks without due dates."""
task = Task(
id=str(uuid.uuid4()),
project_id=test_project.id,
title="No Deadline Task",
status_id=test_status.id,
created_by=test_user.id,
due_date=None,
)
db.add(task)
db.commit()
logs = TriggerSchedulerService.execute_deadline_reminders(db)
assert len(logs) == 0
def test_deadline_reminder_different_reminder_days(
self, db, test_project, test_user, test_status
):
"""Test deadline reminder with different reminder days configuration."""
# Create a trigger for 7 days reminder
trigger = Trigger(
id=str(uuid.uuid4()),
project_id=test_project.id,
name="7 Day Reminder",
trigger_type="schedule",
conditions={"deadline_reminder_days": 7},
actions=[{"type": "notify", "target": "assignee"}],
is_active=True,
created_by=test_user.id,
)
db.add(trigger)
# Create a task due in 7 days
task = Task(
id=str(uuid.uuid4()),
project_id=test_project.id,
title="Task Due in 7 Days",
status_id=test_status.id,
created_by=test_user.id,
assignee_id=test_user.id,
due_date=datetime.now(timezone.utc) + timedelta(days=7),
)
db.add(task)
db.commit()
logs = TriggerSchedulerService.execute_deadline_reminders(db)
db.commit()
assert len(logs) == 1
assert logs[0].details["reminder_days"] == 7
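
# --- Editor's sketch -------------------------------------------------------
# The matching logic these tests imply: for each active schedule trigger with
# deadline_reminder_days, find non-deleted tasks whose due_date lands roughly
# that many days out, and skip tasks already reminded (the real service
# presumably checks TriggerLog for that, per the "only sends once" test).
# The window boundaries here are an assumption.
def find_tasks_due_in_sketch(db, project_id, reminder_days):
    now = datetime.now(timezone.utc)
    window_start = now + timedelta(days=reminder_days) - timedelta(hours=1)
    window_end = window_start + timedelta(days=1)
    return (
        db.query(Task)
        .filter(
            Task.project_id == project_id,
            Task.is_deleted.is_(False),
            Task.due_date.isnot(None),
            Task.due_date >= window_start,
            Task.due_date < window_end,
        )
        .all()
    )
# ---------------------------------------------------------------------------
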
# ============================================================================
# Tests: Schedule Trigger API
# ============================================================================
class TestScheduleTriggerAPI:
"""Tests for Schedule Trigger API endpoints."""
def test_create_cron_trigger(self, client, test_user_token, test_project):
"""Test creating a schedule trigger with cron expression."""
response = client.post(
f"/api/projects/{test_project.id}/triggers",
headers={"Authorization": f"Bearer {test_user_token}"},
json={
"name": "Weekly Monday Reminder",
"description": "Remind every Monday at 9am",
"trigger_type": "schedule",
"conditions": {
"cron_expression": "0 9 * * 1",
},
"actions": [{
"type": "notify",
"target": "project_owner",
"template": "Weekly reminder for {project_name}",
}],
},
)
assert response.status_code == 201
data = response.json()
assert data["name"] == "Weekly Monday Reminder"
assert data["trigger_type"] == "schedule"
assert data["conditions"]["cron_expression"] == "0 9 * * 1"
def test_create_deadline_trigger(self, client, test_user_token, test_project):
"""Test creating a schedule trigger with deadline reminder."""
response = client.post(
f"/api/projects/{test_project.id}/triggers",
headers={"Authorization": f"Bearer {test_user_token}"},
json={
"name": "Deadline Reminder",
"description": "Remind 5 days before deadline",
"trigger_type": "schedule",
"conditions": {
"deadline_reminder_days": 5,
},
"actions": [{
"type": "notify",
"target": "assignee",
}],
},
)
assert response.status_code == 201
data = response.json()
assert data["conditions"]["deadline_reminder_days"] == 5
def test_create_schedule_trigger_invalid_cron(self, client, test_user_token, test_project):
"""Test creating a schedule trigger with invalid cron expression."""
response = client.post(
f"/api/projects/{test_project.id}/triggers",
headers={"Authorization": f"Bearer {test_user_token}"},
json={
"name": "Invalid Cron Trigger",
"trigger_type": "schedule",
"conditions": {
"cron_expression": "invalid cron",
},
"actions": [{"type": "notify"}],
},
)
assert response.status_code == 400
assert "Invalid cron expression" in response.json()["detail"]
def test_create_schedule_trigger_missing_condition(self, client, test_user_token, test_project):
"""Test creating a schedule trigger without cron or deadline condition."""
response = client.post(
f"/api/projects/{test_project.id}/triggers",
headers={"Authorization": f"Bearer {test_user_token}"},
json={
"name": "Empty Schedule Trigger",
"trigger_type": "schedule",
"conditions": {},
"actions": [{"type": "notify"}],
},
)
assert response.status_code == 400
assert "require either cron_expression or deadline_reminder_days" in response.json()["detail"]
def test_update_schedule_trigger_cron(self, client, test_user_token, cron_trigger):
"""Test updating a schedule trigger's cron expression."""
response = client.put(
f"/api/triggers/{cron_trigger.id}",
headers={"Authorization": f"Bearer {test_user_token}"},
json={
"conditions": {
"cron_expression": "0 10 * * *", # Changed to 10am
},
},
)
assert response.status_code == 200
data = response.json()
assert data["conditions"]["cron_expression"] == "0 10 * * *"
def test_update_schedule_trigger_invalid_cron(self, client, test_user_token, cron_trigger):
"""Test updating a schedule trigger with invalid cron expression."""
response = client.put(
f"/api/triggers/{cron_trigger.id}",
headers={"Authorization": f"Bearer {test_user_token}"},
json={
"conditions": {
"cron_expression": "not valid",
},
},
)
assert response.status_code == 400
assert "Invalid cron expression" in response.json()["detail"]
# ============================================================================
# Tests: Integration - Schedule Trigger Execution
# ============================================================================
class TestScheduleTriggerExecution:
"""Integration tests for schedule trigger execution."""
def test_execute_scheduled_triggers(self, db, cron_trigger, test_user):
"""Test executing scheduled triggers creates logs."""
# Manually set conditions to trigger execution
# Create a log entry as if it was executed before
# The trigger should not fire again immediately
# First, verify no logs exist
logs_before = db.query(TriggerLog).filter(
TriggerLog.trigger_id == cron_trigger.id
).all()
assert len(logs_before) == 0
def test_evaluate_schedule_triggers_combined(
self, db, cron_trigger, deadline_trigger, task_with_deadline
):
"""Test that evaluate_schedule_triggers runs both cron and deadline triggers."""
# Note: This test verifies the combined execution method exists and works
# The actual execution depends on timing, so we mainly test structure
# Execute the combined evaluation
logs = TriggerSchedulerService.evaluate_schedule_triggers(db)
# Should have deadline reminder executed
deadline_logs = [l for l in logs if l.details and l.details.get("trigger_type") == "deadline_reminder"]
assert len(deadline_logs) == 1
def test_trigger_log_details(self, db, deadline_trigger, task_with_deadline):
"""Test that trigger logs contain proper details."""
logs = TriggerSchedulerService.execute_deadline_reminders(db)
db.commit()
assert len(logs) == 1
log = logs[0]
assert log.trigger_id == deadline_trigger.id
assert log.task_id == task_with_deadline.id
assert log.status == "success"
assert log.details is not None
assert log.details["trigger_name"] == deadline_trigger.name
assert log.details["task_title"] == task_with_deadline.title
assert "due_date" in log.details
def test_inactive_trigger_not_executed(self, db, deadline_trigger, task_with_deadline):
"""Test that inactive triggers are not executed."""
deadline_trigger.is_active = False
db.commit()
logs = TriggerSchedulerService.execute_deadline_reminders(db)
assert len(logs) == 0
# ============================================================================
# Tests: Template Formatting
# ============================================================================
class TestTemplateFormatting:
"""Tests for message template formatting."""
def test_format_deadline_template_basic(
self, db, deadline_trigger, task_with_deadline
):
"""Test basic deadline template formatting."""
template = "Task '{task_title}' is due in {reminder_days} days"
result = TriggerSchedulerService._format_deadline_template(
template, deadline_trigger, task_with_deadline, 3
)
assert task_with_deadline.title in result
assert "3" in result
def test_format_deadline_template_all_variables(
self, db, deadline_trigger, task_with_deadline
):
"""Test template with all available variables."""
template = (
"Trigger: {trigger_name}, Task: {task_title}, "
"Due: {due_date}, Days: {reminder_days}, Project: {project_name}"
)
result = TriggerSchedulerService._format_deadline_template(
template, deadline_trigger, task_with_deadline, 3
)
assert deadline_trigger.name in result
assert task_with_deadline.title in result
assert "3" in result
def test_format_scheduled_trigger_template(self, db, cron_trigger):
"""Test scheduled trigger template formatting."""
template = "Trigger '{trigger_name}' fired for project '{project_name}'"
result = TriggerSchedulerService._format_template(
template, cron_trigger, cron_trigger.project
)
assert cron_trigger.name in result
assert cron_trigger.project.title in result
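
# --- Editor's sketch -------------------------------------------------------
# The template tests above are satisfied by plain str.format substitution with
# the variable names they use; anything beyond those names is an assumption.
def format_deadline_template_sketch(template, trigger, task, reminder_days):
    return template.format(
        trigger_name=trigger.name,
        task_title=task.title,
        due_date=task.due_date.isoformat() if task.due_date else "N/A",
        reminder_days=reminder_days,
        project_name=trigger.project.title,
    )
# ---------------------------------------------------------------------------
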


@@ -93,6 +93,263 @@ class TestUserEndpoints:
assert response.status_code == 403
class TestCapacityUpdate:
"""Test user capacity update API endpoint."""
def test_update_own_capacity(self, client, db, mock_redis):
"""Test that a user can update their own capacity."""
from app.core.security import create_access_token, create_token_payload
# Create a test user
test_user = User(
id="capacity-user-001",
email="capacityuser@example.com",
name="Capacity User",
is_active=True,
capacity=40.00,
)
db.add(test_user)
db.commit()
# Create token for the user
token_data = create_token_payload(
user_id="capacity-user-001",
email="capacityuser@example.com",
role="engineer",
department_id=None,
is_system_admin=False,
)
token = create_access_token(token_data)
mock_redis.setex("session:capacity-user-001", 900, token)
# Update own capacity
response = client.put(
"/api/users/capacity-user-001/capacity",
headers={"Authorization": f"Bearer {token}"},
json={"capacity_hours": 35.5},
)
assert response.status_code == 200
data = response.json()
assert float(data["capacity"]) == 35.5
def test_admin_can_update_other_user_capacity(self, client, admin_token, db):
"""Test that admin can update another user's capacity."""
# Create a test user
test_user = User(
id="capacity-user-002",
email="capacityuser2@example.com",
name="Capacity User 2",
is_active=True,
capacity=40.00,
)
db.add(test_user)
db.commit()
# Admin updates another user's capacity
response = client.put(
"/api/users/capacity-user-002/capacity",
headers={"Authorization": f"Bearer {admin_token}"},
json={"capacity_hours": 20.0},
)
assert response.status_code == 200
data = response.json()
assert float(data["capacity"]) == 20.0
def test_non_admin_cannot_update_other_user_capacity(self, client, db, mock_redis):
"""Test that a non-admin user cannot update another user's capacity."""
from app.core.security import create_access_token, create_token_payload
# Create two test users
user1 = User(
id="capacity-user-003",
email="capacityuser3@example.com",
name="Capacity User 3",
is_active=True,
capacity=40.00,
)
user2 = User(
id="capacity-user-004",
email="capacityuser4@example.com",
name="Capacity User 4",
is_active=True,
capacity=40.00,
)
db.add_all([user1, user2])
db.commit()
# Create token for user1
token_data = create_token_payload(
user_id="capacity-user-003",
email="capacityuser3@example.com",
role="engineer",
department_id=None,
is_system_admin=False,
)
token = create_access_token(token_data)
mock_redis.setex("session:capacity-user-003", 900, token)
# User1 tries to update user2's capacity - should fail
response = client.put(
"/api/users/capacity-user-004/capacity",
headers={"Authorization": f"Bearer {token}"},
json={"capacity_hours": 30.0},
)
assert response.status_code == 403
assert "Only admin, manager, or the user themselves" in response.json()["detail"]
def test_update_capacity_invalid_value_negative(self, client, admin_token, db):
"""Test that negative capacity hours are rejected."""
# Create a test user
test_user = User(
id="capacity-user-005",
email="capacityuser5@example.com",
name="Capacity User 5",
is_active=True,
capacity=40.00,
)
db.add(test_user)
db.commit()
response = client.put(
"/api/users/capacity-user-005/capacity",
headers={"Authorization": f"Bearer {admin_token}"},
json={"capacity_hours": -5.0},
)
# Pydantic validation returns 422 Unprocessable Entity
assert response.status_code == 422
error_detail = response.json()["detail"]
# Check validation error message in Pydantic format
assert any("non-negative" in str(err).lower() for err in error_detail)
def test_update_capacity_invalid_value_too_high(self, client, admin_token, db):
"""Test that capacity hours exceeding 168 are rejected."""
# Create a test user
test_user = User(
id="capacity-user-006",
email="capacityuser6@example.com",
name="Capacity User 6",
is_active=True,
capacity=40.00,
)
db.add(test_user)
db.commit()
response = client.put(
"/api/users/capacity-user-006/capacity",
headers={"Authorization": f"Bearer {admin_token}"},
json={"capacity_hours": 200.0},
)
# Pydantic validation returns 422 Unprocessable Entity
assert response.status_code == 422
error_detail = response.json()["detail"]
# Check validation error message in Pydantic format
assert any("168" in str(err) for err in error_detail)
def test_update_capacity_nonexistent_user(self, client, admin_token):
"""Test updating capacity for a nonexistent user."""
response = client.put(
"/api/users/nonexistent-user-id/capacity",
headers={"Authorization": f"Bearer {admin_token}"},
json={"capacity_hours": 40.0},
)
assert response.status_code == 404
assert "User not found" in response.json()["detail"]
def test_manager_can_update_other_user_capacity(self, client, db, mock_redis):
"""Test that manager can update another user's capacity."""
from app.core.security import create_access_token, create_token_payload
from app.models.role import Role
# Create manager role if not exists
manager_role = db.query(Role).filter(Role.name == "manager").first()
if not manager_role:
manager_role = Role(
id="manager-role-cap",
name="manager",
permissions={"users.read": True, "users.write": True},
)
db.add(manager_role)
db.commit()
# Create a manager user
manager_user = User(
id="manager-cap-001",
email="managercap@example.com",
name="Manager Cap",
role_id=manager_role.id,
is_active=True,
is_system_admin=False,
)
# Create a regular user
regular_user = User(
id="regular-cap-001",
email="regularcap@example.com",
name="Regular Cap",
is_active=True,
capacity=40.00,
)
db.add_all([manager_user, regular_user])
db.commit()
# Create token for manager
token_data = create_token_payload(
user_id="manager-cap-001",
email="managercap@example.com",
role="manager",
department_id=None,
is_system_admin=False,
)
token = create_access_token(token_data)
mock_redis.setex("session:manager-cap-001", 900, token)
# Manager updates regular user's capacity
response = client.put(
"/api/users/regular-cap-001/capacity",
headers={"Authorization": f"Bearer {token}"},
json={"capacity_hours": 30.0},
)
assert response.status_code == 200
data = response.json()
assert float(data["capacity"]) == 30.0
def test_capacity_change_creates_audit_log(self, client, admin_token, db):
"""Test that capacity changes are recorded in audit trail."""
from app.models import AuditLog
# Create a test user
test_user = User(
id="capacity-audit-001",
email="capacityaudit@example.com",
name="Capacity Audit User",
is_active=True,
capacity=40.00,
)
db.add(test_user)
db.commit()
# Update capacity
response = client.put(
"/api/users/capacity-audit-001/capacity",
headers={"Authorization": f"Bearer {admin_token}"},
json={"capacity_hours": 35.0},
)
assert response.status_code == 200
# Check audit log was created
audit_log = db.query(AuditLog).filter(
AuditLog.resource_id == "capacity-audit-001",
AuditLog.event_type == "user.capacity_change"
).first()
assert audit_log is not None
assert audit_log.resource_type == "user"
assert audit_log.action == "update"
assert len(audit_log.changes) == 1
assert audit_log.changes[0]["field"] == "capacity"
assert audit_log.changes[0]["old_value"] == 40.0
assert audit_log.changes[0]["new_value"] == 35.0
class TestDepartmentIsolation:
"""Test department-based access control."""


@@ -0,0 +1,755 @@
"""
Tests for MED-009: Dynamic Watermark for Downloads
This module contains unit tests for WatermarkService and
integration tests for the download endpoint with watermark functionality.
"""
import pytest
import uuid
import os
import io
import tempfile
import shutil
from datetime import datetime
from io import BytesIO
from PIL import Image
from app.models import User, Task, Project, Space, Attachment, AttachmentVersion
from app.services.watermark_service import WatermarkService, watermark_service
# =============================================================================
# Test Fixtures
# =============================================================================
@pytest.fixture
def test_user(db):
"""Create a test user for watermark tests."""
user = User(
id=str(uuid.uuid4()),
email="watermark.test@example.com",
employee_id="EMP-WM001",
name="Watermark Tester",
role_id="00000000-0000-0000-0000-000000000003",
is_active=True,
is_system_admin=False,
)
db.add(user)
db.commit()
return user
@pytest.fixture
def test_user_token(client, mock_redis, test_user):
"""Get a token for test user."""
from app.core.security import create_access_token, create_token_payload
token_data = create_token_payload(
user_id=test_user.id,
email=test_user.email,
role="engineer",
department_id=None,
is_system_admin=False,
)
token = create_access_token(token_data)
mock_redis.setex(f"session:{test_user.id}", 900, token)
return token
@pytest.fixture
def test_space(db, test_user):
"""Create a test space."""
space = Space(
id=str(uuid.uuid4()),
name="Watermark Test Space",
description="Test space for watermark tests",
owner_id=test_user.id,
)
db.add(space)
db.commit()
return space
@pytest.fixture
def test_project(db, test_space, test_user):
"""Create a test project."""
project = Project(
id=str(uuid.uuid4()),
space_id=test_space.id,
title="Watermark Test Project",
description="Test project for watermark tests",
owner_id=test_user.id,
)
db.add(project)
db.commit()
return project
@pytest.fixture
def test_task(db, test_project, test_user):
"""Create a test task."""
task = Task(
id=str(uuid.uuid4()),
project_id=test_project.id,
title="Watermark Test Task",
description="Test task for watermark tests",
created_by=test_user.id,
)
db.add(task)
db.commit()
return task
@pytest.fixture
def temp_upload_dir():
"""Create a temporary upload directory."""
temp_dir = tempfile.mkdtemp()
yield temp_dir
shutil.rmtree(temp_dir)
@pytest.fixture
def sample_png_bytes():
"""Create a sample PNG image as bytes."""
img = Image.new("RGB", (200, 200), color=(255, 255, 255))
output = io.BytesIO()
img.save(output, format="PNG")
output.seek(0)
return output.getvalue()
@pytest.fixture
def sample_jpeg_bytes():
"""Create a sample JPEG image as bytes."""
img = Image.new("RGB", (200, 200), color=(255, 255, 255))
output = io.BytesIO()
img.save(output, format="JPEG")
output.seek(0)
return output.getvalue()
@pytest.fixture
def sample_pdf_bytes():
"""Create a sample PDF as bytes."""
from reportlab.lib.pagesizes import letter
from reportlab.pdfgen import canvas
buffer = io.BytesIO()
c = canvas.Canvas(buffer, pagesize=letter)
c.drawString(100, 750, "Test PDF Document")
c.drawString(100, 700, "This is a test page for watermarking.")
c.showPage()
c.drawString(100, 750, "Page 2")
c.drawString(100, 700, "Second page content.")
c.showPage()
c.save()
buffer.seek(0)
return buffer.getvalue()
# =============================================================================
# Unit Tests for WatermarkService
# =============================================================================
class TestWatermarkServiceUnit:
"""Unit tests for WatermarkService class."""
def test_format_watermark_text(self):
"""Test watermark text formatting with employee_id."""
test_time = datetime(2024, 1, 15, 10, 30, 45)
text = WatermarkService._format_watermark_text(
user_name="John Doe",
employee_id="EMP001",
download_time=test_time
)
assert "John Doe" in text
assert "EMP001" in text
assert "2024-01-15 10:30:45" in text
assert text == "John Doe (EMP001) - 2024-01-15 10:30:45"
def test_format_watermark_text_without_employee_id(self):
"""Test that watermark text uses N/A when employee_id is not provided."""
test_time = datetime(2024, 1, 15, 10, 30, 45)
text = WatermarkService._format_watermark_text(
user_name="Jane Doe",
employee_id=None,
download_time=test_time
)
assert "Jane Doe" in text
assert "(N/A)" in text
assert text == "Jane Doe (N/A) - 2024-01-15 10:30:45"
def test_format_watermark_text_defaults_to_now(self):
"""Test that watermark text defaults to current time."""
text = WatermarkService._format_watermark_text(
user_name="Jane Doe",
employee_id="EMP002"
)
assert "Jane Doe" in text
assert "EMP002" in text
# Should contain a date-like string
assert "-" in text # Date separator
def test_is_supported_image_png(self):
"""Test PNG is recognized as supported image."""
service = WatermarkService()
assert service.is_supported_image("image/png") is True
assert service.is_supported_image("IMAGE/PNG") is True
def test_is_supported_image_jpeg(self):
"""Test JPEG is recognized as supported image."""
service = WatermarkService()
assert service.is_supported_image("image/jpeg") is True
assert service.is_supported_image("image/jpg") is True
def test_is_supported_image_unsupported(self):
"""Test unsupported image formats are rejected."""
service = WatermarkService()
assert service.is_supported_image("image/gif") is False
assert service.is_supported_image("image/bmp") is False
assert service.is_supported_image("image/webp") is False
def test_is_supported_pdf(self):
"""Test PDF is recognized."""
service = WatermarkService()
assert service.is_supported_pdf("application/pdf") is True
assert service.is_supported_pdf("APPLICATION/PDF") is True
def test_is_supported_pdf_negative(self):
"""Test non-PDF types are not recognized as PDF."""
service = WatermarkService()
assert service.is_supported_pdf("application/json") is False
assert service.is_supported_pdf("text/plain") is False
def test_supports_watermark_images(self):
"""Test supports_watermark for images."""
service = WatermarkService()
assert service.supports_watermark("image/png") is True
assert service.supports_watermark("image/jpeg") is True
def test_supports_watermark_pdf(self):
"""Test supports_watermark for PDF."""
service = WatermarkService()
assert service.supports_watermark("application/pdf") is True
def test_supports_watermark_unsupported(self):
"""Test supports_watermark for unsupported types."""
service = WatermarkService()
assert service.supports_watermark("text/plain") is False
assert service.supports_watermark("application/zip") is False
assert service.supports_watermark("application/octet-stream") is False
class TestImageWatermarking:
"""Unit tests for image watermarking functionality."""
def test_add_image_watermark_png(self, sample_png_bytes):
"""Test adding watermark to PNG image."""
test_time = datetime(2024, 1, 15, 10, 30, 45)
result_bytes, output_format = watermark_service.add_image_watermark(
image_bytes=sample_png_bytes,
user_name="Test User",
employee_id="EMP001",
download_time=test_time
)
# Verify output is valid image bytes
assert len(result_bytes) > 0
assert output_format.lower() == "png"
# Verify output is valid PNG image
result_image = Image.open(io.BytesIO(result_bytes))
assert result_image.format == "PNG"
assert result_image.size == (200, 200)
def test_add_image_watermark_jpeg(self, sample_jpeg_bytes):
"""Test adding watermark to JPEG image."""
test_time = datetime(2024, 1, 15, 10, 30, 45)
result_bytes, output_format = watermark_service.add_image_watermark(
image_bytes=sample_jpeg_bytes,
user_name="Test User",
employee_id="EMP001",
download_time=test_time
)
# Verify output is valid image bytes
assert len(result_bytes) > 0
assert output_format.lower() == "jpeg"
# Verify output is valid JPEG image
result_image = Image.open(io.BytesIO(result_bytes))
assert result_image.format == "JPEG"
assert result_image.size == (200, 200)
def test_add_image_watermark_preserves_dimensions(self, sample_png_bytes):
"""Test that watermarking preserves image dimensions."""
original = Image.open(io.BytesIO(sample_png_bytes))
original_size = original.size
result_bytes, _ = watermark_service.add_image_watermark(
image_bytes=sample_png_bytes,
user_name="Test User",
employee_id="EMP001"
)
result = Image.open(io.BytesIO(result_bytes))
assert result.size == original_size
def test_add_image_watermark_modifies_image(self, sample_png_bytes):
"""Test that watermark actually modifies the image."""
result_bytes, _ = watermark_service.add_image_watermark(
image_bytes=sample_png_bytes,
user_name="Test User",
employee_id="EMP001"
)
# The watermarked image should be different from original
# (Note: size might differ slightly due to compression)
# We verify the image data is actually different
original = Image.open(io.BytesIO(sample_png_bytes))
result = Image.open(io.BytesIO(result_bytes))
# Convert to same mode for comparison
original_rgb = original.convert("RGB")
result_rgb = result.convert("RGB")
# Compare pixel data - they should be different
original_data = list(original_rgb.getdata())
result_data = list(result_rgb.getdata())
# At least some pixels should be different (watermark added)
different_pixels = sum(1 for o, r in zip(original_data, result_data) if o != r)
assert different_pixels > 0, "Watermark should modify image pixels"
def test_add_image_watermark_large_image(self):
"""Test watermarking a larger image."""
# Create a larger image
large_img = Image.new("RGB", (1920, 1080), color=(100, 150, 200))
output = io.BytesIO()
large_img.save(output, format="PNG")
large_bytes = output.getvalue()
result_bytes, output_format = watermark_service.add_image_watermark(
image_bytes=large_bytes,
user_name="Large Image User",
employee_id="EMP-LARGE"
)
assert len(result_bytes) > 0
result_image = Image.open(io.BytesIO(result_bytes))
assert result_image.size == (1920, 1080)
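
# --- Editor's sketch -------------------------------------------------------
# A PIL-based sketch consistent with the image tests: format and dimensions
# are preserved and some pixels change. Text placement, font, and opacity
# handling are illustrative assumptions, not the shipped layout.
from PIL import ImageDraw, ImageFont


def add_image_watermark_sketch(image_bytes, text, opacity=0.3):
    base = Image.open(io.BytesIO(image_bytes))
    fmt = (base.format or "PNG").upper()
    overlay = Image.new("RGBA", base.size, (0, 0, 0, 0))
    draw = ImageDraw.Draw(overlay)
    draw.text(
        (10, base.size[1] // 2),
        text,
        fill=(128, 128, 128, int(255 * opacity)),
        font=ImageFont.load_default(),
    )
    merged = Image.alpha_composite(base.convert("RGBA"), overlay)
    if fmt == "JPEG":
        merged = merged.convert("RGB")  # JPEG has no alpha channel
    out = io.BytesIO()
    merged.save(out, format=fmt)
    return out.getvalue(), fmt.lower()
# ---------------------------------------------------------------------------
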
class TestPdfWatermarking:
"""Unit tests for PDF watermarking functionality."""
def test_add_pdf_watermark_basic(self, sample_pdf_bytes):
"""Test adding watermark to PDF."""
import fitz # PyMuPDF
test_time = datetime(2024, 1, 15, 10, 30, 45)
result_bytes = watermark_service.add_pdf_watermark(
pdf_bytes=sample_pdf_bytes,
user_name="PDF Test User",
employee_id="EMP-PDF001",
download_time=test_time
)
# Verify output is valid PDF bytes
assert len(result_bytes) > 0
# Verify output is valid PDF using PyMuPDF
result_pdf = fitz.open(stream=result_bytes, filetype="pdf")
assert len(result_pdf) == 2
result_pdf.close()
def test_add_pdf_watermark_preserves_page_count(self, sample_pdf_bytes):
"""Test that watermarking preserves page count."""
import fitz # PyMuPDF
original_pdf = fitz.open(stream=sample_pdf_bytes, filetype="pdf")
original_page_count = len(original_pdf)
original_pdf.close()
result_bytes = watermark_service.add_pdf_watermark(
pdf_bytes=sample_pdf_bytes,
user_name="Test User",
employee_id="EMP001"
)
result_pdf = fitz.open(stream=result_bytes, filetype="pdf")
assert len(result_pdf) == original_page_count
result_pdf.close()
def test_add_pdf_watermark_modifies_content(self, sample_pdf_bytes):
"""Test that watermark actually modifies the PDF content."""
result_bytes = watermark_service.add_pdf_watermark(
pdf_bytes=sample_pdf_bytes,
user_name="Modified User",
employee_id="EMP-MOD"
)
# The watermarked PDF should be different from original
assert result_bytes != sample_pdf_bytes
def test_add_pdf_watermark_single_page(self):
"""Test watermarking a single-page PDF."""
import fitz # PyMuPDF
# Create single page PDF with PyMuPDF
doc = fitz.open()
page = doc.new_page(width=612, height=792) # Letter size
page.insert_text(point=(100, 750), text="Single Page Document", fontsize=12)
buffer = io.BytesIO()
doc.save(buffer)
doc.close()
single_page_bytes = buffer.getvalue()
result_bytes = watermark_service.add_pdf_watermark(
pdf_bytes=single_page_bytes,
user_name="Single Page User",
employee_id="EMP-SINGLE"
)
result_pdf = fitz.open(stream=result_bytes, filetype="pdf")
assert len(result_pdf) == 1
result_pdf.close()
def test_add_pdf_watermark_many_pages(self):
"""Test watermarking a multi-page PDF."""
import fitz # PyMuPDF
# Create multi-page PDF with PyMuPDF
doc = fitz.open()
for i in range(5):
page = doc.new_page(width=612, height=792)
page.insert_text(point=(100, 750), text=f"Page {i + 1}", fontsize=12)
buffer = io.BytesIO()
doc.save(buffer)
doc.close()
multi_page_bytes = buffer.getvalue()
result_bytes = watermark_service.add_pdf_watermark(
pdf_bytes=multi_page_bytes,
user_name="Multi Page User",
employee_id="EMP-MULTI"
)
result_pdf = fitz.open(stream=result_bytes, filetype="pdf")
assert len(result_pdf) == 5
result_pdf.close()
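
# --- Editor's sketch -------------------------------------------------------
# A PyMuPDF-based sketch matching the observable contract of the PDF tests:
# page count preserved, output bytes differ from the input. Placement,
# fontsize, and color mirror the configuration constants tested below but
# are otherwise assumptions.
import fitz  # PyMuPDF


def add_pdf_watermark_sketch(pdf_bytes, text):
    doc = fitz.open(stream=pdf_bytes, filetype="pdf")
    for page in doc:
        rect = page.rect
        page.insert_text(
            point=(rect.width / 4, rect.height / 2),
            text=text,
            fontsize=24,
            color=(0.5, 0.5, 0.5),
            # insert_text only rotates in 90-degree steps; a -45 degree banner
            # would need insert_textbox with a morph matrix in the real service.
        )
    out = doc.tobytes()
    doc.close()
    return out
# ---------------------------------------------------------------------------
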
class TestWatermarkServiceConfiguration:
"""Tests for WatermarkService configuration constants."""
def test_default_opacity(self):
"""Test default watermark opacity."""
assert WatermarkService.WATERMARK_OPACITY == 0.3
def test_default_angle(self):
"""Test default watermark angle."""
assert WatermarkService.WATERMARK_ANGLE == -45
def test_default_font_size(self):
"""Test default watermark font size."""
assert WatermarkService.WATERMARK_FONT_SIZE == 24
def test_default_color(self):
"""Test default watermark color (gray)."""
assert WatermarkService.WATERMARK_COLOR == (128, 128, 128)
# =============================================================================
# Integration Tests for Download with Watermark
# =============================================================================
class TestDownloadWithWatermark:
"""Integration tests for download endpoint with watermark."""
def test_download_png_with_watermark(
self, client, test_user_token, test_task, db, monkeypatch, temp_upload_dir, sample_png_bytes
):
"""Test downloading PNG file applies watermark."""
from pathlib import Path
from app.services.file_storage_service import file_storage_service
monkeypatch.setattr("app.core.config.settings.UPLOAD_DIR", temp_upload_dir)
monkeypatch.setattr(file_storage_service, "base_dir", Path(temp_upload_dir))
# Create attachment and version
attachment_id = str(uuid.uuid4())
version_id = str(uuid.uuid4())
# Save the file to disk
file_dir = os.path.join(temp_upload_dir, test_task.project_id, test_task.id, attachment_id, "v1")
os.makedirs(file_dir, exist_ok=True)
file_path = os.path.join(file_dir, "test.png")
with open(file_path, "wb") as f:
f.write(sample_png_bytes)
relative_path = os.path.join(test_task.project_id, test_task.id, attachment_id, "v1", "test.png")
attachment = Attachment(
id=attachment_id,
task_id=test_task.id,
filename="test.png",
original_filename="test.png",
mime_type="image/png",
file_size=len(sample_png_bytes),
current_version=1,
uploaded_by=test_task.created_by,
)
db.add(attachment)
version = AttachmentVersion(
id=version_id,
attachment_id=attachment_id,
version=1,
file_path=relative_path,
file_size=len(sample_png_bytes),
checksum="0" * 64,
uploaded_by=test_task.created_by,
)
db.add(version)
db.commit()
# Download the file
response = client.get(
f"/api/attachments/{attachment_id}/download",
headers={"Authorization": f"Bearer {test_user_token}"},
)
assert response.status_code == 200
assert response.headers["content-type"] == "image/png"
# Verify watermark was applied (image should be different)
downloaded_image = Image.open(io.BytesIO(response.content))
original_image = Image.open(io.BytesIO(sample_png_bytes))
# Convert to comparable format
downloaded_rgb = downloaded_image.convert("RGB")
original_rgb = original_image.convert("RGB")
downloaded_data = list(downloaded_rgb.getdata())
original_data = list(original_rgb.getdata())
# At least some pixels should be different (watermark present)
different_pixels = sum(1 for o, d in zip(original_data, downloaded_data) if o != d)
assert different_pixels > 0, "Downloaded image should have watermark"
def test_download_pdf_with_watermark(
self, client, test_user_token, test_task, db, monkeypatch, temp_upload_dir, sample_pdf_bytes
):
"""Test downloading PDF file applies watermark."""
from pathlib import Path
from app.services.file_storage_service import file_storage_service
monkeypatch.setattr("app.core.config.settings.UPLOAD_DIR", temp_upload_dir)
monkeypatch.setattr(file_storage_service, "base_dir", Path(temp_upload_dir))
# Create attachment and version
attachment_id = str(uuid.uuid4())
version_id = str(uuid.uuid4())
# Save the file to disk
file_dir = os.path.join(temp_upload_dir, test_task.project_id, test_task.id, attachment_id, "v1")
os.makedirs(file_dir, exist_ok=True)
file_path = os.path.join(file_dir, "test.pdf")
with open(file_path, "wb") as f:
f.write(sample_pdf_bytes)
relative_path = os.path.join(test_task.project_id, test_task.id, attachment_id, "v1", "test.pdf")
attachment = Attachment(
id=attachment_id,
task_id=test_task.id,
filename="test.pdf",
original_filename="test.pdf",
mime_type="application/pdf",
file_size=len(sample_pdf_bytes),
current_version=1,
uploaded_by=test_task.created_by,
)
db.add(attachment)
version = AttachmentVersion(
id=version_id,
attachment_id=attachment_id,
version=1,
file_path=relative_path,
file_size=len(sample_pdf_bytes),
checksum="0" * 64,
uploaded_by=test_task.created_by,
)
db.add(version)
db.commit()
# Download the file
response = client.get(
f"/api/attachments/{attachment_id}/download",
headers={"Authorization": f"Bearer {test_user_token}"},
)
assert response.status_code == 200
assert response.headers["content-type"] == "application/pdf"
# Verify watermark was applied (PDF content should be different)
assert response.content != sample_pdf_bytes, "Downloaded PDF should have watermark"
def test_download_unsupported_file_no_watermark(
self, client, test_user_token, test_task, db, monkeypatch, temp_upload_dir
):
"""Test downloading unsupported file type returns original without watermark."""
from pathlib import Path
from app.services.file_storage_service import file_storage_service
monkeypatch.setattr("app.core.config.settings.UPLOAD_DIR", temp_upload_dir)
monkeypatch.setattr(file_storage_service, "base_dir", Path(temp_upload_dir))
# Create a text file
text_content = b"This is a plain text file."
attachment_id = str(uuid.uuid4())
version_id = str(uuid.uuid4())
# Save the file to disk
file_dir = os.path.join(temp_upload_dir, test_task.project_id, test_task.id, attachment_id, "v1")
os.makedirs(file_dir, exist_ok=True)
file_path = os.path.join(file_dir, "test.txt")
with open(file_path, "wb") as f:
f.write(text_content)
relative_path = os.path.join(test_task.project_id, test_task.id, attachment_id, "v1", "test.txt")
attachment = Attachment(
id=attachment_id,
task_id=test_task.id,
filename="test.txt",
original_filename="test.txt",
mime_type="text/plain",
file_size=len(text_content),
current_version=1,
uploaded_by=test_task.created_by,
)
db.add(attachment)
version = AttachmentVersion(
id=version_id,
attachment_id=attachment_id,
version=1,
file_path=relative_path,
file_size=len(text_content),
checksum="0" * 64,
uploaded_by=test_task.created_by,
)
db.add(version)
db.commit()
# Download the file
response = client.get(
f"/api/attachments/{attachment_id}/download",
headers={"Authorization": f"Bearer {test_user_token}"},
)
assert response.status_code == 200
# Content should be unchanged for unsupported types
assert response.content == text_content
def test_download_jpeg_with_watermark(
self, client, test_user_token, test_task, db, monkeypatch, temp_upload_dir, sample_jpeg_bytes
):
"""Test downloading JPEG file applies watermark."""
from pathlib import Path
from app.services.file_storage_service import file_storage_service
monkeypatch.setattr("app.core.config.settings.UPLOAD_DIR", temp_upload_dir)
monkeypatch.setattr(file_storage_service, "base_dir", Path(temp_upload_dir))
attachment_id = str(uuid.uuid4())
version_id = str(uuid.uuid4())
# Save the file to disk
file_dir = os.path.join(temp_upload_dir, test_task.project_id, test_task.id, attachment_id, "v1")
os.makedirs(file_dir, exist_ok=True)
file_path = os.path.join(file_dir, "test.jpg")
with open(file_path, "wb") as f:
f.write(sample_jpeg_bytes)
relative_path = os.path.join(test_task.project_id, test_task.id, attachment_id, "v1", "test.jpg")
attachment = Attachment(
id=attachment_id,
task_id=test_task.id,
filename="test.jpg",
original_filename="test.jpg",
mime_type="image/jpeg",
file_size=len(sample_jpeg_bytes),
current_version=1,
uploaded_by=test_task.created_by,
)
db.add(attachment)
version = AttachmentVersion(
id=version_id,
attachment_id=attachment_id,
version=1,
file_path=relative_path,
file_size=len(sample_jpeg_bytes),
checksum="0" * 64,
uploaded_by=test_task.created_by,
)
db.add(version)
db.commit()
# Download the file
response = client.get(
f"/api/attachments/{attachment_id}/download",
headers={"Authorization": f"Bearer {test_user_token}"},
)
assert response.status_code == 200
assert response.headers["content-type"] == "image/jpeg"
# Verify the response is a valid JPEG
downloaded_image = Image.open(io.BytesIO(response.content))
assert downloaded_image.format == "JPEG"
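
# --- Editor's sketch -------------------------------------------------------
# The download integration tests imply a branch like this in the endpoint:
# supported types are watermarked in memory and returned directly, everything
# else streams back unchanged. Function shape and parameters are assumptions;
# the watermark_service calls match the signatures exercised above.
from fastapi.responses import FileResponse, Response


def build_download_response_sketch(file_path, mime_type, user):
    if watermark_service.supports_watermark(mime_type):
        with open(file_path, "rb") as f:
            raw = f.read()
        if watermark_service.is_supported_pdf(mime_type):
            stamped = watermark_service.add_pdf_watermark(
                pdf_bytes=raw, user_name=user.name, employee_id=user.employee_id
            )
        else:
            stamped, _ = watermark_service.add_image_watermark(
                image_bytes=raw, user_name=user.name, employee_id=user.employee_id
            )
        return Response(content=stamped, media_type=mime_type)
    return FileResponse(file_path, media_type=mime_type)  # unsupported: original bytes
# ---------------------------------------------------------------------------
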
class TestWatermarkErrorHandling:
"""Tests for watermark error handling and graceful degradation."""
def test_watermark_service_singleton_exists(self):
"""Test that watermark_service singleton is available."""
assert watermark_service is not None
assert isinstance(watermark_service, WatermarkService)
def test_invalid_image_bytes_graceful_handling(self):
"""Test handling of invalid image bytes."""
invalid_bytes = b"not an image"
with pytest.raises(Exception):
# Should raise an exception for invalid image data
watermark_service.add_image_watermark(
image_bytes=invalid_bytes,
user_name="Test",
employee_id="EMP001"
)
def test_invalid_pdf_bytes_graceful_handling(self):
"""Test handling of invalid PDF bytes."""
invalid_bytes = b"not a pdf"
with pytest.raises(Exception):
# Should raise an exception for invalid PDF data
watermark_service.add_pdf_watermark(
pdf_bytes=invalid_bytes,
user_name="Test",
employee_id="EMP001"
)