feat: implement custom fields, gantt view, calendar view, and file encryption

- Custom Fields (FEAT-001):
  - CustomField and TaskCustomValue models with formula support
  - CRUD API for custom field management
  - Formula engine for calculated fields
  - Frontend: CustomFieldEditor, CustomFieldInput, ProjectSettings page
  - Task list API now includes custom_values
  - KanbanBoard displays custom field values

- Gantt View (FEAT-003):
  - TaskDependency model with FS/SS/FF/SF dependency types
  - Dependency CRUD API with cycle detection
  - start_date field added to tasks
  - GanttChart component with Frappe Gantt integration
  - Dependency type selector in UI

- Calendar View (FEAT-004):
  - CalendarView component with FullCalendar integration
  - Date range filtering API for tasks
  - Drag-and-drop date updates
  - View mode switching in Tasks page

- File Encryption (FEAT-010):
  - AES-256-GCM encryption service
  - EncryptionKey model with key rotation support
  - Admin API for key management
  - Encrypted upload/download for confidential projects

- Migrations: 011 (custom fields), 012 (encryption keys), 013 (task dependencies)
- Updated issues.md with completion status

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
beabigegg
2026-01-05 23:39:12 +08:00
parent 69b81d9241
commit 2d80a8384e
65 changed files with 11045 additions and 82 deletions

View File

@@ -20,3 +20,14 @@ AUTH_API_URL=https://pj-auth-api.vercel.app
# System Admin
SYSTEM_ADMIN_EMAIL=ymirliu@panjit.com.tw
# File Encryption (AES-256)
# Master key for encrypting file encryption keys (optional - if not set, file encryption is disabled)
# Generate a new key with:
# python -c "import secrets, base64; print(base64.urlsafe_b64encode(secrets.token_bytes(32)).decode())"
#
# IMPORTANT:
# - Keep this key secure and back it up! If lost, encrypted files cannot be decrypted.
# - Store backup in a secure location separate from the database backup.
# - Do NOT change this key after files have been encrypted (use key rotation instead).
ENCRYPTION_MASTER_KEY=

View File

@@ -0,0 +1 @@
# Admin API module

View File

@@ -0,0 +1,299 @@
"""
Encryption Key Management API (Admin only).
Provides endpoints for:
- Listing encryption keys (without actual key data)
- Creating new encryption keys
- Key rotation
- Checking encryption status
"""
import logging
import uuid
from datetime import datetime
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy.orm import Session

from app.core.database import get_db
from app.middleware.auth import get_current_user, require_system_admin
from app.models import User, EncryptionKey, AuditAction
from app.schemas.encryption_key import (
    EncryptionKeyResponse,
    EncryptionKeyListResponse,
    EncryptionKeyCreateResponse,
    EncryptionKeyRotateResponse,
    EncryptionStatusResponse,
)
from app.services.encryption_service import (
    encryption_service,
    MasterKeyNotConfiguredError,
)
from app.services.audit_service import AuditService
router = APIRouter(prefix="/api/admin/encryption-keys", tags=["Admin - Encryption Keys"])
def key_to_response(key: EncryptionKey) -> EncryptionKeyResponse:
    """Serialize an EncryptionKey row for API output, omitting the key material."""
    public_fields = {
        "id": key.id,
        "algorithm": key.algorithm,
        "is_active": key.is_active,
        "created_at": key.created_at,
        "rotated_at": key.rotated_at,
    }
    return EncryptionKeyResponse(**public_fields)
@router.get("/status", response_model=EncryptionStatusResponse)
async def get_encryption_status(
    db: Session = Depends(get_db),
    current_user: User = Depends(require_system_admin),
):
    """
    Report the current encryption configuration.

    Includes whether a master key is configured, which key (if any) is
    currently active, and how many keys exist in total.
    """
    available = encryption_service.is_encryption_available()

    active_key_id = None
    key_count = 0
    if available:
        active = (
            db.query(EncryptionKey)
            .filter(EncryptionKey.is_active == True)
            .first()
        )
        if active is not None:
            active_key_id = active.id
        key_count = db.query(EncryptionKey).count()

    if available:
        message = "Encryption is available"
    else:
        message = "Encryption is not configured (ENCRYPTION_MASTER_KEY not set)"

    return EncryptionStatusResponse(
        encryption_available=available,
        active_key_id=active_key_id,
        total_keys=key_count,
        message=message,
    )
@router.get("", response_model=EncryptionKeyListResponse)
async def list_encryption_keys(
    db: Session = Depends(get_db),
    current_user: User = Depends(require_system_admin),
):
    """
    Return every encryption key, newest first, with the key material omitted.

    Restricted to system administrators via the dependency.
    """
    ordered = (
        db.query(EncryptionKey)
        .order_by(EncryptionKey.created_at.desc())
        .all()
    )
    responses = [key_to_response(row) for row in ordered]
    return EncryptionKeyListResponse(keys=responses, total=len(responses))
@router.post("", response_model=EncryptionKeyCreateResponse)
async def create_encryption_key(
    request: Request,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_system_admin),
):
    """
    Create a new (inactive) encryption key.

    A fresh data key is generated, wrapped with the master key, and stored.
    The new key is NOT activated here - use the /rotate endpoint for that.

    Raises:
        HTTPException 400: encryption/master key is not configured.
        HTTPException 500: key generation or persistence failed.
    """
    if not encryption_service.is_encryption_available():
        raise HTTPException(
            status_code=400,
            detail="Encryption is not configured. Set ENCRYPTION_MASTER_KEY in environment."
        )
    try:
        # Generate a raw data key and wrap it with the master key so the
        # plaintext key never touches the database.
        raw_key = encryption_service.generate_key()
        encrypted_key = encryption_service.encrypt_key(raw_key)

        # Persist the wrapped key; inactive until an explicit rotation.
        key = EncryptionKey(
            id=str(uuid.uuid4()),
            key_data=encrypted_key,
            algorithm="AES-256-GCM",
            is_active=False,
        )
        db.add(key)

        # Record the creation in the audit trail (same transaction as the key).
        AuditService.log_event(
            db=db,
            event_type="encryption.key_create",
            resource_type="encryption_key",
            action=AuditAction.CREATE,
            user_id=current_user.id,
            resource_id=key.id,
            changes=[{"field": "algorithm", "old_value": None, "new_value": "AES-256-GCM"}],
            request_metadata=getattr(request.state, "audit_metadata", None),
        )
        db.commit()
        db.refresh(key)
        return EncryptionKeyCreateResponse(
            id=key.id,
            algorithm=key.algorithm,
            is_active=key.is_active,
            created_at=key.created_at,
            message="Encryption key created successfully. Use /rotate to make it active.",
        )
    except MasterKeyNotConfiguredError:
        raise HTTPException(
            status_code=400,
            detail="Master key is not configured"
        )
    except Exception:
        db.rollback()
        # Log the full traceback server-side; return a generic message so
        # internal details are not disclosed to API clients.
        logging.getLogger(__name__).exception("Failed to create encryption key")
        raise HTTPException(
            status_code=500,
            detail="Failed to create encryption key"
        )
@router.post("/rotate", response_model=EncryptionKeyRotateResponse)
async def rotate_encryption_key(
    request: Request,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_system_admin),
):
    """
    Rotate to a new encryption key.

    Steps:
    1. Generate a new key (wrapped with the master key) and mark it active.
    2. Deactivate the previously active key, keeping it around so files
       encrypted with it remain decryptable.

    After rotation, new uploads use the new key; old files remain readable
    with their original keys.

    Raises:
        HTTPException 400: encryption/master key is not configured.
        HTTPException 500: rotation failed (transaction rolled back).
    """
    if not encryption_service.is_encryption_available():
        raise HTTPException(
            status_code=400,
            detail="Encryption is not configured. Set ENCRYPTION_MASTER_KEY in environment."
        )
    try:
        # The currently active key (if any) is retired below.
        old_active_key = db.query(EncryptionKey).filter(
            EncryptionKey.is_active == True
        ).first()

        # Generate and wrap the replacement key.
        raw_key = encryption_service.generate_key()
        encrypted_key = encryption_service.encrypt_key(raw_key)

        new_key = EncryptionKey(
            id=str(uuid.uuid4()),
            key_data=encrypted_key,
            algorithm="AES-256-GCM",
            is_active=True,
        )
        db.add(new_key)

        # Retire the old key but never delete it - old files still need it.
        old_key_id = None
        if old_active_key:
            old_active_key.is_active = False
            old_active_key.rotated_at = datetime.utcnow()
            old_key_id = old_active_key.id

        # Audit the rotation in the same transaction as the key swap.
        AuditService.log_event(
            db=db,
            event_type="encryption.key_rotate",
            resource_type="encryption_key",
            action=AuditAction.UPDATE,
            user_id=current_user.id,
            resource_id=new_key.id,
            changes=[
                {"field": "rotation", "old_value": old_key_id, "new_value": new_key.id},
                {"field": "is_active", "old_value": False, "new_value": True},
            ],
            request_metadata=getattr(request.state, "audit_metadata", None),
        )
        db.commit()
        return EncryptionKeyRotateResponse(
            new_key_id=new_key.id,
            old_key_id=old_key_id,
            message="Key rotation completed successfully. New uploads will use the new key.",
        )
    except MasterKeyNotConfiguredError:
        raise HTTPException(
            status_code=400,
            detail="Master key is not configured"
        )
    except Exception:
        db.rollback()
        # Log the traceback server-side; return a generic message so internal
        # details (SQL, crypto state) are not disclosed to API clients.
        logging.getLogger(__name__).exception("Failed to rotate encryption key")
        raise HTTPException(
            status_code=500,
            detail="Failed to rotate encryption key"
        )
@router.delete("/{key_id}")
async def deactivate_encryption_key(
    key_id: str,
    request: Request,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_system_admin),
):
    """
    Mark an encryption key as inactive.

    Keys are retained forever (never deleted) so previously encrypted files
    stay decryptable. Deactivating the sole active key is rejected; rotate
    to a replacement key first.
    """
    key = db.query(EncryptionKey).filter(EncryptionKey.id == key_id).first()
    if key is None:
        raise HTTPException(status_code=404, detail="Encryption key not found")

    if key.is_active:
        # Refuse to leave the system without any active key.
        replacement = db.query(EncryptionKey).filter(
            EncryptionKey.is_active == True,
            EncryptionKey.id != key_id,
        ).first()
        if replacement is None:
            raise HTTPException(
                status_code=400,
                detail="Cannot deactivate the only active key. Rotate to a new key first."
            )

    key.is_active = False
    key.rotated_at = datetime.utcnow()

    # Record the state change in the audit trail before committing.
    AuditService.log_event(
        db=db,
        event_type="encryption.key_deactivate",
        resource_type="encryption_key",
        action=AuditAction.UPDATE,
        user_id=current_user.id,
        resource_id=key.id,
        changes=[{"field": "is_active", "old_value": True, "new_value": False}],
        request_metadata=getattr(request.state, "audit_metadata", None),
    )
    db.commit()
    return {"detail": "Encryption key deactivated", "id": key_id}

View File

