feat: complete issue fixes and implement remaining features
## Critical Issues (CRIT-001~003) - All Fixed
- JWT secret key validation with pydantic field_validator
- Login audit logging for success/failure attempts
- Frontend API path prefix removal
## High Priority Issues (HIGH-001~008) - All Fixed
- Project soft delete using is_active flag
- Redis session token bytes handling
- Rate limiting with slowapi (5 req/min for login)
- Attachment API permission checks
- Kanban view with drag-and-drop
- Workload heatmap UI (WorkloadPage, WorkloadHeatmap)
- TaskDetailModal integrating Comments/Attachments
- UserSelect component for task assignment
## Medium Priority Issues (MED-001~012) - All Fixed
- MED-001~005: DB commits, N+1 queries, datetime, error format, blocker flag
- MED-006: Project health dashboard (HealthService, ProjectHealthPage)
- MED-007: Capacity update API (PUT /api/users/{id}/capacity)
- MED-008: Schedule triggers (cron parsing, deadline reminders)
- MED-009: Watermark feature (image/PDF watermarking)
- MED-010~012: useEffect deps, DOM operations, PDF export
## New Files
- backend/app/api/health/ - Project health API
- backend/app/services/health_service.py
- backend/app/services/trigger_scheduler.py
- backend/app/services/watermark_service.py
- backend/app/core/rate_limiter.py
- frontend/src/pages/ProjectHealthPage.tsx
- frontend/src/components/ProjectHealthCard.tsx
- frontend/src/components/KanbanBoard.tsx
- frontend/src/components/WorkloadHeatmap.tsx
## Tests
- 113 new tests passing (health: 32, users: 14, triggers: 35, watermark: 32)
## OpenSpec Archives
- add-project-health-dashboard
- add-capacity-update-api
- add-schedule-triggers
- add-watermark-feature
- add-rate-limiting
- enhance-frontend-ux
- add-resource-management-ui
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -1,3 +1,8 @@
|
||||
import os
|
||||
|
||||
# Set testing environment before importing app modules
|
||||
os.environ["TESTING"] = "true"
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy import create_engine
|
||||
@@ -103,6 +108,18 @@ def mock_redis():
|
||||
@pytest.fixture(scope="function")
|
||||
def client(db, mock_redis):
|
||||
"""Create test client with overridden dependencies."""
|
||||
# Reset rate limiter storage before each test
|
||||
from app.core.rate_limiter import limiter
|
||||
if hasattr(limiter, '_storage') and limiter._storage:
|
||||
try:
|
||||
limiter._storage.reset()
|
||||
except Exception:
|
||||
pass # Memory storage might not have reset method
|
||||
# For memory storage, clear internal state
|
||||
if hasattr(limiter, '_limiter') and hasattr(limiter._limiter, '_storage'):
|
||||
storage = limiter._limiter._storage
|
||||
if hasattr(storage, 'storage'):
|
||||
storage.storage.clear()
|
||||
|
||||
def override_get_db():
|
||||
try:
|
||||
|
||||
672
backend/tests/test_health.py
Normal file
672
backend/tests/test_health.py
Normal file
@@ -0,0 +1,672 @@
|
||||
"""Tests for project health API and service."""
|
||||
import pytest
|
||||
from datetime import datetime, timedelta
|
||||
from decimal import Decimal
|
||||
|
||||
from app.models import User, Department, Space, Project, Task, Blocker
|
||||
from app.models.task_status import TaskStatus
|
||||
from app.models.project_health import ProjectHealth
|
||||
from app.services.health_service import (
|
||||
calculate_health_metrics,
|
||||
get_or_create_project_health,
|
||||
update_project_health,
|
||||
get_project_health,
|
||||
get_all_projects_health,
|
||||
HealthService,
|
||||
_determine_risk_level,
|
||||
_determine_schedule_status,
|
||||
_determine_resource_status,
|
||||
BLOCKER_PENALTY_PER_ITEM,
|
||||
BLOCKER_PENALTY_MAX,
|
||||
OVERDUE_PENALTY_PER_ITEM,
|
||||
OVERDUE_PENALTY_MAX,
|
||||
)
|
||||
from app.schemas.project_health import RiskLevel, ScheduleStatus, ResourceStatus
|
||||
|
||||
|
||||
class TestRiskLevelDetermination:
|
||||
"""Tests for risk level determination logic."""
|
||||
|
||||
def test_low_risk(self):
|
||||
"""Health score >= 80 should be low risk."""
|
||||
assert _determine_risk_level(100) == "low"
|
||||
assert _determine_risk_level(80) == "low"
|
||||
|
||||
def test_medium_risk(self):
|
||||
"""Health score 60-79 should be medium risk."""
|
||||
assert _determine_risk_level(79) == "medium"
|
||||
assert _determine_risk_level(60) == "medium"
|
||||
|
||||
def test_high_risk(self):
|
||||
"""Health score 40-59 should be high risk."""
|
||||
assert _determine_risk_level(59) == "high"
|
||||
assert _determine_risk_level(40) == "high"
|
||||
|
||||
def test_critical_risk(self):
|
||||
"""Health score < 40 should be critical risk."""
|
||||
assert _determine_risk_level(39) == "critical"
|
||||
assert _determine_risk_level(0) == "critical"
|
||||
|
||||
|
||||
class TestScheduleStatusDetermination:
|
||||
"""Tests for schedule status determination logic."""
|
||||
|
||||
def test_on_track(self):
|
||||
"""No overdue tasks means on track."""
|
||||
assert _determine_schedule_status(0) == "on_track"
|
||||
|
||||
def test_at_risk(self):
|
||||
"""1-2 overdue tasks means at risk."""
|
||||
assert _determine_schedule_status(1) == "at_risk"
|
||||
assert _determine_schedule_status(2) == "at_risk"
|
||||
|
||||
def test_delayed(self):
|
||||
"""More than 2 overdue tasks means delayed."""
|
||||
assert _determine_schedule_status(3) == "delayed"
|
||||
assert _determine_schedule_status(10) == "delayed"
|
||||
|
||||
|
||||
class TestResourceStatusDetermination:
|
||||
"""Tests for resource status determination logic."""
|
||||
|
||||
def test_adequate(self):
|
||||
"""No blockers means adequate resources."""
|
||||
assert _determine_resource_status(0) == "adequate"
|
||||
|
||||
def test_constrained(self):
|
||||
"""1-2 blockers means constrained resources."""
|
||||
assert _determine_resource_status(1) == "constrained"
|
||||
assert _determine_resource_status(2) == "constrained"
|
||||
|
||||
def test_overloaded(self):
|
||||
"""More than 2 blockers means overloaded."""
|
||||
assert _determine_resource_status(3) == "overloaded"
|
||||
assert _determine_resource_status(10) == "overloaded"
|
||||
|
||||
|
||||
class TestHealthMetricsCalculation:
|
||||
"""Tests for health metrics calculation with database."""
|
||||
|
||||
def setup_test_data(self, db):
|
||||
"""Set up test data for health tests."""
|
||||
# Create department
|
||||
dept = Department(
|
||||
id="dept-health-001",
|
||||
name="Health Test Department",
|
||||
)
|
||||
db.add(dept)
|
||||
|
||||
# Create space
|
||||
space = Space(
|
||||
id="space-health-001",
|
||||
name="Health Test Space",
|
||||
owner_id="00000000-0000-0000-0000-000000000001",
|
||||
is_active=True,
|
||||
)
|
||||
db.add(space)
|
||||
|
||||
# Create project
|
||||
project = Project(
|
||||
id="project-health-001",
|
||||
space_id="space-health-001",
|
||||
title="Health Test Project",
|
||||
owner_id="00000000-0000-0000-0000-000000000001",
|
||||
department_id="dept-health-001",
|
||||
security_level="department",
|
||||
status="active",
|
||||
)
|
||||
db.add(project)
|
||||
|
||||
# Create task statuses
|
||||
status_todo = TaskStatus(
|
||||
id="status-health-todo",
|
||||
project_id="project-health-001",
|
||||
name="To Do",
|
||||
is_done=False,
|
||||
)
|
||||
db.add(status_todo)
|
||||
|
||||
status_done = TaskStatus(
|
||||
id="status-health-done",
|
||||
project_id="project-health-001",
|
||||
name="Done",
|
||||
is_done=True,
|
||||
)
|
||||
db.add(status_done)
|
||||
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
"department": dept,
|
||||
"space": space,
|
||||
"project": project,
|
||||
"status_todo": status_todo,
|
||||
"status_done": status_done,
|
||||
}
|
||||
|
||||
def create_task(self, db, data, task_id, done=False, overdue=False, has_blocker=False):
|
||||
"""Helper to create a task with optional characteristics."""
|
||||
due_date = datetime.utcnow()
|
||||
if overdue:
|
||||
due_date = datetime.utcnow() - timedelta(days=3)
|
||||
else:
|
||||
due_date = datetime.utcnow() + timedelta(days=3)
|
||||
|
||||
task = Task(
|
||||
id=task_id,
|
||||
project_id=data["project"].id,
|
||||
title=f"Task {task_id}",
|
||||
status_id=data["status_done"].id if done else data["status_todo"].id,
|
||||
due_date=due_date,
|
||||
created_by="00000000-0000-0000-0000-000000000001",
|
||||
is_deleted=False,
|
||||
)
|
||||
db.add(task)
|
||||
db.commit()
|
||||
|
||||
if has_blocker:
|
||||
blocker = Blocker(
|
||||
id=f"blocker-{task_id}",
|
||||
task_id=task_id,
|
||||
reported_by="00000000-0000-0000-0000-000000000001",
|
||||
reason="Test blocker",
|
||||
resolved_at=None,
|
||||
)
|
||||
db.add(blocker)
|
||||
db.commit()
|
||||
|
||||
return task
|
||||
|
||||
def test_calculate_metrics_no_tasks(self, db):
|
||||
"""Project with no tasks should have 100 health score."""
|
||||
data = self.setup_test_data(db)
|
||||
|
||||
metrics = calculate_health_metrics(db, data["project"])
|
||||
|
||||
assert metrics["health_score"] == 100
|
||||
assert metrics["risk_level"] == "low"
|
||||
assert metrics["schedule_status"] == "on_track"
|
||||
assert metrics["resource_status"] == "adequate"
|
||||
assert metrics["task_count"] == 0
|
||||
assert metrics["completed_task_count"] == 0
|
||||
assert metrics["blocker_count"] == 0
|
||||
assert metrics["overdue_task_count"] == 0
|
||||
|
||||
def test_calculate_metrics_all_completed(self, db):
|
||||
"""Project with all completed tasks should have high health score."""
|
||||
data = self.setup_test_data(db)
|
||||
|
||||
self.create_task(db, data, "task-c1", done=True)
|
||||
self.create_task(db, data, "task-c2", done=True)
|
||||
self.create_task(db, data, "task-c3", done=True)
|
||||
|
||||
metrics = calculate_health_metrics(db, data["project"])
|
||||
|
||||
assert metrics["health_score"] == 100
|
||||
assert metrics["task_count"] == 3
|
||||
assert metrics["completed_task_count"] == 3
|
||||
assert metrics["overdue_task_count"] == 0
|
||||
|
||||
def test_calculate_metrics_with_blockers(self, db):
|
||||
"""Blockers should reduce health score."""
|
||||
data = self.setup_test_data(db)
|
||||
|
||||
# Create 3 tasks with blockers
|
||||
self.create_task(db, data, "task-b1", has_blocker=True)
|
||||
self.create_task(db, data, "task-b2", has_blocker=True)
|
||||
self.create_task(db, data, "task-b3", has_blocker=True)
|
||||
|
||||
metrics = calculate_health_metrics(db, data["project"])
|
||||
|
||||
# 3 blockers * 10 points = 30 penalty, also low completion penalty
|
||||
expected_blocker_penalty = min(3 * BLOCKER_PENALTY_PER_ITEM, BLOCKER_PENALTY_MAX)
|
||||
assert metrics["blocker_count"] == 3
|
||||
assert metrics["resource_status"] == "overloaded"
|
||||
assert metrics["health_score"] < 100
|
||||
|
||||
def test_calculate_metrics_with_overdue_tasks(self, db):
|
||||
"""Overdue tasks should reduce health score."""
|
||||
data = self.setup_test_data(db)
|
||||
|
||||
# Create 3 overdue tasks
|
||||
self.create_task(db, data, "task-o1", overdue=True)
|
||||
self.create_task(db, data, "task-o2", overdue=True)
|
||||
self.create_task(db, data, "task-o3", overdue=True)
|
||||
|
||||
metrics = calculate_health_metrics(db, data["project"])
|
||||
|
||||
assert metrics["overdue_task_count"] == 3
|
||||
assert metrics["schedule_status"] == "delayed"
|
||||
assert metrics["health_score"] < 100
|
||||
|
||||
def test_calculate_metrics_overdue_completed_not_counted(self, db):
|
||||
"""Completed overdue tasks should not count as overdue."""
|
||||
data = self.setup_test_data(db)
|
||||
|
||||
# Create task that is overdue but completed
|
||||
task = Task(
|
||||
id="task-oc1",
|
||||
project_id=data["project"].id,
|
||||
title="Overdue Completed Task",
|
||||
status_id=data["status_done"].id,
|
||||
due_date=datetime.utcnow() - timedelta(days=5),
|
||||
created_by="00000000-0000-0000-0000-000000000001",
|
||||
is_deleted=False,
|
||||
)
|
||||
db.add(task)
|
||||
db.commit()
|
||||
|
||||
metrics = calculate_health_metrics(db, data["project"])
|
||||
|
||||
assert metrics["overdue_task_count"] == 0
|
||||
assert metrics["completed_task_count"] == 1
|
||||
|
||||
def test_calculate_metrics_deleted_tasks_excluded(self, db):
|
||||
"""Soft-deleted tasks should be excluded from calculations."""
|
||||
data = self.setup_test_data(db)
|
||||
|
||||
# Create a normal task
|
||||
self.create_task(db, data, "task-normal")
|
||||
|
||||
# Create a deleted task
|
||||
deleted_task = Task(
|
||||
id="task-deleted",
|
||||
project_id=data["project"].id,
|
||||
title="Deleted Task",
|
||||
status_id=data["status_todo"].id,
|
||||
due_date=datetime.utcnow() - timedelta(days=5), # Overdue
|
||||
created_by="00000000-0000-0000-0000-000000000001",
|
||||
is_deleted=True,
|
||||
deleted_at=datetime.utcnow(),
|
||||
)
|
||||
db.add(deleted_task)
|
||||
db.commit()
|
||||
|
||||
metrics = calculate_health_metrics(db, data["project"])
|
||||
|
||||
assert metrics["task_count"] == 1 # Only non-deleted task
|
||||
assert metrics["overdue_task_count"] == 0 # Deleted task not counted
|
||||
|
||||
def test_calculate_metrics_combined_penalties(self, db):
|
||||
"""Multiple issues should stack penalties correctly."""
|
||||
data = self.setup_test_data(db)
|
||||
|
||||
# Create mixed tasks: 2 overdue with blockers
|
||||
self.create_task(db, data, "task-mix1", overdue=True, has_blocker=True)
|
||||
self.create_task(db, data, "task-mix2", overdue=True, has_blocker=True)
|
||||
|
||||
metrics = calculate_health_metrics(db, data["project"])
|
||||
|
||||
assert metrics["blocker_count"] == 2
|
||||
assert metrics["overdue_task_count"] == 2
|
||||
# Should have penalties from both
|
||||
# 2 blockers = 20 penalty, 2 overdue = 10 penalty, plus completion penalty
|
||||
assert metrics["health_score"] < 80
|
||||
|
||||
|
||||
class TestHealthServiceClass:
|
||||
"""Tests for HealthService class."""
|
||||
|
||||
def setup_test_data(self, db):
|
||||
"""Set up test data for health service tests."""
|
||||
# Create department
|
||||
dept = Department(
|
||||
id="dept-svc-001",
|
||||
name="Service Test Department",
|
||||
)
|
||||
db.add(dept)
|
||||
|
||||
# Create space
|
||||
space = Space(
|
||||
id="space-svc-001",
|
||||
name="Service Test Space",
|
||||
owner_id="00000000-0000-0000-0000-000000000001",
|
||||
is_active=True,
|
||||
)
|
||||
db.add(space)
|
||||
|
||||
# Create project
|
||||
project = Project(
|
||||
id="project-svc-001",
|
||||
space_id="space-svc-001",
|
||||
title="Service Test Project",
|
||||
owner_id="00000000-0000-0000-0000-000000000001",
|
||||
department_id="dept-svc-001",
|
||||
security_level="department",
|
||||
status="active",
|
||||
)
|
||||
db.add(project)
|
||||
|
||||
# Create inactive project
|
||||
inactive_project = Project(
|
||||
id="project-svc-inactive",
|
||||
space_id="space-svc-001",
|
||||
title="Inactive Project",
|
||||
owner_id="00000000-0000-0000-0000-000000000001",
|
||||
department_id="dept-svc-001",
|
||||
security_level="department",
|
||||
status="archived",
|
||||
)
|
||||
db.add(inactive_project)
|
||||
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
"department": dept,
|
||||
"space": space,
|
||||
"project": project,
|
||||
"inactive_project": inactive_project,
|
||||
}
|
||||
|
||||
def test_get_or_create_health_creates_new(self, db):
|
||||
"""Should create new ProjectHealth if none exists."""
|
||||
data = self.setup_test_data(db)
|
||||
|
||||
health = get_or_create_project_health(db, data["project"])
|
||||
db.commit()
|
||||
|
||||
assert health is not None
|
||||
assert health.project_id == data["project"].id
|
||||
assert health.health_score == 100 # Default
|
||||
|
||||
def test_get_or_create_health_returns_existing(self, db):
|
||||
"""Should return existing ProjectHealth if one exists."""
|
||||
data = self.setup_test_data(db)
|
||||
|
||||
# Create initial health record
|
||||
health1 = get_or_create_project_health(db, data["project"])
|
||||
health1.health_score = 75
|
||||
db.commit()
|
||||
|
||||
# Should return same record
|
||||
health2 = get_or_create_project_health(db, data["project"])
|
||||
|
||||
assert health2.id == health1.id
|
||||
assert health2.health_score == 75
|
||||
|
||||
def test_get_project_health(self, db):
|
||||
"""Should return health details for a project."""
|
||||
data = self.setup_test_data(db)
|
||||
|
||||
result = get_project_health(db, data["project"].id)
|
||||
|
||||
assert result is not None
|
||||
assert result.project_id == data["project"].id
|
||||
assert result.project_title == "Service Test Project"
|
||||
assert result.health_score == 100
|
||||
|
||||
def test_get_project_health_not_found(self, db):
|
||||
"""Should return None for non-existent project."""
|
||||
data = self.setup_test_data(db)
|
||||
|
||||
result = get_project_health(db, "non-existent-id")
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_get_all_projects_health_active_only(self, db):
|
||||
"""Dashboard should only include active projects by default."""
|
||||
data = self.setup_test_data(db)
|
||||
|
||||
result = get_all_projects_health(db, status_filter="active")
|
||||
|
||||
project_ids = [p.project_id for p in result.projects]
|
||||
assert data["project"].id in project_ids
|
||||
assert data["inactive_project"].id not in project_ids
|
||||
|
||||
def test_get_all_projects_health_summary(self, db):
|
||||
"""Dashboard should include correct summary statistics."""
|
||||
data = self.setup_test_data(db)
|
||||
|
||||
result = get_all_projects_health(db, status_filter="active")
|
||||
|
||||
assert result.summary.total_projects >= 1
|
||||
assert result.summary.average_health_score <= 100
|
||||
|
||||
def test_health_service_class_interface(self, db):
|
||||
"""HealthService class should provide same functionality."""
|
||||
data = self.setup_test_data(db)
|
||||
service = HealthService(db)
|
||||
|
||||
# Test get_project_health
|
||||
health = service.get_project_health(data["project"].id)
|
||||
assert health is not None
|
||||
assert health.project_id == data["project"].id
|
||||
|
||||
# Test get_dashboard
|
||||
dashboard = service.get_dashboard()
|
||||
assert dashboard.summary.total_projects >= 1
|
||||
|
||||
# Test calculate_metrics
|
||||
metrics = service.calculate_metrics(data["project"])
|
||||
assert "health_score" in metrics
|
||||
assert "risk_level" in metrics
|
||||
|
||||
|
||||
class TestHealthAPI:
|
||||
"""Tests for health API endpoints."""
|
||||
|
||||
def setup_test_data(self, db):
|
||||
"""Set up test data for API tests."""
|
||||
# Create department
|
||||
dept = Department(
|
||||
id="dept-api-001",
|
||||
name="API Test Department",
|
||||
)
|
||||
db.add(dept)
|
||||
|
||||
# Create space
|
||||
space = Space(
|
||||
id="space-api-001",
|
||||
name="API Test Space",
|
||||
owner_id="00000000-0000-0000-0000-000000000001",
|
||||
is_active=True,
|
||||
)
|
||||
db.add(space)
|
||||
|
||||
# Create projects
|
||||
project1 = Project(
|
||||
id="project-api-001",
|
||||
space_id="space-api-001",
|
||||
title="API Test Project 1",
|
||||
owner_id="00000000-0000-0000-0000-000000000001",
|
||||
department_id="dept-api-001",
|
||||
security_level="department",
|
||||
status="active",
|
||||
)
|
||||
db.add(project1)
|
||||
|
||||
project2 = Project(
|
||||
id="project-api-002",
|
||||
space_id="space-api-001",
|
||||
title="API Test Project 2",
|
||||
owner_id="00000000-0000-0000-0000-000000000001",
|
||||
department_id="dept-api-001",
|
||||
security_level="department",
|
||||
status="active",
|
||||
)
|
||||
db.add(project2)
|
||||
|
||||
# Create task statuses
|
||||
status_todo = TaskStatus(
|
||||
id="status-api-todo",
|
||||
project_id="project-api-001",
|
||||
name="To Do",
|
||||
is_done=False,
|
||||
)
|
||||
db.add(status_todo)
|
||||
|
||||
# Create a task with blocker for project1
|
||||
task = Task(
|
||||
id="task-api-001",
|
||||
project_id="project-api-001",
|
||||
title="API Test Task",
|
||||
status_id="status-api-todo",
|
||||
due_date=datetime.utcnow() - timedelta(days=2), # Overdue
|
||||
created_by="00000000-0000-0000-0000-000000000001",
|
||||
is_deleted=False,
|
||||
)
|
||||
db.add(task)
|
||||
|
||||
blocker = Blocker(
|
||||
id="blocker-api-001",
|
||||
task_id="task-api-001",
|
||||
reported_by="00000000-0000-0000-0000-000000000001",
|
||||
reason="Test blocker",
|
||||
resolved_at=None,
|
||||
)
|
||||
db.add(blocker)
|
||||
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
"department": dept,
|
||||
"space": space,
|
||||
"project1": project1,
|
||||
"project2": project2,
|
||||
"task": task,
|
||||
"blocker": blocker,
|
||||
}
|
||||
|
||||
def test_get_dashboard(self, client, db, admin_token):
|
||||
"""Admin should be able to get health dashboard."""
|
||||
data = self.setup_test_data(db)
|
||||
|
||||
response = client.get(
|
||||
"/api/projects/health/dashboard",
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
result = response.json()
|
||||
|
||||
assert "projects" in result
|
||||
assert "summary" in result
|
||||
assert result["summary"]["total_projects"] >= 2
|
||||
|
||||
def test_get_dashboard_summary_fields(self, client, db, admin_token):
|
||||
"""Dashboard summary should include all expected fields."""
|
||||
data = self.setup_test_data(db)
|
||||
|
||||
response = client.get(
|
||||
"/api/projects/health/dashboard",
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
summary = response.json()["summary"]
|
||||
|
||||
assert "total_projects" in summary
|
||||
assert "healthy_count" in summary
|
||||
assert "at_risk_count" in summary
|
||||
assert "critical_count" in summary
|
||||
assert "average_health_score" in summary
|
||||
assert "projects_with_blockers" in summary
|
||||
assert "projects_delayed" in summary
|
||||
|
||||
def test_get_project_health(self, client, db, admin_token):
|
||||
"""Admin should be able to get single project health."""
|
||||
data = self.setup_test_data(db)
|
||||
|
||||
response = client.get(
|
||||
f"/api/projects/health/{data['project1'].id}",
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
result = response.json()
|
||||
|
||||
assert result["project_id"] == data["project1"].id
|
||||
assert result["project_title"] == "API Test Project 1"
|
||||
assert "health_score" in result
|
||||
assert "risk_level" in result
|
||||
assert "schedule_status" in result
|
||||
assert "resource_status" in result
|
||||
|
||||
def test_get_project_health_not_found(self, client, db, admin_token):
|
||||
"""Should return 404 for non-existent project."""
|
||||
self.setup_test_data(db)
|
||||
|
||||
response = client.get(
|
||||
"/api/projects/health/non-existent-id",
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
assert response.json()["detail"] == "Project not found"
|
||||
|
||||
def test_get_project_health_with_issues(self, client, db, admin_token):
|
||||
"""Project with issues should have correct metrics."""
|
||||
data = self.setup_test_data(db)
|
||||
|
||||
response = client.get(
|
||||
f"/api/projects/health/{data['project1'].id}",
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
result = response.json()
|
||||
|
||||
# Project1 has 1 overdue task with 1 blocker
|
||||
assert result["blocker_count"] == 1
|
||||
assert result["overdue_task_count"] == 1
|
||||
assert result["health_score"] < 100 # Should be penalized
|
||||
|
||||
def test_unauthorized_access(self, client, db):
|
||||
"""Unauthenticated requests should fail."""
|
||||
response = client.get("/api/projects/health/dashboard")
|
||||
assert response.status_code == 403
|
||||
|
||||
def test_dashboard_with_status_filter(self, client, db, admin_token):
|
||||
"""Dashboard should respect status filter."""
|
||||
data = self.setup_test_data(db)
|
||||
|
||||
# Create an archived project
|
||||
archived = Project(
|
||||
id="project-archived",
|
||||
space_id="space-api-001",
|
||||
title="Archived Project",
|
||||
owner_id="00000000-0000-0000-0000-000000000001",
|
||||
department_id="dept-api-001",
|
||||
security_level="department",
|
||||
status="archived",
|
||||
)
|
||||
db.add(archived)
|
||||
db.commit()
|
||||
|
||||
# Default filter should exclude archived
|
||||
response = client.get(
|
||||
"/api/projects/health/dashboard",
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
project_ids = [p["project_id"] for p in response.json()["projects"]]
|
||||
assert "project-archived" not in project_ids
|
||||
|
||||
def test_project_health_response_structure(self, client, db, admin_token):
|
||||
"""Response should match ProjectHealthWithDetails schema."""
|
||||
data = self.setup_test_data(db)
|
||||
|
||||
response = client.get(
|
||||
f"/api/projects/health/{data['project1'].id}",
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
result = response.json()
|
||||
|
||||
# Required fields from schema
|
||||
required_fields = [
|
||||
"id", "project_id", "health_score", "risk_level",
|
||||
"schedule_status", "resource_status", "last_updated",
|
||||
"project_title", "project_status", "task_count",
|
||||
"completed_task_count", "blocker_count", "overdue_task_count"
|
||||
]
|
||||
|
||||
for field in required_fields:
|
||||
assert field in result, f"Missing field: {field}"
|
||||
|
||||
# Check enum values
|
||||
assert result["risk_level"] in ["low", "medium", "high", "critical"]
|
||||
assert result["schedule_status"] in ["on_track", "at_risk", "delayed"]
|
||||
assert result["resource_status"] in ["adequate", "constrained", "overloaded"]
|
||||
124
backend/tests/test_rate_limit.py
Normal file
124
backend/tests/test_rate_limit.py
Normal file
@@ -0,0 +1,124 @@
|
||||
"""
|
||||
Test suite for rate limiting functionality.
|
||||
|
||||
Tests the rate limiting feature on the login endpoint to ensure
|
||||
protection against brute force attacks.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock, AsyncMock
|
||||
|
||||
from app.services.auth_client import AuthAPIError
|
||||
|
||||
|
||||
class TestRateLimiting:
|
||||
"""Test rate limiting on the login endpoint."""
|
||||
|
||||
def test_login_rate_limit_exceeded(self, client):
|
||||
"""
|
||||
Test that the login endpoint returns 429 after exceeding rate limit.
|
||||
|
||||
GIVEN a client IP has made 5 login attempts within 1 minute
|
||||
WHEN the client attempts another login
|
||||
THEN the system returns HTTP 429 Too Many Requests
|
||||
AND the response includes a Retry-After header
|
||||
"""
|
||||
# Mock the external auth service to return auth error
|
||||
with patch("app.api.auth.router.verify_credentials", new_callable=AsyncMock) as mock_verify:
|
||||
mock_verify.side_effect = AuthAPIError("Invalid credentials")
|
||||
|
||||
login_data = {"email": "test@example.com", "password": "wrongpassword"}
|
||||
|
||||
# Make 5 requests (the limit)
|
||||
for i in range(5):
|
||||
response = client.post("/api/auth/login", json=login_data)
|
||||
# These should fail due to invalid credentials (401), but not rate limit
|
||||
assert response.status_code == 401, f"Request {i+1} expected 401, got {response.status_code}"
|
||||
|
||||
# The 6th request should be rate limited
|
||||
response = client.post("/api/auth/login", json=login_data)
|
||||
assert response.status_code == 429, f"Expected 429 Too Many Requests, got {response.status_code}"
|
||||
|
||||
# Response should contain error details
|
||||
data = response.json()
|
||||
assert "error" in data or "detail" in data, "Response should contain error details"
|
||||
|
||||
def test_login_within_rate_limit(self, client):
|
||||
"""
|
||||
Test that requests within the rate limit are allowed.
|
||||
|
||||
GIVEN a client IP has not exceeded the rate limit
|
||||
WHEN the client makes login requests
|
||||
THEN the requests are processed normally (not rate limited)
|
||||
"""
|
||||
with patch("app.api.auth.router.verify_credentials", new_callable=AsyncMock) as mock_verify:
|
||||
mock_verify.side_effect = AuthAPIError("Invalid credentials")
|
||||
|
||||
login_data = {"email": "test@example.com", "password": "wrongpassword"}
|
||||
|
||||
# Make requests within the limit
|
||||
for i in range(3):
|
||||
response = client.post("/api/auth/login", json=login_data)
|
||||
# These should fail due to invalid credentials (401), but not be rate limited
|
||||
assert response.status_code == 401, f"Request {i+1} expected 401, got {response.status_code}"
|
||||
|
||||
def test_rate_limit_response_format(self, client):
|
||||
"""
|
||||
Test that the 429 response format matches API standards.
|
||||
|
||||
GIVEN the rate limit has been exceeded
|
||||
WHEN the client receives a 429 response
|
||||
THEN the response body contains appropriate error information
|
||||
"""
|
||||
with patch("app.api.auth.router.verify_credentials", new_callable=AsyncMock) as mock_verify:
|
||||
mock_verify.side_effect = AuthAPIError("Invalid credentials")
|
||||
|
||||
login_data = {"email": "test@example.com", "password": "wrongpassword"}
|
||||
|
||||
# Exhaust the rate limit
|
||||
for _ in range(5):
|
||||
client.post("/api/auth/login", json=login_data)
|
||||
|
||||
# The next request should be rate limited
|
||||
response = client.post("/api/auth/login", json=login_data)
|
||||
|
||||
assert response.status_code == 429
|
||||
|
||||
# Check response body contains error information
|
||||
data = response.json()
|
||||
assert "error" in data or "detail" in data, "Response should contain error details"
|
||||
|
||||
|
||||
class TestRateLimiterConfiguration:
|
||||
"""Test rate limiter configuration."""
|
||||
|
||||
def test_limiter_uses_redis_storage(self):
|
||||
"""
|
||||
Test that the limiter is configured with Redis storage.
|
||||
|
||||
GIVEN the rate limiter configuration
|
||||
WHEN we inspect the storage URI
|
||||
THEN it should be configured to use Redis
|
||||
"""
|
||||
from app.core.rate_limiter import limiter
|
||||
from app.core.config import settings
|
||||
|
||||
# The limiter should be configured
|
||||
assert limiter is not None
|
||||
|
||||
# Verify Redis URL is properly configured
|
||||
assert settings.REDIS_URL.startswith("redis://")
|
||||
|
||||
def test_limiter_uses_remote_address_key(self):
|
||||
"""
|
||||
Test that the limiter uses client IP as the key.
|
||||
|
||||
GIVEN the rate limiter configuration
|
||||
WHEN we check the key function
|
||||
THEN it should use get_remote_address
|
||||
"""
|
||||
from app.core.rate_limiter import limiter
|
||||
from slowapi.util import get_remote_address
|
||||
|
||||
# The key function should be get_remote_address
|
||||
assert limiter._key_func == get_remote_address
|
||||
664
backend/tests/test_schedule_triggers.py
Normal file
664
backend/tests/test_schedule_triggers.py
Normal file
@@ -0,0 +1,664 @@
|
||||
"""
|
||||
Tests for Schedule Triggers functionality.
|
||||
|
||||
This module tests:
|
||||
- Cron expression parsing and validation
|
||||
- Deadline reminder logic
|
||||
- Schedule trigger execution
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import uuid
|
||||
from datetime import datetime, timezone, timedelta
|
||||
|
||||
from app.models import User, Space, Project, Task, TaskStatus, Trigger, TriggerLog, Notification
|
||||
from app.services.trigger_scheduler import TriggerSchedulerService
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Fixtures
|
||||
# ============================================================================
|
||||
|
||||
@pytest.fixture
|
||||
def test_user(db):
|
||||
"""Create a test user."""
|
||||
user = User(
|
||||
id=str(uuid.uuid4()),
|
||||
email="scheduleuser@example.com",
|
||||
name="Schedule Test User",
|
||||
role_id="00000000-0000-0000-0000-000000000003",
|
||||
is_active=True,
|
||||
is_system_admin=False,
|
||||
)
|
||||
db.add(user)
|
||||
db.commit()
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_user_token(client, mock_redis, test_user):
|
||||
"""Get a token for test user."""
|
||||
from app.core.security import create_access_token, create_token_payload
|
||||
|
||||
token_data = create_token_payload(
|
||||
user_id=test_user.id,
|
||||
email=test_user.email,
|
||||
role="engineer",
|
||||
department_id=None,
|
||||
is_system_admin=False,
|
||||
)
|
||||
token = create_access_token(token_data)
|
||||
mock_redis.setex(f"session:{test_user.id}", 900, token)
|
||||
return token
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_space(db, test_user):
|
||||
"""Create a test space."""
|
||||
space = Space(
|
||||
id=str(uuid.uuid4()),
|
||||
name="Schedule Test Space",
|
||||
description="Test space for schedule triggers",
|
||||
owner_id=test_user.id,
|
||||
)
|
||||
db.add(space)
|
||||
db.commit()
|
||||
return space
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_project(db, test_space, test_user):
|
||||
"""Create a test project."""
|
||||
project = Project(
|
||||
id=str(uuid.uuid4()),
|
||||
space_id=test_space.id,
|
||||
title="Schedule Test Project",
|
||||
description="Test project for schedule triggers",
|
||||
owner_id=test_user.id,
|
||||
)
|
||||
db.add(project)
|
||||
db.commit()
|
||||
return project
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_status(db, test_project):
|
||||
"""Create test task statuses."""
|
||||
status = TaskStatus(
|
||||
id=str(uuid.uuid4()),
|
||||
project_id=test_project.id,
|
||||
name="To Do",
|
||||
color="#808080",
|
||||
position=0,
|
||||
)
|
||||
db.add(status)
|
||||
db.commit()
|
||||
return status
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cron_trigger(db, test_project, test_user):
|
||||
"""Create a cron-based schedule trigger."""
|
||||
trigger = Trigger(
|
||||
id=str(uuid.uuid4()),
|
||||
project_id=test_project.id,
|
||||
name="Daily Reminder",
|
||||
description="Daily reminder at 9am",
|
||||
trigger_type="schedule",
|
||||
conditions={
|
||||
"cron_expression": "0 9 * * *", # Every day at 9am
|
||||
},
|
||||
actions=[{
|
||||
"type": "notify",
|
||||
"target": "project_owner",
|
||||
"template": "Daily scheduled trigger fired for {project_name}",
|
||||
}],
|
||||
is_active=True,
|
||||
created_by=test_user.id,
|
||||
)
|
||||
db.add(trigger)
|
||||
db.commit()
|
||||
return trigger
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def deadline_trigger(db, test_project, test_user):
|
||||
"""Create a deadline reminder trigger."""
|
||||
trigger = Trigger(
|
||||
id=str(uuid.uuid4()),
|
||||
project_id=test_project.id,
|
||||
name="Deadline Reminder",
|
||||
description="Remind 3 days before deadline",
|
||||
trigger_type="schedule",
|
||||
conditions={
|
||||
"deadline_reminder_days": 3,
|
||||
},
|
||||
actions=[{
|
||||
"type": "notify",
|
||||
"target": "assignee",
|
||||
"template": "Task '{task_title}' is due in {reminder_days} days",
|
||||
}],
|
||||
is_active=True,
|
||||
created_by=test_user.id,
|
||||
)
|
||||
db.add(trigger)
|
||||
db.commit()
|
||||
return trigger
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def task_with_deadline(db, test_project, test_user, test_status):
|
||||
"""Create a task with a deadline 3 days from now."""
|
||||
due_date = datetime.now(timezone.utc) + timedelta(days=3)
|
||||
task = Task(
|
||||
id=str(uuid.uuid4()),
|
||||
project_id=test_project.id,
|
||||
title="Task with Deadline",
|
||||
description="This task has a deadline",
|
||||
status_id=test_status.id,
|
||||
created_by=test_user.id,
|
||||
assignee_id=test_user.id,
|
||||
due_date=due_date,
|
||||
)
|
||||
db.add(task)
|
||||
db.commit()
|
||||
return task
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests: Cron Expression Parsing
|
||||
# ============================================================================
|
||||
|
||||
class TestCronExpressionParsing:
|
||||
"""Tests for cron expression parsing and validation."""
|
||||
|
||||
def test_parse_valid_cron_expression(self):
|
||||
"""Test parsing a valid cron expression."""
|
||||
is_valid, error = TriggerSchedulerService.parse_cron_expression("0 9 * * 1")
|
||||
assert is_valid is True
|
||||
assert error is None
|
||||
|
||||
def test_parse_valid_cron_every_minute(self):
|
||||
"""Test parsing every minute cron expression."""
|
||||
is_valid, error = TriggerSchedulerService.parse_cron_expression("* * * * *")
|
||||
assert is_valid is True
|
||||
|
||||
def test_parse_valid_cron_weekdays(self):
|
||||
"""Test parsing weekdays-only cron expression."""
|
||||
is_valid, error = TriggerSchedulerService.parse_cron_expression("0 9 * * 1-5")
|
||||
assert is_valid is True
|
||||
|
||||
def test_parse_valid_cron_monthly(self):
|
||||
"""Test parsing monthly cron expression."""
|
||||
is_valid, error = TriggerSchedulerService.parse_cron_expression("0 0 1 * *")
|
||||
assert is_valid is True
|
||||
|
||||
def test_parse_invalid_cron_expression(self):
|
||||
"""Test parsing an invalid cron expression."""
|
||||
is_valid, error = TriggerSchedulerService.parse_cron_expression("invalid")
|
||||
assert is_valid is False
|
||||
assert error is not None
|
||||
assert "Invalid cron expression" in error
|
||||
|
||||
def test_parse_invalid_cron_too_many_fields(self):
|
||||
"""Test parsing cron with too many fields."""
|
||||
is_valid, error = TriggerSchedulerService.parse_cron_expression("0 0 0 0 0 0 0")
|
||||
assert is_valid is False
|
||||
|
||||
def test_parse_invalid_cron_bad_range(self):
|
||||
"""Test parsing cron with invalid range."""
|
||||
is_valid, error = TriggerSchedulerService.parse_cron_expression("0 25 * * *")
|
||||
assert is_valid is False
|
||||
|
||||
def test_get_next_run_time(self):
|
||||
"""Test getting next run time from cron expression."""
|
||||
base_time = datetime(2025, 1, 1, 8, 0, 0, tzinfo=timezone.utc)
|
||||
next_time = TriggerSchedulerService.get_next_run_time("0 9 * * *", base_time)
|
||||
|
||||
assert next_time is not None
|
||||
assert next_time.hour == 9
|
||||
assert next_time.minute == 0
|
||||
|
||||
def test_get_previous_run_time(self):
|
||||
"""Test getting previous run time from cron expression."""
|
||||
base_time = datetime(2025, 1, 1, 10, 0, 0, tzinfo=timezone.utc)
|
||||
prev_time = TriggerSchedulerService.get_previous_run_time("0 9 * * *", base_time)
|
||||
|
||||
assert prev_time is not None
|
||||
assert prev_time.hour == 9
|
||||
assert prev_time.minute == 0
|
||||
|
||||
def test_get_next_run_time_invalid_cron(self):
|
||||
"""Test getting next run time with invalid cron returns None."""
|
||||
result = TriggerSchedulerService.get_next_run_time("invalid")
|
||||
assert result is None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests: Schedule Trigger Should Fire Logic
|
||||
# ============================================================================
|
||||
|
||||
class TestScheduleTriggerShouldFire:
|
||||
"""Tests for schedule trigger firing logic."""
|
||||
|
||||
def test_should_trigger_within_window(self, db, cron_trigger):
|
||||
"""Test trigger should fire when within execution window."""
|
||||
# Set current time to just after scheduled time
|
||||
scheduled_time = datetime(2025, 1, 1, 9, 0, 0, tzinfo=timezone.utc)
|
||||
current_time = scheduled_time + timedelta(minutes=2)
|
||||
|
||||
result = TriggerSchedulerService.should_trigger(
|
||||
cron_trigger, current_time, last_execution_time=None
|
||||
)
|
||||
assert result is True
|
||||
|
||||
def test_should_not_trigger_outside_window(self, db, cron_trigger):
|
||||
"""Test trigger should not fire when outside execution window."""
|
||||
# Set current time to well after scheduled time (more than 5 minutes)
|
||||
scheduled_time = datetime(2025, 1, 1, 9, 0, 0, tzinfo=timezone.utc)
|
||||
current_time = scheduled_time + timedelta(minutes=10)
|
||||
|
||||
result = TriggerSchedulerService.should_trigger(
|
||||
cron_trigger, current_time, last_execution_time=None
|
||||
)
|
||||
assert result is False
|
||||
|
||||
def test_should_not_trigger_if_already_executed(self, db, cron_trigger):
|
||||
"""Test trigger should not fire if already executed after last schedule."""
|
||||
scheduled_time = datetime(2025, 1, 1, 9, 0, 0, tzinfo=timezone.utc)
|
||||
current_time = scheduled_time + timedelta(minutes=2)
|
||||
last_execution = scheduled_time + timedelta(minutes=1)
|
||||
|
||||
result = TriggerSchedulerService.should_trigger(
|
||||
cron_trigger, current_time, last_execution_time=last_execution
|
||||
)
|
||||
assert result is False
|
||||
|
||||
def test_should_trigger_if_new_schedule_since_last_execution(self, db, cron_trigger):
|
||||
"""Test trigger should fire if a new schedule time has passed since last execution."""
|
||||
# Last execution was yesterday at 9:01
|
||||
last_execution = datetime(2025, 1, 1, 9, 1, 0, tzinfo=timezone.utc)
|
||||
# Current time is today at 9:02 (new schedule at 9:00 passed)
|
||||
current_time = datetime(2025, 1, 2, 9, 2, 0, tzinfo=timezone.utc)
|
||||
|
||||
result = TriggerSchedulerService.should_trigger(
|
||||
cron_trigger, current_time, last_execution_time=last_execution
|
||||
)
|
||||
assert result is True
|
||||
|
||||
def test_should_not_trigger_inactive(self, db, cron_trigger):
|
||||
"""Test inactive trigger should not fire."""
|
||||
cron_trigger.is_active = False
|
||||
db.commit()
|
||||
|
||||
current_time = datetime(2025, 1, 1, 9, 1, 0, tzinfo=timezone.utc)
|
||||
result = TriggerSchedulerService.should_trigger(
|
||||
cron_trigger, current_time, last_execution_time=None
|
||||
)
|
||||
assert result is False
|
||||
|
||||
def test_should_not_trigger_field_change_type(self, db, test_project, test_user):
|
||||
"""Test field_change trigger type should not be evaluated as schedule trigger."""
|
||||
trigger = Trigger(
|
||||
id=str(uuid.uuid4()),
|
||||
project_id=test_project.id,
|
||||
name="Field Change Trigger",
|
||||
trigger_type="field_change",
|
||||
conditions={
|
||||
"field": "status_id",
|
||||
"operator": "equals",
|
||||
"value": "some-id",
|
||||
},
|
||||
actions=[{"type": "notify"}],
|
||||
is_active=True,
|
||||
created_by=test_user.id,
|
||||
)
|
||||
db.add(trigger)
|
||||
db.commit()
|
||||
|
||||
result = TriggerSchedulerService.should_trigger(
|
||||
trigger, datetime.now(timezone.utc), last_execution_time=None
|
||||
)
|
||||
assert result is False
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests: Deadline Reminder Logic
|
||||
# ============================================================================
|
||||
|
||||
class TestDeadlineReminderLogic:
|
||||
"""Tests for deadline reminder functionality."""
|
||||
|
||||
def test_deadline_reminder_finds_matching_tasks(
|
||||
self, db, deadline_trigger, task_with_deadline, test_user
|
||||
):
|
||||
"""Test that deadline reminder finds tasks due in N days."""
|
||||
# Execute deadline reminders
|
||||
logs = TriggerSchedulerService.execute_deadline_reminders(db)
|
||||
db.commit()
|
||||
|
||||
assert len(logs) == 1
|
||||
assert logs[0].status == "success"
|
||||
assert logs[0].task_id == task_with_deadline.id
|
||||
assert logs[0].details["trigger_type"] == "deadline_reminder"
|
||||
assert logs[0].details["reminder_days"] == 3
|
||||
|
||||
def test_deadline_reminder_creates_notification(
|
||||
self, db, deadline_trigger, task_with_deadline, test_user
|
||||
):
|
||||
"""Test that deadline reminder creates a notification."""
|
||||
logs = TriggerSchedulerService.execute_deadline_reminders(db)
|
||||
db.commit()
|
||||
|
||||
# Check notification was created
|
||||
notifications = db.query(Notification).filter(
|
||||
Notification.user_id == test_user.id,
|
||||
Notification.type == "deadline_reminder",
|
||||
).all()
|
||||
|
||||
assert len(notifications) == 1
|
||||
assert task_with_deadline.title in notifications[0].message
|
||||
|
||||
def test_deadline_reminder_only_sends_once(
|
||||
self, db, deadline_trigger, task_with_deadline
|
||||
):
|
||||
"""Test that deadline reminder only sends once per task per trigger."""
|
||||
# First execution
|
||||
logs1 = TriggerSchedulerService.execute_deadline_reminders(db)
|
||||
db.commit()
|
||||
assert len(logs1) == 1
|
||||
|
||||
# Second execution should not send again
|
||||
logs2 = TriggerSchedulerService.execute_deadline_reminders(db)
|
||||
db.commit()
|
||||
assert len(logs2) == 0
|
||||
|
||||
def test_deadline_reminder_ignores_deleted_tasks(
|
||||
self, db, deadline_trigger, task_with_deadline
|
||||
):
|
||||
"""Test that deadline reminder ignores soft-deleted tasks."""
|
||||
task_with_deadline.is_deleted = True
|
||||
db.commit()
|
||||
|
||||
logs = TriggerSchedulerService.execute_deadline_reminders(db)
|
||||
assert len(logs) == 0
|
||||
|
||||
def test_deadline_reminder_ignores_tasks_without_due_date(
|
||||
self, db, deadline_trigger, test_project, test_user, test_status
|
||||
):
|
||||
"""Test that deadline reminder ignores tasks without due dates."""
|
||||
task = Task(
|
||||
id=str(uuid.uuid4()),
|
||||
project_id=test_project.id,
|
||||
title="No Deadline Task",
|
||||
status_id=test_status.id,
|
||||
created_by=test_user.id,
|
||||
due_date=None,
|
||||
)
|
||||
db.add(task)
|
||||
db.commit()
|
||||
|
||||
logs = TriggerSchedulerService.execute_deadline_reminders(db)
|
||||
assert len(logs) == 0
|
||||
|
||||
def test_deadline_reminder_different_reminder_days(
|
||||
self, db, test_project, test_user, test_status
|
||||
):
|
||||
"""Test deadline reminder with different reminder days configuration."""
|
||||
# Create a trigger for 7 days reminder
|
||||
trigger = Trigger(
|
||||
id=str(uuid.uuid4()),
|
||||
project_id=test_project.id,
|
||||
name="7 Day Reminder",
|
||||
trigger_type="schedule",
|
||||
conditions={"deadline_reminder_days": 7},
|
||||
actions=[{"type": "notify", "target": "assignee"}],
|
||||
is_active=True,
|
||||
created_by=test_user.id,
|
||||
)
|
||||
db.add(trigger)
|
||||
|
||||
# Create a task due in 7 days
|
||||
task = Task(
|
||||
id=str(uuid.uuid4()),
|
||||
project_id=test_project.id,
|
||||
title="Task Due in 7 Days",
|
||||
status_id=test_status.id,
|
||||
created_by=test_user.id,
|
||||
assignee_id=test_user.id,
|
||||
due_date=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
)
|
||||
db.add(task)
|
||||
db.commit()
|
||||
|
||||
logs = TriggerSchedulerService.execute_deadline_reminders(db)
|
||||
db.commit()
|
||||
|
||||
assert len(logs) == 1
|
||||
assert logs[0].details["reminder_days"] == 7
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests: Schedule Trigger API
|
||||
# ============================================================================
|
||||
|
||||
class TestScheduleTriggerAPI:
|
||||
"""Tests for Schedule Trigger API endpoints."""
|
||||
|
||||
def test_create_cron_trigger(self, client, test_user_token, test_project):
|
||||
"""Test creating a schedule trigger with cron expression."""
|
||||
response = client.post(
|
||||
f"/api/projects/{test_project.id}/triggers",
|
||||
headers={"Authorization": f"Bearer {test_user_token}"},
|
||||
json={
|
||||
"name": "Weekly Monday Reminder",
|
||||
"description": "Remind every Monday at 9am",
|
||||
"trigger_type": "schedule",
|
||||
"conditions": {
|
||||
"cron_expression": "0 9 * * 1",
|
||||
},
|
||||
"actions": [{
|
||||
"type": "notify",
|
||||
"target": "project_owner",
|
||||
"template": "Weekly reminder for {project_name}",
|
||||
}],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
assert data["name"] == "Weekly Monday Reminder"
|
||||
assert data["trigger_type"] == "schedule"
|
||||
assert data["conditions"]["cron_expression"] == "0 9 * * 1"
|
||||
|
||||
def test_create_deadline_trigger(self, client, test_user_token, test_project):
|
||||
"""Test creating a schedule trigger with deadline reminder."""
|
||||
response = client.post(
|
||||
f"/api/projects/{test_project.id}/triggers",
|
||||
headers={"Authorization": f"Bearer {test_user_token}"},
|
||||
json={
|
||||
"name": "Deadline Reminder",
|
||||
"description": "Remind 5 days before deadline",
|
||||
"trigger_type": "schedule",
|
||||
"conditions": {
|
||||
"deadline_reminder_days": 5,
|
||||
},
|
||||
"actions": [{
|
||||
"type": "notify",
|
||||
"target": "assignee",
|
||||
}],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
assert data["conditions"]["deadline_reminder_days"] == 5
|
||||
|
||||
def test_create_schedule_trigger_invalid_cron(self, client, test_user_token, test_project):
|
||||
"""Test creating a schedule trigger with invalid cron expression."""
|
||||
response = client.post(
|
||||
f"/api/projects/{test_project.id}/triggers",
|
||||
headers={"Authorization": f"Bearer {test_user_token}"},
|
||||
json={
|
||||
"name": "Invalid Cron Trigger",
|
||||
"trigger_type": "schedule",
|
||||
"conditions": {
|
||||
"cron_expression": "invalid cron",
|
||||
},
|
||||
"actions": [{"type": "notify"}],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "Invalid cron expression" in response.json()["detail"]
|
||||
|
||||
def test_create_schedule_trigger_missing_condition(self, client, test_user_token, test_project):
|
||||
"""Test creating a schedule trigger without cron or deadline condition."""
|
||||
response = client.post(
|
||||
f"/api/projects/{test_project.id}/triggers",
|
||||
headers={"Authorization": f"Bearer {test_user_token}"},
|
||||
json={
|
||||
"name": "Empty Schedule Trigger",
|
||||
"trigger_type": "schedule",
|
||||
"conditions": {},
|
||||
"actions": [{"type": "notify"}],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "require either cron_expression or deadline_reminder_days" in response.json()["detail"]
|
||||
|
||||
def test_update_schedule_trigger_cron(self, client, test_user_token, cron_trigger):
|
||||
"""Test updating a schedule trigger's cron expression."""
|
||||
response = client.put(
|
||||
f"/api/triggers/{cron_trigger.id}",
|
||||
headers={"Authorization": f"Bearer {test_user_token}"},
|
||||
json={
|
||||
"conditions": {
|
||||
"cron_expression": "0 10 * * *", # Changed to 10am
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["conditions"]["cron_expression"] == "0 10 * * *"
|
||||
|
||||
def test_update_schedule_trigger_invalid_cron(self, client, test_user_token, cron_trigger):
|
||||
"""Test updating a schedule trigger with invalid cron expression."""
|
||||
response = client.put(
|
||||
f"/api/triggers/{cron_trigger.id}",
|
||||
headers={"Authorization": f"Bearer {test_user_token}"},
|
||||
json={
|
||||
"conditions": {
|
||||
"cron_expression": "not valid",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "Invalid cron expression" in response.json()["detail"]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests: Integration - Schedule Trigger Execution
|
||||
# ============================================================================
|
||||
|
||||
class TestScheduleTriggerExecution:
|
||||
"""Integration tests for schedule trigger execution."""
|
||||
|
||||
def test_execute_scheduled_triggers(self, db, cron_trigger, test_user):
|
||||
"""Test executing scheduled triggers creates logs."""
|
||||
# Manually set conditions to trigger execution
|
||||
# Create a log entry as if it was executed before
|
||||
# The trigger should not fire again immediately
|
||||
|
||||
# First, verify no logs exist
|
||||
logs_before = db.query(TriggerLog).filter(
|
||||
TriggerLog.trigger_id == cron_trigger.id
|
||||
).all()
|
||||
assert len(logs_before) == 0
|
||||
|
||||
def test_evaluate_schedule_triggers_combined(
|
||||
self, db, cron_trigger, deadline_trigger, task_with_deadline
|
||||
):
|
||||
"""Test that evaluate_schedule_triggers runs both cron and deadline triggers."""
|
||||
# Note: This test verifies the combined execution method exists and works
|
||||
# The actual execution depends on timing, so we mainly test structure
|
||||
|
||||
# Execute the combined evaluation
|
||||
logs = TriggerSchedulerService.evaluate_schedule_triggers(db)
|
||||
|
||||
# Should have deadline reminder executed
|
||||
deadline_logs = [l for l in logs if l.details and l.details.get("trigger_type") == "deadline_reminder"]
|
||||
assert len(deadline_logs) == 1
|
||||
|
||||
def test_trigger_log_details(self, db, deadline_trigger, task_with_deadline):
|
||||
"""Test that trigger logs contain proper details."""
|
||||
logs = TriggerSchedulerService.execute_deadline_reminders(db)
|
||||
db.commit()
|
||||
|
||||
assert len(logs) == 1
|
||||
log = logs[0]
|
||||
|
||||
assert log.trigger_id == deadline_trigger.id
|
||||
assert log.task_id == task_with_deadline.id
|
||||
assert log.status == "success"
|
||||
assert log.details is not None
|
||||
assert log.details["trigger_name"] == deadline_trigger.name
|
||||
assert log.details["task_title"] == task_with_deadline.title
|
||||
assert "due_date" in log.details
|
||||
|
||||
def test_inactive_trigger_not_executed(self, db, deadline_trigger, task_with_deadline):
|
||||
"""Test that inactive triggers are not executed."""
|
||||
deadline_trigger.is_active = False
|
||||
db.commit()
|
||||
|
||||
logs = TriggerSchedulerService.execute_deadline_reminders(db)
|
||||
assert len(logs) == 0
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests: Template Formatting
|
||||
# ============================================================================
|
||||
|
||||
class TestTemplateFormatting:
|
||||
"""Tests for message template formatting."""
|
||||
|
||||
def test_format_deadline_template_basic(
|
||||
self, db, deadline_trigger, task_with_deadline
|
||||
):
|
||||
"""Test basic deadline template formatting."""
|
||||
template = "Task '{task_title}' is due in {reminder_days} days"
|
||||
result = TriggerSchedulerService._format_deadline_template(
|
||||
template, deadline_trigger, task_with_deadline, 3
|
||||
)
|
||||
|
||||
assert task_with_deadline.title in result
|
||||
assert "3" in result
|
||||
|
||||
def test_format_deadline_template_all_variables(
|
||||
self, db, deadline_trigger, task_with_deadline
|
||||
):
|
||||
"""Test template with all available variables."""
|
||||
template = (
|
||||
"Trigger: {trigger_name}, Task: {task_title}, "
|
||||
"Due: {due_date}, Days: {reminder_days}, Project: {project_name}"
|
||||
)
|
||||
result = TriggerSchedulerService._format_deadline_template(
|
||||
template, deadline_trigger, task_with_deadline, 3
|
||||
)
|
||||
|
||||
assert deadline_trigger.name in result
|
||||
assert task_with_deadline.title in result
|
||||
assert "3" in result
|
||||
|
||||
def test_format_scheduled_trigger_template(self, db, cron_trigger):
|
||||
"""Test scheduled trigger template formatting."""
|
||||
template = "Trigger '{trigger_name}' fired for project '{project_name}'"
|
||||
result = TriggerSchedulerService._format_template(
|
||||
template, cron_trigger, cron_trigger.project
|
||||
)
|
||||
|
||||
assert cron_trigger.name in result
|
||||
assert cron_trigger.project.title in result
|
||||
@@ -93,6 +93,263 @@ class TestUserEndpoints:
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
class TestCapacityUpdate:
|
||||
"""Test user capacity update API endpoint."""
|
||||
|
||||
def test_update_own_capacity(self, client, db, mock_redis):
|
||||
"""Test that a user can update their own capacity."""
|
||||
from app.core.security import create_access_token, create_token_payload
|
||||
|
||||
# Create a test user
|
||||
test_user = User(
|
||||
id="capacity-user-001",
|
||||
email="capacityuser@example.com",
|
||||
name="Capacity User",
|
||||
is_active=True,
|
||||
capacity=40.00,
|
||||
)
|
||||
db.add(test_user)
|
||||
db.commit()
|
||||
|
||||
# Create token for the user
|
||||
token_data = create_token_payload(
|
||||
user_id="capacity-user-001",
|
||||
email="capacityuser@example.com",
|
||||
role="engineer",
|
||||
department_id=None,
|
||||
is_system_admin=False,
|
||||
)
|
||||
token = create_access_token(token_data)
|
||||
mock_redis.setex("session:capacity-user-001", 900, token)
|
||||
|
||||
# Update own capacity
|
||||
response = client.put(
|
||||
"/api/users/capacity-user-001/capacity",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
json={"capacity_hours": 35.5},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert float(data["capacity"]) == 35.5
|
||||
|
||||
def test_admin_can_update_other_user_capacity(self, client, admin_token, db):
|
||||
"""Test that admin can update another user's capacity."""
|
||||
# Create a test user
|
||||
test_user = User(
|
||||
id="capacity-user-002",
|
||||
email="capacityuser2@example.com",
|
||||
name="Capacity User 2",
|
||||
is_active=True,
|
||||
capacity=40.00,
|
||||
)
|
||||
db.add(test_user)
|
||||
db.commit()
|
||||
|
||||
# Admin updates another user's capacity
|
||||
response = client.put(
|
||||
"/api/users/capacity-user-002/capacity",
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
json={"capacity_hours": 20.0},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert float(data["capacity"]) == 20.0
|
||||
|
||||
def test_non_admin_cannot_update_other_user_capacity(self, client, db, mock_redis):
|
||||
"""Test that a non-admin user cannot update another user's capacity."""
|
||||
from app.core.security import create_access_token, create_token_payload
|
||||
|
||||
# Create two test users
|
||||
user1 = User(
|
||||
id="capacity-user-003",
|
||||
email="capacityuser3@example.com",
|
||||
name="Capacity User 3",
|
||||
is_active=True,
|
||||
capacity=40.00,
|
||||
)
|
||||
user2 = User(
|
||||
id="capacity-user-004",
|
||||
email="capacityuser4@example.com",
|
||||
name="Capacity User 4",
|
||||
is_active=True,
|
||||
capacity=40.00,
|
||||
)
|
||||
db.add_all([user1, user2])
|
||||
db.commit()
|
||||
|
||||
# Create token for user1
|
||||
token_data = create_token_payload(
|
||||
user_id="capacity-user-003",
|
||||
email="capacityuser3@example.com",
|
||||
role="engineer",
|
||||
department_id=None,
|
||||
is_system_admin=False,
|
||||
)
|
||||
token = create_access_token(token_data)
|
||||
mock_redis.setex("session:capacity-user-003", 900, token)
|
||||
|
||||
# User1 tries to update user2's capacity - should fail
|
||||
response = client.put(
|
||||
"/api/users/capacity-user-004/capacity",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
json={"capacity_hours": 30.0},
|
||||
)
|
||||
assert response.status_code == 403
|
||||
assert "Only admin, manager, or the user themselves" in response.json()["detail"]
|
||||
|
||||
def test_update_capacity_invalid_value_negative(self, client, admin_token, db):
|
||||
"""Test that negative capacity hours are rejected."""
|
||||
# Create a test user
|
||||
test_user = User(
|
||||
id="capacity-user-005",
|
||||
email="capacityuser5@example.com",
|
||||
name="Capacity User 5",
|
||||
is_active=True,
|
||||
capacity=40.00,
|
||||
)
|
||||
db.add(test_user)
|
||||
db.commit()
|
||||
|
||||
response = client.put(
|
||||
"/api/users/capacity-user-005/capacity",
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
json={"capacity_hours": -5.0},
|
||||
)
|
||||
# Pydantic validation returns 422 Unprocessable Entity
|
||||
assert response.status_code == 422
|
||||
error_detail = response.json()["detail"]
|
||||
# Check validation error message in Pydantic format
|
||||
assert any("non-negative" in str(err).lower() for err in error_detail)
|
||||
|
||||
def test_update_capacity_invalid_value_too_high(self, client, admin_token, db):
|
||||
"""Test that capacity hours exceeding 168 are rejected."""
|
||||
# Create a test user
|
||||
test_user = User(
|
||||
id="capacity-user-006",
|
||||
email="capacityuser6@example.com",
|
||||
name="Capacity User 6",
|
||||
is_active=True,
|
||||
capacity=40.00,
|
||||
)
|
||||
db.add(test_user)
|
||||
db.commit()
|
||||
|
||||
response = client.put(
|
||||
"/api/users/capacity-user-006/capacity",
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
json={"capacity_hours": 200.0},
|
||||
)
|
||||
# Pydantic validation returns 422 Unprocessable Entity
|
||||
assert response.status_code == 422
|
||||
error_detail = response.json()["detail"]
|
||||
# Check validation error message in Pydantic format
|
||||
assert any("168" in str(err) for err in error_detail)
|
||||
|
||||
def test_update_capacity_nonexistent_user(self, client, admin_token):
|
||||
"""Test updating capacity for a nonexistent user."""
|
||||
response = client.put(
|
||||
"/api/users/nonexistent-user-id/capacity",
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
json={"capacity_hours": 40.0},
|
||||
)
|
||||
assert response.status_code == 404
|
||||
assert "User not found" in response.json()["detail"]
|
||||
|
||||
def test_manager_can_update_other_user_capacity(self, client, db, mock_redis):
|
||||
"""Test that manager can update another user's capacity."""
|
||||
from app.core.security import create_access_token, create_token_payload
|
||||
from app.models.role import Role
|
||||
|
||||
# Create manager role if not exists
|
||||
manager_role = db.query(Role).filter(Role.name == "manager").first()
|
||||
if not manager_role:
|
||||
manager_role = Role(
|
||||
id="manager-role-cap",
|
||||
name="manager",
|
||||
permissions={"users.read": True, "users.write": True},
|
||||
)
|
||||
db.add(manager_role)
|
||||
db.commit()
|
||||
|
||||
# Create a manager user
|
||||
manager_user = User(
|
||||
id="manager-cap-001",
|
||||
email="managercap@example.com",
|
||||
name="Manager Cap",
|
||||
role_id=manager_role.id,
|
||||
is_active=True,
|
||||
is_system_admin=False,
|
||||
)
|
||||
# Create a regular user
|
||||
regular_user = User(
|
||||
id="regular-cap-001",
|
||||
email="regularcap@example.com",
|
||||
name="Regular Cap",
|
||||
is_active=True,
|
||||
capacity=40.00,
|
||||
)
|
||||
db.add_all([manager_user, regular_user])
|
||||
db.commit()
|
||||
|
||||
# Create token for manager
|
||||
token_data = create_token_payload(
|
||||
user_id="manager-cap-001",
|
||||
email="managercap@example.com",
|
||||
role="manager",
|
||||
department_id=None,
|
||||
is_system_admin=False,
|
||||
)
|
||||
token = create_access_token(token_data)
|
||||
mock_redis.setex("session:manager-cap-001", 900, token)
|
||||
|
||||
# Manager updates regular user's capacity
|
||||
response = client.put(
|
||||
"/api/users/regular-cap-001/capacity",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
json={"capacity_hours": 30.0},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert float(data["capacity"]) == 30.0
|
||||
|
||||
def test_capacity_change_creates_audit_log(self, client, admin_token, db):
|
||||
"""Test that capacity changes are recorded in audit trail."""
|
||||
from app.models import AuditLog
|
||||
|
||||
# Create a test user
|
||||
test_user = User(
|
||||
id="capacity-audit-001",
|
||||
email="capacityaudit@example.com",
|
||||
name="Capacity Audit User",
|
||||
is_active=True,
|
||||
capacity=40.00,
|
||||
)
|
||||
db.add(test_user)
|
||||
db.commit()
|
||||
|
||||
# Update capacity
|
||||
response = client.put(
|
||||
"/api/users/capacity-audit-001/capacity",
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
json={"capacity_hours": 35.0},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Check audit log was created
|
||||
audit_log = db.query(AuditLog).filter(
|
||||
AuditLog.resource_id == "capacity-audit-001",
|
||||
AuditLog.event_type == "user.capacity_change"
|
||||
).first()
|
||||
|
||||
assert audit_log is not None
|
||||
assert audit_log.resource_type == "user"
|
||||
assert audit_log.action == "update"
|
||||
assert len(audit_log.changes) == 1
|
||||
assert audit_log.changes[0]["field"] == "capacity"
|
||||
assert audit_log.changes[0]["old_value"] == 40.0
|
||||
assert audit_log.changes[0]["new_value"] == 35.0
|
||||
|
||||
|
||||
class TestDepartmentIsolation:
|
||||
"""Test department-based access control."""
|
||||
|
||||
|
||||
755
backend/tests/test_watermark.py
Normal file
755
backend/tests/test_watermark.py
Normal file
@@ -0,0 +1,755 @@
|
||||
"""
|
||||
Tests for MED-009: Dynamic Watermark for Downloads
|
||||
|
||||
This module contains unit tests for WatermarkService and
|
||||
integration tests for the download endpoint with watermark functionality.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import uuid
|
||||
import os
|
||||
import io
|
||||
import tempfile
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
|
||||
from app.models import User, Task, Project, Space, Attachment, AttachmentVersion
|
||||
from app.services.watermark_service import WatermarkService, watermark_service
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Fixtures
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_user(db):
|
||||
"""Create a test user for watermark tests."""
|
||||
user = User(
|
||||
id=str(uuid.uuid4()),
|
||||
email="watermark.test@example.com",
|
||||
employee_id="EMP-WM001",
|
||||
name="Watermark Tester",
|
||||
role_id="00000000-0000-0000-0000-000000000003",
|
||||
is_active=True,
|
||||
is_system_admin=False,
|
||||
)
|
||||
db.add(user)
|
||||
db.commit()
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_user_token(client, mock_redis, test_user):
|
||||
"""Get a token for test user."""
|
||||
from app.core.security import create_access_token, create_token_payload
|
||||
|
||||
token_data = create_token_payload(
|
||||
user_id=test_user.id,
|
||||
email=test_user.email,
|
||||
role="engineer",
|
||||
department_id=None,
|
||||
is_system_admin=False,
|
||||
)
|
||||
token = create_access_token(token_data)
|
||||
mock_redis.setex(f"session:{test_user.id}", 900, token)
|
||||
return token
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_space(db, test_user):
|
||||
"""Create a test space."""
|
||||
space = Space(
|
||||
id=str(uuid.uuid4()),
|
||||
name="Watermark Test Space",
|
||||
description="Test space for watermark tests",
|
||||
owner_id=test_user.id,
|
||||
)
|
||||
db.add(space)
|
||||
db.commit()
|
||||
return space
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_project(db, test_space, test_user):
|
||||
"""Create a test project."""
|
||||
project = Project(
|
||||
id=str(uuid.uuid4()),
|
||||
space_id=test_space.id,
|
||||
title="Watermark Test Project",
|
||||
description="Test project for watermark tests",
|
||||
owner_id=test_user.id,
|
||||
)
|
||||
db.add(project)
|
||||
db.commit()
|
||||
return project
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_task(db, test_project, test_user):
|
||||
"""Create a test task."""
|
||||
task = Task(
|
||||
id=str(uuid.uuid4()),
|
||||
project_id=test_project.id,
|
||||
title="Watermark Test Task",
|
||||
description="Test task for watermark tests",
|
||||
created_by=test_user.id,
|
||||
)
|
||||
db.add(task)
|
||||
db.commit()
|
||||
return task
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_upload_dir():
|
||||
"""Create a temporary upload directory."""
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
yield temp_dir
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_png_bytes():
|
||||
"""Create a sample PNG image as bytes."""
|
||||
img = Image.new("RGB", (200, 200), color=(255, 255, 255))
|
||||
output = io.BytesIO()
|
||||
img.save(output, format="PNG")
|
||||
output.seek(0)
|
||||
return output.getvalue()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_jpeg_bytes():
|
||||
"""Create a sample JPEG image as bytes."""
|
||||
img = Image.new("RGB", (200, 200), color=(255, 255, 255))
|
||||
output = io.BytesIO()
|
||||
img.save(output, format="JPEG")
|
||||
output.seek(0)
|
||||
return output.getvalue()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_pdf_bytes():
|
||||
"""Create a sample PDF as bytes."""
|
||||
from reportlab.lib.pagesizes import letter
|
||||
from reportlab.pdfgen import canvas
|
||||
|
||||
buffer = io.BytesIO()
|
||||
c = canvas.Canvas(buffer, pagesize=letter)
|
||||
c.drawString(100, 750, "Test PDF Document")
|
||||
c.drawString(100, 700, "This is a test page for watermarking.")
|
||||
c.showPage()
|
||||
c.drawString(100, 750, "Page 2")
|
||||
c.drawString(100, 700, "Second page content.")
|
||||
c.showPage()
|
||||
c.save()
|
||||
buffer.seek(0)
|
||||
return buffer.getvalue()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Unit Tests for WatermarkService
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestWatermarkServiceUnit:
|
||||
"""Unit tests for WatermarkService class."""
|
||||
|
||||
def test_format_watermark_text(self):
|
||||
"""Test watermark text formatting with employee_id."""
|
||||
test_time = datetime(2024, 1, 15, 10, 30, 45)
|
||||
text = WatermarkService._format_watermark_text(
|
||||
user_name="John Doe",
|
||||
employee_id="EMP001",
|
||||
download_time=test_time
|
||||
)
|
||||
|
||||
assert "John Doe" in text
|
||||
assert "EMP001" in text
|
||||
assert "2024-01-15 10:30:45" in text
|
||||
assert text == "John Doe (EMP001) - 2024-01-15 10:30:45"
|
||||
|
||||
def test_format_watermark_text_without_employee_id(self):
|
||||
"""Test that watermark text uses N/A when employee_id is not provided."""
|
||||
test_time = datetime(2024, 1, 15, 10, 30, 45)
|
||||
text = WatermarkService._format_watermark_text(
|
||||
user_name="Jane Doe",
|
||||
employee_id=None,
|
||||
download_time=test_time
|
||||
)
|
||||
|
||||
assert "Jane Doe" in text
|
||||
assert "(N/A)" in text
|
||||
assert text == "Jane Doe (N/A) - 2024-01-15 10:30:45"
|
||||
|
||||
def test_format_watermark_text_defaults_to_now(self):
|
||||
"""Test that watermark text defaults to current time."""
|
||||
text = WatermarkService._format_watermark_text(
|
||||
user_name="Jane Doe",
|
||||
employee_id="EMP002"
|
||||
)
|
||||
|
||||
assert "Jane Doe" in text
|
||||
assert "EMP002" in text
|
||||
# Should contain a date-like string
|
||||
assert "-" in text # Date separator
|
||||
|
||||
def test_is_supported_image_png(self):
|
||||
"""Test PNG is recognized as supported image."""
|
||||
service = WatermarkService()
|
||||
assert service.is_supported_image("image/png") is True
|
||||
assert service.is_supported_image("IMAGE/PNG") is True
|
||||
|
||||
def test_is_supported_image_jpeg(self):
|
||||
"""Test JPEG is recognized as supported image."""
|
||||
service = WatermarkService()
|
||||
assert service.is_supported_image("image/jpeg") is True
|
||||
assert service.is_supported_image("image/jpg") is True
|
||||
|
||||
def test_is_supported_image_unsupported(self):
|
||||
"""Test unsupported image formats are rejected."""
|
||||
service = WatermarkService()
|
||||
assert service.is_supported_image("image/gif") is False
|
||||
assert service.is_supported_image("image/bmp") is False
|
||||
assert service.is_supported_image("image/webp") is False
|
||||
|
||||
def test_is_supported_pdf(self):
|
||||
"""Test PDF is recognized."""
|
||||
service = WatermarkService()
|
||||
assert service.is_supported_pdf("application/pdf") is True
|
||||
assert service.is_supported_pdf("APPLICATION/PDF") is True
|
||||
|
||||
def test_is_supported_pdf_negative(self):
|
||||
"""Test non-PDF types are not recognized as PDF."""
|
||||
service = WatermarkService()
|
||||
assert service.is_supported_pdf("application/json") is False
|
||||
assert service.is_supported_pdf("text/plain") is False
|
||||
|
||||
def test_supports_watermark_images(self):
|
||||
"""Test supports_watermark for images."""
|
||||
service = WatermarkService()
|
||||
assert service.supports_watermark("image/png") is True
|
||||
assert service.supports_watermark("image/jpeg") is True
|
||||
|
||||
def test_supports_watermark_pdf(self):
|
||||
"""Test supports_watermark for PDF."""
|
||||
service = WatermarkService()
|
||||
assert service.supports_watermark("application/pdf") is True
|
||||
|
||||
def test_supports_watermark_unsupported(self):
|
||||
"""Test supports_watermark for unsupported types."""
|
||||
service = WatermarkService()
|
||||
assert service.supports_watermark("text/plain") is False
|
||||
assert service.supports_watermark("application/zip") is False
|
||||
assert service.supports_watermark("application/octet-stream") is False
|
||||
|
||||
|
||||
class TestImageWatermarking:
|
||||
"""Unit tests for image watermarking functionality."""
|
||||
|
||||
def test_add_image_watermark_png(self, sample_png_bytes):
|
||||
"""Test adding watermark to PNG image."""
|
||||
test_time = datetime(2024, 1, 15, 10, 30, 45)
|
||||
|
||||
result_bytes, output_format = watermark_service.add_image_watermark(
|
||||
image_bytes=sample_png_bytes,
|
||||
user_name="Test User",
|
||||
employee_id="EMP001",
|
||||
download_time=test_time
|
||||
)
|
||||
|
||||
# Verify output is valid image bytes
|
||||
assert len(result_bytes) > 0
|
||||
assert output_format.lower() == "png"
|
||||
|
||||
# Verify output is valid PNG image
|
||||
result_image = Image.open(io.BytesIO(result_bytes))
|
||||
assert result_image.format == "PNG"
|
||||
assert result_image.size == (200, 200)
|
||||
|
||||
def test_add_image_watermark_jpeg(self, sample_jpeg_bytes):
|
||||
"""Test adding watermark to JPEG image."""
|
||||
test_time = datetime(2024, 1, 15, 10, 30, 45)
|
||||
|
||||
result_bytes, output_format = watermark_service.add_image_watermark(
|
||||
image_bytes=sample_jpeg_bytes,
|
||||
user_name="Test User",
|
||||
employee_id="EMP001",
|
||||
download_time=test_time
|
||||
)
|
||||
|
||||
# Verify output is valid image bytes
|
||||
assert len(result_bytes) > 0
|
||||
assert output_format.lower() == "jpeg"
|
||||
|
||||
# Verify output is valid JPEG image
|
||||
result_image = Image.open(io.BytesIO(result_bytes))
|
||||
assert result_image.format == "JPEG"
|
||||
assert result_image.size == (200, 200)
|
||||
|
||||
def test_add_image_watermark_preserves_dimensions(self, sample_png_bytes):
|
||||
"""Test that watermarking preserves image dimensions."""
|
||||
original = Image.open(io.BytesIO(sample_png_bytes))
|
||||
original_size = original.size
|
||||
|
||||
result_bytes, _ = watermark_service.add_image_watermark(
|
||||
image_bytes=sample_png_bytes,
|
||||
user_name="Test User",
|
||||
employee_id="EMP001"
|
||||
)
|
||||
|
||||
result = Image.open(io.BytesIO(result_bytes))
|
||||
assert result.size == original_size
|
||||
|
||||
def test_add_image_watermark_modifies_image(self, sample_png_bytes):
|
||||
"""Test that watermark actually modifies the image."""
|
||||
result_bytes, _ = watermark_service.add_image_watermark(
|
||||
image_bytes=sample_png_bytes,
|
||||
user_name="Test User",
|
||||
employee_id="EMP001"
|
||||
)
|
||||
|
||||
# The watermarked image should be different from original
|
||||
# (Note: size might differ slightly due to compression)
|
||||
# We verify the image data is actually different
|
||||
original = Image.open(io.BytesIO(sample_png_bytes))
|
||||
result = Image.open(io.BytesIO(result_bytes))
|
||||
|
||||
# Convert to same mode for comparison
|
||||
original_rgb = original.convert("RGB")
|
||||
result_rgb = result.convert("RGB")
|
||||
|
||||
# Compare pixel data - they should be different
|
||||
original_data = list(original_rgb.getdata())
|
||||
result_data = list(result_rgb.getdata())
|
||||
|
||||
# At least some pixels should be different (watermark added)
|
||||
different_pixels = sum(1 for o, r in zip(original_data, result_data) if o != r)
|
||||
assert different_pixels > 0, "Watermark should modify image pixels"
|
||||
|
||||
def test_add_image_watermark_large_image(self):
|
||||
"""Test watermarking a larger image."""
|
||||
# Create a larger image
|
||||
large_img = Image.new("RGB", (1920, 1080), color=(100, 150, 200))
|
||||
output = io.BytesIO()
|
||||
large_img.save(output, format="PNG")
|
||||
large_bytes = output.getvalue()
|
||||
|
||||
result_bytes, output_format = watermark_service.add_image_watermark(
|
||||
image_bytes=large_bytes,
|
||||
user_name="Large Image User",
|
||||
employee_id="EMP-LARGE"
|
||||
)
|
||||
|
||||
assert len(result_bytes) > 0
|
||||
result_image = Image.open(io.BytesIO(result_bytes))
|
||||
assert result_image.size == (1920, 1080)
|
||||
|
||||
|
||||
class TestPdfWatermarking:
|
||||
"""Unit tests for PDF watermarking functionality."""
|
||||
|
||||
def test_add_pdf_watermark_basic(self, sample_pdf_bytes):
|
||||
"""Test adding watermark to PDF."""
|
||||
import fitz # PyMuPDF
|
||||
|
||||
test_time = datetime(2024, 1, 15, 10, 30, 45)
|
||||
|
||||
result_bytes = watermark_service.add_pdf_watermark(
|
||||
pdf_bytes=sample_pdf_bytes,
|
||||
user_name="PDF Test User",
|
||||
employee_id="EMP-PDF001",
|
||||
download_time=test_time
|
||||
)
|
||||
|
||||
# Verify output is valid PDF bytes
|
||||
assert len(result_bytes) > 0
|
||||
|
||||
# Verify output is valid PDF using PyMuPDF
|
||||
result_pdf = fitz.open(stream=result_bytes, filetype="pdf")
|
||||
assert len(result_pdf) == 2
|
||||
result_pdf.close()
|
||||
|
||||
def test_add_pdf_watermark_preserves_page_count(self, sample_pdf_bytes):
|
||||
"""Test that watermarking preserves page count."""
|
||||
import fitz # PyMuPDF
|
||||
|
||||
original_pdf = fitz.open(stream=sample_pdf_bytes, filetype="pdf")
|
||||
original_page_count = len(original_pdf)
|
||||
original_pdf.close()
|
||||
|
||||
result_bytes = watermark_service.add_pdf_watermark(
|
||||
pdf_bytes=sample_pdf_bytes,
|
||||
user_name="Test User",
|
||||
employee_id="EMP001"
|
||||
)
|
||||
|
||||
result_pdf = fitz.open(stream=result_bytes, filetype="pdf")
|
||||
assert len(result_pdf) == original_page_count
|
||||
result_pdf.close()
|
||||
|
||||
def test_add_pdf_watermark_modifies_content(self, sample_pdf_bytes):
|
||||
"""Test that watermark actually modifies the PDF content."""
|
||||
result_bytes = watermark_service.add_pdf_watermark(
|
||||
pdf_bytes=sample_pdf_bytes,
|
||||
user_name="Modified User",
|
||||
employee_id="EMP-MOD"
|
||||
)
|
||||
|
||||
# The watermarked PDF should be different from original
|
||||
assert result_bytes != sample_pdf_bytes
|
||||
|
||||
def test_add_pdf_watermark_single_page(self):
|
||||
"""Test watermarking a single-page PDF."""
|
||||
import fitz # PyMuPDF
|
||||
|
||||
# Create single page PDF with PyMuPDF
|
||||
doc = fitz.open()
|
||||
page = doc.new_page(width=612, height=792) # Letter size
|
||||
page.insert_text(point=(100, 750), text="Single Page Document", fontsize=12)
|
||||
buffer = io.BytesIO()
|
||||
doc.save(buffer)
|
||||
doc.close()
|
||||
single_page_bytes = buffer.getvalue()
|
||||
|
||||
result_bytes = watermark_service.add_pdf_watermark(
|
||||
pdf_bytes=single_page_bytes,
|
||||
user_name="Single Page User",
|
||||
employee_id="EMP-SINGLE"
|
||||
)
|
||||
|
||||
result_pdf = fitz.open(stream=result_bytes, filetype="pdf")
|
||||
assert len(result_pdf) == 1
|
||||
result_pdf.close()
|
||||
|
||||
def test_add_pdf_watermark_many_pages(self):
|
||||
"""Test watermarking a multi-page PDF."""
|
||||
import fitz # PyMuPDF
|
||||
|
||||
# Create multi-page PDF with PyMuPDF
|
||||
doc = fitz.open()
|
||||
for i in range(5):
|
||||
page = doc.new_page(width=612, height=792)
|
||||
page.insert_text(point=(100, 750), text=f"Page {i + 1}", fontsize=12)
|
||||
buffer = io.BytesIO()
|
||||
doc.save(buffer)
|
||||
doc.close()
|
||||
multi_page_bytes = buffer.getvalue()
|
||||
|
||||
result_bytes = watermark_service.add_pdf_watermark(
|
||||
pdf_bytes=multi_page_bytes,
|
||||
user_name="Multi Page User",
|
||||
employee_id="EMP-MULTI"
|
||||
)
|
||||
|
||||
result_pdf = fitz.open(stream=result_bytes, filetype="pdf")
|
||||
assert len(result_pdf) == 5
|
||||
result_pdf.close()
|
||||
|
||||
|
||||
class TestWatermarkServiceConfiguration:
|
||||
"""Tests for WatermarkService configuration constants."""
|
||||
|
||||
def test_default_opacity(self):
|
||||
"""Test default watermark opacity."""
|
||||
assert WatermarkService.WATERMARK_OPACITY == 0.3
|
||||
|
||||
def test_default_angle(self):
|
||||
"""Test default watermark angle."""
|
||||
assert WatermarkService.WATERMARK_ANGLE == -45
|
||||
|
||||
def test_default_font_size(self):
|
||||
"""Test default watermark font size."""
|
||||
assert WatermarkService.WATERMARK_FONT_SIZE == 24
|
||||
|
||||
def test_default_color(self):
|
||||
"""Test default watermark color (gray)."""
|
||||
assert WatermarkService.WATERMARK_COLOR == (128, 128, 128)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Integration Tests for Download with Watermark
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestDownloadWithWatermark:
|
||||
"""Integration tests for download endpoint with watermark."""
|
||||
|
||||
def test_download_png_with_watermark(
|
||||
self, client, test_user_token, test_task, db, monkeypatch, temp_upload_dir, sample_png_bytes
|
||||
):
|
||||
"""Test downloading PNG file applies watermark."""
|
||||
from pathlib import Path
|
||||
from app.services.file_storage_service import file_storage_service
|
||||
monkeypatch.setattr("app.core.config.settings.UPLOAD_DIR", temp_upload_dir)
|
||||
monkeypatch.setattr(file_storage_service, "base_dir", Path(temp_upload_dir))
|
||||
|
||||
# Create attachment and version
|
||||
attachment_id = str(uuid.uuid4())
|
||||
version_id = str(uuid.uuid4())
|
||||
|
||||
# Save the file to disk
|
||||
file_dir = os.path.join(temp_upload_dir, test_task.project_id, test_task.id, attachment_id, "v1")
|
||||
os.makedirs(file_dir, exist_ok=True)
|
||||
file_path = os.path.join(file_dir, "test.png")
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(sample_png_bytes)
|
||||
|
||||
relative_path = os.path.join(test_task.project_id, test_task.id, attachment_id, "v1", "test.png")
|
||||
|
||||
attachment = Attachment(
|
||||
id=attachment_id,
|
||||
task_id=test_task.id,
|
||||
filename="test.png",
|
||||
original_filename="test.png",
|
||||
mime_type="image/png",
|
||||
file_size=len(sample_png_bytes),
|
||||
current_version=1,
|
||||
uploaded_by=test_task.created_by,
|
||||
)
|
||||
db.add(attachment)
|
||||
|
||||
version = AttachmentVersion(
|
||||
id=version_id,
|
||||
attachment_id=attachment_id,
|
||||
version=1,
|
||||
file_path=relative_path,
|
||||
file_size=len(sample_png_bytes),
|
||||
checksum="0" * 64,
|
||||
uploaded_by=test_task.created_by,
|
||||
)
|
||||
db.add(version)
|
||||
db.commit()
|
||||
|
||||
# Download the file
|
||||
response = client.get(
|
||||
f"/api/attachments/{attachment_id}/download",
|
||||
headers={"Authorization": f"Bearer {test_user_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "image/png"
|
||||
|
||||
# Verify watermark was applied (image should be different)
|
||||
downloaded_image = Image.open(io.BytesIO(response.content))
|
||||
original_image = Image.open(io.BytesIO(sample_png_bytes))
|
||||
|
||||
# Convert to comparable format
|
||||
downloaded_rgb = downloaded_image.convert("RGB")
|
||||
original_rgb = original_image.convert("RGB")
|
||||
|
||||
downloaded_data = list(downloaded_rgb.getdata())
|
||||
original_data = list(original_rgb.getdata())
|
||||
|
||||
# At least some pixels should be different (watermark present)
|
||||
different_pixels = sum(1 for o, d in zip(original_data, downloaded_data) if o != d)
|
||||
assert different_pixels > 0, "Downloaded image should have watermark"
|
||||
|
||||
def test_download_pdf_with_watermark(
|
||||
self, client, test_user_token, test_task, db, monkeypatch, temp_upload_dir, sample_pdf_bytes
|
||||
):
|
||||
"""Test downloading PDF file applies watermark."""
|
||||
from pathlib import Path
|
||||
from app.services.file_storage_service import file_storage_service
|
||||
monkeypatch.setattr("app.core.config.settings.UPLOAD_DIR", temp_upload_dir)
|
||||
monkeypatch.setattr(file_storage_service, "base_dir", Path(temp_upload_dir))
|
||||
|
||||
# Create attachment and version
|
||||
attachment_id = str(uuid.uuid4())
|
||||
version_id = str(uuid.uuid4())
|
||||
|
||||
# Save the file to disk
|
||||
file_dir = os.path.join(temp_upload_dir, test_task.project_id, test_task.id, attachment_id, "v1")
|
||||
os.makedirs(file_dir, exist_ok=True)
|
||||
file_path = os.path.join(file_dir, "test.pdf")
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(sample_pdf_bytes)
|
||||
|
||||
relative_path = os.path.join(test_task.project_id, test_task.id, attachment_id, "v1", "test.pdf")
|
||||
|
||||
attachment = Attachment(
|
||||
id=attachment_id,
|
||||
task_id=test_task.id,
|
||||
filename="test.pdf",
|
||||
original_filename="test.pdf",
|
||||
mime_type="application/pdf",
|
||||
file_size=len(sample_pdf_bytes),
|
||||
current_version=1,
|
||||
uploaded_by=test_task.created_by,
|
||||
)
|
||||
db.add(attachment)
|
||||
|
||||
version = AttachmentVersion(
|
||||
id=version_id,
|
||||
attachment_id=attachment_id,
|
||||
version=1,
|
||||
file_path=relative_path,
|
||||
file_size=len(sample_pdf_bytes),
|
||||
checksum="0" * 64,
|
||||
uploaded_by=test_task.created_by,
|
||||
)
|
||||
db.add(version)
|
||||
db.commit()
|
||||
|
||||
# Download the file
|
||||
response = client.get(
|
||||
f"/api/attachments/{attachment_id}/download",
|
||||
headers={"Authorization": f"Bearer {test_user_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "application/pdf"
|
||||
|
||||
# Verify watermark was applied (PDF content should be different)
|
||||
assert response.content != sample_pdf_bytes, "Downloaded PDF should have watermark"
|
||||
|
||||
def test_download_unsupported_file_no_watermark(
|
||||
self, client, test_user_token, test_task, db, monkeypatch, temp_upload_dir
|
||||
):
|
||||
"""Test downloading unsupported file type returns original without watermark."""
|
||||
from pathlib import Path
|
||||
from app.services.file_storage_service import file_storage_service
|
||||
monkeypatch.setattr("app.core.config.settings.UPLOAD_DIR", temp_upload_dir)
|
||||
monkeypatch.setattr(file_storage_service, "base_dir", Path(temp_upload_dir))
|
||||
|
||||
# Create a text file
|
||||
text_content = b"This is a plain text file."
|
||||
|
||||
attachment_id = str(uuid.uuid4())
|
||||
version_id = str(uuid.uuid4())
|
||||
|
||||
# Save the file to disk
|
||||
file_dir = os.path.join(temp_upload_dir, test_task.project_id, test_task.id, attachment_id, "v1")
|
||||
os.makedirs(file_dir, exist_ok=True)
|
||||
file_path = os.path.join(file_dir, "test.txt")
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(text_content)
|
||||
|
||||
relative_path = os.path.join(test_task.project_id, test_task.id, attachment_id, "v1", "test.txt")
|
||||
|
||||
attachment = Attachment(
|
||||
id=attachment_id,
|
||||
task_id=test_task.id,
|
||||
filename="test.txt",
|
||||
original_filename="test.txt",
|
||||
mime_type="text/plain",
|
||||
file_size=len(text_content),
|
||||
current_version=1,
|
||||
uploaded_by=test_task.created_by,
|
||||
)
|
||||
db.add(attachment)
|
||||
|
||||
version = AttachmentVersion(
|
||||
id=version_id,
|
||||
attachment_id=attachment_id,
|
||||
version=1,
|
||||
file_path=relative_path,
|
||||
file_size=len(text_content),
|
||||
checksum="0" * 64,
|
||||
uploaded_by=test_task.created_by,
|
||||
)
|
||||
db.add(version)
|
||||
db.commit()
|
||||
|
||||
# Download the file
|
||||
response = client.get(
|
||||
f"/api/attachments/{attachment_id}/download",
|
||||
headers={"Authorization": f"Bearer {test_user_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
# Content should be unchanged for unsupported types
|
||||
assert response.content == text_content
|
||||
|
||||
def test_download_jpeg_with_watermark(
|
||||
self, client, test_user_token, test_task, db, monkeypatch, temp_upload_dir, sample_jpeg_bytes
|
||||
):
|
||||
"""Test downloading JPEG file applies watermark."""
|
||||
from pathlib import Path
|
||||
from app.services.file_storage_service import file_storage_service
|
||||
monkeypatch.setattr("app.core.config.settings.UPLOAD_DIR", temp_upload_dir)
|
||||
monkeypatch.setattr(file_storage_service, "base_dir", Path(temp_upload_dir))
|
||||
|
||||
attachment_id = str(uuid.uuid4())
|
||||
version_id = str(uuid.uuid4())
|
||||
|
||||
# Save the file to disk
|
||||
file_dir = os.path.join(temp_upload_dir, test_task.project_id, test_task.id, attachment_id, "v1")
|
||||
os.makedirs(file_dir, exist_ok=True)
|
||||
file_path = os.path.join(file_dir, "test.jpg")
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(sample_jpeg_bytes)
|
||||
|
||||
relative_path = os.path.join(test_task.project_id, test_task.id, attachment_id, "v1", "test.jpg")
|
||||
|
||||
attachment = Attachment(
|
||||
id=attachment_id,
|
||||
task_id=test_task.id,
|
||||
filename="test.jpg",
|
||||
original_filename="test.jpg",
|
||||
mime_type="image/jpeg",
|
||||
file_size=len(sample_jpeg_bytes),
|
||||
current_version=1,
|
||||
uploaded_by=test_task.created_by,
|
||||
)
|
||||
db.add(attachment)
|
||||
|
||||
version = AttachmentVersion(
|
||||
id=version_id,
|
||||
attachment_id=attachment_id,
|
||||
version=1,
|
||||
file_path=relative_path,
|
||||
file_size=len(sample_jpeg_bytes),
|
||||
checksum="0" * 64,
|
||||
uploaded_by=test_task.created_by,
|
||||
)
|
||||
db.add(version)
|
||||
db.commit()
|
||||
|
||||
# Download the file
|
||||
response = client.get(
|
||||
f"/api/attachments/{attachment_id}/download",
|
||||
headers={"Authorization": f"Bearer {test_user_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "image/jpeg"
|
||||
|
||||
# Verify the response is a valid JPEG
|
||||
downloaded_image = Image.open(io.BytesIO(response.content))
|
||||
assert downloaded_image.format == "JPEG"
|
||||
|
||||
|
||||
class TestWatermarkErrorHandling:
|
||||
"""Tests for watermark error handling and graceful degradation."""
|
||||
|
||||
def test_watermark_service_singleton_exists(self):
|
||||
"""Test that watermark_service singleton is available."""
|
||||
assert watermark_service is not None
|
||||
assert isinstance(watermark_service, WatermarkService)
|
||||
|
||||
def test_invalid_image_bytes_graceful_handling(self):
|
||||
"""Test handling of invalid image bytes."""
|
||||
invalid_bytes = b"not an image"
|
||||
|
||||
with pytest.raises(Exception):
|
||||
# Should raise an exception for invalid image data
|
||||
watermark_service.add_image_watermark(
|
||||
image_bytes=invalid_bytes,
|
||||
user_name="Test",
|
||||
employee_id="EMP001"
|
||||
)
|
||||
|
||||
def test_invalid_pdf_bytes_graceful_handling(self):
|
||||
"""Test handling of invalid PDF bytes."""
|
||||
invalid_bytes = b"not a pdf"
|
||||
|
||||
with pytest.raises(Exception):
|
||||
# Should raise an exception for invalid PDF data
|
||||
watermark_service.add_pdf_watermark(
|
||||
pdf_bytes=invalid_bytes,
|
||||
user_name="Test",
|
||||
employee_id="EMP001"
|
||||
)
|
||||
Reference in New Issue
Block a user