@@ -1,5 +1,7 @@
import uuid
import logging
from datetime import datetime
from io import BytesIO
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Request
from fastapi.responses import FileResponse, Response
from sqlalchemy.orm import Session
@@ -7,7 +9,7 @@ from typing import Optional
from app.core.database import get_db
from app.middleware.auth import get_current_user, check_task_access, check_task_edit_access
from app.models import User, Task, Project, Attachment, AttachmentVersion, AuditAction
from app.models import User, Task, Project, Attachment, AttachmentVersion, EncryptionKey, AuditAction
from app.schemas.attachment import (
AttachmentResponse, AttachmentListResponse, AttachmentDetailResponse,
AttachmentVersionResponse, VersionHistoryResponse
@@ -15,6 +17,13 @@ from app.schemas.attachment import (
from app.services.file_storage_service import file_storage_service
from app.services.audit_service import AuditService
from app.services.watermark_service import watermark_service
from app.services.encryption_service import (
encryption_service,
MasterKeyNotConfiguredError,
DecryptionError,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api", tags=["attachments"])
@@ -103,6 +112,40 @@ def version_to_response(version: AttachmentVersion) -> AttachmentVersionResponse
)
def should_encrypt_file(project: Project, db: Session) -> tuple[bool, Optional[EncryptionKey]]:
    """
    Decide whether an upload for this project must be encrypted.

    Returns:
        Tuple of (should_encrypt, key) - key is the active EncryptionKey
        when encryption applies, otherwise None.
    """
    # Encryption only applies to confidential projects.
    if project.security_level != "confidential":
        return False, None

    # A confidential project without a configured master key falls back to
    # plaintext storage; log it so admins can react.
    if not encryption_service.is_encryption_available():
        logger.warning(
            f"Project {project.id} is confidential but encryption is not configured. "
            "Files will be stored unencrypted."
        )
        return False, None

    # Same fallback when no key has been activated yet.
    active_key = db.query(EncryptionKey).filter(
        EncryptionKey.is_active == True
    ).first()
    if active_key is None:
        logger.warning(
            f"Project {project.id} is confidential but no active encryption key exists. "
            "Files will be stored unencrypted. Create a key using /api/admin/encryption-keys/rotate"
        )
        return False, None

    return True, active_key
@router.post("/tasks/{task_id}/attachments", response_model=AttachmentResponse)
async def upload_attachment(
task_id: str,
@@ -111,10 +154,22 @@ async def upload_attachment(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Upload a file attachment to a task."""
"""
Upload a file attachment to a task.
For confidential projects, files are automatically encrypted using AES-256-GCM.
"""
task = get_task_with_access_check(db, task_id, current_user, require_edit=True)
# Check if attachment with same filename exists (for versioning in Phase 2)
# Get project to check security level
project = db.query(Project).filter(Project.id == task.project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
# Determine if encryption is needed
should_encrypt, encryption_key = should_encrypt_file(project, db)
# Check if attachment with same filename exists (for versioning)
existing = db.query(Attachment).filter(
Attachment.task_id == task_id,
Attachment.original_filename == file.filename,
@@ -122,17 +177,73 @@ async def upload_attachment(
).first()
if existing:
# Phase 2: Create new version
# Create new version for existing attachment
new_version = existing.current_version + 1
# Save file
file_path, file_size, checksum = await file_storage_service.save_file(
file=file,
project_id=task.project_id,
task_id=task_id,
attachment_id=existing.id,
version=new_version
)
if should_encrypt and encryption_key:
# Read and encrypt file content
file_content = await file.read()
await file.seek(0)
try:
# Decrypt the encryption key
raw_key = encryption_service.decrypt_key(encryption_key.key_data)
# Encrypt the file
encrypted_content = encryption_service.encrypt_bytes(file_content, raw_key)
# Create a new UploadFile-like object with encrypted content
encrypted_file = BytesIO(encrypted_content)
encrypted_file.seek(0)
# Save encrypted file using a modified approach
file_path, _, checksum = await file_storage_service.save_file(
file=file,
project_id=task.project_id,
task_id=task_id,
attachment_id=existing.id,
version=new_version
)
# Overwrite with encrypted content
with open(file_path, "wb") as f:
f.write(encrypted_content)
file_size = len(encrypted_content)
# Update existing attachment with encryption info
existing.is_encrypted = True
existing.encryption_key_id = encryption_key.id
# Audit log for encryption
AuditService.log_event(
db=db,
event_type="attachment.encrypt",
resource_type="attachment",
action=AuditAction.UPDATE,
user_id=current_user.id,
resource_id=existing.id,
changes=[
{"field": "is_encrypted", "old_value": False, "new_value": True},
{"field": "encryption_key_id", "old_value": None, "new_value": encryption_key.id},
],
request_metadata=getattr(request.state, "audit_metadata", None),
)
except Exception as e:
logger.error(f"Failed to encrypt file for attachment {existing.id}: {e}")
raise HTTPException(
status_code=500,
detail="Failed to encrypt file. Please try again."
)
else:
# Save file without encryption
file_path, file_size, checksum = await file_storage_service.save_file(
file=file,
project_id=task.project_id,
task_id=task_id,
attachment_id=existing.id,
version=new_version
)
# Create version record
version = AttachmentVersion(
@@ -170,15 +281,52 @@ async def upload_attachment(
# Create new attachment
attachment_id = str(uuid.uuid4())
is_encrypted = False
encryption_key_id = None
# Save file
file_path, file_size, checksum = await file_storage_service.save_file(
file=file,
project_id=task.project_id,
task_id=task_id,
attachment_id=attachment_id,
version=1
)
if should_encrypt and encryption_key:
# Read and encrypt file content
file_content = await file.read()
await file.seek(0)
try:
# Decrypt the encryption key
raw_key = encryption_service.decrypt_key(encryption_key.key_data)
# Encrypt the file
encrypted_content = encryption_service.encrypt_bytes(file_content, raw_key)
# Save file first to get path
file_path, _, checksum = await file_storage_service.save_file(
file=file,
project_id=task.project_id,
task_id=task_id,
attachment_id=attachment_id,
version=1
)
# Overwrite with encrypted content
with open(file_path, "wb") as f:
f.write(encrypted_content)
file_size = len(encrypted_content)
is_encrypted = True
encryption_key_id = encryption_key.id
except Exception as e:
logger.error(f"Failed to encrypt new file: {e}")
raise HTTPException(
status_code=500,
detail="Failed to encrypt file. Please try again."
)
else:
# Save file without encryption
file_path, file_size, checksum = await file_storage_service.save_file(
file=file,
project_id=task.project_id,
task_id=task_id,
attachment_id=attachment_id,
version=1
)
# Get mime type from file storage validation
extension = file_storage_service.get_extension(file.filename or "")
@@ -193,7 +341,8 @@ async def upload_attachment(
mime_type=mime_type,
file_size=file_size,
current_version=1,
is_encrypted=False,
is_encrypted=is_encrypted,
encryption_key_id=encryption_key_id,
uploaded_by=current_user.id
)
db.add(attachment)
@@ -211,6 +360,10 @@ async def upload_attachment(
db.add(version)
# Audit log
changes = [{"field": "filename", "old_value": None, "new_value": attachment.filename}]
if is_encrypted:
changes.append({"field": "is_encrypted", "old_value": None, "new_value": True})
AuditService.log_event(
db=db,
event_type="attachment.upload",
@@ -218,7 +371,7 @@ async def upload_attachment(
action=AuditAction.CREATE,
user_id=current_user.id,
resource_id=attachment.id,
changes=[{"field": "filename", "old_value": None, "new_value": attachment.filename}],
changes=changes,
request_metadata=getattr(request.state, "audit_metadata", None)
)
@@ -286,7 +439,11 @@ async def download_attachment(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Download an attachment file with dynamic watermark."""
"""
Download an attachment file with dynamic watermark.
For encrypted files, the file is automatically decrypted before returning.
"""
attachment = get_attachment_with_access_check(db, attachment_id, current_user, require_edit=False)
# Get version to download
@@ -319,14 +476,69 @@ async def download_attachment(
)
db.commit()
# Read file content
with open(file_path, "rb") as f:
file_bytes = f.read()
# Decrypt if encrypted
if attachment.is_encrypted:
if not attachment.encryption_key_id:
raise HTTPException(
status_code=500,
detail="Encrypted file is missing encryption key reference"
)
encryption_key = db.query(EncryptionKey).filter(
EncryptionKey.id == attachment.encryption_key_id
).first()
if not encryption_key:
raise HTTPException(
status_code=500,
detail="Encryption key not found for this file"
)
try:
# Decrypt the encryption key
raw_key = encryption_service.decrypt_key(encryption_key.key_data)
# Decrypt the file
file_bytes = encryption_service.decrypt_bytes(file_bytes, raw_key)
# Audit log for decryption
AuditService.log_event(
db=db,
event_type="attachment.decrypt",
resource_type="attachment",
action=AuditAction.UPDATE,
user_id=current_user.id,
resource_id=attachment.id,
changes=[{"field": "decrypted_for_download", "old_value": None, "new_value": True}],
request_metadata=getattr(request.state, "audit_metadata", None) if request else None,
)
db.commit()
except DecryptionError as e:
logger.error(f"Failed to decrypt attachment {attachment_id}: {e}")
raise HTTPException(
status_code=500,
detail="Failed to decrypt file. The file may be corrupted."
)
except MasterKeyNotConfiguredError:
raise HTTPException(
status_code=500,
detail="Encryption is not configured. Cannot decrypt file."
)
except Exception as e:
logger.error(f"Unexpected error decrypting attachment {attachment_id}: {e}")
raise HTTPException(
status_code=500,
detail="Failed to decrypt file."
)
# Check if watermark should be applied
mime_type = attachment.mime_type or ""
if watermark_service.supports_watermark(mime_type):
try:
# Read the original file
with open(file_path, "rb") as f:
file_bytes = f.read()
# Apply watermark based on file type
if watermark_service.is_supported_image(mime_type):
watermarked_bytes, output_format = watermark_service.add_image_watermark(
@@ -367,19 +579,19 @@ async def download_attachment(
)
except Exception as e:
# If watermarking fails, log the error but still return the original file
# This ensures users can still download files even if watermarking has issues
import logging
logging.getLogger(__name__).warning(
# If watermarking fails, log the error but still return the file
logger.warning(
f"Watermarking failed for attachment {attachment_id}: {str(e)}. "
"Returning original file."
"Returning file without watermark."
)
# Return original file without watermark for unsupported types or on error
return FileResponse(
path=str(file_path),
filename=attachment.original_filename,
media_type=attachment.mime_type
# Return file (decrypted if needed, without watermark for unsupported types)
return Response(
content=file_bytes,
media_type=attachment.mime_type,
headers={
"Content-Disposition": f'attachment; filename="{attachment.original_filename}"'
}
)

View File

@@ -0,0 +1,3 @@
# Re-export the custom-fields router as this package's public API.
from app.api.custom_fields.router import router
__all__ = ["router"]

View File

@@ -0,0 +1,368 @@
import uuid
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from typing import Optional
from app.core.database import get_db
from app.models import User, Project, CustomField, TaskCustomValue
from app.schemas.custom_field import (
CustomFieldCreate, CustomFieldUpdate, CustomFieldResponse, CustomFieldListResponse
)
from app.middleware.auth import get_current_user, check_project_access, check_project_edit_access
from app.services.formula_service import FormulaService
router = APIRouter(tags=["custom-fields"])
# Maximum custom fields per project
MAX_CUSTOM_FIELDS_PER_PROJECT = 20
def custom_field_to_response(field: CustomField) -> CustomFieldResponse:
    """Serialize a CustomField ORM row into its API response schema."""
    payload = {
        "id": field.id,
        "project_id": field.project_id,
        "name": field.name,
        "field_type": field.field_type,
        "options": field.options,
        "formula": field.formula,
        "is_required": field.is_required,
        "position": field.position,
        "created_at": field.created_at,
        "updated_at": field.updated_at,
    }
    return CustomFieldResponse(**payload)
@router.post("/api/projects/{project_id}/custom-fields", response_model=CustomFieldResponse, status_code=status.HTTP_201_CREATED)
async def create_custom_field(
    project_id: str,
    field_data: CustomFieldCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Create a new custom field for a project.

    Only the project owner or a system admin can create custom fields.
    A project is limited to MAX_CUSTOM_FIELDS_PER_PROJECT fields, names must
    be unique within the project, formula fields must carry a valid formula,
    and dropdown fields must have between 1 and 50 options (same rules the
    update endpoint enforces).

    Raises:
        HTTPException 404: project does not exist.
        HTTPException 403: caller may not manage this project's fields.
        HTTPException 400: limit exceeded, duplicate name, or invalid
            formula/options.
    """
    project = db.query(Project).filter(Project.id == project_id).first()
    if not project:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Project not found",
        )
    if not check_project_edit_access(current_user, project):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Permission denied - only project owner can manage custom fields",
        )

    # Enforce the per-project field limit.
    field_count = db.query(CustomField).filter(
        CustomField.project_id == project_id
    ).count()
    if field_count >= MAX_CUSTOM_FIELDS_PER_PROJECT:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Maximum {MAX_CUSTOM_FIELDS_PER_PROJECT} custom fields per project exceeded",
        )

    # Field names are unique within a project.
    existing = db.query(CustomField).filter(
        CustomField.project_id == project_id,
        CustomField.name == field_data.name,
    ).first()
    if existing:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Custom field with name '{field_data.name}' already exists",
        )

    # Formula fields must carry a formula that validates against the project.
    if field_data.field_type.value == "formula":
        if not field_data.formula:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Formula is required for formula fields",
            )
        is_valid, error_msg = FormulaService.validate_formula(
            field_data.formula, project_id, db
        )
        if not is_valid:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=error_msg,
            )

    # Dropdown fields need 1-50 options; previously only the update endpoint
    # validated this, allowing dropdowns to be created with no options.
    if field_data.field_type.value == "dropdown":
        if not field_data.options:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Dropdown fields must have at least one option",
            )
        if len(field_data.options) > 50:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Dropdown fields can have at most 50 options",
            )

    # Append the new field after the current last position.
    max_pos = db.query(CustomField).filter(
        CustomField.project_id == project_id
    ).order_by(CustomField.position.desc()).first()
    next_position = (max_pos.position + 1) if max_pos else 0

    field = CustomField(
        id=str(uuid.uuid4()),
        project_id=project_id,
        name=field_data.name,
        field_type=field_data.field_type.value,
        options=field_data.options,
        formula=field_data.formula,
        is_required=field_data.is_required,
        position=next_position,
    )
    db.add(field)
    db.commit()
    db.refresh(field)
    return custom_field_to_response(field)
@router.get("/api/projects/{project_id}/custom-fields", response_model=CustomFieldListResponse)
async def list_custom_fields(
    project_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return the project's custom fields ordered by display position."""
    project = db.query(Project).filter(Project.id == project_id).first()
    if project is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Project not found",
        )
    if not check_project_access(current_user, project):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Access denied",
        )

    ordered_fields = (
        db.query(CustomField)
        .filter(CustomField.project_id == project_id)
        .order_by(CustomField.position)
        .all()
    )
    responses = [custom_field_to_response(cf) for cf in ordered_fields]
    return CustomFieldListResponse(fields=responses, total=len(responses))
@router.get("/api/custom-fields/{field_id}", response_model=CustomFieldResponse)
async def get_custom_field(
    field_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Fetch a single custom field by ID, enforcing project read access."""
    field = db.query(CustomField).filter(CustomField.id == field_id).first()
    if field is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Custom field not found",
        )
    if not check_project_access(current_user, field.project):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Access denied",
        )
    return custom_field_to_response(field)
@router.put("/api/custom-fields/{field_id}", response_model=CustomFieldResponse)
async def update_custom_field(
    field_id: str,
    field_data: CustomFieldUpdate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Update a custom field's name, options, formula, or required flag.

    Only the project owner or a system admin may update fields. The
    field_type is immutable after creation; type-specific attributes
    (options, formula) are applied only when they match the existing type.
    """
    field = db.query(CustomField).filter(CustomField.id == field_id).first()
    if field is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Custom field not found",
        )
    if not check_project_edit_access(current_user, field.project):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Permission denied",
        )

    renaming = field_data.name is not None and field_data.name != field.name
    if renaming:
        # The new name must not collide with another field in the project.
        clash = db.query(CustomField).filter(
            CustomField.project_id == field.project_id,
            CustomField.name == field_data.name,
            CustomField.id != field_id,
        ).first()
        if clash:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Custom field with name '{field_data.name}' already exists",
            )

    if field.field_type == "formula" and field_data.formula is not None:
        # Re-validate the new formula against the project's fields.
        is_valid, error_msg = FormulaService.validate_formula(
            field_data.formula, field.project_id, db, field_id
        )
        if not is_valid:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=error_msg,
            )

    if field.field_type == "dropdown" and field_data.options is not None:
        # Dropdowns need between 1 and 50 options.
        if len(field_data.options) == 0:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Dropdown fields must have at least one option",
            )
        if len(field_data.options) > 50:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Dropdown fields can have at most 50 options",
            )

    # Apply only the attributes that were provided (and match the type).
    if field_data.name is not None:
        field.name = field_data.name
    if field_data.options is not None and field.field_type == "dropdown":
        field.options = field_data.options
    if field_data.formula is not None and field.field_type == "formula":
        field.formula = field_data.formula
    if field_data.is_required is not None:
        field.is_required = field_data.is_required

    db.commit()
    db.refresh(field)
    return custom_field_to_response(field)
@router.delete("/api/custom-fields/{field_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_custom_field(
    field_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Delete a custom field and (via cascade) all of its stored task values.

    Only the project owner or a system admin may delete fields. Deletion is
    blocked while any formula field in the project still references this
    field by name.
    """
    field = db.query(CustomField).filter(CustomField.id == field_id).first()
    if field is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Custom field not found",
        )
    if not check_project_edit_access(current_user, field.project):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Permission denied",
        )

    if field.field_type != "formula":
        # Formula fields reference other fields by name; refuse to break them.
        sibling_formulas = db.query(CustomField).filter(
            CustomField.project_id == field.project_id,
            CustomField.field_type == "formula",
        ).all()
        for candidate in sibling_formulas:
            if not candidate.formula:
                continue
            if field.name in FormulaService.extract_field_references(candidate.formula):
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=f"Cannot delete: field is referenced by formula field '{candidate.name}'",
                )

    # Cascade removes the associated TaskCustomValue rows.
    db.delete(field)
    db.commit()
@router.patch("/api/custom-fields/{field_id}/position", response_model=CustomFieldResponse)
async def update_custom_field_position(
    field_id: str,
    position: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Move a custom field to a new position (reordering).

    Positions are 0-based and contiguous within a project; fields between
    the old and new position are shifted by one to keep the sequence
    gap-free.

    Raises:
        HTTPException 404: field not found.
        HTTPException 403: caller may not edit the project.
        HTTPException 400: target position is outside the valid range.
    """
    field = db.query(CustomField).filter(CustomField.id == field_id).first()
    if not field:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Custom field not found",
        )
    project = field.project
    if not check_project_edit_access(current_user, project):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Permission denied",
        )

    # Reject targets outside [0, field_count - 1]; accepting them would
    # leave gaps or negative positions in the project's ordering.
    field_count = db.query(CustomField).filter(
        CustomField.project_id == field.project_id
    ).count()
    if position < 0 or position >= field_count:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Position must be between 0 and {field_count - 1}",
        )

    old_position = field.position
    if position == old_position:
        return custom_field_to_response(field)

    if position > old_position:
        # Moving down: pull intermediate fields up by one.
        db.query(CustomField).filter(
            CustomField.project_id == field.project_id,
            CustomField.position > old_position,
            CustomField.position <= position,
        ).update({CustomField.position: CustomField.position - 1})
    else:
        # Moving up: push intermediate fields down by one.
        db.query(CustomField).filter(
            CustomField.project_id == field.project_id,
            CustomField.position >= position,
            CustomField.position < old_position,
        ).update({CustomField.position: CustomField.position + 1})

    field.position = position
    db.commit()
    db.refresh(field)
    return custom_field_to_response(field)

View File

@@ -0,0 +1,3 @@
# Re-export the task-dependencies router as this package's public API.
from app.api.task_dependencies.router import router
__all__ = ["router"]

View File

@@ -0,0 +1,431 @@
"""
Task Dependencies API Router
Provides CRUD operations for task dependencies used in Gantt view.
Includes circular dependency detection and date constraint validation.
"""
import uuid
from fastapi import APIRouter, Depends, HTTPException, status, Request
from sqlalchemy.orm import Session
from typing import Optional
from app.core.database import get_db
from app.models import User, Task, TaskDependency, AuditAction
from app.schemas.task_dependency import (
TaskDependencyCreate,
TaskDependencyUpdate,
TaskDependencyResponse,
TaskDependencyListResponse,
TaskInfo
)
from app.middleware.auth import get_current_user, check_task_access, check_task_edit_access
from app.middleware.audit import get_audit_metadata
from app.services.audit_service import AuditService
from app.services.dependency_service import DependencyService, DependencyValidationError
router = APIRouter(tags=["task-dependencies"])
def dependency_to_response(
    dep: TaskDependency,
    include_tasks: bool = True
) -> TaskDependencyResponse:
    """Build a TaskDependencyResponse from a TaskDependency ORM row.

    When include_tasks is True, the linked predecessor/successor tasks are
    summarized as TaskInfo payloads (id, title, start_date, due_date).
    A missing/unloaded task relationship yields None for that side.
    """
    def _summarize(task) -> Optional[TaskInfo]:
        # Relationship may be None (e.g. row not present / not loaded).
        if task is None:
            return None
        return TaskInfo(
            id=task.id,
            title=task.title,
            start_date=task.start_date,
            due_date=task.due_date,
        )

    pred_info = _summarize(dep.predecessor) if include_tasks else None
    succ_info = _summarize(dep.successor) if include_tasks else None
    return TaskDependencyResponse(
        id=dep.id,
        predecessor_id=dep.predecessor_id,
        successor_id=dep.successor_id,
        dependency_type=dep.dependency_type,
        lag_days=dep.lag_days,
        created_at=dep.created_at,
        predecessor=pred_info,
        successor=succ_info,
    )
@router.post(
    "/api/tasks/{task_id}/dependencies",
    response_model=TaskDependencyResponse,
    status_code=status.HTTP_201_CREATED
)
async def create_dependency(
    task_id: str,
    dependency_data: TaskDependencyCreate,
    request: Request,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Add a dependency to a task (the task becomes the successor).

    The predecessor_id in the request body specifies which task must complete first.
    The task_id in the URL becomes the successor (depends on the predecessor).

    Validates:
    - Both tasks exist and are in the same project
    - No self-reference
    - No duplicate dependency
    - No circular dependency
    - Dependency limit not exceeded

    Raises:
        404: successor task does not exist.
        403: caller lacks edit access on the successor task.
        400: DependencyService rejected the link (structured detail body
             with error_type / message / details so the UI can be specific).
    """
    # Get the successor task (from URL)
    successor = db.query(Task).filter(Task.id == task_id).first()
    if not successor:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Task not found"
        )
    # Check edit permission on successor (permission is evaluated against
    # the successor's project).
    if not check_task_edit_access(current_user, successor, successor.project):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Permission denied"
        )
    # Validate the dependency (existence, same-project, self-reference,
    # duplicates, cycles, dependency limit — all delegated to the service).
    try:
        DependencyService.validate_dependency(
            db,
            predecessor_id=dependency_data.predecessor_id,
            successor_id=task_id
        )
    except DependencyValidationError as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={
                "error_type": e.error_type,
                "message": e.message,
                "details": e.details
            }
        )
    # Create the dependency
    dependency = TaskDependency(
        id=str(uuid.uuid4()),
        predecessor_id=dependency_data.predecessor_id,
        successor_id=task_id,
        dependency_type=dependency_data.dependency_type.value,
        lag_days=dependency_data.lag_days
    )
    db.add(dependency)
    # Audit log — written before the commit so it lands in the same
    # transaction as the insert.
    AuditService.log_event(
        db=db,
        event_type="task.dependency.create",
        resource_type="task_dependency",
        action=AuditAction.CREATE,
        user_id=current_user.id,
        resource_id=dependency.id,
        changes=[{
            "field": "dependency",
            "old_value": None,
            "new_value": {
                "predecessor_id": dependency.predecessor_id,
                "successor_id": dependency.successor_id,
                "dependency_type": dependency.dependency_type,
                "lag_days": dependency.lag_days
            }
        }],
        request_metadata=get_audit_metadata(request)
    )
    db.commit()
    db.refresh(dependency)
    return dependency_to_response(dependency)
@router.get(
    "/api/tasks/{task_id}/dependencies",
    response_model=TaskDependencyListResponse
)
async def list_task_dependencies(
    task_id: str,
    direction: Optional[str] = None,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Get all dependencies for a task.

    Args:
        task_id: The task to get dependencies for
        direction: Optional filter
            - 'predecessors': Only dependencies where this task is the successor
            - 'successors': Only dependencies where this task is the predecessor
            - None: Both directions

    Raises:
        400: unrecognized `direction` value (previously this silently
             returned an empty list, masking caller typos).
        404: task not found.
        403: caller lacks read access to the task's project.
    """
    # Validate the filter up front instead of silently returning nothing.
    if direction not in (None, "predecessors", "successors"):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="direction must be 'predecessors' or 'successors'",
        )
    task = db.query(Task).filter(Task.id == task_id).first()
    if not task:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Task not found"
        )
    if not check_task_access(current_user, task, task.project):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Access denied"
        )
    dependencies = []
    # Deduplicate by primary key (O(n)) instead of scanning the result
    # list per row (O(n^2) identity comparisons).
    seen_ids = set()
    if direction is None or direction == "predecessors":
        # Dependencies where this task is the successor (its predecessors).
        for dep in db.query(TaskDependency).filter(
            TaskDependency.successor_id == task_id
        ).all():
            if dep.id not in seen_ids:
                seen_ids.add(dep.id)
                dependencies.append(dep)
    if direction is None or direction == "successors":
        # Dependencies where this task is the predecessor (its successors).
        for dep in db.query(TaskDependency).filter(
            TaskDependency.predecessor_id == task_id
        ).all():
            if dep.id not in seen_ids:
                seen_ids.add(dep.id)
                dependencies.append(dep)
    return TaskDependencyListResponse(
        dependencies=[dependency_to_response(d) for d in dependencies],
        total=len(dependencies)
    )
@router.get(
    "/api/task-dependencies/{dependency_id}",
    response_model=TaskDependencyResponse
)
async def get_dependency(
    dependency_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Fetch one dependency by ID.

    Read access is checked against the successor task's project.
    """
    dep = (
        db.query(TaskDependency)
        .filter(TaskDependency.id == dependency_id)
        .first()
    )
    if dep is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Dependency not found"
        )
    successor_task = dep.successor
    if not check_task_access(current_user, successor_task, successor_task.project):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Access denied"
        )
    return dependency_to_response(dep)
@router.patch(
    "/api/task-dependencies/{dependency_id}",
    response_model=TaskDependencyResponse
)
async def update_dependency(
    dependency_id: str,
    update_data: TaskDependencyUpdate,
    request: Request,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Update a dependency's type or lag days.

    The endpoints of the link are immutable: to change predecessor_id or
    successor_id, delete the dependency and create a new one.
    """
    dep = (
        db.query(TaskDependency)
        .filter(TaskDependency.id == dependency_id)
        .first()
    )
    if dep is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Dependency not found"
        )
    # Edit permission is evaluated against the successor task's project.
    successor_task = dep.successor
    if not check_task_edit_access(current_user, successor_task, successor_task.project):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Permission denied"
        )
    # Snapshot the mutable fields before applying the patch so the audit
    # service can diff old vs. new.
    old_values = {
        "dependency_type": dep.dependency_type,
        "lag_days": dep.lag_days
    }
    if update_data.dependency_type is not None:
        dep.dependency_type = update_data.dependency_type.value
    if update_data.lag_days is not None:
        dep.lag_days = update_data.lag_days
    new_values = {
        "dependency_type": dep.dependency_type,
        "lag_days": dep.lag_days
    }
    # Only write an audit entry when something actually changed.
    diff = AuditService.detect_changes(old_values, new_values)
    if diff:
        AuditService.log_event(
            db=db,
            event_type="task.dependency.update",
            resource_type="task_dependency",
            action=AuditAction.UPDATE,
            user_id=current_user.id,
            resource_id=dep.id,
            changes=diff,
            request_metadata=get_audit_metadata(request)
        )
    db.commit()
    db.refresh(dep)
    return dependency_to_response(dep)
@router.delete(
    "/api/task-dependencies/{dependency_id}",
    status_code=status.HTTP_204_NO_CONTENT
)
async def delete_dependency(
    dependency_id: str,
    request: Request,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Delete a dependency. Requires edit access on the successor task."""
    dep = (
        db.query(TaskDependency)
        .filter(TaskDependency.id == dependency_id)
        .first()
    )
    if dep is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Dependency not found"
        )
    successor_task = dep.successor
    if not check_task_edit_access(current_user, successor_task, successor_task.project):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Permission denied"
        )
    # Snapshot the row for the audit trail before it is removed.
    removed_snapshot = {
        "predecessor_id": dep.predecessor_id,
        "successor_id": dep.successor_id,
        "dependency_type": dep.dependency_type,
        "lag_days": dep.lag_days
    }
    AuditService.log_event(
        db=db,
        event_type="task.dependency.delete",
        resource_type="task_dependency",
        action=AuditAction.DELETE,
        user_id=current_user.id,
        resource_id=dep.id,
        changes=[{
            "field": "dependency",
            "old_value": removed_snapshot,
            "new_value": None
        }],
        request_metadata=get_audit_metadata(request)
    )
    db.delete(dep)
    db.commit()
    return None
@router.get(
    "/api/projects/{project_id}/dependencies",
    response_model=TaskDependencyListResponse
)
async def list_project_dependencies(
    project_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Get all dependencies for a project.

    Useful for rendering the full Gantt chart with all dependency arrows.
    Dependencies whose predecessor or successor task is soft-deleted are
    excluded.
    """
    from app.models import Project
    from app.middleware.auth import check_project_access
    from sqlalchemy.orm import aliased

    project = db.query(Project).filter(Project.id == project_id).first()
    if project is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Project not found"
        )
    if not check_project_access(current_user, project):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Access denied"
        )
    # Join both endpoints of each dependency so soft-deleted tasks never
    # leak into the chart; two aliases are needed to join Task twice.
    succ_task = aliased(Task)
    pred_task = aliased(Task)
    deps = (
        db.query(TaskDependency)
        .join(succ_task, TaskDependency.successor_id == succ_task.id)
        .join(pred_task, TaskDependency.predecessor_id == pred_task.id)
        .filter(
            succ_task.project_id == project_id,
            succ_task.is_deleted == False,
            pred_task.is_deleted == False,
        )
        .all()
    )
    return TaskDependencyListResponse(
        dependencies=[dependency_to_response(d) for d in deps],
        total=len(deps)
    )

View File

@@ -10,7 +10,7 @@ from app.core.redis_pubsub import publish_task_event
from app.models import User, Project, Task, TaskStatus, AuditAction, Blocker
from app.schemas.task import (
TaskCreate, TaskUpdate, TaskResponse, TaskWithDetails, TaskListResponse,
TaskStatusUpdate, TaskAssignUpdate
TaskStatusUpdate, TaskAssignUpdate, CustomValueResponse
)
from app.middleware.auth import (
get_current_user, check_project_access, check_task_access, check_task_edit_access
@@ -19,6 +19,8 @@ from app.middleware.audit import get_audit_metadata
from app.services.audit_service import AuditService
from app.services.trigger_service import TriggerService
from app.services.workload_cache import invalidate_user_workload_cache
from app.services.custom_value_service import CustomValueService
from app.services.dependency_service import DependencyService
logger = logging.getLogger(__name__)
@@ -40,13 +42,18 @@ def get_task_depth(db: Session, task: Task) -> int:
return depth
def task_to_response(task: Task) -> TaskWithDetails:
def task_to_response(task: Task, db: Session = None, include_custom_values: bool = False) -> TaskWithDetails:
"""Convert a Task model to TaskWithDetails response."""
# Count only non-deleted subtasks
subtask_count = 0
if task.subtasks:
subtask_count = sum(1 for st in task.subtasks if not st.is_deleted)
# Get custom values if requested
custom_values = None
if include_custom_values and db:
custom_values = CustomValueService.get_custom_values_for_task(db, task)
return TaskWithDetails(
id=task.id,
project_id=task.project_id,
@@ -56,6 +63,7 @@ def task_to_response(task: Task) -> TaskWithDetails:
priority=task.priority,
original_estimate=task.original_estimate,
time_spent=task.time_spent,
start_date=task.start_date,
due_date=task.due_date,
assignee_id=task.assignee_id,
status_id=task.status_id,
@@ -69,6 +77,7 @@ def task_to_response(task: Task) -> TaskWithDetails:
status_color=task.status.color if task.status else None,
creator_name=task.creator.name if task.creator else None,
subtask_count=subtask_count,
custom_values=custom_values,
)
@@ -78,12 +87,24 @@ async def list_tasks(
parent_task_id: Optional[str] = Query(None, description="Filter by parent task"),
status_id: Optional[str] = Query(None, description="Filter by status"),
assignee_id: Optional[str] = Query(None, description="Filter by assignee"),
due_after: Optional[datetime] = Query(None, description="Filter tasks with due_date >= this value (for calendar view)"),
due_before: Optional[datetime] = Query(None, description="Filter tasks with due_date <= this value (for calendar view)"),
include_deleted: bool = Query(False, description="Include deleted tasks (admin only)"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
List all tasks in a project.
Supports filtering by:
- parent_task_id: Filter by parent task (empty string for root tasks only)
- status_id: Filter by task status
- assignee_id: Filter by assigned user
- due_after: Filter tasks with due_date >= this value (ISO 8601 datetime)
- due_before: Filter tasks with due_date <= this value (ISO 8601 datetime)
The due_after and due_before parameters are useful for calendar view
to fetch tasks within a specific date range.
"""
project = db.query(Project).filter(Project.id == project_id).first()
@@ -124,10 +145,17 @@ async def list_tasks(
if assignee_id:
query = query.filter(Task.assignee_id == assignee_id)
# Date range filter for calendar view
if due_after:
query = query.filter(Task.due_date >= due_after)
if due_before:
query = query.filter(Task.due_date <= due_before)
tasks = query.order_by(Task.position, Task.created_at).all()
return TaskListResponse(
tasks=[task_to_response(t) for t in tasks],
tasks=[task_to_response(t, db=db, include_custom_values=True) for t in tasks],
total=len(tasks),
)
@@ -204,6 +232,25 @@ async def create_task(
).order_by(Task.position.desc()).first()
next_position = (max_pos_result.position + 1) if max_pos_result else 0
# Validate required custom fields
if task_data.custom_values:
missing_fields = CustomValueService.validate_required_fields(
db, project_id, task_data.custom_values
)
if missing_fields:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Missing required custom fields: {', '.join(missing_fields)}",
)
# Validate start_date <= due_date
if task_data.start_date and task_data.due_date:
if task_data.start_date > task_data.due_date:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Start date cannot be after due date",
)
task = Task(
id=str(uuid.uuid4()),
project_id=project_id,
@@ -212,6 +259,7 @@ async def create_task(
description=task_data.description,
priority=task_data.priority.value if task_data.priority else "medium",
original_estimate=task_data.original_estimate,
start_date=task_data.start_date,
due_date=task_data.due_date,
assignee_id=task_data.assignee_id,
status_id=task_data.status_id,
@@ -220,6 +268,17 @@ async def create_task(
)
db.add(task)
db.flush() # Flush to get task.id for custom values
# Save custom values
if task_data.custom_values:
try:
CustomValueService.save_custom_values(db, task, task_data.custom_values)
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e),
)
# Audit log
AuditService.log_event(
@@ -256,6 +315,7 @@ async def create_task(
"assignee_id": str(task.assignee_id) if task.assignee_id else None,
"assignee_name": task.assignee.name if task.assignee else None,
"priority": task.priority,
"start_date": str(task.start_date) if task.start_date else None,
"due_date": str(task.due_date) if task.due_date else None,
"time_estimate": task.original_estimate,
"original_estimate": task.original_estimate,
@@ -303,7 +363,7 @@ async def get_task(
detail="Access denied",
)
return task_to_response(task)
return task_to_response(task, db, include_custom_values=True)
@router.patch("/api/tasks/{task_id}", response_model=TaskResponse)
@@ -336,13 +396,42 @@ async def update_task(
"title": task.title,
"description": task.description,
"priority": task.priority,
"start_date": task.start_date,
"due_date": task.due_date,
"original_estimate": task.original_estimate,
"time_spent": task.time_spent,
}
# Update fields
# Update fields (exclude custom_values, handle separately)
update_data = task_data.model_dump(exclude_unset=True)
custom_values_data = update_data.pop("custom_values", None)
# Get the proposed start_date and due_date for validation
new_start_date = update_data.get("start_date", task.start_date)
new_due_date = update_data.get("due_date", task.due_date)
# Validate start_date <= due_date
if new_start_date and new_due_date:
if new_start_date > new_due_date:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Start date cannot be after due date",
)
# Validate date constraints against dependencies
if "start_date" in update_data or "due_date" in update_data:
violations = DependencyService.validate_date_constraints(
task, new_start_date, new_due_date, db
)
if violations:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"message": "Date change violates dependency constraints",
"violations": violations
}
)
for field, value in update_data.items():
if field == "priority" and value:
setattr(task, field, value.value)
@@ -354,6 +443,7 @@ async def update_task(
"title": task.title,
"description": task.description,
"priority": task.priority,
"start_date": task.start_date,
"due_date": task.due_date,
"original_estimate": task.original_estimate,
"time_spent": task.time_spent,
@@ -377,6 +467,18 @@ async def update_task(
if "priority" in update_data:
TriggerService.evaluate_triggers(db, task, old_values, new_values, current_user)
# Handle custom values update
if custom_values_data:
try:
from app.schemas.task import CustomValueInput
custom_values = [CustomValueInput(**cv) for cv in custom_values_data]
CustomValueService.save_custom_values(db, task, custom_values)
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e),
)
db.commit()
db.refresh(task)
@@ -400,6 +502,7 @@ async def update_task(
"assignee_id": str(task.assignee_id) if task.assignee_id else None,
"assignee_name": task.assignee.name if task.assignee else None,
"priority": task.priority,
"start_date": str(task.start_date) if task.start_date else None,
"due_date": str(task.due_date) if task.due_date else None,
"time_estimate": task.original_estimate,
"original_estimate": task.original_estimate,

View File

@@ -1,6 +1,6 @@
from pydantic_settings import BaseSettings
from pydantic import field_validator
from typing import List
from typing import List, Optional
import os
@@ -52,6 +52,35 @@ class Settings(BaseSettings):
)
return v
# Encryption - Master key for encrypting file encryption keys
# Must be a 32-byte (256-bit) key encoded as base64 for AES-256
# Generate with: python -c "import secrets, base64; print(base64.urlsafe_b64encode(secrets.token_bytes(32)).decode())"
ENCRYPTION_MASTER_KEY: Optional[str] = None
@field_validator("ENCRYPTION_MASTER_KEY")
@classmethod
def validate_encryption_master_key(cls, v: Optional[str]) -> Optional[str]:
    """Validate that ENCRYPTION_MASTER_KEY is a base64-encoded 32-byte key.

    Returns None when the variable is unset or blank (file encryption is
    then disabled). Raises ValueError with a reproducible "generate" hint
    for a malformed or wrong-length key.
    """
    if v is None or v.strip() == "":
        return None
    import base64
    # Decode inside the try, but length-check outside it: the original
    # raised the length error inside the same try and then filtered its
    # own ValueError by matching the message text, which is fragile.
    try:
        decoded = base64.urlsafe_b64decode(v)
    except Exception as e:
        raise ValueError(
            "ENCRYPTION_MASTER_KEY must be a valid base64-encoded string. "
            f"Error: {e}"
        ) from e
    if len(decoded) != 32:
        raise ValueError(
            "ENCRYPTION_MASTER_KEY must be a base64-encoded 32-byte key. "
            "Generate with: python -c \"import secrets, base64; print(base64.urlsafe_b64encode(secrets.token_bytes(32)).decode())\""
        )
    return v
# External Auth API
AUTH_API_URL: str = "https://pj-auth-api.vercel.app"

View File

@@ -34,6 +34,9 @@ from app.api.attachments import router as attachments_router
from app.api.triggers import router as triggers_router
from app.api.reports import router as reports_router
from app.api.health import router as health_router
from app.api.custom_fields import router as custom_fields_router
from app.api.task_dependencies import router as task_dependencies_router
from app.api.admin import encryption_keys as admin_encryption_keys_router
from app.core.config import settings
app = FastAPI(
@@ -76,6 +79,9 @@ app.include_router(attachments_router)
app.include_router(triggers_router)
app.include_router(reports_router)
app.include_router(health_router)
app.include_router(custom_fields_router)
app.include_router(task_dependencies_router)
app.include_router(admin_encryption_keys_router.router)
@app.get("/health")

View File

@@ -12,6 +12,7 @@ from app.models.notification import Notification
from app.models.blocker import Blocker
from app.models.audit_log import AuditLog, AuditAction, SensitivityLevel, EVENT_SENSITIVITY, ALERT_EVENTS
from app.models.audit_alert import AuditAlert
from app.models.encryption_key import EncryptionKey
from app.models.attachment import Attachment
from app.models.attachment_version import AttachmentVersion
from app.models.trigger import Trigger, TriggerType
@@ -19,13 +20,18 @@ from app.models.trigger_log import TriggerLog, TriggerLogStatus
from app.models.scheduled_report import ScheduledReport, ReportType
from app.models.report_history import ReportHistory, ReportHistoryStatus
from app.models.project_health import ProjectHealth, RiskLevel, ScheduleStatus, ResourceStatus
from app.models.custom_field import CustomField, FieldType
from app.models.task_custom_value import TaskCustomValue
from app.models.task_dependency import TaskDependency, DependencyType
__all__ = [
"User", "Role", "Department", "Space", "Project", "TaskStatus", "Task", "WorkloadSnapshot",
"Comment", "Mention", "Notification", "Blocker",
"AuditLog", "AuditAlert", "AuditAction", "SensitivityLevel", "EVENT_SENSITIVITY", "ALERT_EVENTS",
"Attachment", "AttachmentVersion",
"EncryptionKey", "Attachment", "AttachmentVersion",
"Trigger", "TriggerType", "TriggerLog", "TriggerLogStatus",
"ScheduledReport", "ReportType", "ReportHistory", "ReportHistoryStatus",
"ProjectHealth", "RiskLevel", "ScheduleStatus", "ResourceStatus"
"ProjectHealth", "RiskLevel", "ScheduleStatus", "ResourceStatus",
"CustomField", "FieldType", "TaskCustomValue",
"TaskDependency", "DependencyType"
]

View File

@@ -16,6 +16,11 @@ class Attachment(Base):
file_size = Column(BigInteger, nullable=False)
current_version = Column(Integer, default=1, nullable=False)
is_encrypted = Column(Boolean, default=False, nullable=False)
encryption_key_id = Column(
String(36),
ForeignKey("pjctrl_encryption_keys.id", ondelete="SET NULL"),
nullable=True
)
uploaded_by = Column(String(36), ForeignKey("pjctrl_users.id", ondelete="SET NULL"), nullable=True)
is_deleted = Column(Boolean, default=False, nullable=False)
created_at = Column(DateTime, server_default=func.now(), nullable=False)
@@ -24,6 +29,7 @@ class Attachment(Base):
# Relationships
task = relationship("Task", back_populates="attachments")
uploader = relationship("User", foreign_keys=[uploaded_by])
encryption_key = relationship("EncryptionKey", foreign_keys=[encryption_key_id])
versions = relationship("AttachmentVersion", back_populates="attachment", cascade="all, delete-orphan")
__table_args__ = (

View File

@@ -0,0 +1,37 @@
import uuid
import enum
from sqlalchemy import Column, String, Text, Boolean, DateTime, ForeignKey, Enum, JSON, Integer
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship
from app.core.database import Base
class FieldType(str, enum.Enum):
    """Supported custom field types (mirrored by app.schemas.custom_field.FieldType)."""
    TEXT = "text"          # free-form text
    NUMBER = "number"      # numeric value (stored as text, parsed on read)
    DROPDOWN = "dropdown"  # single choice from the field's `options` list
    DATE = "date"          # date value
    PERSON = "person"      # references a user id -- TODO confirm against CustomValueService
    FORMULA = "formula"    # computed from the field's `formula` expression
class CustomField(Base):
    """Project-scoped custom field definition (FEAT-001).

    A field belongs to exactly one project and may hold one typed value per
    task via TaskCustomValue. Dropdown fields keep their choices in
    ``options``; formula fields keep their expression in ``formula``.
    """
    __tablename__ = "pjctrl_custom_fields"

    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    # Deleting the project cascades to its field definitions.
    project_id = Column(String(36), ForeignKey("pjctrl_projects.id", ondelete="CASCADE"), nullable=False)
    name = Column(String(100), nullable=False)
    # Plain string DB enum; values mirror the FieldType Python enum above.
    field_type = Column(
        Enum("text", "number", "dropdown", "date", "person", "formula", name="field_type_enum"),
        nullable=False
    )
    options = Column(JSON, nullable=True)  # For dropdown: list of options
    formula = Column(Text, nullable=True)  # For formula: formula expression
    is_required = Column(Boolean, default=False, nullable=False)
    position = Column(Integer, default=0, nullable=False)  # For ordering fields
    created_at = Column(DateTime, server_default=func.now(), nullable=False)
    updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), nullable=False)

    # Relationships
    project = relationship("Project", back_populates="custom_fields")
    # Deleting a field removes all of its per-task values.
    values = relationship("TaskCustomValue", back_populates="field", cascade="all, delete-orphan")

View File

@@ -0,0 +1,22 @@
"""EncryptionKey model for AES-256 file encryption key management."""
import uuid
from sqlalchemy import Column, String, Text, Boolean, DateTime
from sqlalchemy.sql import func
from app.core.database import Base
class EncryptionKey(Base):
    """
    Encryption key storage for file encryption.

    Keys are encrypted with the Master Key before storage.
    Only system admin can manage encryption keys.
    """
    __tablename__ = "pjctrl_encryption_keys"

    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    # Ciphertext of the data key, wrapped with ENCRYPTION_MASTER_KEY;
    # the plaintext key is never persisted.
    key_data = Column(Text, nullable=False)  # Encrypted key using Master Key
    algorithm = Column(String(20), default="AES-256-GCM", nullable=False)
    # NOTE(review): presumably only one key is active at a time — confirm
    # against the rotation logic in the admin encryption-key API.
    is_active = Column(Boolean, default=True, nullable=False)
    created_at = Column(DateTime, server_default=func.now(), nullable=False)
    rotated_at = Column(DateTime, nullable=True)  # When this key was superseded

View File

@@ -40,3 +40,4 @@ class Project(Base):
tasks = relationship("Task", back_populates="project", cascade="all, delete-orphan")
triggers = relationship("Trigger", back_populates="project", cascade="all, delete-orphan")
health = relationship("ProjectHealth", back_populates="project", uselist=False, cascade="all, delete-orphan")
custom_fields = relationship("CustomField", back_populates="project", cascade="all, delete-orphan")

View File

@@ -30,6 +30,7 @@ class Task(Base):
original_estimate = Column(Numeric(8, 2), nullable=True)
time_spent = Column(Numeric(8, 2), default=0, nullable=False)
blocker_flag = Column(Boolean, default=False, nullable=False)
start_date = Column(DateTime, nullable=True)
due_date = Column(DateTime, nullable=True)
position = Column(Integer, default=0, nullable=False)
created_by = Column(String(36), ForeignKey("pjctrl_users.id"), nullable=False)
@@ -55,3 +56,18 @@ class Task(Base):
blockers = relationship("Blocker", back_populates="task", cascade="all, delete-orphan")
attachments = relationship("Attachment", back_populates="task", cascade="all, delete-orphan")
trigger_logs = relationship("TriggerLog", back_populates="task")
custom_values = relationship("TaskCustomValue", back_populates="task", cascade="all, delete-orphan")
# Dependency relationships (for Gantt view)
predecessors = relationship(
"TaskDependency",
foreign_keys="TaskDependency.successor_id",
back_populates="successor",
cascade="all, delete-orphan"
)
successors = relationship(
"TaskDependency",
foreign_keys="TaskDependency.predecessor_id",
back_populates="predecessor",
cascade="all, delete-orphan"
)

View File

@@ -0,0 +1,24 @@
import uuid
from sqlalchemy import Column, String, Text, DateTime, ForeignKey, UniqueConstraint
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship
from app.core.database import Base
class TaskCustomValue(Base):
    """Per-task value of a project custom field.

    At most one row exists per (task, field) pair; the raw value is stored
    as text and interpreted according to the field's field_type.
    """
    __tablename__ = "pjctrl_task_custom_values"

    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    # Both FKs cascade: deleting the task or the field definition removes the value.
    task_id = Column(String(36), ForeignKey("pjctrl_tasks.id", ondelete="CASCADE"), nullable=False)
    field_id = Column(String(36), ForeignKey("pjctrl_custom_fields.id", ondelete="CASCADE"), nullable=False)
    value = Column(Text, nullable=True)  # Stored as text, parsed based on field_type
    updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), nullable=False)

    # Unique constraint: one value per task-field combination
    __table_args__ = (
        UniqueConstraint('task_id', 'field_id', name='uq_task_field'),
    )

    # Relationships
    task = relationship("Task", back_populates="custom_values")
    field = relationship("CustomField", back_populates="values")

View File

@@ -0,0 +1,68 @@
import enum
import uuid

from sqlalchemy import Column, String, Integer, Enum, DateTime, ForeignKey, UniqueConstraint
from sqlalchemy.orm import relationship
from sqlalchemy.sql import func

from app.core.database import Base
class DependencyType(str, enum.Enum):
    """
    Task dependency types for Gantt chart.

    FS (Finish-to-Start): Predecessor must finish before successor starts (most common)
    SS (Start-to-Start): Predecessor must start before successor starts
    FF (Finish-to-Finish): Predecessor must finish before successor finishes
    SF (Start-to-Finish): Predecessor must start before successor finishes (rare)

    Values mirror the "dependency_type_enum" DB enum on TaskDependency below.
    """
    FS = "FS"  # Finish-to-Start
    SS = "SS"  # Start-to-Start
    FF = "FF"  # Finish-to-Finish
    SF = "SF"  # Start-to-Finish
class TaskDependency(Base):
    """
    Represents a dependency relationship between two tasks.

    The predecessor task affects when the successor task can be scheduled,
    based on the dependency_type. This is used for Gantt chart visualization
    and date validation.
    """
    __tablename__ = "pjctrl_task_dependencies"
    __table_args__ = (
        # Each predecessor/successor pair may be linked at most once.
        UniqueConstraint('predecessor_id', 'successor_id', name='uq_predecessor_successor'),
    )

    # Client-side UUID default, for consistency with the other models in
    # this package (CustomField, TaskCustomValue, EncryptionKey). Callers
    # may still pass an explicit id; previously a writer that omitted the
    # id would insert NULL and violate the primary key.
    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    predecessor_id = Column(
        String(36),
        ForeignKey("pjctrl_tasks.id", ondelete="CASCADE"),
        nullable=False,
        index=True
    )
    successor_id = Column(
        String(36),
        ForeignKey("pjctrl_tasks.id", ondelete="CASCADE"),
        nullable=False,
        index=True
    )
    # Plain string DB enum; values mirror the DependencyType enum above.
    dependency_type = Column(
        Enum("FS", "SS", "FF", "SF", name="dependency_type_enum"),
        default="FS",
        nullable=False
    )
    # Lag in days applied to the constraint; sign convention is defined by
    # DependencyService -- TODO confirm whether negative lead is allowed.
    lag_days = Column(Integer, default=0, nullable=False)
    created_at = Column(DateTime, server_default=func.now(), nullable=False)

    # Relationships (back_populates pairs with Task.successors/Task.predecessors)
    predecessor = relationship(
        "Task",
        foreign_keys=[predecessor_id],
        back_populates="successors"
    )
    successor = relationship(
        "Task",
        foreign_keys=[successor_id],
        back_populates="predecessors"
    )

View File

@@ -0,0 +1,88 @@
from pydantic import BaseModel, Field, field_validator
from typing import Optional, List, Any, Dict
from datetime import datetime
from enum import Enum
class FieldType(str, Enum):
    """Custom field types; mirrors app.models.custom_field.FieldType."""
    TEXT = "text"
    NUMBER = "number"
    DROPDOWN = "dropdown"
    DATE = "date"
    PERSON = "person"
    FORMULA = "formula"
class CustomFieldBase(BaseModel):
    """Shared fields and validation for custom field schemas."""

    name: str = Field(..., min_length=1, max_length=100)
    field_type: FieldType
    options: Optional[List[str]] = None  # For dropdown type
    formula: Optional[str] = None        # For formula type
    is_required: bool = False

    @field_validator('options')
    @classmethod
    def validate_options(cls, v, info):
        # Dropdown fields need 1..50 choices; other field types leave
        # `options` unconstrained.
        if info.data.get('field_type') == FieldType.DROPDOWN:
            if not v:
                raise ValueError('Dropdown fields must have at least one option')
            if len(v) > 50:
                raise ValueError('Dropdown fields can have at most 50 options')
        return v

    @field_validator('formula')
    @classmethod
    def validate_formula(cls, v, info):
        # A formula expression is mandatory for formula-typed fields.
        if info.data.get('field_type') == FieldType.FORMULA and not v:
            raise ValueError('Formula fields must have a formula expression')
        return v
class CustomFieldCreate(CustomFieldBase):
    """Payload for creating a custom field; inherits all base fields and validators."""
    pass
class CustomFieldUpdate(BaseModel):
    """Partial update for a custom field.

    ``field_type`` is not included here, so the type of an existing field
    cannot be changed through this schema.
    """
    name: Optional[str] = Field(None, min_length=1, max_length=100)
    options: Optional[List[str]] = None
    formula: Optional[str] = None
    is_required: Optional[bool] = None
class CustomFieldResponse(CustomFieldBase):
    """API representation of a custom field, loadable from the ORM model.

    NOTE(review): inherits CustomFieldBase's validators, so a stored
    dropdown field without options would fail response validation —
    confirm that cannot occur in practice.
    """
    id: str
    project_id: str
    position: int
    created_at: datetime
    updated_at: datetime

    class Config:
        from_attributes = True
class CustomFieldListResponse(BaseModel):
    """List wrapper for custom fields with a total count."""
    fields: List[CustomFieldResponse]
    total: int
# Task custom value schemas
class CustomValueInput(BaseModel):
    """Input for setting one custom field value on a task."""
    field_id: str
    value: Optional[Any] = None  # Can be string, number, date string, or user id
class CustomValueResponse(BaseModel):
    """A task's value for one custom field, including display metadata."""
    field_id: str
    field_name: str
    field_type: FieldType
    value: Optional[Any] = None
    display_value: Optional[str] = None  # Formatted for display

    class Config:
        from_attributes = True
class TaskCustomValuesUpdate(BaseModel):
    """Request schema for bulk-updating a task's custom field values."""
    custom_values: List[CustomValueInput]

View File

@@ -0,0 +1,46 @@
"""Schemas for encryption key API."""
from datetime import datetime
from typing import List, Optional
from pydantic import BaseModel
class EncryptionKeyResponse(BaseModel):
    """Response schema for encryption key metadata.

    Deliberately excludes the key material itself; only metadata is exposed.
    """
    id: str
    algorithm: str
    is_active: bool
    created_at: datetime
    rotated_at: Optional[datetime] = None

    class Config:
        from_attributes = True
class EncryptionKeyListResponse(BaseModel):
    """Response schema for list of encryption keys."""
    keys: List[EncryptionKeyResponse]
    total: int
class EncryptionKeyCreateResponse(BaseModel):
    """Response schema after creating a new encryption key."""
    id: str
    algorithm: str
    is_active: bool
    created_at: datetime
    message: str  # human-readable status for the admin UI
class EncryptionKeyRotateResponse(BaseModel):
    """Response schema after key rotation."""
    new_key_id: str
    old_key_id: Optional[str]  # None when no key existed before rotation
    message: str
class EncryptionStatusResponse(BaseModel):
    """Response schema summarizing overall encryption availability."""
    encryption_available: bool  # True when a master key is configured
    active_key_id: Optional[str]
    total_keys: int
    message: str

View File

@@ -1,5 +1,5 @@
from pydantic import BaseModel
from typing import Optional, List
from typing import Optional, List, Any, Dict
from datetime import datetime
from decimal import Decimal
from enum import Enum
@@ -12,11 +12,27 @@ class Priority(str, Enum):
URGENT = "urgent"
class CustomValueInput(BaseModel):
    """Input for setting a custom field value."""
    field_id: str
    value: Optional[Any] = None  # Can be string, number, date string, or user id
class CustomValueResponse(BaseModel):
    """Response for a custom field value."""
    field_id: str
    field_name: str
    field_type: str  # plain string here (not the FieldType enum)
    value: Optional[Any] = None
    display_value: Optional[str] = None  # Formatted for display
class TaskBase(BaseModel):
title: str
description: Optional[str] = None
priority: Priority = Priority.MEDIUM
original_estimate: Optional[Decimal] = None
start_date: Optional[datetime] = None
due_date: Optional[datetime] = None
@@ -24,6 +40,7 @@ class TaskCreate(TaskBase):
parent_task_id: Optional[str] = None
assignee_id: Optional[str] = None
status_id: Optional[str] = None
custom_values: Optional[List[CustomValueInput]] = None
class TaskUpdate(BaseModel):
@@ -32,8 +49,10 @@ class TaskUpdate(BaseModel):
priority: Optional[Priority] = None
original_estimate: Optional[Decimal] = None
time_spent: Optional[Decimal] = None
start_date: Optional[datetime] = None
due_date: Optional[datetime] = None
position: Optional[int] = None
custom_values: Optional[List[CustomValueInput]] = None
class TaskStatusUpdate(BaseModel):
@@ -67,6 +86,7 @@ class TaskWithDetails(TaskResponse):
status_color: Optional[str] = None
creator_name: Optional[str] = None
subtask_count: int = 0
custom_values: Optional[List[CustomValueResponse]] = None
class TaskListResponse(BaseModel):

View File

@@ -0,0 +1,78 @@
from pydantic import BaseModel, field_validator
from typing import Optional, List
from datetime import datetime
from enum import Enum
class DependencyType(str, Enum):
    """Task dependency types for Gantt chart."""
    FS = "FS"  # Finish-to-Start (most common)
    SS = "SS"  # Start-to-Start
    FF = "FF"  # Finish-to-Finish
    SF = "SF"  # Start-to-Finish (rare)
class TaskDependencyCreate(BaseModel):
    """Schema for creating a task dependency.

    The successor task id comes from the URL path, so only the
    predecessor is supplied in the body.
    """
    predecessor_id: str
    dependency_type: DependencyType = DependencyType.FS
    lag_days: int = 0  # offset applied to the constraint; may be negative (lead)

    @field_validator('lag_days')
    @classmethod
    def validate_lag_days(cls, v):
        # Cap lag to one year in either direction.
        if v < -365 or v > 365:
            raise ValueError('lag_days must be between -365 and 365')
        return v
class TaskDependencyUpdate(BaseModel):
    """Schema for updating a task dependency (partial update)."""
    dependency_type: Optional[DependencyType] = None
    lag_days: Optional[int] = None

    @field_validator('lag_days')
    @classmethod
    def validate_lag_days(cls, v):
        # Same bounds as on create; None means "leave unchanged".
        if v is not None and (v < -365 or v > 365):
            raise ValueError('lag_days must be between -365 and 365')
        return v
class TaskInfo(BaseModel):
    """Brief task information embedded in dependency responses."""
    id: str
    title: str
    start_date: Optional[datetime] = None
    due_date: Optional[datetime] = None

    class Config:
        from_attributes = True
class TaskDependencyResponse(BaseModel):
    """Schema for task dependency response."""
    id: str
    predecessor_id: str
    successor_id: str
    dependency_type: DependencyType
    lag_days: int
    created_at: datetime
    # Optional embedded task summaries for UI convenience.
    predecessor: Optional[TaskInfo] = None
    successor: Optional[TaskInfo] = None

    class Config:
        from_attributes = True
class TaskDependencyListResponse(BaseModel):
    """Schema for list of task dependencies."""
    dependencies: List[TaskDependencyResponse]
    total: int
class DependencyValidationError(BaseModel):
    """Schema for dependency validation error details returned to clients."""
    error_type: str  # 'circular', 'self_reference', 'duplicate', 'cross_project'
    message: str
    details: Optional[dict] = None  # e.g. the cycle path for 'circular'

View File

@@ -0,0 +1,278 @@
"""
Service for managing task custom values.
"""
import uuid
from typing import List, Dict, Any, Optional
from decimal import Decimal, InvalidOperation
from datetime import datetime
from sqlalchemy.orm import Session
from app.models import Task, CustomField, TaskCustomValue, User
from app.schemas.task import CustomValueInput, CustomValueResponse
from app.services.formula_service import FormulaService
class CustomValueService:
    """Service for managing custom field values on tasks.

    Values are persisted as strings in TaskCustomValue; formula fields are
    never stored directly and are recomputed on read/write.
    """

    @staticmethod
    def get_custom_values_for_task(
        db: Session,
        task: Task,
        include_formula_calculations: bool = True,
    ) -> List[CustomValueResponse]:
        """
        Get all custom field values for a task.

        Args:
            db: Database session
            task: The task to get values for
            include_formula_calculations: Whether to calculate formula field values

        Returns:
            List of CustomValueResponse objects, ordered by field position.
            Fields with no stored value are still included (value=None).
        """
        # Get all custom fields for the project
        fields = db.query(CustomField).filter(
            CustomField.project_id == task.project_id
        ).order_by(CustomField.position).all()
        if not fields:
            return []
        # Get stored values
        stored_values = db.query(TaskCustomValue).filter(
            TaskCustomValue.task_id == task.id
        ).all()
        value_map = {v.field_id: v.value for v in stored_values}
        # Calculate formula values if requested
        formula_values = {}
        if include_formula_calculations:
            formula_values = FormulaService.calculate_all_formulas_for_task(db, task)
        result = []
        for field in fields:
            if field.field_type == "formula":
                # Use calculated formula value (stringified for consistency
                # with the string-typed stored values)
                calculated = formula_values.get(field.id)
                value = str(calculated) if calculated is not None else None
                display_value = CustomValueService._format_display_value(
                    field, value, db
                )
            else:
                # Use stored value
                value = value_map.get(field.id)
                display_value = CustomValueService._format_display_value(
                    field, value, db
                )
            result.append(CustomValueResponse(
                field_id=field.id,
                field_name=field.name,
                field_type=field.field_type,
                value=value,
                display_value=display_value,
            ))
        return result

    @staticmethod
    def _format_display_value(
        field: CustomField,
        value: Optional[str],
        db: Session,
    ) -> Optional[str]:
        """Format a stored string value for display based on field type.

        Falls back to the raw value whenever formatting/lookup fails.
        """
        if value is None:
            return None
        field_type = field.field_type
        if field_type == "person":
            # Look up user name
            from app.models import User
            user = db.query(User).filter(User.id == value).first()
            return user.name if user else value
        elif field_type == "number" or field_type == "formula":
            # Format number with thousands separators, up to 4 decimals,
            # then strip trailing zeros / dangling decimal point.
            try:
                num = Decimal(value)
                # Remove trailing zeros after decimal point
                formatted = f"{num:,.4f}".rstrip('0').rstrip('.')
                return formatted
            except (InvalidOperation, ValueError):
                return value
        elif field_type == "date":
            # Format date (accepts ISO-8601, incl. trailing 'Z')
            try:
                dt = datetime.fromisoformat(value.replace('Z', '+00:00'))
                return dt.strftime('%Y-%m-%d')
            except (ValueError, AttributeError):
                return value
        else:
            return value

    @staticmethod
    def save_custom_values(
        db: Session,
        task: Task,
        custom_values: List[CustomValueInput],
    ) -> List[str]:
        """
        Save custom field values for a task.

        Unknown field ids and formula fields are silently skipped;
        values are validated per field type before persisting.

        Args:
            db: Database session
            task: The task to save values for
            custom_values: List of values to save

        Returns:
            List of field IDs that were updated (for formula recalculation)

        Raises:
            ValueError: via _validate_value when a value fails validation.
        """
        if not custom_values:
            return []
        updated_field_ids = []
        for cv in custom_values:
            # Scope the lookup to the task's project to prevent writing
            # values for fields belonging to other projects.
            field = db.query(CustomField).filter(
                CustomField.id == cv.field_id,
                CustomField.project_id == task.project_id,
            ).first()
            if not field:
                continue
            # Skip formula fields - they are calculated, not stored directly
            if field.field_type == "formula":
                continue
            # Validate value based on field type
            validated_value = CustomValueService._validate_value(
                field, cv.value, db
            )
            # Find existing value or create new
            existing = db.query(TaskCustomValue).filter(
                TaskCustomValue.task_id == task.id,
                TaskCustomValue.field_id == cv.field_id,
            ).first()
            if existing:
                if existing.value != validated_value:
                    existing.value = validated_value
                    updated_field_ids.append(cv.field_id)
            else:
                new_value = TaskCustomValue(
                    id=str(uuid.uuid4()),
                    task_id=task.id,
                    field_id=cv.field_id,
                    value=validated_value,
                )
                db.add(new_value)
                updated_field_ids.append(cv.field_id)
        # Recalculate formula fields if any values were updated
        if updated_field_ids:
            for field_id in updated_field_ids:
                FormulaService.recalculate_dependent_formulas(db, task, field_id)
        return updated_field_ids

    @staticmethod
    def _validate_value(
        field: CustomField,
        value: Any,
        db: Session,
    ) -> Optional[str]:
        """
        Validate and normalize a value based on field type.

        Returns the validated value as a string, or None for empty input
        on non-required fields.

        Raises:
            ValueError: if the value is invalid or a required field is empty.
        """
        if value is None or value == "":
            if field.is_required:
                raise ValueError(f"Field '{field.name}' is required")
            return None
        field_type = field.field_type
        str_value = str(value)
        if field_type == "text":
            return str_value
        elif field_type == "number":
            try:
                # Parse only to validate; the original string is stored.
                Decimal(str_value)
                return str_value
            except (InvalidOperation, ValueError):
                raise ValueError(f"Invalid number for field '{field.name}'")
        elif field_type == "dropdown":
            if field.options and str_value not in field.options:
                raise ValueError(
                    f"Invalid option for field '{field.name}'. "
                    f"Must be one of: {', '.join(field.options)}"
                )
            return str_value
        elif field_type == "date":
            # Validate date format (ISO datetime first, then plain date)
            try:
                datetime.fromisoformat(str_value.replace('Z', '+00:00'))
                return str_value
            except ValueError:
                # Try parsing as date only
                try:
                    datetime.strptime(str_value, '%Y-%m-%d')
                    return str_value
                except ValueError:
                    raise ValueError(f"Invalid date for field '{field.name}'")
        elif field_type == "person":
            # Validate user exists
            from app.models import User
            user = db.query(User).filter(User.id == str_value).first()
            if not user:
                raise ValueError(f"Invalid user ID for field '{field.name}'")
            return str_value
        return str_value

    @staticmethod
    def validate_required_fields(
        db: Session,
        project_id: str,
        custom_values: Optional[List[CustomValueInput]],
    ) -> List[str]:
        """
        Validate that all required custom fields have values.

        Returns list of missing required field names (empty when valid).
        """
        # Formula fields are calculated, so they can never be "missing".
        required_fields = db.query(CustomField).filter(
            CustomField.project_id == project_id,
            CustomField.is_required == True,
            CustomField.field_type != "formula",  # Formula fields are calculated
        ).all()
        if not required_fields:
            return []
        provided_field_ids = set()
        if custom_values:
            for cv in custom_values:
                if cv.value is not None and cv.value != "":
                    provided_field_ids.add(cv.field_id)
        missing = []
        for field in required_fields:
            if field.id not in provided_field_ids:
                missing.append(field.name)
        return missing

View File

@@ -0,0 +1,424 @@
"""
Dependency Service
Handles task dependency validation including:
- Circular dependency detection using DFS
- Date constraint validation based on dependency types
- Self-reference prevention
- Cross-project dependency prevention
"""
from typing import List, Optional, Set, Tuple, Dict, Any
from collections import defaultdict
from sqlalchemy.orm import Session
from datetime import datetime, timedelta
from app.models import Task, TaskDependency
class DependencyValidationError(Exception):
    """Raised when a task dependency fails validation.

    Carries a machine-readable ``error_type`` and an optional ``details``
    mapping so API handlers can build structured error responses.
    """

    def __init__(self, error_type: str, message: str, details: Optional[dict] = None):
        super().__init__(message)
        self.error_type = error_type
        self.message = message
        self.details = {} if not details else details
class DependencyService:
    """Service for managing task dependencies with validation.

    Provides cycle detection (DFS), creation-time validation, date
    constraint checking for all four dependency types, and transitive
    predecessor/successor traversal (BFS).
    """

    # Maximum number of direct dependencies per task (as per spec)
    MAX_DIRECT_DEPENDENCIES = 10

    @staticmethod
    def detect_circular_dependency(
        db: Session,
        predecessor_id: str,
        successor_id: str,
        project_id: str
    ) -> Optional[List[str]]:
        """
        Detect if adding a dependency would create a circular reference.

        Uses DFS to traverse from the successor to check if we can reach
        the predecessor through existing dependencies.

        Args:
            db: Database session
            predecessor_id: The task that must complete first
            successor_id: The task that depends on the predecessor
            project_id: Project ID to scope the query

        Returns:
            List of task IDs forming the cycle if circular, None otherwise
        """
        # If adding predecessor -> successor, check if successor can reach predecessor
        # This would mean predecessor depends (transitively) on successor, creating a cycle
        # Build adjacency list for the project's dependencies
        dependencies = db.query(TaskDependency).join(
            Task, TaskDependency.successor_id == Task.id
        ).filter(Task.project_id == project_id).all()
        # Graph: successor -> [predecessors]
        # We need to check if predecessor is reachable from successor
        # by following the chain of "what does this task depend on"
        graph: Dict[str, List[str]] = defaultdict(list)
        for dep in dependencies:
            graph[dep.successor_id].append(dep.predecessor_id)
        # Simulate adding the new edge
        graph[successor_id].append(predecessor_id)
        # DFS to find if there's a path from predecessor back to successor
        # (which would complete a cycle)
        visited: Set[str] = set()
        path: List[str] = []
        in_path: Set[str] = set()

        def dfs(node: str) -> Optional[List[str]]:
            """DFS traversal to detect cycles."""
            if node in in_path:
                # Found a cycle - return the cycle path
                cycle_start = path.index(node)
                return path[cycle_start:] + [node]
            if node in visited:
                return None
            visited.add(node)
            in_path.add(node)
            path.append(node)
            for neighbor in graph.get(node, []):
                result = dfs(neighbor)
                if result:
                    return result
            path.pop()
            in_path.remove(node)
            return None

        # Start DFS from the successor to check if we can reach back to it
        return dfs(successor_id)

    @staticmethod
    def validate_dependency(
        db: Session,
        predecessor_id: str,
        successor_id: str
    ) -> None:
        """
        Validate that a dependency can be created.

        Raises DependencyValidationError if validation fails.

        Checks:
        1. Self-reference
        2. Both tasks exist
        3. Both tasks are in the same project
        4. No duplicate dependency
        5. No circular dependency
        6. Dependency limit not exceeded
        """
        # Check self-reference
        if predecessor_id == successor_id:
            raise DependencyValidationError(
                error_type="self_reference",
                message="A task cannot depend on itself"
            )
        # Get both tasks
        predecessor = db.query(Task).filter(Task.id == predecessor_id).first()
        successor = db.query(Task).filter(Task.id == successor_id).first()
        if not predecessor:
            raise DependencyValidationError(
                error_type="not_found",
                message="Predecessor task not found",
                details={"task_id": predecessor_id}
            )
        if not successor:
            raise DependencyValidationError(
                error_type="not_found",
                message="Successor task not found",
                details={"task_id": successor_id}
            )
        # Check same project
        if predecessor.project_id != successor.project_id:
            raise DependencyValidationError(
                error_type="cross_project",
                message="Dependencies can only be created between tasks in the same project",
                details={
                    "predecessor_project_id": predecessor.project_id,
                    "successor_project_id": successor.project_id
                }
            )
        # Check duplicate
        existing = db.query(TaskDependency).filter(
            TaskDependency.predecessor_id == predecessor_id,
            TaskDependency.successor_id == successor_id
        ).first()
        if existing:
            raise DependencyValidationError(
                error_type="duplicate",
                message="This dependency already exists"
            )
        # Check dependency limit (counts the successor's direct predecessors)
        current_count = db.query(TaskDependency).filter(
            TaskDependency.successor_id == successor_id
        ).count()
        if current_count >= DependencyService.MAX_DIRECT_DEPENDENCIES:
            raise DependencyValidationError(
                error_type="limit_exceeded",
                message=f"A task cannot have more than {DependencyService.MAX_DIRECT_DEPENDENCIES} direct dependencies",
                details={"current_count": current_count}
            )
        # Check circular dependency (most expensive check, so done last)
        cycle = DependencyService.detect_circular_dependency(
            db, predecessor_id, successor_id, predecessor.project_id
        )
        if cycle:
            raise DependencyValidationError(
                error_type="circular",
                message="Adding this dependency would create a circular reference",
                details={"cycle": cycle}
            )

    @staticmethod
    def validate_date_constraints(
        task: Task,
        start_date: Optional[datetime],
        due_date: Optional[datetime],
        db: Session
    ) -> List[Dict[str, Any]]:
        """
        Validate date changes against dependency constraints.

        Returns a list of constraint violations (empty if valid).
        Constraints with missing dates on either side are skipped.

        Dependency type meanings:
        - FS: predecessor.due_date + lag <= successor.start_date
        - SS: predecessor.start_date + lag <= successor.start_date
        - FF: predecessor.due_date + lag <= successor.due_date
        - SF: predecessor.start_date + lag <= successor.due_date
        """
        violations = []
        # Use provided dates or fall back to current task dates
        new_start = start_date if start_date is not None else task.start_date
        new_due = due_date if due_date is not None else task.due_date
        # Basic date validation
        if new_start and new_due and new_start > new_due:
            violations.append({
                "type": "date_order",
                "message": "Start date cannot be after due date",
                "start_date": str(new_start),
                "due_date": str(new_due)
            })
        # Get dependencies where this task is the successor (predecessors)
        predecessors = db.query(TaskDependency).filter(
            TaskDependency.successor_id == task.id
        ).all()
        for dep in predecessors:
            pred_task = dep.predecessor
            if not pred_task:
                continue
            lag = timedelta(days=dep.lag_days)
            violation = None
            if dep.dependency_type == "FS":
                # Predecessor must finish before successor starts
                if pred_task.due_date and new_start:
                    required_start = pred_task.due_date + lag
                    if new_start < required_start:
                        violation = {
                            "type": "dependency_constraint",
                            "dependency_type": "FS",
                            "predecessor_id": pred_task.id,
                            "predecessor_title": pred_task.title,
                            "message": f"Start date must be on or after {required_start.date()} (predecessor due date + {dep.lag_days} days lag)"
                        }
            elif dep.dependency_type == "SS":
                # Predecessor must start before successor starts
                if pred_task.start_date and new_start:
                    required_start = pred_task.start_date + lag
                    if new_start < required_start:
                        violation = {
                            "type": "dependency_constraint",
                            "dependency_type": "SS",
                            "predecessor_id": pred_task.id,
                            "predecessor_title": pred_task.title,
                            "message": f"Start date must be on or after {required_start.date()} (predecessor start date + {dep.lag_days} days lag)"
                        }
            elif dep.dependency_type == "FF":
                # Predecessor must finish before successor finishes
                if pred_task.due_date and new_due:
                    required_due = pred_task.due_date + lag
                    if new_due < required_due:
                        violation = {
                            "type": "dependency_constraint",
                            "dependency_type": "FF",
                            "predecessor_id": pred_task.id,
                            "predecessor_title": pred_task.title,
                            "message": f"Due date must be on or after {required_due.date()} (predecessor due date + {dep.lag_days} days lag)"
                        }
            elif dep.dependency_type == "SF":
                # Predecessor must start before successor finishes
                if pred_task.start_date and new_due:
                    required_due = pred_task.start_date + lag
                    if new_due < required_due:
                        violation = {
                            "type": "dependency_constraint",
                            "dependency_type": "SF",
                            "predecessor_id": pred_task.id,
                            "predecessor_title": pred_task.title,
                            "message": f"Due date must be on or after {required_due.date()} (predecessor start date + {dep.lag_days} days lag)"
                        }
            if violation:
                violations.append(violation)
        # Get dependencies where this task is the predecessor (successors).
        # Same four constraints, checked from the predecessor's side.
        successors = db.query(TaskDependency).filter(
            TaskDependency.predecessor_id == task.id
        ).all()
        for dep in successors:
            succ_task = dep.successor
            if not succ_task:
                continue
            lag = timedelta(days=dep.lag_days)
            violation = None
            if dep.dependency_type == "FS":
                # This task must finish before successor starts
                if new_due and succ_task.start_date:
                    required_due = succ_task.start_date - lag
                    if new_due > required_due:
                        violation = {
                            "type": "dependency_constraint",
                            "dependency_type": "FS",
                            "successor_id": succ_task.id,
                            "successor_title": succ_task.title,
                            "message": f"Due date must be on or before {required_due.date()} (successor start date - {dep.lag_days} days lag)"
                        }
            elif dep.dependency_type == "SS":
                # This task must start before successor starts
                if new_start and succ_task.start_date:
                    required_start = succ_task.start_date - lag
                    if new_start > required_start:
                        violation = {
                            "type": "dependency_constraint",
                            "dependency_type": "SS",
                            "successor_id": succ_task.id,
                            "successor_title": succ_task.title,
                            "message": f"Start date must be on or before {required_start.date()} (successor start date - {dep.lag_days} days lag)"
                        }
            elif dep.dependency_type == "FF":
                # This task must finish before successor finishes
                if new_due and succ_task.due_date:
                    required_due = succ_task.due_date - lag
                    if new_due > required_due:
                        violation = {
                            "type": "dependency_constraint",
                            "dependency_type": "FF",
                            "successor_id": succ_task.id,
                            "successor_title": succ_task.title,
                            "message": f"Due date must be on or before {required_due.date()} (successor due date - {dep.lag_days} days lag)"
                        }
            elif dep.dependency_type == "SF":
                # This task must start before successor finishes
                if new_start and succ_task.due_date:
                    required_start = succ_task.due_date - lag
                    if new_start > required_start:
                        violation = {
                            "type": "dependency_constraint",
                            "dependency_type": "SF",
                            "successor_id": succ_task.id,
                            "successor_title": succ_task.title,
                            "message": f"Start date must be on or before {required_start.date()} (successor due date - {dep.lag_days} days lag)"
                        }
            if violation:
                violations.append(violation)
        return violations

    @staticmethod
    def get_all_predecessors(db: Session, task_id: str) -> List[str]:
        """
        Get all transitive predecessors of a task.

        Uses BFS to find all tasks that this task depends on (directly or indirectly).
        """
        visited: Set[str] = set()
        queue = [task_id]
        predecessors = []
        while queue:
            current = queue.pop(0)
            if current in visited:
                continue
            visited.add(current)
            deps = db.query(TaskDependency).filter(
                TaskDependency.successor_id == current
            ).all()
            for dep in deps:
                if dep.predecessor_id not in visited:
                    predecessors.append(dep.predecessor_id)
                    queue.append(dep.predecessor_id)
        return predecessors

    @staticmethod
    def get_all_successors(db: Session, task_id: str) -> List[str]:
        """
        Get all transitive successors of a task.

        Uses BFS to find all tasks that depend on this task (directly or indirectly).
        """
        visited: Set[str] = set()
        queue = [task_id]
        successors = []
        while queue:
            current = queue.pop(0)
            if current in visited:
                continue
            visited.add(current)
            deps = db.query(TaskDependency).filter(
                TaskDependency.predecessor_id == current
            ).all()
            for dep in deps:
                if dep.successor_id not in visited:
                    successors.append(dep.successor_id)
                    queue.append(dep.successor_id)
        return successors

View File

@@ -0,0 +1,300 @@
"""
Encryption service for AES-256-GCM file encryption.
This service handles:
- File encryption key generation and management
- Encrypting/decrypting file encryption keys with Master Key
- Streaming file encryption/decryption with AES-256-GCM
"""
import os
import base64
import secrets
import logging
from typing import BinaryIO, Tuple, Optional, Generator
from io import BytesIO
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from cryptography.hazmat.backends import default_backend
from app.core.config import settings
logger = logging.getLogger(__name__)
# Constants
KEY_SIZE = 32 # 256 bits for AES-256
NONCE_SIZE = 12 # 96 bits for GCM recommended nonce size
TAG_SIZE = 16 # 128 bits for GCM authentication tag
CHUNK_SIZE = 64 * 1024 # 64KB chunks for streaming
class EncryptionError(Exception):
    """Base exception for all file-encryption related errors."""
    pass
class MasterKeyNotConfiguredError(EncryptionError):
    """Raised when the master key (ENCRYPTION_MASTER_KEY) is not configured."""
    pass
class DecryptionError(EncryptionError):
    """Raised when decryption fails (wrong key, corrupted or truncated data)."""
    pass
class EncryptionService:
    """
    Service for file encryption using AES-256-GCM.

    Key hierarchy:
    1. Master Key (from environment) -> encrypts file encryption keys
    2. File Encryption Keys (stored in DB) -> encrypt actual files

    Wire formats:
    - Whole-file: nonce (12 bytes) + ciphertext + tag (16 bytes).
    - Streaming: version byte 0x01, then repeated
      [chunk_size (4 bytes LE) + nonce + ciphertext + tag], terminated by a
      zero chunk_size marker. Each chunk uses a fresh random nonce.
    """

    def __init__(self):
        # Master key is decoded lazily on first use so the service can be
        # instantiated even when encryption is not configured.
        self._master_key: Optional[bytes] = None

    @property
    def master_key(self) -> bytes:
        """Get the master key, loading from config if needed.

        Raises:
            MasterKeyNotConfiguredError: if ENCRYPTION_MASTER_KEY is unset or empty.
        """
        if self._master_key is None:
            if not settings.ENCRYPTION_MASTER_KEY:
                raise MasterKeyNotConfiguredError(
                    "ENCRYPTION_MASTER_KEY is not configured. "
                    "File encryption is disabled."
                )
            self._master_key = base64.urlsafe_b64decode(settings.ENCRYPTION_MASTER_KEY)
        return self._master_key

    def is_encryption_available(self) -> bool:
        """Check if encryption is available (master key configured).

        Fix: uses truthiness rather than `is not None`, so an empty-string
        ENCRYPTION_MASTER_KEY reports unavailable — consistent with the
        `not settings.ENCRYPTION_MASTER_KEY` guard in `master_key`, which
        would otherwise raise despite this method returning True.
        """
        return bool(settings.ENCRYPTION_MASTER_KEY)

    def generate_key(self) -> bytes:
        """
        Generate a new AES-256 encryption key.

        Returns:
            32-byte cryptographically random key
        """
        return secrets.token_bytes(KEY_SIZE)

    def encrypt_key(self, key: bytes) -> str:
        """
        Encrypt a file encryption key using the Master Key.

        Args:
            key: The raw 32-byte file encryption key

        Returns:
            Base64-encoded encrypted key (nonce + ciphertext + tag)
        """
        aesgcm = AESGCM(self.master_key)
        nonce = secrets.token_bytes(NONCE_SIZE)
        # Encrypt the key (AESGCM.encrypt appends the auth tag)
        ciphertext = aesgcm.encrypt(nonce, key, None)
        # Combine nonce + ciphertext (includes tag)
        encrypted_data = nonce + ciphertext
        return base64.urlsafe_b64encode(encrypted_data).decode('utf-8')

    def decrypt_key(self, encrypted_key: str) -> bytes:
        """
        Decrypt a file encryption key using the Master Key.

        Args:
            encrypted_key: Base64-encoded encrypted key

        Returns:
            The raw 32-byte file encryption key

        Raises:
            DecryptionError: if decoding or authenticated decryption fails.
        """
        try:
            encrypted_data = base64.urlsafe_b64decode(encrypted_key)
            # Extract nonce and ciphertext
            nonce = encrypted_data[:NONCE_SIZE]
            ciphertext = encrypted_data[NONCE_SIZE:]
            # Decrypt
            aesgcm = AESGCM(self.master_key)
            return aesgcm.decrypt(nonce, ciphertext, None)
        except Exception as e:
            # Log the cause but return a generic error to callers.
            logger.error(f"Failed to decrypt encryption key: {e}")
            raise DecryptionError("Failed to decrypt file encryption key")

    def encrypt_file(self, file_content: BinaryIO, key: bytes) -> bytes:
        """
        Encrypt file content using AES-256-GCM.

        For smaller files, encrypts the entire content at once.
        The format is: nonce (12 bytes) + ciphertext + tag (16 bytes)

        Args:
            file_content: File-like object to encrypt
            key: 32-byte AES-256 key

        Returns:
            Encrypted bytes (nonce + ciphertext + tag)
        """
        # Read all content (whole file held in memory; use
        # encrypt_file_streaming for large files)
        plaintext = file_content.read()
        # Generate nonce
        nonce = secrets.token_bytes(NONCE_SIZE)
        # Encrypt
        aesgcm = AESGCM(key)
        ciphertext = aesgcm.encrypt(nonce, plaintext, None)
        # Return nonce + ciphertext (tag is appended by encrypt)
        return nonce + ciphertext

    def decrypt_file(self, encrypted_content: BinaryIO, key: bytes) -> bytes:
        """
        Decrypt file content using AES-256-GCM.

        Args:
            encrypted_content: File-like object containing encrypted data
            key: 32-byte AES-256 key

        Returns:
            Decrypted bytes

        Raises:
            DecryptionError: if authentication or decryption fails.
        """
        try:
            # Read all encrypted content
            encrypted_data = encrypted_content.read()
            # Extract nonce and ciphertext
            nonce = encrypted_data[:NONCE_SIZE]
            ciphertext = encrypted_data[NONCE_SIZE:]
            # Decrypt
            aesgcm = AESGCM(key)
            plaintext = aesgcm.decrypt(nonce, ciphertext, None)
            return plaintext
        except Exception as e:
            logger.error(f"Failed to decrypt file: {e}")
            raise DecryptionError("Failed to decrypt file. The file may be corrupted or the key is incorrect.")

    def encrypt_file_streaming(self, file_content: BinaryIO, key: bytes) -> Generator[bytes, None, None]:
        """
        Encrypt file content using AES-256-GCM with streaming.

        For large files, encrypts in chunks. Each chunk has its own nonce.
        Format per chunk: chunk_size (4 bytes) + nonce (12 bytes) + ciphertext + tag

        Args:
            file_content: File-like object to encrypt
            key: 32-byte AES-256 key

        Yields:
            Encrypted chunks
        """
        aesgcm = AESGCM(key)
        # Write header with version byte
        yield b'\x01'  # Version 1 for streaming format
        while True:
            chunk = file_content.read(CHUNK_SIZE)
            if not chunk:
                break
            # Generate nonce for this chunk (fresh nonce per chunk is
            # required for GCM nonce-uniqueness under the same key)
            nonce = secrets.token_bytes(NONCE_SIZE)
            # Encrypt chunk
            ciphertext = aesgcm.encrypt(nonce, chunk, None)
            # Write chunk size (4 bytes, little endian)
            chunk_size = len(ciphertext) + NONCE_SIZE
            yield chunk_size.to_bytes(4, 'little')
            # Write nonce + ciphertext
            yield nonce + ciphertext
        # Write end marker (zero size)
        yield b'\x00\x00\x00\x00'

    def decrypt_file_streaming(self, encrypted_content: BinaryIO, key: bytes) -> Generator[bytes, None, None]:
        """
        Decrypt file content using AES-256-GCM with streaming.

        Args:
            encrypted_content: File-like object containing encrypted data
            key: 32-byte AES-256 key

        Yields:
            Decrypted chunks

        Raises:
            DecryptionError: on unknown format version, truncated input,
                or chunk authentication failure.
        """
        aesgcm = AESGCM(key)
        # Read version byte
        version = encrypted_content.read(1)
        if version != b'\x01':
            raise DecryptionError("Unknown encryption format version")
        while True:
            # Read chunk size
            size_bytes = encrypted_content.read(4)
            if len(size_bytes) < 4:
                raise DecryptionError("Unexpected end of file")
            chunk_size = int.from_bytes(size_bytes, 'little')
            # Check for end marker
            if chunk_size == 0:
                break
            # Read chunk (nonce + ciphertext)
            chunk = encrypted_content.read(chunk_size)
            if len(chunk) < chunk_size:
                raise DecryptionError("Unexpected end of file")
            # Extract nonce and ciphertext
            nonce = chunk[:NONCE_SIZE]
            ciphertext = chunk[NONCE_SIZE:]
            try:
                # Decrypt
                plaintext = aesgcm.decrypt(nonce, ciphertext, None)
                yield plaintext
            except Exception as e:
                raise DecryptionError(f"Failed to decrypt chunk: {e}")

    def encrypt_bytes(self, data: bytes, key: bytes) -> bytes:
        """
        Encrypt bytes directly (convenience wrapper around encrypt_file).

        Args:
            data: Bytes to encrypt
            key: 32-byte AES-256 key

        Returns:
            Encrypted bytes
        """
        return self.encrypt_file(BytesIO(data), key)

    def decrypt_bytes(self, encrypted_data: bytes, key: bytes) -> bytes:
        """
        Decrypt bytes directly (convenience wrapper around decrypt_file).

        Args:
            encrypted_data: Encrypted bytes
            key: 32-byte AES-256 key

        Returns:
            Decrypted bytes
        """
        return self.decrypt_file(BytesIO(encrypted_data), key)
# Module-level singleton; import this instance instead of constructing
# EncryptionService directly so the lazily-decoded master key is shared.
encryption_service = EncryptionService()

View File

@@ -0,0 +1,420 @@
"""
Formula Service for Custom Fields
Supports:
- Basic math operations: +, -, *, /
- Field references: {field_name}
- Built-in task fields: {original_estimate}, {time_spent}
- Parentheses for grouping
Example formulas:
- "{time_spent} / {original_estimate} * 100"
- "{cost_per_hour} * {hours_worked}"
- "({field_a} + {field_b}) / 2"
"""
import re
import ast
import operator
from typing import Dict, Any, Optional, List, Set, Tuple
from decimal import Decimal, InvalidOperation
from sqlalchemy.orm import Session
from app.models import Task, CustomField, TaskCustomValue
class FormulaError(Exception):
    """Exception raised for formula parsing or calculation errors."""
    pass
class CircularReferenceError(FormulaError):
    """Raised when formula fields reference each other in a cycle."""
    pass
class FormulaService:
    """Service for parsing and calculating formula fields."""

    # Built-in task fields that can be referenced in formulas
    # (resolved from the Task model rather than custom values).
    BUILTIN_FIELDS = {
        "original_estimate",
        "time_spent",
    }

    # Supported operators: whitelist of AST node types -> functions,
    # so formula evaluation never executes arbitrary code.
    OPERATORS = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.USub: operator.neg,
    }
@staticmethod
def extract_field_references(formula: str) -> Set[str]:
"""
Extract all field references from a formula.
Field references are in the format {field_name}.
Returns a set of field names referenced in the formula.
"""
pattern = r'\{([^}]+)\}'
matches = re.findall(pattern, formula)
return set(matches)
    @staticmethod
    def validate_formula(
        formula: str,
        project_id: str,
        db: Session,
        current_field_id: Optional[str] = None,
    ) -> Tuple[bool, Optional[str]]:
        """
        Validate a formula expression.

        Checks:
        1. Syntax is valid
        2. All referenced fields exist
        3. Referenced fields are number or formula type
        4. No circular references

        Args:
            formula: Expression using ``{field_name}`` references.
            project_id: Project whose custom fields resolve references.
            db: Database session.
            current_field_id: Id of the formula field being edited, if any;
                needed for circular reference detection.

        Returns (is_valid, error_message)
        """
        if not formula or not formula.strip():
            return False, "Formula cannot be empty"
        # Extract field references
        references = FormulaService.extract_field_references(formula)
        if not references:
            return False, "Formula must reference at least one field"
        # Validate syntax by trying to parse
        try:
            # Replace field references with dummy numbers for syntax check
            test_formula = formula
            for ref in references:
                test_formula = test_formula.replace(f"{{{ref}}}", "1")
            # Try to parse and evaluate with safe operations
            FormulaService._safe_eval(test_formula)
        except Exception as e:
            return False, f"Invalid formula syntax: {str(e)}"
        # Separate builtin and custom field references
        custom_references = references - FormulaService.BUILTIN_FIELDS
        # Validate custom field references exist and are numeric types
        if custom_references:
            fields = db.query(CustomField).filter(
                CustomField.project_id == project_id,
                CustomField.name.in_(custom_references),
            ).all()
            found_names = {f.name for f in fields}
            missing = custom_references - found_names
            if missing:
                return False, f"Unknown field references: {', '.join(missing)}"
            # Check field types (must be number or formula)
            for field in fields:
                if field.field_type not in ("number", "formula"):
                    return False, f"Field '{field.name}' is not a numeric type"
        # Check for circular references (only meaningful when editing an
        # existing field, since a new field cannot yet be referenced)
        if current_field_id:
            try:
                FormulaService._check_circular_references(
                    db, project_id, current_field_id, references
                )
            except CircularReferenceError as e:
                return False, str(e)
        return True, None
@staticmethod
def _check_circular_references(
db: Session,
project_id: str,
field_id: str,
references: Set[str],
visited: Optional[Set[str]] = None,
) -> None:
"""
Check for circular references in formula fields.
Raises CircularReferenceError if a cycle is detected.
"""
if visited is None:
visited = set()
# Get the current field's name
current_field = db.query(CustomField).filter(
CustomField.id == field_id
).first()
if current_field:
if current_field.name in references:
raise CircularReferenceError(
f"Circular reference detected: field cannot reference itself"
)
# Get all referenced formula fields
custom_references = references - FormulaService.BUILTIN_FIELDS
if not custom_references:
return
formula_fields = db.query(CustomField).filter(
CustomField.project_id == project_id,
CustomField.name.in_(custom_references),
CustomField.field_type == "formula",
).all()
for field in formula_fields:
if field.id in visited:
raise CircularReferenceError(
f"Circular reference detected involving field '{field.name}'"
)
visited.add(field.id)
if field.formula:
nested_refs = FormulaService.extract_field_references(field.formula)
if current_field and current_field.name in nested_refs:
raise CircularReferenceError(
f"Circular reference detected: '{field.name}' references the current field"
)
FormulaService._check_circular_references(
db, project_id, field_id, nested_refs, visited
)
@staticmethod
def _safe_eval(expression: str) -> Decimal:
"""
Safely evaluate a mathematical expression.
Only allows basic arithmetic operations (+, -, *, /).
"""
try:
node = ast.parse(expression, mode='eval')
return FormulaService._eval_node(node.body)
except Exception as e:
raise FormulaError(f"Failed to evaluate expression: {str(e)}")
@staticmethod
def _eval_node(node: ast.AST) -> Decimal:
"""Recursively evaluate an AST node."""
if isinstance(node, ast.Constant):
if isinstance(node.value, (int, float)):
return Decimal(str(node.value))
raise FormulaError(f"Invalid constant: {node.value}")
elif isinstance(node, ast.BinOp):
left = FormulaService._eval_node(node.left)
right = FormulaService._eval_node(node.right)
op = FormulaService.OPERATORS.get(type(node.op))
if op is None:
raise FormulaError(f"Unsupported operator: {type(node.op).__name__}")
# Handle division by zero
if isinstance(node.op, ast.Div) and right == 0:
return Decimal('0') # Return 0 instead of raising error
return Decimal(str(op(float(left), float(right))))
elif isinstance(node, ast.UnaryOp):
operand = FormulaService._eval_node(node.operand)
op = FormulaService.OPERATORS.get(type(node.op))
if op is None:
raise FormulaError(f"Unsupported operator: {type(node.op).__name__}")
return Decimal(str(op(float(operand))))
else:
raise FormulaError(f"Unsupported expression type: {type(node).__name__}")
@staticmethod
def calculate_formula(
formula: str,
task: Task,
db: Session,
calculated_cache: Optional[Dict[str, Decimal]] = None,
) -> Optional[Decimal]:
"""
Calculate the value of a formula for a given task.
Args:
formula: The formula expression
task: The task to calculate for
db: Database session
calculated_cache: Cache for already calculated formula values (for recursion)
Returns:
The calculated value, or None if calculation fails
"""
if calculated_cache is None:
calculated_cache = {}
references = FormulaService.extract_field_references(formula)
values: Dict[str, Decimal] = {}
# Get builtin field values
for ref in references:
if ref in FormulaService.BUILTIN_FIELDS:
task_value = getattr(task, ref, None)
if task_value is not None:
values[ref] = Decimal(str(task_value))
else:
values[ref] = Decimal('0')
# Get custom field values
custom_references = references - FormulaService.BUILTIN_FIELDS
if custom_references:
# Get field definitions
fields = db.query(CustomField).filter(
CustomField.project_id == task.project_id,
CustomField.name.in_(custom_references),
).all()
field_map = {f.name: f for f in fields}
# Get custom values for this task
custom_values = db.query(TaskCustomValue).filter(
TaskCustomValue.task_id == task.id,
TaskCustomValue.field_id.in_([f.id for f in fields]),
).all()
value_map = {cv.field_id: cv.value for cv in custom_values}
for ref in custom_references:
field = field_map.get(ref)
if not field:
values[ref] = Decimal('0')
continue
if field.field_type == "formula":
# Recursively calculate formula fields
if field.id in calculated_cache:
values[ref] = calculated_cache[field.id]
else:
nested_value = FormulaService.calculate_formula(
field.formula, task, db, calculated_cache
)
values[ref] = nested_value if nested_value is not None else Decimal('0')
calculated_cache[field.id] = values[ref]
else:
# Get stored value
stored_value = value_map.get(field.id)
if stored_value:
try:
values[ref] = Decimal(str(stored_value))
except (InvalidOperation, ValueError):
values[ref] = Decimal('0')
else:
values[ref] = Decimal('0')
# Substitute values into formula
expression = formula
for ref, value in values.items():
expression = expression.replace(f"{{{ref}}}", str(value))
# Evaluate the expression
try:
result = FormulaService._safe_eval(expression)
# Round to 4 decimal places
return result.quantize(Decimal('0.0001'))
except Exception:
return None
@staticmethod
def recalculate_dependent_formulas(
db: Session,
task: Task,
changed_field_id: str,
) -> Dict[str, Decimal]:
"""
Recalculate all formula fields that depend on a changed field.
Returns a dict of field_id -> calculated_value for updated formulas.
"""
# Get the changed field
changed_field = db.query(CustomField).filter(
CustomField.id == changed_field_id
).first()
if not changed_field:
return {}
# Find all formula fields in the project
formula_fields = db.query(CustomField).filter(
CustomField.project_id == task.project_id,
CustomField.field_type == "formula",
).all()
results = {}
calculated_cache: Dict[str, Decimal] = {}
for field in formula_fields:
if not field.formula:
continue
# Check if this formula depends on the changed field
references = FormulaService.extract_field_references(field.formula)
if changed_field.name in references or changed_field.name in FormulaService.BUILTIN_FIELDS:
value = FormulaService.calculate_formula(
field.formula, task, db, calculated_cache
)
if value is not None:
results[field.id] = value
calculated_cache[field.id] = value
# Update or create the custom value
existing = db.query(TaskCustomValue).filter(
TaskCustomValue.task_id == task.id,
TaskCustomValue.field_id == field.id,
).first()
if existing:
existing.value = str(value)
else:
import uuid
new_value = TaskCustomValue(
id=str(uuid.uuid4()),
task_id=task.id,
field_id=field.id,
value=str(value),
)
db.add(new_value)
return results
@staticmethod
def calculate_all_formulas_for_task(
db: Session,
task: Task,
) -> Dict[str, Decimal]:
"""
Calculate all formula fields for a task.
Used when loading a task to get current formula values.
"""
formula_fields = db.query(CustomField).filter(
CustomField.project_id == task.project_id,
CustomField.field_type == "formula",
).all()
results = {}
calculated_cache: Dict[str, Decimal] = {}
for field in formula_fields:
if not field.formula:
continue
value = FormulaService.calculate_formula(
field.formula, task, db, calculated_cache
)
if value is not None:
results[field.id] = value
calculated_cache[field.id] = value
return results

View File

@@ -0,0 +1,70 @@
"""Add custom fields and task custom values tables
Revision ID: 011
Revises: 010
Create Date: 2026-01-05
FEAT-001: Add custom fields feature for flexible task data extension.
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# Alembic revision identifiers: migration 011 applies on top of 010.
revision: str = '011'
down_revision: Union[str, None] = '010'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Create the custom-field tables (FEAT-001).

    Adds ``pjctrl_custom_fields`` (field definitions per project) and
    ``pjctrl_task_custom_values`` (per-task values, one row per
    task/field pair).
    """
    # Enum of supported field types; checkfirst makes the migration
    # re-runnable on databases where the type already exists.
    field_type_enum = sa.Enum(
        'text', 'number', 'dropdown', 'date', 'person', 'formula',
        name='field_type_enum'
    )
    field_type_enum.create(op.get_bind(), checkfirst=True)

    # Field definitions, scoped to a project and removed with it (CASCADE).
    op.create_table(
        'pjctrl_custom_fields',
        sa.Column('id', sa.String(36), primary_key=True),
        sa.Column('project_id', sa.String(36), sa.ForeignKey('pjctrl_projects.id', ondelete='CASCADE'), nullable=False),
        sa.Column('name', sa.String(100), nullable=False),
        sa.Column('field_type', field_type_enum, nullable=False),
        sa.Column('options', sa.JSON, nullable=True),  # dropdown choices
        sa.Column('formula', sa.Text, nullable=True),  # expression for formula fields
        # NOTE: ``default=`` only applies to ORM inserts; server_default is
        # added so raw-SQL inserts into these NOT NULL columns do not fail.
        sa.Column('is_required', sa.Boolean, default=False, server_default=sa.false(), nullable=False),
        sa.Column('position', sa.Integer, default=0, server_default=sa.text('0'), nullable=False),
        sa.Column('created_at', sa.DateTime, server_default=sa.func.now(), nullable=False),
        sa.Column('updated_at', sa.DateTime, server_default=sa.func.now(), onupdate=sa.func.now(), nullable=False),
    )
    # Create indexes for custom_fields
    op.create_index('ix_pjctrl_custom_fields_project_id', 'pjctrl_custom_fields', ['project_id'])

    # Per-task values; (task_id, field_id) is unique so each task stores at
    # most one value per field.
    op.create_table(
        'pjctrl_task_custom_values',
        sa.Column('id', sa.String(36), primary_key=True),
        sa.Column('task_id', sa.String(36), sa.ForeignKey('pjctrl_tasks.id', ondelete='CASCADE'), nullable=False),
        sa.Column('field_id', sa.String(36), sa.ForeignKey('pjctrl_custom_fields.id', ondelete='CASCADE'), nullable=False),
        sa.Column('value', sa.Text, nullable=True),
        sa.Column('updated_at', sa.DateTime, server_default=sa.func.now(), onupdate=sa.func.now(), nullable=False),
        sa.UniqueConstraint('task_id', 'field_id', name='uq_task_field'),
    )
    # Create indexes for task_custom_values
    op.create_index('ix_pjctrl_task_custom_values_task_id', 'pjctrl_task_custom_values', ['task_id'])
    op.create_index('ix_pjctrl_task_custom_values_field_id', 'pjctrl_task_custom_values', ['field_id'])
def downgrade() -> None:
    """Drop the custom-field tables and the field_type enum (reverse of upgrade)."""
    # Drop the value table first: it holds the FK into pjctrl_custom_fields.
    op.drop_index('ix_pjctrl_task_custom_values_field_id', table_name='pjctrl_task_custom_values')
    op.drop_index('ix_pjctrl_task_custom_values_task_id', table_name='pjctrl_task_custom_values')
    op.drop_table('pjctrl_task_custom_values')
    op.drop_index('ix_pjctrl_custom_fields_project_id', table_name='pjctrl_custom_fields')
    op.drop_table('pjctrl_custom_fields')
    # Drop the enum type last, once no column uses it; checkfirst tolerates
    # backends where it was never created as a standalone type.
    sa.Enum(name='field_type_enum').drop(op.get_bind(), checkfirst=True)

View File

@@ -0,0 +1,48 @@
"""Encryption keys table and attachment encryption_key_id
Revision ID: 012
Revises: 011
Create Date: 2026-01-05
"""
from alembic import op
import sqlalchemy as sa
# Alembic revision identifiers: migration 012 applies on top of 011.
revision = '012'
down_revision = '011'
branch_labels = None
depends_on = None
def upgrade():
    """Create the encryption-key table and link attachments to keys (FEAT-010)."""
    # Create encryption_keys table
    op.create_table(
        'pjctrl_encryption_keys',
        sa.Column('id', sa.String(36), primary_key=True),
        # File-encryption key, itself encrypted with the Master Key; never
        # stored in plaintext.
        sa.Column('key_data', sa.Text, nullable=False),
        # NOTE: ``default=`` only applies to ORM inserts; server_default is
        # added so raw-SQL inserts into these NOT NULL columns do not fail.
        sa.Column('algorithm', sa.String(20), default='AES-256-GCM', server_default='AES-256-GCM', nullable=False),
        sa.Column('is_active', sa.Boolean, default=True, server_default=sa.true(), nullable=False),
        sa.Column('created_at', sa.DateTime, server_default=sa.func.now(), nullable=False),
        sa.Column('rotated_at', sa.DateTime, nullable=True),
    )
    op.create_index('idx_encryption_key_active', 'pjctrl_encryption_keys', ['is_active'])

    # Attachments reference the key used to encrypt them; SET NULL keeps the
    # attachment row if a key row is ever deleted.
    op.add_column(
        'pjctrl_attachments',
        sa.Column(
            'encryption_key_id',
            sa.String(36),
            sa.ForeignKey('pjctrl_encryption_keys.id', ondelete='SET NULL'),
            nullable=True
        )
    )
    op.create_index('idx_attachment_encryption_key', 'pjctrl_attachments', ['encryption_key_id'])
def downgrade():
    """Reverse of upgrade: remove the attachment column, then the key table."""
    # Remove the attachment-side references first so the FK does not block
    # dropping pjctrl_encryption_keys.
    op.drop_index('idx_attachment_encryption_key', 'pjctrl_attachments')
    op.drop_column('pjctrl_attachments', 'encryption_key_id')
    op.drop_index('idx_encryption_key_active', 'pjctrl_encryption_keys')
    op.drop_table('pjctrl_encryption_keys')

View File

@@ -0,0 +1,92 @@
"""Add start_date to tasks and create task dependencies table for Gantt view
Revision ID: 013
Revises: 012
Create Date: 2026-01-05
FEAT-003: Add Gantt view support with task dependencies and start dates.
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# Alembic revision identifiers: migration 013 applies on top of 012.
revision: str = '013'
down_revision: Union[str, None] = '012'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Add Gantt support (FEAT-003): task start dates and dependency links."""
    # Tasks gain an optional planned start date for the Gantt timeline.
    op.add_column(
        'pjctrl_tasks',
        sa.Column('start_date', sa.DateTime, nullable=True)
    )

    # Dependency kinds: Finish-Start, Start-Start, Finish-Finish,
    # Start-Finish.  checkfirst makes re-runs safe where the type exists.
    dependency_type_enum = sa.Enum(
        'FS', 'SS', 'FF', 'SF',
        name='dependency_type_enum'
    )
    dependency_type_enum.create(op.get_bind(), checkfirst=True)

    # Create pjctrl_task_dependencies table
    op.create_table(
        'pjctrl_task_dependencies',
        sa.Column('id', sa.String(36), primary_key=True),
        sa.Column(
            'predecessor_id',
            sa.String(36),
            sa.ForeignKey('pjctrl_tasks.id', ondelete='CASCADE'),
            nullable=False
        ),
        sa.Column(
            'successor_id',
            sa.String(36),
            sa.ForeignKey('pjctrl_tasks.id', ondelete='CASCADE'),
            nullable=False
        ),
        # NOTE: ``default=`` only applies to ORM inserts; server_default is
        # added so raw-SQL inserts into these NOT NULL columns do not fail.
        sa.Column(
            'dependency_type',
            dependency_type_enum,
            default='FS',
            server_default='FS',
            nullable=False
        ),
        sa.Column('lag_days', sa.Integer, default=0, server_default=sa.text('0'), nullable=False),
        sa.Column('created_at', sa.DateTime, server_default=sa.func.now(), nullable=False),
        # Unique constraint to prevent duplicate dependencies
        sa.UniqueConstraint('predecessor_id', 'successor_id', name='uq_predecessor_successor'),
    )

    # Indexes for dependency lookups in both directions.
    op.create_index(
        'ix_pjctrl_task_dependencies_predecessor_id',
        'pjctrl_task_dependencies',
        ['predecessor_id']
    )
    op.create_index(
        'ix_pjctrl_task_dependencies_successor_id',
        'pjctrl_task_dependencies',
        ['successor_id']
    )
def downgrade() -> None:
    """Reverse of upgrade: drop dependencies, the enum, and start_date."""
    # Drop indexes before the table they belong to.
    op.drop_index(
        'ix_pjctrl_task_dependencies_successor_id',
        table_name='pjctrl_task_dependencies'
    )
    op.drop_index(
        'ix_pjctrl_task_dependencies_predecessor_id',
        table_name='pjctrl_task_dependencies'
    )
    # Drop the table
    op.drop_table('pjctrl_task_dependencies')
    # Drop the enum type once no column uses it; checkfirst tolerates
    # backends where it was never created as a standalone type.
    sa.Enum(name='dependency_type_enum').drop(op.get_bind(), checkfirst=True)
    # Remove start_date column from tasks
    op.drop_column('pjctrl_tasks', 'start_date')

View File

@@ -0,0 +1,440 @@
"""
Tests for Custom Fields feature.
"""
import pytest
from app.models import User, Space, Project, Task, TaskStatus, CustomField, TaskCustomValue
from app.services.formula_service import FormulaService, FormulaError, CircularReferenceError
class TestCustomFieldsCRUD:
    """Test custom fields CRUD operations.

    Uses the ``client``, ``db`` and ``admin_token`` fixtures; the owner id
    "00000000-0000-0000-0000-000000000001" is presumably the admin user
    behind ``admin_token`` — confirm against the conftest fixtures.
    """

    def setup_project(self, db, owner_id: str):
        """Create a space and project (with one default status) for testing."""
        space = Space(
            id="test-space-001",
            name="Test Space",
            owner_id=owner_id,
        )
        db.add(space)
        project = Project(
            id="test-project-001",
            space_id=space.id,
            title="Test Project",
            owner_id=owner_id,
        )
        db.add(project)
        # Add default task status
        status = TaskStatus(
            id="test-status-001",
            project_id=project.id,
            name="To Do",
            color="#3B82F6",
            position=0,
        )
        db.add(status)
        db.commit()
        return project

    def test_create_text_field(self, client, db, admin_token):
        """Test creating a text custom field."""
        project = self.setup_project(db, "00000000-0000-0000-0000-000000000001")
        response = client.post(
            f"/api/projects/{project.id}/custom-fields",
            json={
                "name": "Sprint Number",
                "field_type": "text",
                "is_required": False,
            },
            headers={"Authorization": f"Bearer {admin_token}"},
        )
        assert response.status_code == 201
        data = response.json()
        assert data["name"] == "Sprint Number"
        assert data["field_type"] == "text"
        assert data["is_required"] is False

    def test_create_number_field(self, client, db, admin_token):
        """Test creating a number custom field."""
        project = self.setup_project(db, "00000000-0000-0000-0000-000000000001")
        response = client.post(
            f"/api/projects/{project.id}/custom-fields",
            json={
                "name": "Story Points",
                "field_type": "number",
                "is_required": True,
            },
            headers={"Authorization": f"Bearer {admin_token}"},
        )
        assert response.status_code == 201
        data = response.json()
        assert data["name"] == "Story Points"
        assert data["field_type"] == "number"
        assert data["is_required"] is True

    def test_create_dropdown_field(self, client, db, admin_token):
        """Test creating a dropdown custom field with its options list."""
        project = self.setup_project(db, "00000000-0000-0000-0000-000000000001")
        response = client.post(
            f"/api/projects/{project.id}/custom-fields",
            json={
                "name": "Component",
                "field_type": "dropdown",
                "options": ["Frontend", "Backend", "Database", "API"],
                "is_required": False,
            },
            headers={"Authorization": f"Bearer {admin_token}"},
        )
        assert response.status_code == 201
        data = response.json()
        assert data["name"] == "Component"
        assert data["field_type"] == "dropdown"
        assert data["options"] == ["Frontend", "Backend", "Database", "API"]

    def test_create_dropdown_field_without_options_fails(self, client, db, admin_token):
        """Test that creating a dropdown field with empty options fails."""
        project = self.setup_project(db, "00000000-0000-0000-0000-000000000001")
        response = client.post(
            f"/api/projects/{project.id}/custom-fields",
            json={
                "name": "Component",
                "field_type": "dropdown",
                "options": [],
            },
            headers={"Authorization": f"Bearer {admin_token}"},
        )
        # 422 — presumably rejected by request-schema validation before the
        # route handler runs; confirm against the API schema.
        assert response.status_code == 422  # Validation error

    def test_create_formula_field(self, client, db, admin_token):
        """Test creating a formula custom field referencing built-in fields."""
        project = self.setup_project(db, "00000000-0000-0000-0000-000000000001")
        # First create a number field to reference
        client.post(
            f"/api/projects/{project.id}/custom-fields",
            json={
                "name": "hours_worked",
                "field_type": "number",
            },
            headers={"Authorization": f"Bearer {admin_token}"},
        )
        # Create formula field
        response = client.post(
            f"/api/projects/{project.id}/custom-fields",
            json={
                "name": "Progress",
                "field_type": "formula",
                "formula": "{time_spent} / {original_estimate} * 100",
            },
            headers={"Authorization": f"Bearer {admin_token}"},
        )
        assert response.status_code == 201
        data = response.json()
        assert data["name"] == "Progress"
        assert data["field_type"] == "formula"
        assert "{time_spent}" in data["formula"]

    def test_list_custom_fields(self, client, db, admin_token):
        """Test listing custom fields for a project."""
        project = self.setup_project(db, "00000000-0000-0000-0000-000000000001")
        # Create some fields
        client.post(
            f"/api/projects/{project.id}/custom-fields",
            json={"name": "Field 1", "field_type": "text"},
            headers={"Authorization": f"Bearer {admin_token}"},
        )
        client.post(
            f"/api/projects/{project.id}/custom-fields",
            json={"name": "Field 2", "field_type": "number"},
            headers={"Authorization": f"Bearer {admin_token}"},
        )
        response = client.get(
            f"/api/projects/{project.id}/custom-fields",
            headers={"Authorization": f"Bearer {admin_token}"},
        )
        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 2
        assert len(data["fields"]) == 2

    def test_update_custom_field(self, client, db, admin_token):
        """Test updating a custom field's name and required flag."""
        project = self.setup_project(db, "00000000-0000-0000-0000-000000000001")
        # Create a field
        create_response = client.post(
            f"/api/projects/{project.id}/custom-fields",
            json={"name": "Original Name", "field_type": "text"},
            headers={"Authorization": f"Bearer {admin_token}"},
        )
        field_id = create_response.json()["id"]
        # Update it
        response = client.put(
            f"/api/custom-fields/{field_id}",
            json={"name": "Updated Name", "is_required": True},
            headers={"Authorization": f"Bearer {admin_token}"},
        )
        assert response.status_code == 200
        data = response.json()
        assert data["name"] == "Updated Name"
        assert data["is_required"] is True

    def test_delete_custom_field(self, client, db, admin_token):
        """Test deleting a custom field and that it is gone afterwards."""
        project = self.setup_project(db, "00000000-0000-0000-0000-000000000001")
        # Create a field
        create_response = client.post(
            f"/api/projects/{project.id}/custom-fields",
            json={"name": "To Delete", "field_type": "text"},
            headers={"Authorization": f"Bearer {admin_token}"},
        )
        field_id = create_response.json()["id"]
        # Delete it
        response = client.delete(
            f"/api/custom-fields/{field_id}",
            headers={"Authorization": f"Bearer {admin_token}"},
        )
        assert response.status_code == 204
        # Verify it's gone
        get_response = client.get(
            f"/api/custom-fields/{field_id}",
            headers={"Authorization": f"Bearer {admin_token}"},
        )
        assert get_response.status_code == 404

    def test_max_fields_limit(self, client, db, admin_token):
        """Test that a maximum of 20 custom fields per project is enforced."""
        project = self.setup_project(db, "00000000-0000-0000-0000-000000000001")
        # Create 20 fields
        for i in range(20):
            response = client.post(
                f"/api/projects/{project.id}/custom-fields",
                json={"name": f"Field {i}", "field_type": "text"},
                headers={"Authorization": f"Bearer {admin_token}"},
            )
            assert response.status_code == 201
        # Try to create the 21st field
        response = client.post(
            f"/api/projects/{project.id}/custom-fields",
            json={"name": "Field 21", "field_type": "text"},
            headers={"Authorization": f"Bearer {admin_token}"},
        )
        assert response.status_code == 400
        assert "Maximum" in response.json()["detail"]

    def test_duplicate_name_rejected(self, client, db, admin_token):
        """Test that duplicate field names within a project are rejected."""
        project = self.setup_project(db, "00000000-0000-0000-0000-000000000001")
        # Create a field
        client.post(
            f"/api/projects/{project.id}/custom-fields",
            json={"name": "Unique Name", "field_type": "text"},
            headers={"Authorization": f"Bearer {admin_token}"},
        )
        # Try to create another with same name (different type does not help)
        response = client.post(
            f"/api/projects/{project.id}/custom-fields",
            json={"name": "Unique Name", "field_type": "number"},
            headers={"Authorization": f"Bearer {admin_token}"},
        )
        assert response.status_code == 400
        assert "already exists" in response.json()["detail"]
class TestFormulaService:
    """Unit tests for formula parsing and the restricted evaluator."""

    def test_extract_field_references(self):
        """References inside braces are extracted as a set."""
        refs = FormulaService.extract_field_references(
            "{time_spent} / {original_estimate} * 100"
        )
        assert refs == {"time_spent", "original_estimate"}

    def test_extract_multiple_references(self):
        """Every distinct reference in the formula is returned."""
        expected = {"field_a", "field_b", "field_c"}
        assert FormulaService.extract_field_references(
            "{field_a} + {field_b} - {field_c}"
        ) == expected

    def test_safe_eval_addition(self):
        """Addition evaluates correctly."""
        assert float(FormulaService._safe_eval("10 + 5")) == 15.0

    def test_safe_eval_division(self):
        """Division evaluates correctly."""
        assert float(FormulaService._safe_eval("20 / 4")) == 5.0

    def test_safe_eval_complex_expression(self):
        """Parentheses and operator precedence are honoured."""
        assert float(FormulaService._safe_eval("(10 + 5) * 2 / 3")) == 10.0

    def test_safe_eval_division_by_zero(self):
        """Division by zero yields 0 instead of raising."""
        assert float(FormulaService._safe_eval("10 / 0")) == 0.0

    def test_safe_eval_negative_numbers(self):
        """Unary minus is supported."""
        assert float(FormulaService._safe_eval("-5 + 10")) == 5.0
class TestCustomValuesWithTasks:
    """Test custom values integration with tasks (create/read/update)."""

    def setup_project_with_fields(self, db, client, admin_token, owner_id: str):
        """Create a project with one text and one number custom field.

        Returns:
            ``(project, text_field_id, number_field_id)``.
        """
        space = Space(
            id="test-space-002",
            name="Test Space",
            owner_id=owner_id,
        )
        db.add(space)
        project = Project(
            id="test-project-002",
            space_id=space.id,
            title="Test Project",
            owner_id=owner_id,
        )
        db.add(project)
        status = TaskStatus(
            id="test-status-002",
            project_id=project.id,
            name="To Do",
            color="#3B82F6",
            position=0,
        )
        db.add(status)
        db.commit()
        # Create custom fields via API
        text_response = client.post(
            f"/api/projects/{project.id}/custom-fields",
            json={"name": "sprint_number", "field_type": "text"},
            headers={"Authorization": f"Bearer {admin_token}"},
        )
        text_field_id = text_response.json()["id"]
        number_response = client.post(
            f"/api/projects/{project.id}/custom-fields",
            json={"name": "story_points", "field_type": "number"},
            headers={"Authorization": f"Bearer {admin_token}"},
        )
        number_field_id = number_response.json()["id"]
        return project, text_field_id, number_field_id

    def test_create_task_with_custom_values(self, client, db, admin_token):
        """Test creating a task with custom values supplied inline."""
        project, text_field_id, number_field_id = self.setup_project_with_fields(
            db, client, admin_token, "00000000-0000-0000-0000-000000000001"
        )
        response = client.post(
            f"/api/projects/{project.id}/tasks",
            json={
                "title": "Test Task",
                "custom_values": [
                    {"field_id": text_field_id, "value": "Sprint 5"},
                    {"field_id": number_field_id, "value": "8"},
                ],
            },
            headers={"Authorization": f"Bearer {admin_token}"},
        )
        assert response.status_code == 201

    def test_get_task_includes_custom_values(self, client, db, admin_token):
        """Test that getting a task includes its custom values."""
        project, text_field_id, number_field_id = self.setup_project_with_fields(
            db, client, admin_token, "00000000-0000-0000-0000-000000000001"
        )
        # Create task with custom values
        create_response = client.post(
            f"/api/projects/{project.id}/tasks",
            json={
                "title": "Test Task",
                "custom_values": [
                    {"field_id": text_field_id, "value": "Sprint 5"},
                    {"field_id": number_field_id, "value": "8"},
                ],
            },
            headers={"Authorization": f"Bearer {admin_token}"},
        )
        task_id = create_response.json()["id"]
        # Get task and check custom values
        get_response = client.get(
            f"/api/tasks/{task_id}",
            headers={"Authorization": f"Bearer {admin_token}"},
        )
        assert get_response.status_code == 200
        data = get_response.json()
        assert data["custom_values"] is not None
        # >= 2 rather than == 2: the response may include entries for other
        # fields as well (e.g. formula fields).
        assert len(data["custom_values"]) >= 2

    def test_update_task_custom_values(self, client, db, admin_token):
        """Test updating custom values on an existing task via PATCH."""
        project, text_field_id, number_field_id = self.setup_project_with_fields(
            db, client, admin_token, "00000000-0000-0000-0000-000000000001"
        )
        # Create task
        create_response = client.post(
            f"/api/projects/{project.id}/tasks",
            json={
                "title": "Test Task",
                "custom_values": [
                    {"field_id": text_field_id, "value": "Sprint 5"},
                ],
            },
            headers={"Authorization": f"Bearer {admin_token}"},
        )
        task_id = create_response.json()["id"]
        # Update custom values
        update_response = client.patch(
            f"/api/tasks/{task_id}",
            json={
                "custom_values": [
                    {"field_id": text_field_id, "value": "Sprint 6"},
                    {"field_id": number_field_id, "value": "13"},
                ],
            },
            headers={"Authorization": f"Bearer {admin_token}"},
        )
        assert update_response.status_code == 200

View File

@@ -0,0 +1,275 @@
"""
Tests for the file encryption functionality.
Tests cover:
- Encryption service (key generation, encrypt/decrypt)
- Encryption key management API
- Attachment upload with encryption
- Attachment download with decryption
"""
import pytest
import base64
import secrets
from io import BytesIO
from unittest.mock import patch, MagicMock
from app.services.encryption_service import (
EncryptionService,
encryption_service,
MasterKeyNotConfiguredError,
DecryptionError,
KEY_SIZE,
NONCE_SIZE,
)
class TestEncryptionService:
"""Tests for the encryption service."""
@pytest.fixture
def mock_master_key(self):
"""Generate a valid test master key."""
return base64.urlsafe_b64encode(secrets.token_bytes(32)).decode()
@pytest.fixture
def service_with_key(self, mock_master_key):
"""Create an encryption service with a mock master key."""
with patch('app.services.encryption_service.settings') as mock_settings:
mock_settings.ENCRYPTION_MASTER_KEY = mock_master_key
service = EncryptionService()
yield service
def test_generate_key(self, service_with_key):
"""Test that generate_key produces a 32-byte key."""
key = service_with_key.generate_key()
assert len(key) == KEY_SIZE
assert isinstance(key, bytes)
def test_generate_key_uniqueness(self, service_with_key):
"""Test that each generated key is unique."""
keys = [service_with_key.generate_key() for _ in range(10)]
unique_keys = set(keys)
assert len(unique_keys) == 10
def test_encrypt_decrypt_key(self, service_with_key):
"""Test encryption and decryption of a file encryption key."""
# Generate a key to encrypt
original_key = service_with_key.generate_key()
# Encrypt the key
encrypted_key = service_with_key.encrypt_key(original_key)
assert isinstance(encrypted_key, str)
assert encrypted_key != base64.urlsafe_b64encode(original_key).decode()
# Decrypt the key
decrypted_key = service_with_key.decrypt_key(encrypted_key)
assert decrypted_key == original_key
def test_encrypt_decrypt_file(self, service_with_key):
"""Test file encryption and decryption."""
# Create test file content
original_content = b"This is a test file content for encryption."
file_obj = BytesIO(original_content)
# Generate encryption key
key = service_with_key.generate_key()
# Encrypt
encrypted_content = service_with_key.encrypt_file(file_obj, key)
assert encrypted_content != original_content
assert len(encrypted_content) > len(original_content) # Due to nonce and tag
# Decrypt
encrypted_file = BytesIO(encrypted_content)
decrypted_content = service_with_key.decrypt_file(encrypted_file, key)
assert decrypted_content == original_content
def test_encrypt_decrypt_bytes(self, service_with_key):
"""Test bytes encryption and decryption convenience methods."""
original_data = b"Test data for encryption"
key = service_with_key.generate_key()
# Encrypt
encrypted_data = service_with_key.encrypt_bytes(original_data, key)
assert encrypted_data != original_data
# Decrypt
decrypted_data = service_with_key.decrypt_bytes(encrypted_data, key)
assert decrypted_data == original_data
def test_encrypt_large_file(self, service_with_key):
"""Test encryption of a larger file (1MB)."""
# Create 1MB of random data
original_content = secrets.token_bytes(1024 * 1024)
file_obj = BytesIO(original_content)
key = service_with_key.generate_key()
# Encrypt
encrypted_content = service_with_key.encrypt_file(file_obj, key)
# Decrypt
encrypted_file = BytesIO(encrypted_content)
decrypted_content = service_with_key.decrypt_file(encrypted_file, key)
assert decrypted_content == original_content
def test_decrypt_with_wrong_key(self, service_with_key):
"""Test that decryption fails with wrong key."""
original_content = b"Secret content"
file_obj = BytesIO(original_content)
key1 = service_with_key.generate_key()
key2 = service_with_key.generate_key()
# Encrypt with key1
encrypted_content = service_with_key.encrypt_file(file_obj, key1)
# Try to decrypt with key2
encrypted_file = BytesIO(encrypted_content)
with pytest.raises(DecryptionError):
service_with_key.decrypt_file(encrypted_file, key2)
def test_decrypt_corrupted_data(self, service_with_key):
    """Tampering with the ciphertext must make decryption raise DecryptionError."""
    key = service_with_key.generate_key()
    ciphertext = service_with_key.encrypt_file(BytesIO(b"Secret content"), key)

    # Flip every bit of one byte so the authentication check fails.
    tampered = bytearray(ciphertext)
    tampered[20] ^= 0xFF

    with pytest.raises(DecryptionError):
        service_with_key.decrypt_file(BytesIO(bytes(tampered)), key)
def test_is_encryption_available_with_key(self, mock_master_key):
    """is_encryption_available reports True once a master key is configured."""
    with patch('app.services.encryption_service.settings') as mock_settings:
        mock_settings.ENCRYPTION_MASTER_KEY = mock_master_key
        assert EncryptionService().is_encryption_available() is True
def test_is_encryption_available_without_key(self):
    """is_encryption_available reports False when no master key is configured."""
    with patch('app.services.encryption_service.settings') as mock_settings:
        mock_settings.ENCRYPTION_MASTER_KEY = None
        assert EncryptionService().is_encryption_available() is False
def test_master_key_not_configured_error(self):
    """encrypt_key must fail loudly when no master key is configured."""
    with patch('app.services.encryption_service.settings') as mock_settings:
        mock_settings.ENCRYPTION_MASTER_KEY = None
        service = EncryptionService()
        # Wrapping a key without a master key is impossible by design.
        with pytest.raises(MasterKeyNotConfiguredError):
            service.encrypt_key(secrets.token_bytes(32))
def test_encrypted_key_format(self, service_with_key):
    """An encrypted key is urlsafe base64 wrapping nonce + ciphertext + tag."""
    wrapped = service_with_key.encrypt_key(service_with_key.generate_key())

    # Raises binascii.Error if the value is not valid urlsafe base64.
    raw = base64.urlsafe_b64decode(wrapped)

    # Layout: nonce || encrypted key material || 16-byte GCM tag.
    assert len(raw) >= NONCE_SIZE + KEY_SIZE + 16
class TestEncryptionServiceStreaming:
    """Tests for streaming encryption (for large files)."""

    @pytest.fixture
    def mock_master_key(self):
        """Return a freshly generated, base64-encoded 32-byte master key."""
        return base64.urlsafe_b64encode(secrets.token_bytes(32)).decode()

    @pytest.fixture
    def service_with_key(self, mock_master_key):
        """Yield an EncryptionService whose settings carry the mock master key."""
        with patch('app.services.encryption_service.settings') as mock_settings:
            mock_settings.ENCRYPTION_MASTER_KEY = mock_master_key
            yield EncryptionService()

    def test_streaming_encrypt_decrypt(self, service_with_key):
        """A streamed encrypt followed by a streamed decrypt is lossless."""
        plaintext = b"Test content for streaming encryption. " * 1000
        key = service_with_key.generate_key()

        # Drain the encryption generator and reassemble the ciphertext.
        ciphertext = b''.join(
            service_with_key.encrypt_file_streaming(BytesIO(plaintext), key)
        )

        # Feed the ciphertext back through the streaming decryptor.
        restored = b''.join(
            service_with_key.decrypt_file_streaming(BytesIO(ciphertext), key)
        )
        assert restored == plaintext
class TestEncryptionKeyValidation:
    """Tests for encryption key validation in config."""

    def test_valid_master_key(self):
        """A correctly sized base64 master key passes Settings validation."""
        from app.core.config import Settings

        good_key = base64.urlsafe_b64encode(secrets.token_bytes(32)).decode()
        env = {
            'JWT_SECRET_KEY': 'test-secret-key-that-is-valid',
            'ENCRYPTION_MASTER_KEY': good_key,
        }
        with patch.dict('os.environ', env):
            # Validation must accept the key and preserve it verbatim.
            assert Settings().ENCRYPTION_MASTER_KEY == good_key

    def test_invalid_master_key_length(self):
        """A key that decodes to fewer than 32 bytes is rejected."""
        from app.core.config import Settings

        short_key = base64.urlsafe_b64encode(secrets.token_bytes(16)).decode()
        env = {
            'JWT_SECRET_KEY': 'test-secret-key-that-is-valid',
            'ENCRYPTION_MASTER_KEY': short_key,
        }
        with patch.dict('os.environ', env):
            with pytest.raises(ValueError, match="must be a base64-encoded 32-byte key"):
                Settings()

    def test_invalid_master_key_format(self):
        """A key that is not valid base64 fails pydantic validation."""
        from app.core.config import Settings
        from pydantic import ValidationError

        env = {
            'JWT_SECRET_KEY': 'test-secret-key-that-is-valid',
            'ENCRYPTION_MASTER_KEY': "not-valid-base64!@#$",
        }
        with patch.dict('os.environ', env):
            with pytest.raises(ValidationError, match="ENCRYPTION_MASTER_KEY"):
                Settings()

    def test_empty_master_key_allowed(self):
        """An empty key means encryption is disabled; it is normalized to None."""
        from app.core.config import Settings

        env = {
            'JWT_SECRET_KEY': 'test-secret-key-that-is-valid',
            'ENCRYPTION_MASTER_KEY': '',
        }
        with patch.dict('os.environ', env):
            assert Settings().ENCRYPTION_MASTER_KEY is None

File diff suppressed because it is too large Load Diff

View File

@@ -114,3 +114,383 @@ class TestSubtaskDepth:
"""Test that MAX_SUBTASK_DEPTH is defined."""
from app.api.tasks.router import MAX_SUBTASK_DEPTH
assert MAX_SUBTASK_DEPTH == 2
class TestDateRangeFilter:
    """Tests for the due_after / due_before date-range filters on the task list API (calendar view)."""
def test_due_after_filter(self, client, db, admin_token):
    """Test filtering tasks with due_date >= due_after."""
    from datetime import datetime, timedelta
    from app.models import Space, Project, Task, TaskStatus

    owner = "00000000-0000-0000-0000-000000000001"

    # Minimal scaffolding: space -> project -> one status column.
    db.add(Space(id="test-space-id", name="Test Space", owner_id=owner))
    db.add(Project(
        id="test-project-id",
        space_id="test-space-id",
        title="Test Project",
        owner_id=owner,
        security_level="public",
    ))
    db.add(TaskStatus(
        id="test-status-id",
        project_id="test-project-id",
        name="To Do",
        color="#808080",
        position=0,
    ))

    def make_task(task_id, title, due):
        # All tasks share project/status/creator; only id, title, due vary.
        return Task(
            id=task_id,
            project_id="test-project-id",
            title=title,
            priority="medium",
            created_by=owner,
            status_id="test-status-id",
            due_date=due,
        )

    now = datetime.now()
    db.add_all([
        make_task("task-1", "Task Due Yesterday", now - timedelta(days=1)),
        make_task("task-2", "Task Due Today", now),
        make_task("task-3", "Task Due Tomorrow", now + timedelta(days=1)),
        make_task("task-4", "Task No Due Date", None),
    ])
    db.commit()

    # Request everything due from "now" onwards.
    response = client.get(
        f"/api/projects/test-project-id/tasks?due_after={now.isoformat()}",
        headers={"Authorization": f"Bearer {admin_token}"},
    )
    assert response.status_code == 200
    data = response.json()

    # Only today's and tomorrow's tasks qualify; the past task and the
    # undated task must both be filtered out.
    assert data["total"] == 2
    returned = {t["id"] for t in data["tasks"]}
    assert returned == {"task-2", "task-3"}
def test_due_before_filter(self, client, db, admin_token):
    """Test filtering tasks with due_date <= due_before."""
    from datetime import datetime, timedelta
    from app.models import Space, Project, Task, TaskStatus

    owner = "00000000-0000-0000-0000-000000000001"

    # Minimal scaffolding: space -> project -> one status column.
    db.add(Space(id="test-space-id-2", name="Test Space 2", owner_id=owner))
    db.add(Project(
        id="test-project-id-2",
        space_id="test-space-id-2",
        title="Test Project 2",
        owner_id=owner,
        security_level="public",
    ))
    db.add(TaskStatus(
        id="test-status-id-2",
        project_id="test-project-id-2",
        name="To Do",
        color="#808080",
        position=0,
    ))

    def make_task(task_id, title, due):
        # All tasks share project/status/creator; only id, title, due vary.
        return Task(
            id=task_id,
            project_id="test-project-id-2",
            title=title,
            priority="medium",
            created_by=owner,
            status_id="test-status-id-2",
            due_date=due,
        )

    now = datetime.now()
    db.add_all([
        make_task("task-b-1", "Task Due Yesterday", now - timedelta(days=1)),
        make_task("task-b-2", "Task Due Today", now),
        make_task("task-b-3", "Task Due Tomorrow", now + timedelta(days=1)),
    ])
    db.commit()

    # Request everything due up to and including "now".
    response = client.get(
        f"/api/projects/test-project-id-2/tasks?due_before={now.isoformat()}",
        headers={"Authorization": f"Bearer {admin_token}"},
    )
    assert response.status_code == 200
    data = response.json()

    # Yesterday's and today's tasks match; tomorrow's does not.
    assert data["total"] == 2
    returned = {t["id"] for t in data["tasks"]}
    assert returned == {"task-b-1", "task-b-2"}
def test_date_range_filter_combined(self, client, db, admin_token):
    """Test filtering tasks within a date range (due_after AND due_before)."""
    from datetime import datetime, timedelta
    from app.models import Space, Project, Task, TaskStatus

    owner = "00000000-0000-0000-0000-000000000001"

    # Minimal scaffolding: space -> project -> one status column.
    db.add(Space(id="test-space-id-3", name="Test Space 3", owner_id=owner))
    db.add(Project(
        id="test-project-id-3",
        space_id="test-space-id-3",
        title="Test Project 3",
        owner_id=owner,
        security_level="public",
    ))
    db.add(TaskStatus(
        id="test-status-id-3",
        project_id="test-project-id-3",
        name="To Do",
        color="#808080",
        position=0,
    ))

    def make_task(task_id, title, due):
        # All tasks share project/status/creator; only id, title, due vary.
        return Task(
            id=task_id,
            project_id="test-project-id-3",
            title=title,
            priority="medium",
            created_by=owner,
            status_id="test-status-id-3",
            due_date=due,
        )

    # Build a Monday..Sunday window around "now".
    now = datetime.now()
    week_start = now - timedelta(days=now.weekday())
    week_end = week_start + timedelta(days=6)

    db.add_all([
        make_task("task-c-before", "Task Before Week", week_start - timedelta(days=1)),
        make_task("task-c-in-week", "Task In Week", week_start + timedelta(days=3)),
        make_task("task-c-after", "Task After Week", week_end + timedelta(days=1)),
    ])
    db.commit()

    # Query for tasks due within this week only.
    url = (
        f"/api/projects/test-project-id-3/tasks"
        f"?due_after={week_start.isoformat()}&due_before={week_end.isoformat()}"
    )
    response = client.get(url, headers={"Authorization": f"Bearer {admin_token}"})
    assert response.status_code == 200
    data = response.json()

    # Only the mid-week task falls inside the [start, end] window.
    assert data["total"] == 1
    assert data["tasks"][0]["id"] == "task-c-in-week"
def test_date_filter_with_no_due_date(self, client, db, admin_token):
    """Test that tasks without due_date are excluded from date range filters."""
    from datetime import datetime, timedelta
    from app.models import Space, Project, Task, TaskStatus

    owner = "00000000-0000-0000-0000-000000000001"

    # Minimal scaffolding: space -> project -> one status column.
    db.add(Space(id="test-space-id-4", name="Test Space 4", owner_id=owner))
    db.add(Project(
        id="test-project-id-4",
        space_id="test-space-id-4",
        title="Test Project 4",
        owner_id=owner,
        security_level="public",
    ))
    db.add(TaskStatus(
        id="test-status-id-4",
        project_id="test-project-id-4",
        name="To Do",
        color="#808080",
        position=0,
    ))

    def make_task(task_id, title, due):
        # All tasks share project/status/creator; only id, title, due vary.
        return Task(
            id=task_id,
            project_id="test-project-id-4",
            title=title,
            priority="medium",
            created_by=owner,
            status_id="test-status-id-4",
            due_date=due,
        )

    now = datetime.now()
    db.add_all([
        make_task("task-d-with-date", "Task With Due Date", now),
        make_task("task-d-without-date", "Task Without Due Date", None),
    ])
    db.commit()

    # A one-day window around "now": only dated tasks can ever match it.
    window = (
        f"due_after={(now - timedelta(days=1)).isoformat()}"
        f"&due_before={(now + timedelta(days=1)).isoformat()}"
    )
    response = client.get(
        f"/api/projects/test-project-id-4/tasks?{window}",
        headers={"Authorization": f"Bearer {admin_token}"},
    )
    assert response.status_code == 200
    data = response.json()

    # The undated task must be excluded by any date-range filter.
    assert data["total"] == 1
    assert data["tasks"][0]["id"] == "task-d-with-date"
def test_date_filter_backward_compatibility(self, client, db, admin_token):
    """Test that not providing date filters returns all tasks (backward compatibility)."""
    from datetime import datetime, timedelta
    from app.models import Space, Project, Task, TaskStatus

    owner = "00000000-0000-0000-0000-000000000001"

    # Minimal scaffolding: space -> project -> one status column.
    db.add(Space(id="test-space-id-5", name="Test Space 5", owner_id=owner))
    db.add(Project(
        id="test-project-id-5",
        space_id="test-space-id-5",
        title="Test Project 5",
        owner_id=owner,
        security_level="public",
    ))
    db.add(TaskStatus(
        id="test-status-id-5",
        project_id="test-project-id-5",
        name="To Do",
        color="#808080",
        position=0,
    ))

    def make_task(task_id, title, due):
        # All tasks share project/status/creator; only id, title, due vary.
        return Task(
            id=task_id,
            project_id="test-project-id-5",
            title=title,
            priority="medium",
            created_by=owner,
            status_id="test-status-id-5",
            due_date=due,
        )

    now = datetime.now()
    db.add_all([
        make_task("task-e-1", "Task 1", now),
        make_task("task-e-2", "Task 2", None),
    ])
    db.commit()

    # No date query parameters: both dated and undated tasks come back.
    response = client.get(
        "/api/projects/test-project-id-5/tasks",
        headers={"Authorization": f"Bearer {admin_token}"},
    )
    assert response.status_code == 200
    assert response.json()["total"] == 2