feat: implement custom fields, gantt view, calendar view, and file encryption
- Custom Fields (FEAT-001): - CustomField and TaskCustomValue models with formula support - CRUD API for custom field management - Formula engine for calculated fields - Frontend: CustomFieldEditor, CustomFieldInput, ProjectSettings page - Task list API now includes custom_values - KanbanBoard displays custom field values - Gantt View (FEAT-003): - TaskDependency model with FS/SS/FF/SF dependency types - Dependency CRUD API with cycle detection - start_date field added to tasks - GanttChart component with Frappe Gantt integration - Dependency type selector in UI - Calendar View (FEAT-004): - CalendarView component with FullCalendar integration - Date range filtering API for tasks - Drag-and-drop date updates - View mode switching in Tasks page - File Encryption (FEAT-010): - AES-256-GCM encryption service - EncryptionKey model with key rotation support - Admin API for key management - Encrypted upload/download for confidential projects - Migrations: 011 (custom fields), 012 (encryption keys), 013 (task dependencies) - Updated issues.md with completion status 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -20,3 +20,14 @@ AUTH_API_URL=https://pj-auth-api.vercel.app
|
||||
|
||||
# System Admin
|
||||
SYSTEM_ADMIN_EMAIL=ymirliu@panjit.com.tw
|
||||
|
||||
# File Encryption (AES-256)
|
||||
# Master key for encrypting file encryption keys (optional - if not set, file encryption is disabled)
|
||||
# Generate a new key with:
|
||||
# python -c "import secrets, base64; print(base64.urlsafe_b64encode(secrets.token_bytes(32)).decode())"
|
||||
#
|
||||
# IMPORTANT:
|
||||
# - Keep this key secure and back it up! If lost, encrypted files cannot be decrypted.
|
||||
# - Store backup in a secure location separate from the database backup.
|
||||
# - Do NOT change this key after files have been encrypted (use key rotation instead).
|
||||
ENCRYPTION_MASTER_KEY=
|
||||
|
||||
1
backend/app/api/admin/__init__.py
Normal file
1
backend/app/api/admin/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Admin API module
|
||||
299
backend/app/api/admin/encryption_keys.py
Normal file
299
backend/app/api/admin/encryption_keys.py
Normal file
@@ -0,0 +1,299 @@
|
||||
"""
|
||||
Encryption Key Management API (Admin only).
|
||||
|
||||
Provides endpoints for:
|
||||
- Listing encryption keys (without actual key data)
|
||||
- Creating new encryption keys
|
||||
- Key rotation
|
||||
- Checking encryption status
|
||||
"""
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.middleware.auth import get_current_user, require_system_admin
|
||||
from app.models import User, EncryptionKey, AuditAction
|
||||
from app.schemas.encryption_key import (
|
||||
EncryptionKeyResponse,
|
||||
EncryptionKeyListResponse,
|
||||
EncryptionKeyCreateResponse,
|
||||
EncryptionKeyRotateResponse,
|
||||
EncryptionStatusResponse,
|
||||
)
|
||||
from app.services.encryption_service import (
|
||||
encryption_service,
|
||||
MasterKeyNotConfiguredError,
|
||||
)
|
||||
from app.services.audit_service import AuditService
|
||||
|
||||
router = APIRouter(prefix="/api/admin/encryption-keys", tags=["Admin - Encryption Keys"])
|
||||
|
||||
|
||||
def key_to_response(key: EncryptionKey) -> EncryptionKeyResponse:
|
||||
"""Convert EncryptionKey model to response (without key data)."""
|
||||
return EncryptionKeyResponse(
|
||||
id=key.id,
|
||||
algorithm=key.algorithm,
|
||||
is_active=key.is_active,
|
||||
created_at=key.created_at,
|
||||
rotated_at=key.rotated_at,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/status", response_model=EncryptionStatusResponse)
|
||||
async def get_encryption_status(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_system_admin),
|
||||
):
|
||||
"""
|
||||
Get the current encryption status.
|
||||
|
||||
Returns whether encryption is available, active key info, and total key count.
|
||||
"""
|
||||
encryption_available = encryption_service.is_encryption_available()
|
||||
|
||||
active_key = None
|
||||
total_keys = 0
|
||||
|
||||
if encryption_available:
|
||||
active_key = db.query(EncryptionKey).filter(
|
||||
EncryptionKey.is_active == True
|
||||
).first()
|
||||
total_keys = db.query(EncryptionKey).count()
|
||||
|
||||
message = "Encryption is available" if encryption_available else "Encryption is not configured (ENCRYPTION_MASTER_KEY not set)"
|
||||
|
||||
return EncryptionStatusResponse(
|
||||
encryption_available=encryption_available,
|
||||
active_key_id=active_key.id if active_key else None,
|
||||
total_keys=total_keys,
|
||||
message=message,
|
||||
)
|
||||
|
||||
|
||||
@router.get("", response_model=EncryptionKeyListResponse)
|
||||
async def list_encryption_keys(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_system_admin),
|
||||
):
|
||||
"""
|
||||
List all encryption keys (without actual key data).
|
||||
|
||||
Only accessible by system administrators.
|
||||
"""
|
||||
keys = db.query(EncryptionKey).order_by(EncryptionKey.created_at.desc()).all()
|
||||
|
||||
return EncryptionKeyListResponse(
|
||||
keys=[key_to_response(k) for k in keys],
|
||||
total=len(keys),
|
||||
)
|
||||
|
||||
|
||||
@router.post("", response_model=EncryptionKeyCreateResponse)
|
||||
async def create_encryption_key(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_system_admin),
|
||||
):
|
||||
"""
|
||||
Create a new encryption key.
|
||||
|
||||
The key is generated, encrypted with the Master Key, and stored.
|
||||
This does NOT automatically make it the active key - use rotate for that.
|
||||
"""
|
||||
if not encryption_service.is_encryption_available():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Encryption is not configured. Set ENCRYPTION_MASTER_KEY in environment."
|
||||
)
|
||||
|
||||
try:
|
||||
# Generate new key
|
||||
raw_key = encryption_service.generate_key()
|
||||
|
||||
# Encrypt with master key
|
||||
encrypted_key = encryption_service.encrypt_key(raw_key)
|
||||
|
||||
# Create key record (not active by default)
|
||||
key = EncryptionKey(
|
||||
id=str(uuid.uuid4()),
|
||||
key_data=encrypted_key,
|
||||
algorithm="AES-256-GCM",
|
||||
is_active=False,
|
||||
)
|
||||
db.add(key)
|
||||
|
||||
# Audit log
|
||||
AuditService.log_event(
|
||||
db=db,
|
||||
event_type="encryption.key_create",
|
||||
resource_type="encryption_key",
|
||||
action=AuditAction.CREATE,
|
||||
user_id=current_user.id,
|
||||
resource_id=key.id,
|
||||
changes=[{"field": "algorithm", "old_value": None, "new_value": "AES-256-GCM"}],
|
||||
request_metadata=getattr(request.state, "audit_metadata", None),
|
||||
)
|
||||
|
||||
db.commit()
|
||||
db.refresh(key)
|
||||
|
||||
return EncryptionKeyCreateResponse(
|
||||
id=key.id,
|
||||
algorithm=key.algorithm,
|
||||
is_active=key.is_active,
|
||||
created_at=key.created_at,
|
||||
message="Encryption key created successfully. Use /rotate to make it active.",
|
||||
)
|
||||
|
||||
except MasterKeyNotConfiguredError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Master key is not configured"
|
||||
)
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to create encryption key: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/rotate", response_model=EncryptionKeyRotateResponse)
|
||||
async def rotate_encryption_key(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_system_admin),
|
||||
):
|
||||
"""
|
||||
Rotate to a new encryption key.
|
||||
|
||||
This will:
|
||||
1. Create a new encryption key
|
||||
2. Mark the new key as active
|
||||
3. Mark the old active key as inactive (but keep it for decrypting old files)
|
||||
|
||||
After rotation, new file uploads will use the new key.
|
||||
Old files remain readable using their original keys.
|
||||
"""
|
||||
if not encryption_service.is_encryption_available():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Encryption is not configured. Set ENCRYPTION_MASTER_KEY in environment."
|
||||
)
|
||||
|
||||
try:
|
||||
# Find current active key
|
||||
old_active_key = db.query(EncryptionKey).filter(
|
||||
EncryptionKey.is_active == True
|
||||
).first()
|
||||
|
||||
# Generate new key
|
||||
raw_key = encryption_service.generate_key()
|
||||
encrypted_key = encryption_service.encrypt_key(raw_key)
|
||||
|
||||
# Create new key as active
|
||||
new_key = EncryptionKey(
|
||||
id=str(uuid.uuid4()),
|
||||
key_data=encrypted_key,
|
||||
algorithm="AES-256-GCM",
|
||||
is_active=True,
|
||||
)
|
||||
db.add(new_key)
|
||||
|
||||
# Deactivate old key if exists
|
||||
old_key_id = None
|
||||
if old_active_key:
|
||||
old_active_key.is_active = False
|
||||
old_active_key.rotated_at = datetime.utcnow()
|
||||
old_key_id = old_active_key.id
|
||||
|
||||
# Audit log for rotation
|
||||
AuditService.log_event(
|
||||
db=db,
|
||||
event_type="encryption.key_rotate",
|
||||
resource_type="encryption_key",
|
||||
action=AuditAction.UPDATE,
|
||||
user_id=current_user.id,
|
||||
resource_id=new_key.id,
|
||||
changes=[
|
||||
{"field": "rotation", "old_value": old_key_id, "new_value": new_key.id},
|
||||
{"field": "is_active", "old_value": False, "new_value": True},
|
||||
],
|
||||
request_metadata=getattr(request.state, "audit_metadata", None),
|
||||
)
|
||||
|
||||
db.commit()
|
||||
|
||||
return EncryptionKeyRotateResponse(
|
||||
new_key_id=new_key.id,
|
||||
old_key_id=old_key_id,
|
||||
message="Key rotation completed successfully. New uploads will use the new key.",
|
||||
)
|
||||
|
||||
except MasterKeyNotConfiguredError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Master key is not configured"
|
||||
)
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to rotate encryption key: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{key_id}")
|
||||
async def deactivate_encryption_key(
|
||||
key_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_system_admin),
|
||||
):
|
||||
"""
|
||||
Deactivate an encryption key.
|
||||
|
||||
Note: This does NOT delete the key. Keys are never deleted to ensure
|
||||
encrypted files can always be decrypted.
|
||||
|
||||
If this is the only active key, you must rotate to a new key first.
|
||||
"""
|
||||
key = db.query(EncryptionKey).filter(EncryptionKey.id == key_id).first()
|
||||
if not key:
|
||||
raise HTTPException(status_code=404, detail="Encryption key not found")
|
||||
|
||||
if key.is_active:
|
||||
# Check if there are other active keys
|
||||
other_active = db.query(EncryptionKey).filter(
|
||||
EncryptionKey.is_active == True,
|
||||
EncryptionKey.id != key_id
|
||||
).first()
|
||||
|
||||
if not other_active:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Cannot deactivate the only active key. Rotate to a new key first."
|
||||
)
|
||||
|
||||
key.is_active = False
|
||||
key.rotated_at = datetime.utcnow()
|
||||
|
||||
# Audit log
|
||||
AuditService.log_event(
|
||||
db=db,
|
||||
event_type="encryption.key_deactivate",
|
||||
resource_type="encryption_key",
|
||||
action=AuditAction.UPDATE,
|
||||
user_id=current_user.id,
|
||||
resource_id=key.id,
|
||||
changes=[{"field": "is_active", "old_value": True, "new_value": False}],
|
||||
request_metadata=getattr(request.state, "audit_metadata", None),
|
||||
)
|
||||
|
||||
db.commit()
|
||||
|
||||
return {"detail": "Encryption key deactivated", "id": key_id}
|
||||
@@ -1,5 +1,7 @@
|
||||
import uuid
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Request
|
||||
from fastapi.responses import FileResponse, Response
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -7,7 +9,7 @@ from typing import Optional
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.middleware.auth import get_current_user, check_task_access, check_task_edit_access
|
||||
from app.models import User, Task, Project, Attachment, AttachmentVersion, AuditAction
|
||||
from app.models import User, Task, Project, Attachment, AttachmentVersion, EncryptionKey, AuditAction
|
||||
from app.schemas.attachment import (
|
||||
AttachmentResponse, AttachmentListResponse, AttachmentDetailResponse,
|
||||
AttachmentVersionResponse, VersionHistoryResponse
|
||||
@@ -15,6 +17,13 @@ from app.schemas.attachment import (
|
||||
from app.services.file_storage_service import file_storage_service
|
||||
from app.services.audit_service import AuditService
|
||||
from app.services.watermark_service import watermark_service
|
||||
from app.services.encryption_service import (
|
||||
encryption_service,
|
||||
MasterKeyNotConfiguredError,
|
||||
DecryptionError,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api", tags=["attachments"])
|
||||
|
||||
@@ -103,6 +112,40 @@ def version_to_response(version: AttachmentVersion) -> AttachmentVersionResponse
|
||||
)
|
||||
|
||||
|
||||
def should_encrypt_file(project: Project, db: Session) -> tuple[bool, Optional[EncryptionKey]]:
|
||||
"""
|
||||
Determine if a file should be encrypted based on project security level.
|
||||
|
||||
Returns:
|
||||
Tuple of (should_encrypt, encryption_key)
|
||||
"""
|
||||
# Only encrypt for confidential projects
|
||||
if project.security_level != "confidential":
|
||||
return False, None
|
||||
|
||||
# Check if encryption is available
|
||||
if not encryption_service.is_encryption_available():
|
||||
logger.warning(
|
||||
f"Project {project.id} is confidential but encryption is not configured. "
|
||||
"Files will be stored unencrypted."
|
||||
)
|
||||
return False, None
|
||||
|
||||
# Get active encryption key
|
||||
active_key = db.query(EncryptionKey).filter(
|
||||
EncryptionKey.is_active == True
|
||||
).first()
|
||||
|
||||
if not active_key:
|
||||
logger.warning(
|
||||
f"Project {project.id} is confidential but no active encryption key exists. "
|
||||
"Files will be stored unencrypted. Create a key using /api/admin/encryption-keys/rotate"
|
||||
)
|
||||
return False, None
|
||||
|
||||
return True, active_key
|
||||
|
||||
|
||||
@router.post("/tasks/{task_id}/attachments", response_model=AttachmentResponse)
|
||||
async def upload_attachment(
|
||||
task_id: str,
|
||||
@@ -111,10 +154,22 @@ async def upload_attachment(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Upload a file attachment to a task."""
|
||||
"""
|
||||
Upload a file attachment to a task.
|
||||
|
||||
For confidential projects, files are automatically encrypted using AES-256-GCM.
|
||||
"""
|
||||
task = get_task_with_access_check(db, task_id, current_user, require_edit=True)
|
||||
|
||||
# Check if attachment with same filename exists (for versioning in Phase 2)
|
||||
# Get project to check security level
|
||||
project = db.query(Project).filter(Project.id == task.project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
# Determine if encryption is needed
|
||||
should_encrypt, encryption_key = should_encrypt_file(project, db)
|
||||
|
||||
# Check if attachment with same filename exists (for versioning)
|
||||
existing = db.query(Attachment).filter(
|
||||
Attachment.task_id == task_id,
|
||||
Attachment.original_filename == file.filename,
|
||||
@@ -122,17 +177,73 @@ async def upload_attachment(
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
# Phase 2: Create new version
|
||||
# Create new version for existing attachment
|
||||
new_version = existing.current_version + 1
|
||||
|
||||
# Save file
|
||||
file_path, file_size, checksum = await file_storage_service.save_file(
|
||||
file=file,
|
||||
project_id=task.project_id,
|
||||
task_id=task_id,
|
||||
attachment_id=existing.id,
|
||||
version=new_version
|
||||
)
|
||||
if should_encrypt and encryption_key:
|
||||
# Read and encrypt file content
|
||||
file_content = await file.read()
|
||||
await file.seek(0)
|
||||
|
||||
try:
|
||||
# Decrypt the encryption key
|
||||
raw_key = encryption_service.decrypt_key(encryption_key.key_data)
|
||||
# Encrypt the file
|
||||
encrypted_content = encryption_service.encrypt_bytes(file_content, raw_key)
|
||||
|
||||
# Create a new UploadFile-like object with encrypted content
|
||||
encrypted_file = BytesIO(encrypted_content)
|
||||
encrypted_file.seek(0)
|
||||
|
||||
# Save encrypted file using a modified approach
|
||||
file_path, _, checksum = await file_storage_service.save_file(
|
||||
file=file,
|
||||
project_id=task.project_id,
|
||||
task_id=task_id,
|
||||
attachment_id=existing.id,
|
||||
version=new_version
|
||||
)
|
||||
|
||||
# Overwrite with encrypted content
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(encrypted_content)
|
||||
|
||||
file_size = len(encrypted_content)
|
||||
|
||||
# Update existing attachment with encryption info
|
||||
existing.is_encrypted = True
|
||||
existing.encryption_key_id = encryption_key.id
|
||||
|
||||
# Audit log for encryption
|
||||
AuditService.log_event(
|
||||
db=db,
|
||||
event_type="attachment.encrypt",
|
||||
resource_type="attachment",
|
||||
action=AuditAction.UPDATE,
|
||||
user_id=current_user.id,
|
||||
resource_id=existing.id,
|
||||
changes=[
|
||||
{"field": "is_encrypted", "old_value": False, "new_value": True},
|
||||
{"field": "encryption_key_id", "old_value": None, "new_value": encryption_key.id},
|
||||
],
|
||||
request_metadata=getattr(request.state, "audit_metadata", None),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to encrypt file for attachment {existing.id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to encrypt file. Please try again."
|
||||
)
|
||||
else:
|
||||
# Save file without encryption
|
||||
file_path, file_size, checksum = await file_storage_service.save_file(
|
||||
file=file,
|
||||
project_id=task.project_id,
|
||||
task_id=task_id,
|
||||
attachment_id=existing.id,
|
||||
version=new_version
|
||||
)
|
||||
|
||||
# Create version record
|
||||
version = AttachmentVersion(
|
||||
@@ -170,15 +281,52 @@ async def upload_attachment(
|
||||
|
||||
# Create new attachment
|
||||
attachment_id = str(uuid.uuid4())
|
||||
is_encrypted = False
|
||||
encryption_key_id = None
|
||||
|
||||
# Save file
|
||||
file_path, file_size, checksum = await file_storage_service.save_file(
|
||||
file=file,
|
||||
project_id=task.project_id,
|
||||
task_id=task_id,
|
||||
attachment_id=attachment_id,
|
||||
version=1
|
||||
)
|
||||
if should_encrypt and encryption_key:
|
||||
# Read and encrypt file content
|
||||
file_content = await file.read()
|
||||
await file.seek(0)
|
||||
|
||||
try:
|
||||
# Decrypt the encryption key
|
||||
raw_key = encryption_service.decrypt_key(encryption_key.key_data)
|
||||
# Encrypt the file
|
||||
encrypted_content = encryption_service.encrypt_bytes(file_content, raw_key)
|
||||
|
||||
# Save file first to get path
|
||||
file_path, _, checksum = await file_storage_service.save_file(
|
||||
file=file,
|
||||
project_id=task.project_id,
|
||||
task_id=task_id,
|
||||
attachment_id=attachment_id,
|
||||
version=1
|
||||
)
|
||||
|
||||
# Overwrite with encrypted content
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(encrypted_content)
|
||||
|
||||
file_size = len(encrypted_content)
|
||||
is_encrypted = True
|
||||
encryption_key_id = encryption_key.id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to encrypt new file: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to encrypt file. Please try again."
|
||||
)
|
||||
else:
|
||||
# Save file without encryption
|
||||
file_path, file_size, checksum = await file_storage_service.save_file(
|
||||
file=file,
|
||||
project_id=task.project_id,
|
||||
task_id=task_id,
|
||||
attachment_id=attachment_id,
|
||||
version=1
|
||||
)
|
||||
|
||||
# Get mime type from file storage validation
|
||||
extension = file_storage_service.get_extension(file.filename or "")
|
||||
@@ -193,7 +341,8 @@ async def upload_attachment(
|
||||
mime_type=mime_type,
|
||||
file_size=file_size,
|
||||
current_version=1,
|
||||
is_encrypted=False,
|
||||
is_encrypted=is_encrypted,
|
||||
encryption_key_id=encryption_key_id,
|
||||
uploaded_by=current_user.id
|
||||
)
|
||||
db.add(attachment)
|
||||
@@ -211,6 +360,10 @@ async def upload_attachment(
|
||||
db.add(version)
|
||||
|
||||
# Audit log
|
||||
changes = [{"field": "filename", "old_value": None, "new_value": attachment.filename}]
|
||||
if is_encrypted:
|
||||
changes.append({"field": "is_encrypted", "old_value": None, "new_value": True})
|
||||
|
||||
AuditService.log_event(
|
||||
db=db,
|
||||
event_type="attachment.upload",
|
||||
@@ -218,7 +371,7 @@ async def upload_attachment(
|
||||
action=AuditAction.CREATE,
|
||||
user_id=current_user.id,
|
||||
resource_id=attachment.id,
|
||||
changes=[{"field": "filename", "old_value": None, "new_value": attachment.filename}],
|
||||
changes=changes,
|
||||
request_metadata=getattr(request.state, "audit_metadata", None)
|
||||
)
|
||||
|
||||
@@ -286,7 +439,11 @@ async def download_attachment(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Download an attachment file with dynamic watermark."""
|
||||
"""
|
||||
Download an attachment file with dynamic watermark.
|
||||
|
||||
For encrypted files, the file is automatically decrypted before returning.
|
||||
"""
|
||||
attachment = get_attachment_with_access_check(db, attachment_id, current_user, require_edit=False)
|
||||
|
||||
# Get version to download
|
||||
@@ -319,14 +476,69 @@ async def download_attachment(
|
||||
)
|
||||
db.commit()
|
||||
|
||||
# Read file content
|
||||
with open(file_path, "rb") as f:
|
||||
file_bytes = f.read()
|
||||
|
||||
# Decrypt if encrypted
|
||||
if attachment.is_encrypted:
|
||||
if not attachment.encryption_key_id:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Encrypted file is missing encryption key reference"
|
||||
)
|
||||
|
||||
encryption_key = db.query(EncryptionKey).filter(
|
||||
EncryptionKey.id == attachment.encryption_key_id
|
||||
).first()
|
||||
|
||||
if not encryption_key:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Encryption key not found for this file"
|
||||
)
|
||||
|
||||
try:
|
||||
# Decrypt the encryption key
|
||||
raw_key = encryption_service.decrypt_key(encryption_key.key_data)
|
||||
# Decrypt the file
|
||||
file_bytes = encryption_service.decrypt_bytes(file_bytes, raw_key)
|
||||
|
||||
# Audit log for decryption
|
||||
AuditService.log_event(
|
||||
db=db,
|
||||
event_type="attachment.decrypt",
|
||||
resource_type="attachment",
|
||||
action=AuditAction.UPDATE,
|
||||
user_id=current_user.id,
|
||||
resource_id=attachment.id,
|
||||
changes=[{"field": "decrypted_for_download", "old_value": None, "new_value": True}],
|
||||
request_metadata=getattr(request.state, "audit_metadata", None) if request else None,
|
||||
)
|
||||
db.commit()
|
||||
|
||||
except DecryptionError as e:
|
||||
logger.error(f"Failed to decrypt attachment {attachment_id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to decrypt file. The file may be corrupted."
|
||||
)
|
||||
except MasterKeyNotConfiguredError:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Encryption is not configured. Cannot decrypt file."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error decrypting attachment {attachment_id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to decrypt file."
|
||||
)
|
||||
|
||||
# Check if watermark should be applied
|
||||
mime_type = attachment.mime_type or ""
|
||||
if watermark_service.supports_watermark(mime_type):
|
||||
try:
|
||||
# Read the original file
|
||||
with open(file_path, "rb") as f:
|
||||
file_bytes = f.read()
|
||||
|
||||
# Apply watermark based on file type
|
||||
if watermark_service.is_supported_image(mime_type):
|
||||
watermarked_bytes, output_format = watermark_service.add_image_watermark(
|
||||
@@ -367,19 +579,19 @@ async def download_attachment(
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# If watermarking fails, log the error but still return the original file
|
||||
# This ensures users can still download files even if watermarking has issues
|
||||
import logging
|
||||
logging.getLogger(__name__).warning(
|
||||
# If watermarking fails, log the error but still return the file
|
||||
logger.warning(
|
||||
f"Watermarking failed for attachment {attachment_id}: {str(e)}. "
|
||||
"Returning original file."
|
||||
"Returning file without watermark."
|
||||
)
|
||||
|
||||
# Return original file without watermark for unsupported types or on error
|
||||
return FileResponse(
|
||||
path=str(file_path),
|
||||
filename=attachment.original_filename,
|
||||
media_type=attachment.mime_type
|
||||
# Return file (decrypted if needed, without watermark for unsupported types)
|
||||
return Response(
|
||||
content=file_bytes,
|
||||
media_type=attachment.mime_type,
|
||||
headers={
|
||||
"Content-Disposition": f'attachment; filename="{attachment.original_filename}"'
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
|
||||
3
backend/app/api/custom_fields/__init__.py
Normal file
3
backend/app/api/custom_fields/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from app.api.custom_fields.router import router
|
||||
|
||||
__all__ = ["router"]
|
||||
368
backend/app/api/custom_fields/router.py
Normal file
368
backend/app/api/custom_fields/router.py
Normal file
@@ -0,0 +1,368 @@
|
||||
import uuid
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.models import User, Project, CustomField, TaskCustomValue
|
||||
from app.schemas.custom_field import (
|
||||
CustomFieldCreate, CustomFieldUpdate, CustomFieldResponse, CustomFieldListResponse
|
||||
)
|
||||
from app.middleware.auth import get_current_user, check_project_access, check_project_edit_access
|
||||
from app.services.formula_service import FormulaService
|
||||
|
||||
router = APIRouter(tags=["custom-fields"])
|
||||
|
||||
# Maximum custom fields per project
|
||||
MAX_CUSTOM_FIELDS_PER_PROJECT = 20
|
||||
|
||||
|
||||
def custom_field_to_response(field: CustomField) -> CustomFieldResponse:
|
||||
"""Convert CustomField model to response schema."""
|
||||
return CustomFieldResponse(
|
||||
id=field.id,
|
||||
project_id=field.project_id,
|
||||
name=field.name,
|
||||
field_type=field.field_type,
|
||||
options=field.options,
|
||||
formula=field.formula,
|
||||
is_required=field.is_required,
|
||||
position=field.position,
|
||||
created_at=field.created_at,
|
||||
updated_at=field.updated_at,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/api/projects/{project_id}/custom-fields", response_model=CustomFieldResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_custom_field(
|
||||
project_id: str,
|
||||
field_data: CustomFieldCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Create a new custom field for a project.
|
||||
|
||||
Only project owner or system admin can create custom fields.
|
||||
Maximum 20 custom fields per project.
|
||||
"""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
|
||||
if not project:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Project not found",
|
||||
)
|
||||
|
||||
if not check_project_edit_access(current_user, project):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Permission denied - only project owner can manage custom fields",
|
||||
)
|
||||
|
||||
# Check custom field count limit
|
||||
field_count = db.query(CustomField).filter(
|
||||
CustomField.project_id == project_id
|
||||
).count()
|
||||
|
||||
if field_count >= MAX_CUSTOM_FIELDS_PER_PROJECT:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Maximum {MAX_CUSTOM_FIELDS_PER_PROJECT} custom fields per project exceeded",
|
||||
)
|
||||
|
||||
# Check for duplicate name
|
||||
existing = db.query(CustomField).filter(
|
||||
CustomField.project_id == project_id,
|
||||
CustomField.name == field_data.name,
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Custom field with name '{field_data.name}' already exists",
|
||||
)
|
||||
|
||||
# Validate formula if it's a formula field
|
||||
if field_data.field_type.value == "formula":
|
||||
if not field_data.formula:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Formula is required for formula fields",
|
||||
)
|
||||
|
||||
is_valid, error_msg = FormulaService.validate_formula(
|
||||
field_data.formula, project_id, db
|
||||
)
|
||||
if not is_valid:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=error_msg,
|
||||
)
|
||||
|
||||
# Get next position
|
||||
max_pos = db.query(CustomField).filter(
|
||||
CustomField.project_id == project_id
|
||||
).order_by(CustomField.position.desc()).first()
|
||||
next_position = (max_pos.position + 1) if max_pos else 0
|
||||
|
||||
# Create the custom field
|
||||
field = CustomField(
|
||||
id=str(uuid.uuid4()),
|
||||
project_id=project_id,
|
||||
name=field_data.name,
|
||||
field_type=field_data.field_type.value,
|
||||
options=field_data.options,
|
||||
formula=field_data.formula,
|
||||
is_required=field_data.is_required,
|
||||
position=next_position,
|
||||
)
|
||||
|
||||
db.add(field)
|
||||
db.commit()
|
||||
db.refresh(field)
|
||||
|
||||
return custom_field_to_response(field)
|
||||
|
||||
|
||||
@router.get("/api/projects/{project_id}/custom-fields", response_model=CustomFieldListResponse)
|
||||
async def list_custom_fields(
|
||||
project_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
List all custom fields for a project.
|
||||
"""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
|
||||
if not project:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Project not found",
|
||||
)
|
||||
|
||||
if not check_project_access(current_user, project):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied",
|
||||
)
|
||||
|
||||
fields = db.query(CustomField).filter(
|
||||
CustomField.project_id == project_id
|
||||
).order_by(CustomField.position).all()
|
||||
|
||||
return CustomFieldListResponse(
|
||||
fields=[custom_field_to_response(f) for f in fields],
|
||||
total=len(fields),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/api/custom-fields/{field_id}", response_model=CustomFieldResponse)
|
||||
async def get_custom_field(
|
||||
field_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Get a specific custom field by ID.
|
||||
"""
|
||||
field = db.query(CustomField).filter(CustomField.id == field_id).first()
|
||||
|
||||
if not field:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Custom field not found",
|
||||
)
|
||||
|
||||
project = field.project
|
||||
if not check_project_access(current_user, project):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied",
|
||||
)
|
||||
|
||||
return custom_field_to_response(field)
|
||||
|
||||
|
||||
@router.put("/api/custom-fields/{field_id}", response_model=CustomFieldResponse)
|
||||
async def update_custom_field(
|
||||
field_id: str,
|
||||
field_data: CustomFieldUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Update a custom field.
|
||||
|
||||
Only project owner or system admin can update custom fields.
|
||||
Note: field_type cannot be changed after creation.
|
||||
"""
|
||||
field = db.query(CustomField).filter(CustomField.id == field_id).first()
|
||||
|
||||
if not field:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Custom field not found",
|
||||
)
|
||||
|
||||
project = field.project
|
||||
if not check_project_edit_access(current_user, project):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Permission denied",
|
||||
)
|
||||
|
||||
# Check for duplicate name if name is being updated
|
||||
if field_data.name is not None and field_data.name != field.name:
|
||||
existing = db.query(CustomField).filter(
|
||||
CustomField.project_id == field.project_id,
|
||||
CustomField.name == field_data.name,
|
||||
CustomField.id != field_id,
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Custom field with name '{field_data.name}' already exists",
|
||||
)
|
||||
|
||||
# Validate formula if updating formula field
|
||||
if field.field_type == "formula" and field_data.formula is not None:
|
||||
is_valid, error_msg = FormulaService.validate_formula(
|
||||
field_data.formula, field.project_id, db, field_id
|
||||
)
|
||||
if not is_valid:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=error_msg,
|
||||
)
|
||||
|
||||
# Validate options if updating dropdown field
|
||||
if field.field_type == "dropdown" and field_data.options is not None:
|
||||
if len(field_data.options) == 0:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Dropdown fields must have at least one option",
|
||||
)
|
||||
if len(field_data.options) > 50:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Dropdown fields can have at most 50 options",
|
||||
)
|
||||
|
||||
# Update fields
|
||||
if field_data.name is not None:
|
||||
field.name = field_data.name
|
||||
if field_data.options is not None and field.field_type == "dropdown":
|
||||
field.options = field_data.options
|
||||
if field_data.formula is not None and field.field_type == "formula":
|
||||
field.formula = field_data.formula
|
||||
if field_data.is_required is not None:
|
||||
field.is_required = field_data.is_required
|
||||
|
||||
db.commit()
|
||||
db.refresh(field)
|
||||
|
||||
return custom_field_to_response(field)
|
||||
|
||||
|
||||
@router.delete("/api/custom-fields/{field_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_custom_field(
|
||||
field_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Delete a custom field.
|
||||
|
||||
Only project owner or system admin can delete custom fields.
|
||||
This will also delete all stored values for this field.
|
||||
"""
|
||||
field = db.query(CustomField).filter(CustomField.id == field_id).first()
|
||||
|
||||
if not field:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Custom field not found",
|
||||
)
|
||||
|
||||
project = field.project
|
||||
if not check_project_edit_access(current_user, project):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Permission denied",
|
||||
)
|
||||
|
||||
# Check if any formula fields reference this field
|
||||
if field.field_type != "formula":
|
||||
formula_fields = db.query(CustomField).filter(
|
||||
CustomField.project_id == field.project_id,
|
||||
CustomField.field_type == "formula",
|
||||
).all()
|
||||
|
||||
for formula_field in formula_fields:
|
||||
if formula_field.formula:
|
||||
references = FormulaService.extract_field_references(formula_field.formula)
|
||||
if field.name in references:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Cannot delete: field is referenced by formula field '{formula_field.name}'",
|
||||
)
|
||||
|
||||
# Delete the field (cascade will delete associated values)
|
||||
db.delete(field)
|
||||
db.commit()
|
||||
|
||||
|
||||
@router.patch("/api/custom-fields/{field_id}/position", response_model=CustomFieldResponse)
|
||||
async def update_custom_field_position(
|
||||
field_id: str,
|
||||
position: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Update a custom field's position (for reordering).
|
||||
"""
|
||||
field = db.query(CustomField).filter(CustomField.id == field_id).first()
|
||||
|
||||
if not field:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Custom field not found",
|
||||
)
|
||||
|
||||
project = field.project
|
||||
if not check_project_edit_access(current_user, project):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Permission denied",
|
||||
)
|
||||
|
||||
old_position = field.position
|
||||
|
||||
if position == old_position:
|
||||
return custom_field_to_response(field)
|
||||
|
||||
# Reorder other fields
|
||||
if position > old_position:
|
||||
# Moving down: shift fields between old and new position up
|
||||
db.query(CustomField).filter(
|
||||
CustomField.project_id == field.project_id,
|
||||
CustomField.position > old_position,
|
||||
CustomField.position <= position,
|
||||
).update({CustomField.position: CustomField.position - 1})
|
||||
else:
|
||||
# Moving up: shift fields between new and old position down
|
||||
db.query(CustomField).filter(
|
||||
CustomField.project_id == field.project_id,
|
||||
CustomField.position >= position,
|
||||
CustomField.position < old_position,
|
||||
).update({CustomField.position: CustomField.position + 1})
|
||||
|
||||
field.position = position
|
||||
db.commit()
|
||||
db.refresh(field)
|
||||
|
||||
return custom_field_to_response(field)
|
||||
3
backend/app/api/task_dependencies/__init__.py
Normal file
3
backend/app/api/task_dependencies/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from app.api.task_dependencies.router import router
|
||||
|
||||
__all__ = ["router"]
|
||||
431
backend/app/api/task_dependencies/router.py
Normal file
431
backend/app/api/task_dependencies/router.py
Normal file
@@ -0,0 +1,431 @@
|
||||
"""
|
||||
Task Dependencies API Router
|
||||
|
||||
Provides CRUD operations for task dependencies used in Gantt view.
|
||||
Includes circular dependency detection and date constraint validation.
|
||||
"""
|
||||
import uuid
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.models import User, Task, TaskDependency, AuditAction
|
||||
from app.schemas.task_dependency import (
|
||||
TaskDependencyCreate,
|
||||
TaskDependencyUpdate,
|
||||
TaskDependencyResponse,
|
||||
TaskDependencyListResponse,
|
||||
TaskInfo
|
||||
)
|
||||
from app.middleware.auth import get_current_user, check_task_access, check_task_edit_access
|
||||
from app.middleware.audit import get_audit_metadata
|
||||
from app.services.audit_service import AuditService
|
||||
from app.services.dependency_service import DependencyService, DependencyValidationError
|
||||
|
||||
router = APIRouter(tags=["task-dependencies"])
|
||||
|
||||
|
||||
def dependency_to_response(
|
||||
dep: TaskDependency,
|
||||
include_tasks: bool = True
|
||||
) -> TaskDependencyResponse:
|
||||
"""Convert TaskDependency model to response schema."""
|
||||
predecessor_info = None
|
||||
successor_info = None
|
||||
|
||||
if include_tasks:
|
||||
if dep.predecessor:
|
||||
predecessor_info = TaskInfo(
|
||||
id=dep.predecessor.id,
|
||||
title=dep.predecessor.title,
|
||||
start_date=dep.predecessor.start_date,
|
||||
due_date=dep.predecessor.due_date
|
||||
)
|
||||
if dep.successor:
|
||||
successor_info = TaskInfo(
|
||||
id=dep.successor.id,
|
||||
title=dep.successor.title,
|
||||
start_date=dep.successor.start_date,
|
||||
due_date=dep.successor.due_date
|
||||
)
|
||||
|
||||
return TaskDependencyResponse(
|
||||
id=dep.id,
|
||||
predecessor_id=dep.predecessor_id,
|
||||
successor_id=dep.successor_id,
|
||||
dependency_type=dep.dependency_type,
|
||||
lag_days=dep.lag_days,
|
||||
created_at=dep.created_at,
|
||||
predecessor=predecessor_info,
|
||||
successor=successor_info
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/api/tasks/{task_id}/dependencies",
|
||||
response_model=TaskDependencyResponse,
|
||||
status_code=status.HTTP_201_CREATED
|
||||
)
|
||||
async def create_dependency(
|
||||
task_id: str,
|
||||
dependency_data: TaskDependencyCreate,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Add a dependency to a task (the task becomes the successor).
|
||||
|
||||
The predecessor_id in the request body specifies which task must complete first.
|
||||
The task_id in the URL becomes the successor (depends on the predecessor).
|
||||
|
||||
Validates:
|
||||
- Both tasks exist and are in the same project
|
||||
- No self-reference
|
||||
- No duplicate dependency
|
||||
- No circular dependency
|
||||
- Dependency limit not exceeded
|
||||
"""
|
||||
# Get the successor task (from URL)
|
||||
successor = db.query(Task).filter(Task.id == task_id).first()
|
||||
if not successor:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Task not found"
|
||||
)
|
||||
|
||||
# Check edit permission on successor
|
||||
if not check_task_edit_access(current_user, successor, successor.project):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Permission denied"
|
||||
)
|
||||
|
||||
# Validate the dependency
|
||||
try:
|
||||
DependencyService.validate_dependency(
|
||||
db,
|
||||
predecessor_id=dependency_data.predecessor_id,
|
||||
successor_id=task_id
|
||||
)
|
||||
except DependencyValidationError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"error_type": e.error_type,
|
||||
"message": e.message,
|
||||
"details": e.details
|
||||
}
|
||||
)
|
||||
|
||||
# Create the dependency
|
||||
dependency = TaskDependency(
|
||||
id=str(uuid.uuid4()),
|
||||
predecessor_id=dependency_data.predecessor_id,
|
||||
successor_id=task_id,
|
||||
dependency_type=dependency_data.dependency_type.value,
|
||||
lag_days=dependency_data.lag_days
|
||||
)
|
||||
|
||||
db.add(dependency)
|
||||
|
||||
# Audit log
|
||||
AuditService.log_event(
|
||||
db=db,
|
||||
event_type="task.dependency.create",
|
||||
resource_type="task_dependency",
|
||||
action=AuditAction.CREATE,
|
||||
user_id=current_user.id,
|
||||
resource_id=dependency.id,
|
||||
changes=[{
|
||||
"field": "dependency",
|
||||
"old_value": None,
|
||||
"new_value": {
|
||||
"predecessor_id": dependency.predecessor_id,
|
||||
"successor_id": dependency.successor_id,
|
||||
"dependency_type": dependency.dependency_type,
|
||||
"lag_days": dependency.lag_days
|
||||
}
|
||||
}],
|
||||
request_metadata=get_audit_metadata(request)
|
||||
)
|
||||
|
||||
db.commit()
|
||||
db.refresh(dependency)
|
||||
|
||||
return dependency_to_response(dependency)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/api/tasks/{task_id}/dependencies",
|
||||
response_model=TaskDependencyListResponse
|
||||
)
|
||||
async def list_task_dependencies(
|
||||
task_id: str,
|
||||
direction: Optional[str] = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Get all dependencies for a task.
|
||||
|
||||
Args:
|
||||
task_id: The task to get dependencies for
|
||||
direction: Optional filter
|
||||
- 'predecessors': Only get tasks this task depends on
|
||||
- 'successors': Only get tasks that depend on this task
|
||||
- None: Get both
|
||||
|
||||
Returns all dependencies where the task is either the predecessor or successor.
|
||||
"""
|
||||
task = db.query(Task).filter(Task.id == task_id).first()
|
||||
if not task:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Task not found"
|
||||
)
|
||||
|
||||
if not check_task_access(current_user, task, task.project):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied"
|
||||
)
|
||||
|
||||
dependencies = []
|
||||
|
||||
if direction is None or direction == "predecessors":
|
||||
# Get dependencies where this task is the successor (predecessors)
|
||||
predecessor_deps = db.query(TaskDependency).filter(
|
||||
TaskDependency.successor_id == task_id
|
||||
).all()
|
||||
dependencies.extend(predecessor_deps)
|
||||
|
||||
if direction is None or direction == "successors":
|
||||
# Get dependencies where this task is the predecessor (successors)
|
||||
successor_deps = db.query(TaskDependency).filter(
|
||||
TaskDependency.predecessor_id == task_id
|
||||
).all()
|
||||
# Avoid duplicates if direction is None
|
||||
if direction is None:
|
||||
for dep in successor_deps:
|
||||
if dep not in dependencies:
|
||||
dependencies.append(dep)
|
||||
else:
|
||||
dependencies.extend(successor_deps)
|
||||
|
||||
return TaskDependencyListResponse(
|
||||
dependencies=[dependency_to_response(d) for d in dependencies],
|
||||
total=len(dependencies)
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/api/task-dependencies/{dependency_id}",
|
||||
response_model=TaskDependencyResponse
|
||||
)
|
||||
async def get_dependency(
|
||||
dependency_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Get a specific dependency by ID."""
|
||||
dependency = db.query(TaskDependency).filter(
|
||||
TaskDependency.id == dependency_id
|
||||
).first()
|
||||
|
||||
if not dependency:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Dependency not found"
|
||||
)
|
||||
|
||||
# Check access via the successor task
|
||||
task = dependency.successor
|
||||
if not check_task_access(current_user, task, task.project):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied"
|
||||
)
|
||||
|
||||
return dependency_to_response(dependency)
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/api/task-dependencies/{dependency_id}",
|
||||
response_model=TaskDependencyResponse
|
||||
)
|
||||
async def update_dependency(
|
||||
dependency_id: str,
|
||||
update_data: TaskDependencyUpdate,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Update a dependency's type or lag days.
|
||||
|
||||
Cannot change predecessor_id or successor_id - delete and recreate instead.
|
||||
"""
|
||||
dependency = db.query(TaskDependency).filter(
|
||||
TaskDependency.id == dependency_id
|
||||
).first()
|
||||
|
||||
if not dependency:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Dependency not found"
|
||||
)
|
||||
|
||||
# Check edit permission via the successor task
|
||||
task = dependency.successor
|
||||
if not check_task_edit_access(current_user, task, task.project):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Permission denied"
|
||||
)
|
||||
|
||||
# Track changes for audit
|
||||
old_values = {
|
||||
"dependency_type": dependency.dependency_type,
|
||||
"lag_days": dependency.lag_days
|
||||
}
|
||||
|
||||
# Update fields
|
||||
if update_data.dependency_type is not None:
|
||||
dependency.dependency_type = update_data.dependency_type.value
|
||||
|
||||
if update_data.lag_days is not None:
|
||||
dependency.lag_days = update_data.lag_days
|
||||
|
||||
new_values = {
|
||||
"dependency_type": dependency.dependency_type,
|
||||
"lag_days": dependency.lag_days
|
||||
}
|
||||
|
||||
# Audit log
|
||||
changes = AuditService.detect_changes(old_values, new_values)
|
||||
if changes:
|
||||
AuditService.log_event(
|
||||
db=db,
|
||||
event_type="task.dependency.update",
|
||||
resource_type="task_dependency",
|
||||
action=AuditAction.UPDATE,
|
||||
user_id=current_user.id,
|
||||
resource_id=dependency.id,
|
||||
changes=changes,
|
||||
request_metadata=get_audit_metadata(request)
|
||||
)
|
||||
|
||||
db.commit()
|
||||
db.refresh(dependency)
|
||||
|
||||
return dependency_to_response(dependency)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/api/task-dependencies/{dependency_id}",
|
||||
status_code=status.HTTP_204_NO_CONTENT
|
||||
)
|
||||
async def delete_dependency(
|
||||
dependency_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Delete a dependency."""
|
||||
dependency = db.query(TaskDependency).filter(
|
||||
TaskDependency.id == dependency_id
|
||||
).first()
|
||||
|
||||
if not dependency:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Dependency not found"
|
||||
)
|
||||
|
||||
# Check edit permission via the successor task
|
||||
task = dependency.successor
|
||||
if not check_task_edit_access(current_user, task, task.project):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Permission denied"
|
||||
)
|
||||
|
||||
# Audit log
|
||||
AuditService.log_event(
|
||||
db=db,
|
||||
event_type="task.dependency.delete",
|
||||
resource_type="task_dependency",
|
||||
action=AuditAction.DELETE,
|
||||
user_id=current_user.id,
|
||||
resource_id=dependency.id,
|
||||
changes=[{
|
||||
"field": "dependency",
|
||||
"old_value": {
|
||||
"predecessor_id": dependency.predecessor_id,
|
||||
"successor_id": dependency.successor_id,
|
||||
"dependency_type": dependency.dependency_type,
|
||||
"lag_days": dependency.lag_days
|
||||
},
|
||||
"new_value": None
|
||||
}],
|
||||
request_metadata=get_audit_metadata(request)
|
||||
)
|
||||
|
||||
db.delete(dependency)
|
||||
db.commit()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@router.get(
|
||||
"/api/projects/{project_id}/dependencies",
|
||||
response_model=TaskDependencyListResponse
|
||||
)
|
||||
async def list_project_dependencies(
|
||||
project_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Get all dependencies for a project.
|
||||
|
||||
Useful for rendering the full Gantt chart with all dependency arrows.
|
||||
"""
|
||||
from app.models import Project
|
||||
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Project not found"
|
||||
)
|
||||
|
||||
from app.middleware.auth import check_project_access
|
||||
if not check_project_access(current_user, project):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied"
|
||||
)
|
||||
|
||||
# Get all dependencies for tasks in this project (exclude soft-deleted tasks)
|
||||
# Create aliases for joining both predecessor and successor
|
||||
from sqlalchemy.orm import aliased
|
||||
Successor = aliased(Task)
|
||||
Predecessor = aliased(Task)
|
||||
|
||||
dependencies = db.query(TaskDependency).join(
|
||||
Successor, TaskDependency.successor_id == Successor.id
|
||||
).join(
|
||||
Predecessor, TaskDependency.predecessor_id == Predecessor.id
|
||||
).filter(
|
||||
Successor.project_id == project_id,
|
||||
Successor.is_deleted == False,
|
||||
Predecessor.is_deleted == False
|
||||
).all()
|
||||
|
||||
return TaskDependencyListResponse(
|
||||
dependencies=[dependency_to_response(d) for d in dependencies],
|
||||
total=len(dependencies)
|
||||
)
|
||||
@@ -10,7 +10,7 @@ from app.core.redis_pubsub import publish_task_event
|
||||
from app.models import User, Project, Task, TaskStatus, AuditAction, Blocker
|
||||
from app.schemas.task import (
|
||||
TaskCreate, TaskUpdate, TaskResponse, TaskWithDetails, TaskListResponse,
|
||||
TaskStatusUpdate, TaskAssignUpdate
|
||||
TaskStatusUpdate, TaskAssignUpdate, CustomValueResponse
|
||||
)
|
||||
from app.middleware.auth import (
|
||||
get_current_user, check_project_access, check_task_access, check_task_edit_access
|
||||
@@ -19,6 +19,8 @@ from app.middleware.audit import get_audit_metadata
|
||||
from app.services.audit_service import AuditService
|
||||
from app.services.trigger_service import TriggerService
|
||||
from app.services.workload_cache import invalidate_user_workload_cache
|
||||
from app.services.custom_value_service import CustomValueService
|
||||
from app.services.dependency_service import DependencyService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -40,13 +42,18 @@ def get_task_depth(db: Session, task: Task) -> int:
|
||||
return depth
|
||||
|
||||
|
||||
def task_to_response(task: Task) -> TaskWithDetails:
|
||||
def task_to_response(task: Task, db: Session = None, include_custom_values: bool = False) -> TaskWithDetails:
|
||||
"""Convert a Task model to TaskWithDetails response."""
|
||||
# Count only non-deleted subtasks
|
||||
subtask_count = 0
|
||||
if task.subtasks:
|
||||
subtask_count = sum(1 for st in task.subtasks if not st.is_deleted)
|
||||
|
||||
# Get custom values if requested
|
||||
custom_values = None
|
||||
if include_custom_values and db:
|
||||
custom_values = CustomValueService.get_custom_values_for_task(db, task)
|
||||
|
||||
return TaskWithDetails(
|
||||
id=task.id,
|
||||
project_id=task.project_id,
|
||||
@@ -56,6 +63,7 @@ def task_to_response(task: Task) -> TaskWithDetails:
|
||||
priority=task.priority,
|
||||
original_estimate=task.original_estimate,
|
||||
time_spent=task.time_spent,
|
||||
start_date=task.start_date,
|
||||
due_date=task.due_date,
|
||||
assignee_id=task.assignee_id,
|
||||
status_id=task.status_id,
|
||||
@@ -69,6 +77,7 @@ def task_to_response(task: Task) -> TaskWithDetails:
|
||||
status_color=task.status.color if task.status else None,
|
||||
creator_name=task.creator.name if task.creator else None,
|
||||
subtask_count=subtask_count,
|
||||
custom_values=custom_values,
|
||||
)
|
||||
|
||||
|
||||
@@ -78,12 +87,24 @@ async def list_tasks(
|
||||
parent_task_id: Optional[str] = Query(None, description="Filter by parent task"),
|
||||
status_id: Optional[str] = Query(None, description="Filter by status"),
|
||||
assignee_id: Optional[str] = Query(None, description="Filter by assignee"),
|
||||
due_after: Optional[datetime] = Query(None, description="Filter tasks with due_date >= this value (for calendar view)"),
|
||||
due_before: Optional[datetime] = Query(None, description="Filter tasks with due_date <= this value (for calendar view)"),
|
||||
include_deleted: bool = Query(False, description="Include deleted tasks (admin only)"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
List all tasks in a project.
|
||||
|
||||
Supports filtering by:
|
||||
- parent_task_id: Filter by parent task (empty string for root tasks only)
|
||||
- status_id: Filter by task status
|
||||
- assignee_id: Filter by assigned user
|
||||
- due_after: Filter tasks with due_date >= this value (ISO 8601 datetime)
|
||||
- due_before: Filter tasks with due_date <= this value (ISO 8601 datetime)
|
||||
|
||||
The due_after and due_before parameters are useful for calendar view
|
||||
to fetch tasks within a specific date range.
|
||||
"""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
|
||||
@@ -124,10 +145,17 @@ async def list_tasks(
|
||||
if assignee_id:
|
||||
query = query.filter(Task.assignee_id == assignee_id)
|
||||
|
||||
# Date range filter for calendar view
|
||||
if due_after:
|
||||
query = query.filter(Task.due_date >= due_after)
|
||||
|
||||
if due_before:
|
||||
query = query.filter(Task.due_date <= due_before)
|
||||
|
||||
tasks = query.order_by(Task.position, Task.created_at).all()
|
||||
|
||||
return TaskListResponse(
|
||||
tasks=[task_to_response(t) for t in tasks],
|
||||
tasks=[task_to_response(t, db=db, include_custom_values=True) for t in tasks],
|
||||
total=len(tasks),
|
||||
)
|
||||
|
||||
@@ -204,6 +232,25 @@ async def create_task(
|
||||
).order_by(Task.position.desc()).first()
|
||||
next_position = (max_pos_result.position + 1) if max_pos_result else 0
|
||||
|
||||
# Validate required custom fields
|
||||
if task_data.custom_values:
|
||||
missing_fields = CustomValueService.validate_required_fields(
|
||||
db, project_id, task_data.custom_values
|
||||
)
|
||||
if missing_fields:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Missing required custom fields: {', '.join(missing_fields)}",
|
||||
)
|
||||
|
||||
# Validate start_date <= due_date
|
||||
if task_data.start_date and task_data.due_date:
|
||||
if task_data.start_date > task_data.due_date:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Start date cannot be after due date",
|
||||
)
|
||||
|
||||
task = Task(
|
||||
id=str(uuid.uuid4()),
|
||||
project_id=project_id,
|
||||
@@ -212,6 +259,7 @@ async def create_task(
|
||||
description=task_data.description,
|
||||
priority=task_data.priority.value if task_data.priority else "medium",
|
||||
original_estimate=task_data.original_estimate,
|
||||
start_date=task_data.start_date,
|
||||
due_date=task_data.due_date,
|
||||
assignee_id=task_data.assignee_id,
|
||||
status_id=task_data.status_id,
|
||||
@@ -220,6 +268,17 @@ async def create_task(
|
||||
)
|
||||
|
||||
db.add(task)
|
||||
db.flush() # Flush to get task.id for custom values
|
||||
|
||||
# Save custom values
|
||||
if task_data.custom_values:
|
||||
try:
|
||||
CustomValueService.save_custom_values(db, task, task_data.custom_values)
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
# Audit log
|
||||
AuditService.log_event(
|
||||
@@ -256,6 +315,7 @@ async def create_task(
|
||||
"assignee_id": str(task.assignee_id) if task.assignee_id else None,
|
||||
"assignee_name": task.assignee.name if task.assignee else None,
|
||||
"priority": task.priority,
|
||||
"start_date": str(task.start_date) if task.start_date else None,
|
||||
"due_date": str(task.due_date) if task.due_date else None,
|
||||
"time_estimate": task.original_estimate,
|
||||
"original_estimate": task.original_estimate,
|
||||
@@ -303,7 +363,7 @@ async def get_task(
|
||||
detail="Access denied",
|
||||
)
|
||||
|
||||
return task_to_response(task)
|
||||
return task_to_response(task, db, include_custom_values=True)
|
||||
|
||||
|
||||
@router.patch("/api/tasks/{task_id}", response_model=TaskResponse)
|
||||
@@ -336,13 +396,42 @@ async def update_task(
|
||||
"title": task.title,
|
||||
"description": task.description,
|
||||
"priority": task.priority,
|
||||
"start_date": task.start_date,
|
||||
"due_date": task.due_date,
|
||||
"original_estimate": task.original_estimate,
|
||||
"time_spent": task.time_spent,
|
||||
}
|
||||
|
||||
# Update fields
|
||||
# Update fields (exclude custom_values, handle separately)
|
||||
update_data = task_data.model_dump(exclude_unset=True)
|
||||
custom_values_data = update_data.pop("custom_values", None)
|
||||
|
||||
# Get the proposed start_date and due_date for validation
|
||||
new_start_date = update_data.get("start_date", task.start_date)
|
||||
new_due_date = update_data.get("due_date", task.due_date)
|
||||
|
||||
# Validate start_date <= due_date
|
||||
if new_start_date and new_due_date:
|
||||
if new_start_date > new_due_date:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Start date cannot be after due date",
|
||||
)
|
||||
|
||||
# Validate date constraints against dependencies
|
||||
if "start_date" in update_data or "due_date" in update_data:
|
||||
violations = DependencyService.validate_date_constraints(
|
||||
task, new_start_date, new_due_date, db
|
||||
)
|
||||
if violations:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"message": "Date change violates dependency constraints",
|
||||
"violations": violations
|
||||
}
|
||||
)
|
||||
|
||||
for field, value in update_data.items():
|
||||
if field == "priority" and value:
|
||||
setattr(task, field, value.value)
|
||||
@@ -354,6 +443,7 @@ async def update_task(
|
||||
"title": task.title,
|
||||
"description": task.description,
|
||||
"priority": task.priority,
|
||||
"start_date": task.start_date,
|
||||
"due_date": task.due_date,
|
||||
"original_estimate": task.original_estimate,
|
||||
"time_spent": task.time_spent,
|
||||
@@ -377,6 +467,18 @@ async def update_task(
|
||||
if "priority" in update_data:
|
||||
TriggerService.evaluate_triggers(db, task, old_values, new_values, current_user)
|
||||
|
||||
# Handle custom values update
|
||||
if custom_values_data:
|
||||
try:
|
||||
from app.schemas.task import CustomValueInput
|
||||
custom_values = [CustomValueInput(**cv) for cv in custom_values_data]
|
||||
CustomValueService.save_custom_values(db, task, custom_values)
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
|
||||
@@ -400,6 +502,7 @@ async def update_task(
|
||||
"assignee_id": str(task.assignee_id) if task.assignee_id else None,
|
||||
"assignee_name": task.assignee.name if task.assignee else None,
|
||||
"priority": task.priority,
|
||||
"start_date": str(task.start_date) if task.start_date else None,
|
||||
"due_date": str(task.due_date) if task.due_date else None,
|
||||
"time_estimate": task.original_estimate,
|
||||
"original_estimate": task.original_estimate,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from pydantic_settings import BaseSettings
|
||||
from pydantic import field_validator
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
import os
|
||||
|
||||
|
||||
@@ -52,6 +52,35 @@ class Settings(BaseSettings):
|
||||
)
|
||||
return v
|
||||
|
||||
# Encryption - Master key for encrypting file encryption keys
|
||||
# Must be a 32-byte (256-bit) key encoded as base64 for AES-256
|
||||
# Generate with: python -c "import secrets, base64; print(base64.urlsafe_b64encode(secrets.token_bytes(32)).decode())"
|
||||
ENCRYPTION_MASTER_KEY: Optional[str] = None
|
||||
|
||||
@field_validator("ENCRYPTION_MASTER_KEY")
|
||||
@classmethod
|
||||
def validate_encryption_master_key(cls, v: Optional[str]) -> Optional[str]:
|
||||
"""Validate that ENCRYPTION_MASTER_KEY is properly formatted if set."""
|
||||
if v is None or v.strip() == "":
|
||||
return None
|
||||
# Basic validation - should be base64 encoded 32 bytes
|
||||
import base64
|
||||
try:
|
||||
decoded = base64.urlsafe_b64decode(v)
|
||||
if len(decoded) != 32:
|
||||
raise ValueError(
|
||||
"ENCRYPTION_MASTER_KEY must be a base64-encoded 32-byte key. "
|
||||
"Generate with: python -c \"import secrets, base64; print(base64.urlsafe_b64encode(secrets.token_bytes(32)).decode())\""
|
||||
)
|
||||
except Exception as e:
|
||||
if "must be a base64-encoded" in str(e):
|
||||
raise
|
||||
raise ValueError(
|
||||
"ENCRYPTION_MASTER_KEY must be a valid base64-encoded string. "
|
||||
f"Error: {e}"
|
||||
)
|
||||
return v
|
||||
|
||||
# External Auth API
|
||||
AUTH_API_URL: str = "https://pj-auth-api.vercel.app"
|
||||
|
||||
|
||||
@@ -34,6 +34,9 @@ from app.api.attachments import router as attachments_router
|
||||
from app.api.triggers import router as triggers_router
|
||||
from app.api.reports import router as reports_router
|
||||
from app.api.health import router as health_router
|
||||
from app.api.custom_fields import router as custom_fields_router
|
||||
from app.api.task_dependencies import router as task_dependencies_router
|
||||
from app.api.admin import encryption_keys as admin_encryption_keys_router
|
||||
from app.core.config import settings
|
||||
|
||||
app = FastAPI(
|
||||
@@ -76,6 +79,9 @@ app.include_router(attachments_router)
|
||||
app.include_router(triggers_router)
|
||||
app.include_router(reports_router)
|
||||
app.include_router(health_router)
|
||||
app.include_router(custom_fields_router)
|
||||
app.include_router(task_dependencies_router)
|
||||
app.include_router(admin_encryption_keys_router.router)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
|
||||
@@ -12,6 +12,7 @@ from app.models.notification import Notification
|
||||
from app.models.blocker import Blocker
|
||||
from app.models.audit_log import AuditLog, AuditAction, SensitivityLevel, EVENT_SENSITIVITY, ALERT_EVENTS
|
||||
from app.models.audit_alert import AuditAlert
|
||||
from app.models.encryption_key import EncryptionKey
|
||||
from app.models.attachment import Attachment
|
||||
from app.models.attachment_version import AttachmentVersion
|
||||
from app.models.trigger import Trigger, TriggerType
|
||||
@@ -19,13 +20,18 @@ from app.models.trigger_log import TriggerLog, TriggerLogStatus
|
||||
from app.models.scheduled_report import ScheduledReport, ReportType
|
||||
from app.models.report_history import ReportHistory, ReportHistoryStatus
|
||||
from app.models.project_health import ProjectHealth, RiskLevel, ScheduleStatus, ResourceStatus
|
||||
from app.models.custom_field import CustomField, FieldType
|
||||
from app.models.task_custom_value import TaskCustomValue
|
||||
from app.models.task_dependency import TaskDependency, DependencyType
|
||||
|
||||
__all__ = [
|
||||
"User", "Role", "Department", "Space", "Project", "TaskStatus", "Task", "WorkloadSnapshot",
|
||||
"Comment", "Mention", "Notification", "Blocker",
|
||||
"AuditLog", "AuditAlert", "AuditAction", "SensitivityLevel", "EVENT_SENSITIVITY", "ALERT_EVENTS",
|
||||
"Attachment", "AttachmentVersion",
|
||||
"EncryptionKey", "Attachment", "AttachmentVersion",
|
||||
"Trigger", "TriggerType", "TriggerLog", "TriggerLogStatus",
|
||||
"ScheduledReport", "ReportType", "ReportHistory", "ReportHistoryStatus",
|
||||
"ProjectHealth", "RiskLevel", "ScheduleStatus", "ResourceStatus"
|
||||
"ProjectHealth", "RiskLevel", "ScheduleStatus", "ResourceStatus",
|
||||
"CustomField", "FieldType", "TaskCustomValue",
|
||||
"TaskDependency", "DependencyType"
|
||||
]
|
||||
|
||||
@@ -16,6 +16,11 @@ class Attachment(Base):
|
||||
file_size = Column(BigInteger, nullable=False)
|
||||
current_version = Column(Integer, default=1, nullable=False)
|
||||
is_encrypted = Column(Boolean, default=False, nullable=False)
|
||||
encryption_key_id = Column(
|
||||
String(36),
|
||||
ForeignKey("pjctrl_encryption_keys.id", ondelete="SET NULL"),
|
||||
nullable=True
|
||||
)
|
||||
uploaded_by = Column(String(36), ForeignKey("pjctrl_users.id", ondelete="SET NULL"), nullable=True)
|
||||
is_deleted = Column(Boolean, default=False, nullable=False)
|
||||
created_at = Column(DateTime, server_default=func.now(), nullable=False)
|
||||
@@ -24,6 +29,7 @@ class Attachment(Base):
|
||||
# Relationships
|
||||
task = relationship("Task", back_populates="attachments")
|
||||
uploader = relationship("User", foreign_keys=[uploaded_by])
|
||||
encryption_key = relationship("EncryptionKey", foreign_keys=[encryption_key_id])
|
||||
versions = relationship("AttachmentVersion", back_populates="attachment", cascade="all, delete-orphan")
|
||||
|
||||
__table_args__ = (
|
||||
|
||||
37
backend/app/models/custom_field.py
Normal file
37
backend/app/models/custom_field.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import uuid
|
||||
import enum
|
||||
from sqlalchemy import Column, String, Text, Boolean, DateTime, ForeignKey, Enum, JSON, Integer
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.core.database import Base
|
||||
|
||||
|
||||
class FieldType(str, enum.Enum):
|
||||
TEXT = "text"
|
||||
NUMBER = "number"
|
||||
DROPDOWN = "dropdown"
|
||||
DATE = "date"
|
||||
PERSON = "person"
|
||||
FORMULA = "formula"
|
||||
|
||||
|
||||
class CustomField(Base):
|
||||
__tablename__ = "pjctrl_custom_fields"
|
||||
|
||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
project_id = Column(String(36), ForeignKey("pjctrl_projects.id", ondelete="CASCADE"), nullable=False)
|
||||
name = Column(String(100), nullable=False)
|
||||
field_type = Column(
|
||||
Enum("text", "number", "dropdown", "date", "person", "formula", name="field_type_enum"),
|
||||
nullable=False
|
||||
)
|
||||
options = Column(JSON, nullable=True) # For dropdown: list of options
|
||||
formula = Column(Text, nullable=True) # For formula: formula expression
|
||||
is_required = Column(Boolean, default=False, nullable=False)
|
||||
position = Column(Integer, default=0, nullable=False) # For ordering fields
|
||||
created_at = Column(DateTime, server_default=func.now(), nullable=False)
|
||||
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), nullable=False)
|
||||
|
||||
# Relationships
|
||||
project = relationship("Project", back_populates="custom_fields")
|
||||
values = relationship("TaskCustomValue", back_populates="field", cascade="all, delete-orphan")
|
||||
22
backend/app/models/encryption_key.py
Normal file
22
backend/app/models/encryption_key.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""EncryptionKey model for AES-256 file encryption key management."""
|
||||
import uuid
|
||||
from sqlalchemy import Column, String, Text, Boolean, DateTime
|
||||
from sqlalchemy.sql import func
|
||||
from app.core.database import Base
|
||||
|
||||
|
||||
class EncryptionKey(Base):
|
||||
"""
|
||||
Encryption key storage for file encryption.
|
||||
|
||||
Keys are encrypted with the Master Key before storage.
|
||||
Only system admin can manage encryption keys.
|
||||
"""
|
||||
__tablename__ = "pjctrl_encryption_keys"
|
||||
|
||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
key_data = Column(Text, nullable=False) # Encrypted key using Master Key
|
||||
algorithm = Column(String(20), default="AES-256-GCM", nullable=False)
|
||||
is_active = Column(Boolean, default=True, nullable=False)
|
||||
created_at = Column(DateTime, server_default=func.now(), nullable=False)
|
||||
rotated_at = Column(DateTime, nullable=True) # When this key was superseded
|
||||
@@ -40,3 +40,4 @@ class Project(Base):
|
||||
tasks = relationship("Task", back_populates="project", cascade="all, delete-orphan")
|
||||
triggers = relationship("Trigger", back_populates="project", cascade="all, delete-orphan")
|
||||
health = relationship("ProjectHealth", back_populates="project", uselist=False, cascade="all, delete-orphan")
|
||||
custom_fields = relationship("CustomField", back_populates="project", cascade="all, delete-orphan")
|
||||
|
||||
@@ -30,6 +30,7 @@ class Task(Base):
|
||||
original_estimate = Column(Numeric(8, 2), nullable=True)
|
||||
time_spent = Column(Numeric(8, 2), default=0, nullable=False)
|
||||
blocker_flag = Column(Boolean, default=False, nullable=False)
|
||||
start_date = Column(DateTime, nullable=True)
|
||||
due_date = Column(DateTime, nullable=True)
|
||||
position = Column(Integer, default=0, nullable=False)
|
||||
created_by = Column(String(36), ForeignKey("pjctrl_users.id"), nullable=False)
|
||||
@@ -55,3 +56,18 @@ class Task(Base):
|
||||
blockers = relationship("Blocker", back_populates="task", cascade="all, delete-orphan")
|
||||
attachments = relationship("Attachment", back_populates="task", cascade="all, delete-orphan")
|
||||
trigger_logs = relationship("TriggerLog", back_populates="task")
|
||||
custom_values = relationship("TaskCustomValue", back_populates="task", cascade="all, delete-orphan")
|
||||
|
||||
# Dependency relationships (for Gantt view)
|
||||
predecessors = relationship(
|
||||
"TaskDependency",
|
||||
foreign_keys="TaskDependency.successor_id",
|
||||
back_populates="successor",
|
||||
cascade="all, delete-orphan"
|
||||
)
|
||||
successors = relationship(
|
||||
"TaskDependency",
|
||||
foreign_keys="TaskDependency.predecessor_id",
|
||||
back_populates="predecessor",
|
||||
cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
24
backend/app/models/task_custom_value.py
Normal file
24
backend/app/models/task_custom_value.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import uuid
|
||||
from sqlalchemy import Column, String, Text, DateTime, ForeignKey, UniqueConstraint
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.core.database import Base
|
||||
|
||||
|
||||
class TaskCustomValue(Base):
|
||||
__tablename__ = "pjctrl_task_custom_values"
|
||||
|
||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
task_id = Column(String(36), ForeignKey("pjctrl_tasks.id", ondelete="CASCADE"), nullable=False)
|
||||
field_id = Column(String(36), ForeignKey("pjctrl_custom_fields.id", ondelete="CASCADE"), nullable=False)
|
||||
value = Column(Text, nullable=True) # Stored as text, parsed based on field_type
|
||||
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), nullable=False)
|
||||
|
||||
# Unique constraint: one value per task-field combination
|
||||
__table_args__ = (
|
||||
UniqueConstraint('task_id', 'field_id', name='uq_task_field'),
|
||||
)
|
||||
|
||||
# Relationships
|
||||
task = relationship("Task", back_populates="custom_values")
|
||||
field = relationship("CustomField", back_populates="values")
|
||||
68
backend/app/models/task_dependency.py
Normal file
68
backend/app/models/task_dependency.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from sqlalchemy import Column, String, Integer, Enum, DateTime, ForeignKey, UniqueConstraint
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.sql import func
|
||||
from app.core.database import Base
|
||||
import enum
|
||||
|
||||
|
||||
class DependencyType(str, enum.Enum):
|
||||
"""
|
||||
Task dependency types for Gantt chart.
|
||||
|
||||
FS (Finish-to-Start): Predecessor must finish before successor starts (most common)
|
||||
SS (Start-to-Start): Predecessor must start before successor starts
|
||||
FF (Finish-to-Finish): Predecessor must finish before successor finishes
|
||||
SF (Start-to-Finish): Predecessor must start before successor finishes (rare)
|
||||
"""
|
||||
FS = "FS" # Finish-to-Start
|
||||
SS = "SS" # Start-to-Start
|
||||
FF = "FF" # Finish-to-Finish
|
||||
SF = "SF" # Start-to-Finish
|
||||
|
||||
|
||||
class TaskDependency(Base):
|
||||
"""
|
||||
Represents a dependency relationship between two tasks.
|
||||
|
||||
The predecessor task affects when the successor task can be scheduled,
|
||||
based on the dependency_type. This is used for Gantt chart visualization
|
||||
and date validation.
|
||||
"""
|
||||
__tablename__ = "pjctrl_task_dependencies"
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint('predecessor_id', 'successor_id', name='uq_predecessor_successor'),
|
||||
)
|
||||
|
||||
id = Column(String(36), primary_key=True)
|
||||
predecessor_id = Column(
|
||||
String(36),
|
||||
ForeignKey("pjctrl_tasks.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True
|
||||
)
|
||||
successor_id = Column(
|
||||
String(36),
|
||||
ForeignKey("pjctrl_tasks.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True
|
||||
)
|
||||
dependency_type = Column(
|
||||
Enum("FS", "SS", "FF", "SF", name="dependency_type_enum"),
|
||||
default="FS",
|
||||
nullable=False
|
||||
)
|
||||
lag_days = Column(Integer, default=0, nullable=False)
|
||||
created_at = Column(DateTime, server_default=func.now(), nullable=False)
|
||||
|
||||
# Relationships
|
||||
predecessor = relationship(
|
||||
"Task",
|
||||
foreign_keys=[predecessor_id],
|
||||
back_populates="successors"
|
||||
)
|
||||
successor = relationship(
|
||||
"Task",
|
||||
foreign_keys=[successor_id],
|
||||
back_populates="predecessors"
|
||||
)
|
||||
88
backend/app/schemas/custom_field.py
Normal file
88
backend/app/schemas/custom_field.py
Normal file
@@ -0,0 +1,88 @@
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from typing import Optional, List, Any, Dict
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class FieldType(str, Enum):
|
||||
TEXT = "text"
|
||||
NUMBER = "number"
|
||||
DROPDOWN = "dropdown"
|
||||
DATE = "date"
|
||||
PERSON = "person"
|
||||
FORMULA = "formula"
|
||||
|
||||
|
||||
class CustomFieldBase(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=100)
|
||||
field_type: FieldType
|
||||
options: Optional[List[str]] = None # For dropdown type
|
||||
formula: Optional[str] = None # For formula type
|
||||
is_required: bool = False
|
||||
|
||||
@field_validator('options')
|
||||
@classmethod
|
||||
def validate_options(cls, v, info):
|
||||
field_type = info.data.get('field_type')
|
||||
if field_type == FieldType.DROPDOWN:
|
||||
if not v or len(v) == 0:
|
||||
raise ValueError('Dropdown fields must have at least one option')
|
||||
if len(v) > 50:
|
||||
raise ValueError('Dropdown fields can have at most 50 options')
|
||||
return v
|
||||
|
||||
@field_validator('formula')
|
||||
@classmethod
|
||||
def validate_formula(cls, v, info):
|
||||
field_type = info.data.get('field_type')
|
||||
if field_type == FieldType.FORMULA and not v:
|
||||
raise ValueError('Formula fields must have a formula expression')
|
||||
return v
|
||||
|
||||
|
||||
class CustomFieldCreate(CustomFieldBase):
|
||||
pass
|
||||
|
||||
|
||||
class CustomFieldUpdate(BaseModel):
|
||||
name: Optional[str] = Field(None, min_length=1, max_length=100)
|
||||
options: Optional[List[str]] = None
|
||||
formula: Optional[str] = None
|
||||
is_required: Optional[bool] = None
|
||||
|
||||
|
||||
class CustomFieldResponse(CustomFieldBase):
|
||||
id: str
|
||||
project_id: str
|
||||
position: int
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class CustomFieldListResponse(BaseModel):
|
||||
fields: List[CustomFieldResponse]
|
||||
total: int
|
||||
|
||||
|
||||
# Task custom value schemas
|
||||
class CustomValueInput(BaseModel):
|
||||
field_id: str
|
||||
value: Optional[Any] = None # Can be string, number, date string, or user id
|
||||
|
||||
|
||||
class CustomValueResponse(BaseModel):
|
||||
field_id: str
|
||||
field_name: str
|
||||
field_type: FieldType
|
||||
value: Optional[Any] = None
|
||||
display_value: Optional[str] = None # Formatted for display
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class TaskCustomValuesUpdate(BaseModel):
|
||||
custom_values: List[CustomValueInput]
|
||||
46
backend/app/schemas/encryption_key.py
Normal file
46
backend/app/schemas/encryption_key.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""Schemas for encryption key API."""
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class EncryptionKeyResponse(BaseModel):
|
||||
"""Response schema for encryption key (without actual key data)."""
|
||||
id: str
|
||||
algorithm: str
|
||||
is_active: bool
|
||||
created_at: datetime
|
||||
rotated_at: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class EncryptionKeyListResponse(BaseModel):
|
||||
"""Response schema for list of encryption keys."""
|
||||
keys: List[EncryptionKeyResponse]
|
||||
total: int
|
||||
|
||||
|
||||
class EncryptionKeyCreateResponse(BaseModel):
|
||||
"""Response schema after creating a new encryption key."""
|
||||
id: str
|
||||
algorithm: str
|
||||
is_active: bool
|
||||
created_at: datetime
|
||||
message: str
|
||||
|
||||
|
||||
class EncryptionKeyRotateResponse(BaseModel):
|
||||
"""Response schema after key rotation."""
|
||||
new_key_id: str
|
||||
old_key_id: Optional[str]
|
||||
message: str
|
||||
|
||||
|
||||
class EncryptionStatusResponse(BaseModel):
|
||||
"""Response schema for encryption status."""
|
||||
encryption_available: bool
|
||||
active_key_id: Optional[str]
|
||||
total_keys: int
|
||||
message: str
|
||||
@@ -1,5 +1,5 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List
|
||||
from typing import Optional, List, Any, Dict
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
@@ -12,11 +12,27 @@ class Priority(str, Enum):
|
||||
URGENT = "urgent"
|
||||
|
||||
|
||||
class CustomValueInput(BaseModel):
|
||||
"""Input for setting a custom field value."""
|
||||
field_id: str
|
||||
value: Optional[Any] = None # Can be string, number, date string, or user id
|
||||
|
||||
|
||||
class CustomValueResponse(BaseModel):
|
||||
"""Response for a custom field value."""
|
||||
field_id: str
|
||||
field_name: str
|
||||
field_type: str
|
||||
value: Optional[Any] = None
|
||||
display_value: Optional[str] = None # Formatted for display
|
||||
|
||||
|
||||
class TaskBase(BaseModel):
|
||||
title: str
|
||||
description: Optional[str] = None
|
||||
priority: Priority = Priority.MEDIUM
|
||||
original_estimate: Optional[Decimal] = None
|
||||
start_date: Optional[datetime] = None
|
||||
due_date: Optional[datetime] = None
|
||||
|
||||
|
||||
@@ -24,6 +40,7 @@ class TaskCreate(TaskBase):
|
||||
parent_task_id: Optional[str] = None
|
||||
assignee_id: Optional[str] = None
|
||||
status_id: Optional[str] = None
|
||||
custom_values: Optional[List[CustomValueInput]] = None
|
||||
|
||||
|
||||
class TaskUpdate(BaseModel):
|
||||
@@ -32,8 +49,10 @@ class TaskUpdate(BaseModel):
|
||||
priority: Optional[Priority] = None
|
||||
original_estimate: Optional[Decimal] = None
|
||||
time_spent: Optional[Decimal] = None
|
||||
start_date: Optional[datetime] = None
|
||||
due_date: Optional[datetime] = None
|
||||
position: Optional[int] = None
|
||||
custom_values: Optional[List[CustomValueInput]] = None
|
||||
|
||||
|
||||
class TaskStatusUpdate(BaseModel):
|
||||
@@ -67,6 +86,7 @@ class TaskWithDetails(TaskResponse):
|
||||
status_color: Optional[str] = None
|
||||
creator_name: Optional[str] = None
|
||||
subtask_count: int = 0
|
||||
custom_values: Optional[List[CustomValueResponse]] = None
|
||||
|
||||
|
||||
class TaskListResponse(BaseModel):
|
||||
|
||||
78
backend/app/schemas/task_dependency.py
Normal file
78
backend/app/schemas/task_dependency.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from pydantic import BaseModel, field_validator
|
||||
from typing import Optional, List
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class DependencyType(str, Enum):
|
||||
"""Task dependency types for Gantt chart."""
|
||||
FS = "FS" # Finish-to-Start (most common)
|
||||
SS = "SS" # Start-to-Start
|
||||
FF = "FF" # Finish-to-Finish
|
||||
SF = "SF" # Start-to-Finish (rare)
|
||||
|
||||
|
||||
class TaskDependencyCreate(BaseModel):
|
||||
"""Schema for creating a task dependency."""
|
||||
predecessor_id: str
|
||||
dependency_type: DependencyType = DependencyType.FS
|
||||
lag_days: int = 0
|
||||
|
||||
@field_validator('lag_days')
|
||||
@classmethod
|
||||
def validate_lag_days(cls, v):
|
||||
if v < -365 or v > 365:
|
||||
raise ValueError('lag_days must be between -365 and 365')
|
||||
return v
|
||||
|
||||
|
||||
class TaskDependencyUpdate(BaseModel):
|
||||
"""Schema for updating a task dependency."""
|
||||
dependency_type: Optional[DependencyType] = None
|
||||
lag_days: Optional[int] = None
|
||||
|
||||
@field_validator('lag_days')
|
||||
@classmethod
|
||||
def validate_lag_days(cls, v):
|
||||
if v is not None and (v < -365 or v > 365):
|
||||
raise ValueError('lag_days must be between -365 and 365')
|
||||
return v
|
||||
|
||||
|
||||
class TaskInfo(BaseModel):
|
||||
"""Brief task information for dependency response."""
|
||||
id: str
|
||||
title: str
|
||||
start_date: Optional[datetime] = None
|
||||
due_date: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class TaskDependencyResponse(BaseModel):
|
||||
"""Schema for task dependency response."""
|
||||
id: str
|
||||
predecessor_id: str
|
||||
successor_id: str
|
||||
dependency_type: DependencyType
|
||||
lag_days: int
|
||||
created_at: datetime
|
||||
predecessor: Optional[TaskInfo] = None
|
||||
successor: Optional[TaskInfo] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class TaskDependencyListResponse(BaseModel):
|
||||
"""Schema for list of task dependencies."""
|
||||
dependencies: List[TaskDependencyResponse]
|
||||
total: int
|
||||
|
||||
|
||||
class DependencyValidationError(BaseModel):
|
||||
"""Schema for dependency validation error details."""
|
||||
error_type: str # 'circular', 'self_reference', 'duplicate', 'cross_project'
|
||||
message: str
|
||||
details: Optional[dict] = None
|
||||
278
backend/app/services/custom_value_service.py
Normal file
278
backend/app/services/custom_value_service.py
Normal file
@@ -0,0 +1,278 @@
|
||||
"""
|
||||
Service for managing task custom values.
|
||||
"""
|
||||
import uuid
|
||||
from typing import List, Dict, Any, Optional
|
||||
from decimal import Decimal, InvalidOperation
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models import Task, CustomField, TaskCustomValue, User
|
||||
from app.schemas.task import CustomValueInput, CustomValueResponse
|
||||
from app.services.formula_service import FormulaService
|
||||
|
||||
|
||||
class CustomValueService:
|
||||
"""Service for managing custom field values on tasks."""
|
||||
|
||||
@staticmethod
|
||||
def get_custom_values_for_task(
|
||||
db: Session,
|
||||
task: Task,
|
||||
include_formula_calculations: bool = True,
|
||||
) -> List[CustomValueResponse]:
|
||||
"""
|
||||
Get all custom field values for a task.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
task: The task to get values for
|
||||
include_formula_calculations: Whether to calculate formula field values
|
||||
|
||||
Returns:
|
||||
List of CustomValueResponse objects
|
||||
"""
|
||||
# Get all custom fields for the project
|
||||
fields = db.query(CustomField).filter(
|
||||
CustomField.project_id == task.project_id
|
||||
).order_by(CustomField.position).all()
|
||||
|
||||
if not fields:
|
||||
return []
|
||||
|
||||
# Get stored values
|
||||
stored_values = db.query(TaskCustomValue).filter(
|
||||
TaskCustomValue.task_id == task.id
|
||||
).all()
|
||||
|
||||
value_map = {v.field_id: v.value for v in stored_values}
|
||||
|
||||
# Calculate formula values if requested
|
||||
formula_values = {}
|
||||
if include_formula_calculations:
|
||||
formula_values = FormulaService.calculate_all_formulas_for_task(db, task)
|
||||
|
||||
result = []
|
||||
for field in fields:
|
||||
if field.field_type == "formula":
|
||||
# Use calculated formula value
|
||||
calculated = formula_values.get(field.id)
|
||||
value = str(calculated) if calculated is not None else None
|
||||
display_value = CustomValueService._format_display_value(
|
||||
field, value, db
|
||||
)
|
||||
else:
|
||||
# Use stored value
|
||||
value = value_map.get(field.id)
|
||||
display_value = CustomValueService._format_display_value(
|
||||
field, value, db
|
||||
)
|
||||
|
||||
result.append(CustomValueResponse(
|
||||
field_id=field.id,
|
||||
field_name=field.name,
|
||||
field_type=field.field_type,
|
||||
value=value,
|
||||
display_value=display_value,
|
||||
))
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _format_display_value(
|
||||
field: CustomField,
|
||||
value: Optional[str],
|
||||
db: Session,
|
||||
) -> Optional[str]:
|
||||
"""Format a value for display based on field type."""
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
field_type = field.field_type
|
||||
|
||||
if field_type == "person":
|
||||
# Look up user name
|
||||
from app.models import User
|
||||
user = db.query(User).filter(User.id == value).first()
|
||||
return user.name if user else value
|
||||
|
||||
elif field_type == "number" or field_type == "formula":
|
||||
# Format number
|
||||
try:
|
||||
num = Decimal(value)
|
||||
# Remove trailing zeros after decimal point
|
||||
formatted = f"{num:,.4f}".rstrip('0').rstrip('.')
|
||||
return formatted
|
||||
except (InvalidOperation, ValueError):
|
||||
return value
|
||||
|
||||
elif field_type == "date":
|
||||
# Format date
|
||||
try:
|
||||
dt = datetime.fromisoformat(value.replace('Z', '+00:00'))
|
||||
return dt.strftime('%Y-%m-%d')
|
||||
except (ValueError, AttributeError):
|
||||
return value
|
||||
|
||||
else:
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def save_custom_values(
|
||||
db: Session,
|
||||
task: Task,
|
||||
custom_values: List[CustomValueInput],
|
||||
) -> List[str]:
|
||||
"""
|
||||
Save custom field values for a task.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
task: The task to save values for
|
||||
custom_values: List of values to save
|
||||
|
||||
Returns:
|
||||
List of field IDs that were updated (for formula recalculation)
|
||||
"""
|
||||
if not custom_values:
|
||||
return []
|
||||
|
||||
updated_field_ids = []
|
||||
|
||||
for cv in custom_values:
|
||||
field = db.query(CustomField).filter(
|
||||
CustomField.id == cv.field_id,
|
||||
CustomField.project_id == task.project_id,
|
||||
).first()
|
||||
|
||||
if not field:
|
||||
continue
|
||||
|
||||
# Skip formula fields - they are calculated, not stored directly
|
||||
if field.field_type == "formula":
|
||||
continue
|
||||
|
||||
# Validate value based on field type
|
||||
validated_value = CustomValueService._validate_value(
|
||||
field, cv.value, db
|
||||
)
|
||||
|
||||
# Find existing value or create new
|
||||
existing = db.query(TaskCustomValue).filter(
|
||||
TaskCustomValue.task_id == task.id,
|
||||
TaskCustomValue.field_id == cv.field_id,
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
if existing.value != validated_value:
|
||||
existing.value = validated_value
|
||||
updated_field_ids.append(cv.field_id)
|
||||
else:
|
||||
new_value = TaskCustomValue(
|
||||
id=str(uuid.uuid4()),
|
||||
task_id=task.id,
|
||||
field_id=cv.field_id,
|
||||
value=validated_value,
|
||||
)
|
||||
db.add(new_value)
|
||||
updated_field_ids.append(cv.field_id)
|
||||
|
||||
# Recalculate formula fields if any values were updated
|
||||
if updated_field_ids:
|
||||
for field_id in updated_field_ids:
|
||||
FormulaService.recalculate_dependent_formulas(db, task, field_id)
|
||||
|
||||
return updated_field_ids
|
||||
|
||||
@staticmethod
|
||||
def _validate_value(
|
||||
field: CustomField,
|
||||
value: Any,
|
||||
db: Session,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Validate and normalize a value based on field type.
|
||||
|
||||
Returns the validated value as a string, or None.
|
||||
"""
|
||||
if value is None or value == "":
|
||||
if field.is_required:
|
||||
raise ValueError(f"Field '{field.name}' is required")
|
||||
return None
|
||||
|
||||
field_type = field.field_type
|
||||
str_value = str(value)
|
||||
|
||||
if field_type == "text":
|
||||
return str_value
|
||||
|
||||
elif field_type == "number":
|
||||
try:
|
||||
Decimal(str_value)
|
||||
return str_value
|
||||
except (InvalidOperation, ValueError):
|
||||
raise ValueError(f"Invalid number for field '{field.name}'")
|
||||
|
||||
elif field_type == "dropdown":
|
||||
if field.options and str_value not in field.options:
|
||||
raise ValueError(
|
||||
f"Invalid option for field '{field.name}'. "
|
||||
f"Must be one of: {', '.join(field.options)}"
|
||||
)
|
||||
return str_value
|
||||
|
||||
elif field_type == "date":
|
||||
# Validate date format
|
||||
try:
|
||||
datetime.fromisoformat(str_value.replace('Z', '+00:00'))
|
||||
return str_value
|
||||
except ValueError:
|
||||
# Try parsing as date only
|
||||
try:
|
||||
datetime.strptime(str_value, '%Y-%m-%d')
|
||||
return str_value
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid date for field '{field.name}'")
|
||||
|
||||
elif field_type == "person":
|
||||
# Validate user exists
|
||||
from app.models import User
|
||||
user = db.query(User).filter(User.id == str_value).first()
|
||||
if not user:
|
||||
raise ValueError(f"Invalid user ID for field '{field.name}'")
|
||||
return str_value
|
||||
|
||||
return str_value
|
||||
|
||||
@staticmethod
|
||||
def validate_required_fields(
|
||||
db: Session,
|
||||
project_id: str,
|
||||
custom_values: Optional[List[CustomValueInput]],
|
||||
) -> List[str]:
|
||||
"""
|
||||
Validate that all required custom fields have values.
|
||||
|
||||
Returns list of missing required field names.
|
||||
"""
|
||||
required_fields = db.query(CustomField).filter(
|
||||
CustomField.project_id == project_id,
|
||||
CustomField.is_required == True,
|
||||
CustomField.field_type != "formula", # Formula fields are calculated
|
||||
).all()
|
||||
|
||||
if not required_fields:
|
||||
return []
|
||||
|
||||
provided_field_ids = set()
|
||||
if custom_values:
|
||||
for cv in custom_values:
|
||||
if cv.value is not None and cv.value != "":
|
||||
provided_field_ids.add(cv.field_id)
|
||||
|
||||
missing = []
|
||||
for field in required_fields:
|
||||
if field.id not in provided_field_ids:
|
||||
missing.append(field.name)
|
||||
|
||||
return missing
|
||||
424
backend/app/services/dependency_service.py
Normal file
424
backend/app/services/dependency_service.py
Normal file
@@ -0,0 +1,424 @@
|
||||
"""
|
||||
Dependency Service
|
||||
|
||||
Handles task dependency validation including:
|
||||
- Circular dependency detection using DFS
|
||||
- Date constraint validation based on dependency types
|
||||
- Self-reference prevention
|
||||
- Cross-project dependency prevention
|
||||
"""
|
||||
from typing import List, Optional, Set, Tuple, Dict, Any
|
||||
from collections import defaultdict
|
||||
from sqlalchemy.orm import Session
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from app.models import Task, TaskDependency
|
||||
|
||||
|
||||
class DependencyValidationError(Exception):
|
||||
"""Custom exception for dependency validation errors."""
|
||||
|
||||
def __init__(self, error_type: str, message: str, details: Optional[dict] = None):
|
||||
self.error_type = error_type
|
||||
self.message = message
|
||||
self.details = details or {}
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class DependencyService:
|
||||
"""Service for managing task dependencies with validation."""
|
||||
|
||||
# Maximum number of direct dependencies per task (as per spec)
|
||||
MAX_DIRECT_DEPENDENCIES = 10
|
||||
|
||||
@staticmethod
|
||||
def detect_circular_dependency(
|
||||
db: Session,
|
||||
predecessor_id: str,
|
||||
successor_id: str,
|
||||
project_id: str
|
||||
) -> Optional[List[str]]:
|
||||
"""
|
||||
Detect if adding a dependency would create a circular reference.
|
||||
|
||||
Uses DFS to traverse from the successor to check if we can reach
|
||||
the predecessor through existing dependencies.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
predecessor_id: The task that must complete first
|
||||
successor_id: The task that depends on the predecessor
|
||||
project_id: Project ID to scope the query
|
||||
|
||||
Returns:
|
||||
List of task IDs forming the cycle if circular, None otherwise
|
||||
"""
|
||||
# If adding predecessor -> successor, check if successor can reach predecessor
|
||||
# This would mean predecessor depends (transitively) on successor, creating a cycle
|
||||
|
||||
# Build adjacency list for the project's dependencies
|
||||
dependencies = db.query(TaskDependency).join(
|
||||
Task, TaskDependency.successor_id == Task.id
|
||||
).filter(Task.project_id == project_id).all()
|
||||
|
||||
# Graph: successor -> [predecessors]
|
||||
# We need to check if predecessor is reachable from successor
|
||||
# by following the chain of "what does this task depend on"
|
||||
graph: Dict[str, List[str]] = defaultdict(list)
|
||||
for dep in dependencies:
|
||||
graph[dep.successor_id].append(dep.predecessor_id)
|
||||
|
||||
# Simulate adding the new edge
|
||||
graph[successor_id].append(predecessor_id)
|
||||
|
||||
# DFS to find if there's a path from predecessor back to successor
|
||||
# (which would complete a cycle)
|
||||
visited: Set[str] = set()
|
||||
path: List[str] = []
|
||||
in_path: Set[str] = set()
|
||||
|
||||
def dfs(node: str) -> Optional[List[str]]:
|
||||
"""DFS traversal to detect cycles."""
|
||||
if node in in_path:
|
||||
# Found a cycle - return the cycle path
|
||||
cycle_start = path.index(node)
|
||||
return path[cycle_start:] + [node]
|
||||
|
||||
if node in visited:
|
||||
return None
|
||||
|
||||
visited.add(node)
|
||||
in_path.add(node)
|
||||
path.append(node)
|
||||
|
||||
for neighbor in graph.get(node, []):
|
||||
result = dfs(neighbor)
|
||||
if result:
|
||||
return result
|
||||
|
||||
path.pop()
|
||||
in_path.remove(node)
|
||||
return None
|
||||
|
||||
# Start DFS from the successor to check if we can reach back to it
|
||||
return dfs(successor_id)
|
||||
|
||||
@staticmethod
|
||||
def validate_dependency(
|
||||
db: Session,
|
||||
predecessor_id: str,
|
||||
successor_id: str
|
||||
) -> None:
|
||||
"""
|
||||
Validate that a dependency can be created.
|
||||
|
||||
Raises DependencyValidationError if validation fails.
|
||||
|
||||
Checks:
|
||||
1. Self-reference
|
||||
2. Both tasks exist
|
||||
3. Both tasks are in the same project
|
||||
4. No duplicate dependency
|
||||
5. No circular dependency
|
||||
6. Dependency limit not exceeded
|
||||
"""
|
||||
# Check self-reference
|
||||
if predecessor_id == successor_id:
|
||||
raise DependencyValidationError(
|
||||
error_type="self_reference",
|
||||
message="A task cannot depend on itself"
|
||||
)
|
||||
|
||||
# Get both tasks
|
||||
predecessor = db.query(Task).filter(Task.id == predecessor_id).first()
|
||||
successor = db.query(Task).filter(Task.id == successor_id).first()
|
||||
|
||||
if not predecessor:
|
||||
raise DependencyValidationError(
|
||||
error_type="not_found",
|
||||
message="Predecessor task not found",
|
||||
details={"task_id": predecessor_id}
|
||||
)
|
||||
|
||||
if not successor:
|
||||
raise DependencyValidationError(
|
||||
error_type="not_found",
|
||||
message="Successor task not found",
|
||||
details={"task_id": successor_id}
|
||||
)
|
||||
|
||||
# Check same project
|
||||
if predecessor.project_id != successor.project_id:
|
||||
raise DependencyValidationError(
|
||||
error_type="cross_project",
|
||||
message="Dependencies can only be created between tasks in the same project",
|
||||
details={
|
||||
"predecessor_project_id": predecessor.project_id,
|
||||
"successor_project_id": successor.project_id
|
||||
}
|
||||
)
|
||||
|
||||
# Check duplicate
|
||||
existing = db.query(TaskDependency).filter(
|
||||
TaskDependency.predecessor_id == predecessor_id,
|
||||
TaskDependency.successor_id == successor_id
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
raise DependencyValidationError(
|
||||
error_type="duplicate",
|
||||
message="This dependency already exists"
|
||||
)
|
||||
|
||||
# Check dependency limit
|
||||
current_count = db.query(TaskDependency).filter(
|
||||
TaskDependency.successor_id == successor_id
|
||||
).count()
|
||||
|
||||
if current_count >= DependencyService.MAX_DIRECT_DEPENDENCIES:
|
||||
raise DependencyValidationError(
|
||||
error_type="limit_exceeded",
|
||||
message=f"A task cannot have more than {DependencyService.MAX_DIRECT_DEPENDENCIES} direct dependencies",
|
||||
details={"current_count": current_count}
|
||||
)
|
||||
|
||||
# Check circular dependency
|
||||
cycle = DependencyService.detect_circular_dependency(
|
||||
db, predecessor_id, successor_id, predecessor.project_id
|
||||
)
|
||||
|
||||
if cycle:
|
||||
raise DependencyValidationError(
|
||||
error_type="circular",
|
||||
message="Adding this dependency would create a circular reference",
|
||||
details={"cycle": cycle}
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def validate_date_constraints(
|
||||
task: Task,
|
||||
start_date: Optional[datetime],
|
||||
due_date: Optional[datetime],
|
||||
db: Session
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Validate date changes against dependency constraints.
|
||||
|
||||
Returns a list of constraint violations (empty if valid).
|
||||
|
||||
Dependency type meanings:
|
||||
- FS: predecessor.due_date + lag <= successor.start_date
|
||||
- SS: predecessor.start_date + lag <= successor.start_date
|
||||
- FF: predecessor.due_date + lag <= successor.due_date
|
||||
- SF: predecessor.start_date + lag <= successor.due_date
|
||||
"""
|
||||
violations = []
|
||||
|
||||
# Use provided dates or fall back to current task dates
|
||||
new_start = start_date if start_date is not None else task.start_date
|
||||
new_due = due_date if due_date is not None else task.due_date
|
||||
|
||||
# Basic date validation
|
||||
if new_start and new_due and new_start > new_due:
|
||||
violations.append({
|
||||
"type": "date_order",
|
||||
"message": "Start date cannot be after due date",
|
||||
"start_date": str(new_start),
|
||||
"due_date": str(new_due)
|
||||
})
|
||||
|
||||
# Get dependencies where this task is the successor (predecessors)
|
||||
predecessors = db.query(TaskDependency).filter(
|
||||
TaskDependency.successor_id == task.id
|
||||
).all()
|
||||
|
||||
for dep in predecessors:
|
||||
pred_task = dep.predecessor
|
||||
if not pred_task:
|
||||
continue
|
||||
|
||||
lag = timedelta(days=dep.lag_days)
|
||||
violation = None
|
||||
|
||||
if dep.dependency_type == "FS":
|
||||
# Predecessor must finish before successor starts
|
||||
if pred_task.due_date and new_start:
|
||||
required_start = pred_task.due_date + lag
|
||||
if new_start < required_start:
|
||||
violation = {
|
||||
"type": "dependency_constraint",
|
||||
"dependency_type": "FS",
|
||||
"predecessor_id": pred_task.id,
|
||||
"predecessor_title": pred_task.title,
|
||||
"message": f"Start date must be on or after {required_start.date()} (predecessor due date + {dep.lag_days} days lag)"
|
||||
}
|
||||
|
||||
elif dep.dependency_type == "SS":
|
||||
# Predecessor must start before successor starts
|
||||
if pred_task.start_date and new_start:
|
||||
required_start = pred_task.start_date + lag
|
||||
if new_start < required_start:
|
||||
violation = {
|
||||
"type": "dependency_constraint",
|
||||
"dependency_type": "SS",
|
||||
"predecessor_id": pred_task.id,
|
||||
"predecessor_title": pred_task.title,
|
||||
"message": f"Start date must be on or after {required_start.date()} (predecessor start date + {dep.lag_days} days lag)"
|
||||
}
|
||||
|
||||
elif dep.dependency_type == "FF":
|
||||
# Predecessor must finish before successor finishes
|
||||
if pred_task.due_date and new_due:
|
||||
required_due = pred_task.due_date + lag
|
||||
if new_due < required_due:
|
||||
violation = {
|
||||
"type": "dependency_constraint",
|
||||
"dependency_type": "FF",
|
||||
"predecessor_id": pred_task.id,
|
||||
"predecessor_title": pred_task.title,
|
||||
"message": f"Due date must be on or after {required_due.date()} (predecessor due date + {dep.lag_days} days lag)"
|
||||
}
|
||||
|
||||
elif dep.dependency_type == "SF":
|
||||
# Predecessor must start before successor finishes
|
||||
if pred_task.start_date and new_due:
|
||||
required_due = pred_task.start_date + lag
|
||||
if new_due < required_due:
|
||||
violation = {
|
||||
"type": "dependency_constraint",
|
||||
"dependency_type": "SF",
|
||||
"predecessor_id": pred_task.id,
|
||||
"predecessor_title": pred_task.title,
|
||||
"message": f"Due date must be on or after {required_due.date()} (predecessor start date + {dep.lag_days} days lag)"
|
||||
}
|
||||
|
||||
if violation:
|
||||
violations.append(violation)
|
||||
|
||||
# Get dependencies where this task is the predecessor (successors)
|
||||
successors = db.query(TaskDependency).filter(
|
||||
TaskDependency.predecessor_id == task.id
|
||||
).all()
|
||||
|
||||
for dep in successors:
|
||||
succ_task = dep.successor
|
||||
if not succ_task:
|
||||
continue
|
||||
|
||||
lag = timedelta(days=dep.lag_days)
|
||||
violation = None
|
||||
|
||||
if dep.dependency_type == "FS":
|
||||
# This task must finish before successor starts
|
||||
if new_due and succ_task.start_date:
|
||||
required_due = succ_task.start_date - lag
|
||||
if new_due > required_due:
|
||||
violation = {
|
||||
"type": "dependency_constraint",
|
||||
"dependency_type": "FS",
|
||||
"successor_id": succ_task.id,
|
||||
"successor_title": succ_task.title,
|
||||
"message": f"Due date must be on or before {required_due.date()} (successor start date - {dep.lag_days} days lag)"
|
||||
}
|
||||
|
||||
elif dep.dependency_type == "SS":
|
||||
# This task must start before successor starts
|
||||
if new_start and succ_task.start_date:
|
||||
required_start = succ_task.start_date - lag
|
||||
if new_start > required_start:
|
||||
violation = {
|
||||
"type": "dependency_constraint",
|
||||
"dependency_type": "SS",
|
||||
"successor_id": succ_task.id,
|
||||
"successor_title": succ_task.title,
|
||||
"message": f"Start date must be on or before {required_start.date()} (successor start date - {dep.lag_days} days lag)"
|
||||
}
|
||||
|
||||
elif dep.dependency_type == "FF":
|
||||
# This task must finish before successor finishes
|
||||
if new_due and succ_task.due_date:
|
||||
required_due = succ_task.due_date - lag
|
||||
if new_due > required_due:
|
||||
violation = {
|
||||
"type": "dependency_constraint",
|
||||
"dependency_type": "FF",
|
||||
"successor_id": succ_task.id,
|
||||
"successor_title": succ_task.title,
|
||||
"message": f"Due date must be on or before {required_due.date()} (successor due date - {dep.lag_days} days lag)"
|
||||
}
|
||||
|
||||
elif dep.dependency_type == "SF":
|
||||
# This task must start before successor finishes
|
||||
if new_start and succ_task.due_date:
|
||||
required_start = succ_task.due_date - lag
|
||||
if new_start > required_start:
|
||||
violation = {
|
||||
"type": "dependency_constraint",
|
||||
"dependency_type": "SF",
|
||||
"successor_id": succ_task.id,
|
||||
"successor_title": succ_task.title,
|
||||
"message": f"Start date must be on or before {required_start.date()} (successor due date - {dep.lag_days} days lag)"
|
||||
}
|
||||
|
||||
if violation:
|
||||
violations.append(violation)
|
||||
|
||||
return violations
|
||||
|
||||
@staticmethod
|
||||
def get_all_predecessors(db: Session, task_id: str) -> List[str]:
|
||||
"""
|
||||
Get all transitive predecessors of a task.
|
||||
|
||||
Uses BFS to find all tasks that this task depends on (directly or indirectly).
|
||||
"""
|
||||
visited: Set[str] = set()
|
||||
queue = [task_id]
|
||||
predecessors = []
|
||||
|
||||
while queue:
|
||||
current = queue.pop(0)
|
||||
if current in visited:
|
||||
continue
|
||||
|
||||
visited.add(current)
|
||||
|
||||
deps = db.query(TaskDependency).filter(
|
||||
TaskDependency.successor_id == current
|
||||
).all()
|
||||
|
||||
for dep in deps:
|
||||
if dep.predecessor_id not in visited:
|
||||
predecessors.append(dep.predecessor_id)
|
||||
queue.append(dep.predecessor_id)
|
||||
|
||||
return predecessors
|
||||
|
||||
@staticmethod
|
||||
def get_all_successors(db: Session, task_id: str) -> List[str]:
|
||||
"""
|
||||
Get all transitive successors of a task.
|
||||
|
||||
Uses BFS to find all tasks that depend on this task (directly or indirectly).
|
||||
"""
|
||||
visited: Set[str] = set()
|
||||
queue = [task_id]
|
||||
successors = []
|
||||
|
||||
while queue:
|
||||
current = queue.pop(0)
|
||||
if current in visited:
|
||||
continue
|
||||
|
||||
visited.add(current)
|
||||
|
||||
deps = db.query(TaskDependency).filter(
|
||||
TaskDependency.predecessor_id == current
|
||||
).all()
|
||||
|
||||
for dep in deps:
|
||||
if dep.successor_id not in visited:
|
||||
successors.append(dep.successor_id)
|
||||
queue.append(dep.successor_id)
|
||||
|
||||
return successors
|
||||
300
backend/app/services/encryption_service.py
Normal file
300
backend/app/services/encryption_service.py
Normal file
@@ -0,0 +1,300 @@
|
||||
"""
|
||||
Encryption service for AES-256-GCM file encryption.
|
||||
|
||||
This service handles:
|
||||
- File encryption key generation and management
|
||||
- Encrypting/decrypting file encryption keys with Master Key
|
||||
- Streaming file encryption/decryption with AES-256-GCM
|
||||
"""
|
||||
import os
|
||||
import base64
|
||||
import secrets
|
||||
import logging
|
||||
from typing import BinaryIO, Tuple, Optional, Generator
|
||||
from io import BytesIO
|
||||
|
||||
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Constants
|
||||
KEY_SIZE = 32 # 256 bits for AES-256
|
||||
NONCE_SIZE = 12 # 96 bits for GCM recommended nonce size
|
||||
TAG_SIZE = 16 # 128 bits for GCM authentication tag
|
||||
CHUNK_SIZE = 64 * 1024 # 64KB chunks for streaming
|
||||
|
||||
|
||||
class EncryptionError(Exception):
|
||||
"""Base exception for encryption errors."""
|
||||
pass
|
||||
|
||||
|
||||
class MasterKeyNotConfiguredError(EncryptionError):
|
||||
"""Raised when master key is not configured."""
|
||||
pass
|
||||
|
||||
|
||||
class DecryptionError(EncryptionError):
|
||||
"""Raised when decryption fails."""
|
||||
pass
|
||||
|
||||
|
||||
class EncryptionService:
|
||||
"""
|
||||
Service for file encryption using AES-256-GCM.
|
||||
|
||||
Key hierarchy:
|
||||
1. Master Key (from environment) -> encrypts file encryption keys
|
||||
2. File Encryption Keys (stored in DB) -> encrypt actual files
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._master_key: Optional[bytes] = None
|
||||
|
||||
@property
|
||||
def master_key(self) -> bytes:
|
||||
"""Get the master key, loading from config if needed."""
|
||||
if self._master_key is None:
|
||||
if not settings.ENCRYPTION_MASTER_KEY:
|
||||
raise MasterKeyNotConfiguredError(
|
||||
"ENCRYPTION_MASTER_KEY is not configured. "
|
||||
"File encryption is disabled."
|
||||
)
|
||||
self._master_key = base64.urlsafe_b64decode(settings.ENCRYPTION_MASTER_KEY)
|
||||
return self._master_key
|
||||
|
||||
def is_encryption_available(self) -> bool:
|
||||
"""Check if encryption is available (master key configured)."""
|
||||
return settings.ENCRYPTION_MASTER_KEY is not None
|
||||
|
||||
def generate_key(self) -> bytes:
|
||||
"""
|
||||
Generate a new AES-256 encryption key.
|
||||
|
||||
Returns:
|
||||
32-byte random key
|
||||
"""
|
||||
return secrets.token_bytes(KEY_SIZE)
|
||||
|
||||
def encrypt_key(self, key: bytes) -> str:
|
||||
"""
|
||||
Encrypt a file encryption key using the Master Key.
|
||||
|
||||
Args:
|
||||
key: The raw 32-byte file encryption key
|
||||
|
||||
Returns:
|
||||
Base64-encoded encrypted key (nonce + ciphertext + tag)
|
||||
"""
|
||||
aesgcm = AESGCM(self.master_key)
|
||||
nonce = secrets.token_bytes(NONCE_SIZE)
|
||||
|
||||
# Encrypt the key
|
||||
ciphertext = aesgcm.encrypt(nonce, key, None)
|
||||
|
||||
# Combine nonce + ciphertext (includes tag)
|
||||
encrypted_data = nonce + ciphertext
|
||||
|
||||
return base64.urlsafe_b64encode(encrypted_data).decode('utf-8')
|
||||
|
||||
def decrypt_key(self, encrypted_key: str) -> bytes:
|
||||
"""
|
||||
Decrypt a file encryption key using the Master Key.
|
||||
|
||||
Args:
|
||||
encrypted_key: Base64-encoded encrypted key
|
||||
|
||||
Returns:
|
||||
The raw 32-byte file encryption key
|
||||
"""
|
||||
try:
|
||||
encrypted_data = base64.urlsafe_b64decode(encrypted_key)
|
||||
|
||||
# Extract nonce and ciphertext
|
||||
nonce = encrypted_data[:NONCE_SIZE]
|
||||
ciphertext = encrypted_data[NONCE_SIZE:]
|
||||
|
||||
# Decrypt
|
||||
aesgcm = AESGCM(self.master_key)
|
||||
return aesgcm.decrypt(nonce, ciphertext, None)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to decrypt encryption key: {e}")
|
||||
raise DecryptionError("Failed to decrypt file encryption key")
|
||||
|
||||
def encrypt_file(self, file_content: BinaryIO, key: bytes) -> bytes:
|
||||
"""
|
||||
Encrypt file content using AES-256-GCM.
|
||||
|
||||
For smaller files, encrypts the entire content at once.
|
||||
The format is: nonce (12 bytes) + ciphertext + tag (16 bytes)
|
||||
|
||||
Args:
|
||||
file_content: File-like object to encrypt
|
||||
key: 32-byte AES-256 key
|
||||
|
||||
Returns:
|
||||
Encrypted bytes (nonce + ciphertext + tag)
|
||||
"""
|
||||
# Read all content
|
||||
plaintext = file_content.read()
|
||||
|
||||
# Generate nonce
|
||||
nonce = secrets.token_bytes(NONCE_SIZE)
|
||||
|
||||
# Encrypt
|
||||
aesgcm = AESGCM(key)
|
||||
ciphertext = aesgcm.encrypt(nonce, plaintext, None)
|
||||
|
||||
# Return nonce + ciphertext (tag is appended by encrypt)
|
||||
return nonce + ciphertext
|
||||
|
||||
def decrypt_file(self, encrypted_content: BinaryIO, key: bytes) -> bytes:
|
||||
"""
|
||||
Decrypt file content using AES-256-GCM.
|
||||
|
||||
Args:
|
||||
encrypted_content: File-like object containing encrypted data
|
||||
key: 32-byte AES-256 key
|
||||
|
||||
Returns:
|
||||
Decrypted bytes
|
||||
"""
|
||||
try:
|
||||
# Read all encrypted content
|
||||
encrypted_data = encrypted_content.read()
|
||||
|
||||
# Extract nonce and ciphertext
|
||||
nonce = encrypted_data[:NONCE_SIZE]
|
||||
ciphertext = encrypted_data[NONCE_SIZE:]
|
||||
|
||||
# Decrypt
|
||||
aesgcm = AESGCM(key)
|
||||
plaintext = aesgcm.decrypt(nonce, ciphertext, None)
|
||||
|
||||
return plaintext
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to decrypt file: {e}")
|
||||
raise DecryptionError("Failed to decrypt file. The file may be corrupted or the key is incorrect.")
|
||||
|
||||
def encrypt_file_streaming(self, file_content: BinaryIO, key: bytes) -> Generator[bytes, None, None]:
|
||||
"""
|
||||
Encrypt file content using AES-256-GCM with streaming.
|
||||
|
||||
For large files, encrypts in chunks. Each chunk has its own nonce.
|
||||
Format per chunk: chunk_size (4 bytes) + nonce (12 bytes) + ciphertext + tag
|
||||
|
||||
Args:
|
||||
file_content: File-like object to encrypt
|
||||
key: 32-byte AES-256 key
|
||||
|
||||
Yields:
|
||||
Encrypted chunks
|
||||
"""
|
||||
aesgcm = AESGCM(key)
|
||||
|
||||
# Write header with version byte
|
||||
yield b'\x01' # Version 1 for streaming format
|
||||
|
||||
while True:
|
||||
chunk = file_content.read(CHUNK_SIZE)
|
||||
if not chunk:
|
||||
break
|
||||
|
||||
# Generate nonce for this chunk
|
||||
nonce = secrets.token_bytes(NONCE_SIZE)
|
||||
|
||||
# Encrypt chunk
|
||||
ciphertext = aesgcm.encrypt(nonce, chunk, None)
|
||||
|
||||
# Write chunk size (4 bytes, little endian)
|
||||
chunk_size = len(ciphertext) + NONCE_SIZE
|
||||
yield chunk_size.to_bytes(4, 'little')
|
||||
|
||||
# Write nonce + ciphertext
|
||||
yield nonce + ciphertext
|
||||
|
||||
# Write end marker (zero size)
|
||||
yield b'\x00\x00\x00\x00'
|
||||
|
||||
def decrypt_file_streaming(self, encrypted_content: BinaryIO, key: bytes) -> Generator[bytes, None, None]:
|
||||
"""
|
||||
Decrypt file content using AES-256-GCM with streaming.
|
||||
|
||||
Args:
|
||||
encrypted_content: File-like object containing encrypted data
|
||||
key: 32-byte AES-256 key
|
||||
|
||||
Yields:
|
||||
Decrypted chunks
|
||||
"""
|
||||
aesgcm = AESGCM(key)
|
||||
|
||||
# Read version byte
|
||||
version = encrypted_content.read(1)
|
||||
if version != b'\x01':
|
||||
raise DecryptionError(f"Unknown encryption format version")
|
||||
|
||||
while True:
|
||||
# Read chunk size
|
||||
size_bytes = encrypted_content.read(4)
|
||||
if len(size_bytes) < 4:
|
||||
raise DecryptionError("Unexpected end of file")
|
||||
|
||||
chunk_size = int.from_bytes(size_bytes, 'little')
|
||||
|
||||
# Check for end marker
|
||||
if chunk_size == 0:
|
||||
break
|
||||
|
||||
# Read chunk (nonce + ciphertext)
|
||||
chunk = encrypted_content.read(chunk_size)
|
||||
if len(chunk) < chunk_size:
|
||||
raise DecryptionError("Unexpected end of file")
|
||||
|
||||
# Extract nonce and ciphertext
|
||||
nonce = chunk[:NONCE_SIZE]
|
||||
ciphertext = chunk[NONCE_SIZE:]
|
||||
|
||||
try:
|
||||
# Decrypt
|
||||
plaintext = aesgcm.decrypt(nonce, ciphertext, None)
|
||||
yield plaintext
|
||||
except Exception as e:
|
||||
raise DecryptionError(f"Failed to decrypt chunk: {e}")
|
||||
|
||||
def encrypt_bytes(self, data: bytes, key: bytes) -> bytes:
|
||||
"""
|
||||
Encrypt bytes directly (convenience method).
|
||||
|
||||
Args:
|
||||
data: Bytes to encrypt
|
||||
key: 32-byte AES-256 key
|
||||
|
||||
Returns:
|
||||
Encrypted bytes
|
||||
"""
|
||||
return self.encrypt_file(BytesIO(data), key)
|
||||
|
||||
def decrypt_bytes(self, encrypted_data: bytes, key: bytes) -> bytes:
|
||||
"""
|
||||
Decrypt bytes directly (convenience method).
|
||||
|
||||
Args:
|
||||
encrypted_data: Encrypted bytes
|
||||
key: 32-byte AES-256 key
|
||||
|
||||
Returns:
|
||||
Decrypted bytes
|
||||
"""
|
||||
return self.decrypt_file(BytesIO(encrypted_data), key)
|
||||
|
||||
|
||||
# Singleton instance
|
||||
encryption_service = EncryptionService()
|
||||
420
backend/app/services/formula_service.py
Normal file
420
backend/app/services/formula_service.py
Normal file
@@ -0,0 +1,420 @@
|
||||
"""
|
||||
Formula Service for Custom Fields
|
||||
|
||||
Supports:
|
||||
- Basic math operations: +, -, *, /
|
||||
- Field references: {field_name}
|
||||
- Built-in task fields: {original_estimate}, {time_spent}
|
||||
- Parentheses for grouping
|
||||
|
||||
Example formulas:
|
||||
- "{time_spent} / {original_estimate} * 100"
|
||||
- "{cost_per_hour} * {hours_worked}"
|
||||
- "({field_a} + {field_b}) / 2"
|
||||
"""
|
||||
import re
|
||||
import ast
|
||||
import operator
|
||||
from typing import Dict, Any, Optional, List, Set, Tuple
|
||||
from decimal import Decimal, InvalidOperation
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models import Task, CustomField, TaskCustomValue
|
||||
|
||||
|
||||
class FormulaError(Exception):
|
||||
"""Exception raised for formula parsing or calculation errors."""
|
||||
pass
|
||||
|
||||
|
||||
class CircularReferenceError(FormulaError):
|
||||
"""Exception raised when circular references are detected in formulas."""
|
||||
pass
|
||||
|
||||
|
||||
class FormulaService:
|
||||
"""Service for parsing and calculating formula fields."""
|
||||
|
||||
# Built-in task fields that can be referenced in formulas
|
||||
BUILTIN_FIELDS = {
|
||||
"original_estimate",
|
||||
"time_spent",
|
||||
}
|
||||
|
||||
# Supported operators
|
||||
OPERATORS = {
|
||||
ast.Add: operator.add,
|
||||
ast.Sub: operator.sub,
|
||||
ast.Mult: operator.mul,
|
||||
ast.Div: operator.truediv,
|
||||
ast.USub: operator.neg,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def extract_field_references(formula: str) -> Set[str]:
|
||||
"""
|
||||
Extract all field references from a formula.
|
||||
|
||||
Field references are in the format {field_name}.
|
||||
Returns a set of field names referenced in the formula.
|
||||
"""
|
||||
pattern = r'\{([^}]+)\}'
|
||||
matches = re.findall(pattern, formula)
|
||||
return set(matches)
|
||||
|
||||
@staticmethod
|
||||
def validate_formula(
|
||||
formula: str,
|
||||
project_id: str,
|
||||
db: Session,
|
||||
current_field_id: Optional[str] = None,
|
||||
) -> Tuple[bool, Optional[str]]:
|
||||
"""
|
||||
Validate a formula expression.
|
||||
|
||||
Checks:
|
||||
1. Syntax is valid
|
||||
2. All referenced fields exist
|
||||
3. Referenced fields are number or formula type
|
||||
4. No circular references
|
||||
|
||||
Returns (is_valid, error_message)
|
||||
"""
|
||||
if not formula or not formula.strip():
|
||||
return False, "Formula cannot be empty"
|
||||
|
||||
# Extract field references
|
||||
references = FormulaService.extract_field_references(formula)
|
||||
|
||||
if not references:
|
||||
return False, "Formula must reference at least one field"
|
||||
|
||||
# Validate syntax by trying to parse
|
||||
try:
|
||||
# Replace field references with dummy numbers for syntax check
|
||||
test_formula = formula
|
||||
for ref in references:
|
||||
test_formula = test_formula.replace(f"{{{ref}}}", "1")
|
||||
|
||||
# Try to parse and evaluate with safe operations
|
||||
FormulaService._safe_eval(test_formula)
|
||||
except Exception as e:
|
||||
return False, f"Invalid formula syntax: {str(e)}"
|
||||
|
||||
# Separate builtin and custom field references
|
||||
custom_references = references - FormulaService.BUILTIN_FIELDS
|
||||
|
||||
# Validate custom field references exist and are numeric types
|
||||
if custom_references:
|
||||
fields = db.query(CustomField).filter(
|
||||
CustomField.project_id == project_id,
|
||||
CustomField.name.in_(custom_references),
|
||||
).all()
|
||||
|
||||
found_names = {f.name for f in fields}
|
||||
missing = custom_references - found_names
|
||||
|
||||
if missing:
|
||||
return False, f"Unknown field references: {', '.join(missing)}"
|
||||
|
||||
# Check field types (must be number or formula)
|
||||
for field in fields:
|
||||
if field.field_type not in ("number", "formula"):
|
||||
return False, f"Field '{field.name}' is not a numeric type"
|
||||
|
||||
# Check for circular references
|
||||
if current_field_id:
|
||||
try:
|
||||
FormulaService._check_circular_references(
|
||||
db, project_id, current_field_id, references
|
||||
)
|
||||
except CircularReferenceError as e:
|
||||
return False, str(e)
|
||||
|
||||
return True, None
|
||||
|
||||
@staticmethod
|
||||
def _check_circular_references(
|
||||
db: Session,
|
||||
project_id: str,
|
||||
field_id: str,
|
||||
references: Set[str],
|
||||
visited: Optional[Set[str]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Check for circular references in formula fields.
|
||||
|
||||
Raises CircularReferenceError if a cycle is detected.
|
||||
"""
|
||||
if visited is None:
|
||||
visited = set()
|
||||
|
||||
# Get the current field's name
|
||||
current_field = db.query(CustomField).filter(
|
||||
CustomField.id == field_id
|
||||
).first()
|
||||
|
||||
if current_field:
|
||||
if current_field.name in references:
|
||||
raise CircularReferenceError(
|
||||
f"Circular reference detected: field cannot reference itself"
|
||||
)
|
||||
|
||||
# Get all referenced formula fields
|
||||
custom_references = references - FormulaService.BUILTIN_FIELDS
|
||||
if not custom_references:
|
||||
return
|
||||
|
||||
formula_fields = db.query(CustomField).filter(
|
||||
CustomField.project_id == project_id,
|
||||
CustomField.name.in_(custom_references),
|
||||
CustomField.field_type == "formula",
|
||||
).all()
|
||||
|
||||
for field in formula_fields:
|
||||
if field.id in visited:
|
||||
raise CircularReferenceError(
|
||||
f"Circular reference detected involving field '{field.name}'"
|
||||
)
|
||||
|
||||
visited.add(field.id)
|
||||
|
||||
if field.formula:
|
||||
nested_refs = FormulaService.extract_field_references(field.formula)
|
||||
if current_field and current_field.name in nested_refs:
|
||||
raise CircularReferenceError(
|
||||
f"Circular reference detected: '{field.name}' references the current field"
|
||||
)
|
||||
FormulaService._check_circular_references(
|
||||
db, project_id, field_id, nested_refs, visited
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _safe_eval(expression: str) -> Decimal:
|
||||
"""
|
||||
Safely evaluate a mathematical expression.
|
||||
|
||||
Only allows basic arithmetic operations (+, -, *, /).
|
||||
"""
|
||||
try:
|
||||
node = ast.parse(expression, mode='eval')
|
||||
return FormulaService._eval_node(node.body)
|
||||
except Exception as e:
|
||||
raise FormulaError(f"Failed to evaluate expression: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def _eval_node(node: ast.AST) -> Decimal:
|
||||
"""Recursively evaluate an AST node."""
|
||||
if isinstance(node, ast.Constant):
|
||||
if isinstance(node.value, (int, float)):
|
||||
return Decimal(str(node.value))
|
||||
raise FormulaError(f"Invalid constant: {node.value}")
|
||||
|
||||
elif isinstance(node, ast.BinOp):
|
||||
left = FormulaService._eval_node(node.left)
|
||||
right = FormulaService._eval_node(node.right)
|
||||
op = FormulaService.OPERATORS.get(type(node.op))
|
||||
if op is None:
|
||||
raise FormulaError(f"Unsupported operator: {type(node.op).__name__}")
|
||||
|
||||
# Handle division by zero
|
||||
if isinstance(node.op, ast.Div) and right == 0:
|
||||
return Decimal('0') # Return 0 instead of raising error
|
||||
|
||||
return Decimal(str(op(float(left), float(right))))
|
||||
|
||||
elif isinstance(node, ast.UnaryOp):
|
||||
operand = FormulaService._eval_node(node.operand)
|
||||
op = FormulaService.OPERATORS.get(type(node.op))
|
||||
if op is None:
|
||||
raise FormulaError(f"Unsupported operator: {type(node.op).__name__}")
|
||||
return Decimal(str(op(float(operand))))
|
||||
|
||||
else:
|
||||
raise FormulaError(f"Unsupported expression type: {type(node).__name__}")
|
||||
|
||||
@staticmethod
|
||||
def calculate_formula(
|
||||
formula: str,
|
||||
task: Task,
|
||||
db: Session,
|
||||
calculated_cache: Optional[Dict[str, Decimal]] = None,
|
||||
) -> Optional[Decimal]:
|
||||
"""
|
||||
Calculate the value of a formula for a given task.
|
||||
|
||||
Args:
|
||||
formula: The formula expression
|
||||
task: The task to calculate for
|
||||
db: Database session
|
||||
calculated_cache: Cache for already calculated formula values (for recursion)
|
||||
|
||||
Returns:
|
||||
The calculated value, or None if calculation fails
|
||||
"""
|
||||
if calculated_cache is None:
|
||||
calculated_cache = {}
|
||||
|
||||
references = FormulaService.extract_field_references(formula)
|
||||
values: Dict[str, Decimal] = {}
|
||||
|
||||
# Get builtin field values
|
||||
for ref in references:
|
||||
if ref in FormulaService.BUILTIN_FIELDS:
|
||||
task_value = getattr(task, ref, None)
|
||||
if task_value is not None:
|
||||
values[ref] = Decimal(str(task_value))
|
||||
else:
|
||||
values[ref] = Decimal('0')
|
||||
|
||||
# Get custom field values
|
||||
custom_references = references - FormulaService.BUILTIN_FIELDS
|
||||
if custom_references:
|
||||
# Get field definitions
|
||||
fields = db.query(CustomField).filter(
|
||||
CustomField.project_id == task.project_id,
|
||||
CustomField.name.in_(custom_references),
|
||||
).all()
|
||||
|
||||
field_map = {f.name: f for f in fields}
|
||||
|
||||
# Get custom values for this task
|
||||
custom_values = db.query(TaskCustomValue).filter(
|
||||
TaskCustomValue.task_id == task.id,
|
||||
TaskCustomValue.field_id.in_([f.id for f in fields]),
|
||||
).all()
|
||||
|
||||
value_map = {cv.field_id: cv.value for cv in custom_values}
|
||||
|
||||
for ref in custom_references:
|
||||
field = field_map.get(ref)
|
||||
if not field:
|
||||
values[ref] = Decimal('0')
|
||||
continue
|
||||
|
||||
if field.field_type == "formula":
|
||||
# Recursively calculate formula fields
|
||||
if field.id in calculated_cache:
|
||||
values[ref] = calculated_cache[field.id]
|
||||
else:
|
||||
nested_value = FormulaService.calculate_formula(
|
||||
field.formula, task, db, calculated_cache
|
||||
)
|
||||
values[ref] = nested_value if nested_value is not None else Decimal('0')
|
||||
calculated_cache[field.id] = values[ref]
|
||||
else:
|
||||
# Get stored value
|
||||
stored_value = value_map.get(field.id)
|
||||
if stored_value:
|
||||
try:
|
||||
values[ref] = Decimal(str(stored_value))
|
||||
except (InvalidOperation, ValueError):
|
||||
values[ref] = Decimal('0')
|
||||
else:
|
||||
values[ref] = Decimal('0')
|
||||
|
||||
# Substitute values into formula
|
||||
expression = formula
|
||||
for ref, value in values.items():
|
||||
expression = expression.replace(f"{{{ref}}}", str(value))
|
||||
|
||||
# Evaluate the expression
|
||||
try:
|
||||
result = FormulaService._safe_eval(expression)
|
||||
# Round to 4 decimal places
|
||||
return result.quantize(Decimal('0.0001'))
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def recalculate_dependent_formulas(
|
||||
db: Session,
|
||||
task: Task,
|
||||
changed_field_id: str,
|
||||
) -> Dict[str, Decimal]:
|
||||
"""
|
||||
Recalculate all formula fields that depend on a changed field.
|
||||
|
||||
Returns a dict of field_id -> calculated_value for updated formulas.
|
||||
"""
|
||||
# Get the changed field
|
||||
changed_field = db.query(CustomField).filter(
|
||||
CustomField.id == changed_field_id
|
||||
).first()
|
||||
|
||||
if not changed_field:
|
||||
return {}
|
||||
|
||||
# Find all formula fields in the project
|
||||
formula_fields = db.query(CustomField).filter(
|
||||
CustomField.project_id == task.project_id,
|
||||
CustomField.field_type == "formula",
|
||||
).all()
|
||||
|
||||
results = {}
|
||||
calculated_cache: Dict[str, Decimal] = {}
|
||||
|
||||
for field in formula_fields:
|
||||
if not field.formula:
|
||||
continue
|
||||
|
||||
# Check if this formula depends on the changed field
|
||||
references = FormulaService.extract_field_references(field.formula)
|
||||
if changed_field.name in references or changed_field.name in FormulaService.BUILTIN_FIELDS:
|
||||
value = FormulaService.calculate_formula(
|
||||
field.formula, task, db, calculated_cache
|
||||
)
|
||||
if value is not None:
|
||||
results[field.id] = value
|
||||
calculated_cache[field.id] = value
|
||||
|
||||
# Update or create the custom value
|
||||
existing = db.query(TaskCustomValue).filter(
|
||||
TaskCustomValue.task_id == task.id,
|
||||
TaskCustomValue.field_id == field.id,
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
existing.value = str(value)
|
||||
else:
|
||||
import uuid
|
||||
new_value = TaskCustomValue(
|
||||
id=str(uuid.uuid4()),
|
||||
task_id=task.id,
|
||||
field_id=field.id,
|
||||
value=str(value),
|
||||
)
|
||||
db.add(new_value)
|
||||
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def calculate_all_formulas_for_task(
|
||||
db: Session,
|
||||
task: Task,
|
||||
) -> Dict[str, Decimal]:
|
||||
"""
|
||||
Calculate all formula fields for a task.
|
||||
|
||||
Used when loading a task to get current formula values.
|
||||
"""
|
||||
formula_fields = db.query(CustomField).filter(
|
||||
CustomField.project_id == task.project_id,
|
||||
CustomField.field_type == "formula",
|
||||
).all()
|
||||
|
||||
results = {}
|
||||
calculated_cache: Dict[str, Decimal] = {}
|
||||
|
||||
for field in formula_fields:
|
||||
if not field.formula:
|
||||
continue
|
||||
|
||||
value = FormulaService.calculate_formula(
|
||||
field.formula, task, db, calculated_cache
|
||||
)
|
||||
if value is not None:
|
||||
results[field.id] = value
|
||||
calculated_cache[field.id] = value
|
||||
|
||||
return results
|
||||
70
backend/migrations/versions/011_custom_fields_tables.py
Normal file
70
backend/migrations/versions/011_custom_fields_tables.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""Add custom fields and task custom values tables
|
||||
|
||||
Revision ID: 011
|
||||
Revises: 010
|
||||
Create Date: 2026-01-05
|
||||
|
||||
FEAT-001: Add custom fields feature for flexible task data extension.
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision: str = '011'
|
||||
down_revision: Union[str, None] = '010'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create field_type_enum
|
||||
field_type_enum = sa.Enum(
|
||||
'text', 'number', 'dropdown', 'date', 'person', 'formula',
|
||||
name='field_type_enum'
|
||||
)
|
||||
field_type_enum.create(op.get_bind(), checkfirst=True)
|
||||
|
||||
# Create pjctrl_custom_fields table
|
||||
op.create_table(
|
||||
'pjctrl_custom_fields',
|
||||
sa.Column('id', sa.String(36), primary_key=True),
|
||||
sa.Column('project_id', sa.String(36), sa.ForeignKey('pjctrl_projects.id', ondelete='CASCADE'), nullable=False),
|
||||
sa.Column('name', sa.String(100), nullable=False),
|
||||
sa.Column('field_type', field_type_enum, nullable=False),
|
||||
sa.Column('options', sa.JSON, nullable=True),
|
||||
sa.Column('formula', sa.Text, nullable=True),
|
||||
sa.Column('is_required', sa.Boolean, default=False, nullable=False),
|
||||
sa.Column('position', sa.Integer, default=0, nullable=False),
|
||||
sa.Column('created_at', sa.DateTime, server_default=sa.func.now(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime, server_default=sa.func.now(), onupdate=sa.func.now(), nullable=False),
|
||||
)
|
||||
|
||||
# Create indexes for custom_fields
|
||||
op.create_index('ix_pjctrl_custom_fields_project_id', 'pjctrl_custom_fields', ['project_id'])
|
||||
|
||||
# Create pjctrl_task_custom_values table
|
||||
op.create_table(
|
||||
'pjctrl_task_custom_values',
|
||||
sa.Column('id', sa.String(36), primary_key=True),
|
||||
sa.Column('task_id', sa.String(36), sa.ForeignKey('pjctrl_tasks.id', ondelete='CASCADE'), nullable=False),
|
||||
sa.Column('field_id', sa.String(36), sa.ForeignKey('pjctrl_custom_fields.id', ondelete='CASCADE'), nullable=False),
|
||||
sa.Column('value', sa.Text, nullable=True),
|
||||
sa.Column('updated_at', sa.DateTime, server_default=sa.func.now(), onupdate=sa.func.now(), nullable=False),
|
||||
sa.UniqueConstraint('task_id', 'field_id', name='uq_task_field'),
|
||||
)
|
||||
|
||||
# Create indexes for task_custom_values
|
||||
op.create_index('ix_pjctrl_task_custom_values_task_id', 'pjctrl_task_custom_values', ['task_id'])
|
||||
op.create_index('ix_pjctrl_task_custom_values_field_id', 'pjctrl_task_custom_values', ['field_id'])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index('ix_pjctrl_task_custom_values_field_id', table_name='pjctrl_task_custom_values')
|
||||
op.drop_index('ix_pjctrl_task_custom_values_task_id', table_name='pjctrl_task_custom_values')
|
||||
op.drop_table('pjctrl_task_custom_values')
|
||||
|
||||
op.drop_index('ix_pjctrl_custom_fields_project_id', table_name='pjctrl_custom_fields')
|
||||
op.drop_table('pjctrl_custom_fields')
|
||||
|
||||
# Drop the enum type
|
||||
sa.Enum(name='field_type_enum').drop(op.get_bind(), checkfirst=True)
|
||||
48
backend/migrations/versions/012_encryption_keys_table.py
Normal file
48
backend/migrations/versions/012_encryption_keys_table.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""Encryption keys table and attachment encryption_key_id
|
||||
|
||||
Revision ID: 012
|
||||
Revises: 011
|
||||
Create Date: 2026-01-05
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers
|
||||
revision = '012'
|
||||
down_revision = '011'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# Create encryption_keys table
|
||||
op.create_table(
|
||||
'pjctrl_encryption_keys',
|
||||
sa.Column('id', sa.String(36), primary_key=True),
|
||||
sa.Column('key_data', sa.Text, nullable=False), # Encrypted key using Master Key
|
||||
sa.Column('algorithm', sa.String(20), default='AES-256-GCM', nullable=False),
|
||||
sa.Column('is_active', sa.Boolean, default=True, nullable=False),
|
||||
sa.Column('created_at', sa.DateTime, server_default=sa.func.now(), nullable=False),
|
||||
sa.Column('rotated_at', sa.DateTime, nullable=True),
|
||||
)
|
||||
op.create_index('idx_encryption_key_active', 'pjctrl_encryption_keys', ['is_active'])
|
||||
|
||||
# Add encryption_key_id column to attachments table
|
||||
op.add_column(
|
||||
'pjctrl_attachments',
|
||||
sa.Column(
|
||||
'encryption_key_id',
|
||||
sa.String(36),
|
||||
sa.ForeignKey('pjctrl_encryption_keys.id', ondelete='SET NULL'),
|
||||
nullable=True
|
||||
)
|
||||
)
|
||||
op.create_index('idx_attachment_encryption_key', 'pjctrl_attachments', ['encryption_key_id'])
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_index('idx_attachment_encryption_key', 'pjctrl_attachments')
|
||||
op.drop_column('pjctrl_attachments', 'encryption_key_id')
|
||||
op.drop_index('idx_encryption_key_active', 'pjctrl_encryption_keys')
|
||||
op.drop_table('pjctrl_encryption_keys')
|
||||
92
backend/migrations/versions/013_task_dependencies_table.py
Normal file
92
backend/migrations/versions/013_task_dependencies_table.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""Add start_date to tasks and create task dependencies table for Gantt view
|
||||
|
||||
Revision ID: 013
|
||||
Revises: 012
|
||||
Create Date: 2026-01-05
|
||||
|
||||
FEAT-003: Add Gantt view support with task dependencies and start dates.
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision: str = '013'
|
||||
down_revision: Union[str, None] = '012'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add start_date column to pjctrl_tasks
|
||||
op.add_column(
|
||||
'pjctrl_tasks',
|
||||
sa.Column('start_date', sa.DateTime, nullable=True)
|
||||
)
|
||||
|
||||
# Create dependency_type_enum
|
||||
dependency_type_enum = sa.Enum(
|
||||
'FS', 'SS', 'FF', 'SF',
|
||||
name='dependency_type_enum'
|
||||
)
|
||||
dependency_type_enum.create(op.get_bind(), checkfirst=True)
|
||||
|
||||
# Create pjctrl_task_dependencies table
|
||||
op.create_table(
|
||||
'pjctrl_task_dependencies',
|
||||
sa.Column('id', sa.String(36), primary_key=True),
|
||||
sa.Column(
|
||||
'predecessor_id',
|
||||
sa.String(36),
|
||||
sa.ForeignKey('pjctrl_tasks.id', ondelete='CASCADE'),
|
||||
nullable=False
|
||||
),
|
||||
sa.Column(
|
||||
'successor_id',
|
||||
sa.String(36),
|
||||
sa.ForeignKey('pjctrl_tasks.id', ondelete='CASCADE'),
|
||||
nullable=False
|
||||
),
|
||||
sa.Column(
|
||||
'dependency_type',
|
||||
dependency_type_enum,
|
||||
default='FS',
|
||||
nullable=False
|
||||
),
|
||||
sa.Column('lag_days', sa.Integer, default=0, nullable=False),
|
||||
sa.Column('created_at', sa.DateTime, server_default=sa.func.now(), nullable=False),
|
||||
# Unique constraint to prevent duplicate dependencies
|
||||
sa.UniqueConstraint('predecessor_id', 'successor_id', name='uq_predecessor_successor'),
|
||||
)
|
||||
|
||||
# Create indexes for efficient dependency lookups
|
||||
op.create_index(
|
||||
'ix_pjctrl_task_dependencies_predecessor_id',
|
||||
'pjctrl_task_dependencies',
|
||||
['predecessor_id']
|
||||
)
|
||||
op.create_index(
|
||||
'ix_pjctrl_task_dependencies_successor_id',
|
||||
'pjctrl_task_dependencies',
|
||||
['successor_id']
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop indexes
|
||||
op.drop_index(
|
||||
'ix_pjctrl_task_dependencies_successor_id',
|
||||
table_name='pjctrl_task_dependencies'
|
||||
)
|
||||
op.drop_index(
|
||||
'ix_pjctrl_task_dependencies_predecessor_id',
|
||||
table_name='pjctrl_task_dependencies'
|
||||
)
|
||||
|
||||
# Drop the table
|
||||
op.drop_table('pjctrl_task_dependencies')
|
||||
|
||||
# Drop the enum type
|
||||
sa.Enum(name='dependency_type_enum').drop(op.get_bind(), checkfirst=True)
|
||||
|
||||
# Remove start_date column from tasks
|
||||
op.drop_column('pjctrl_tasks', 'start_date')
|
||||
440
backend/tests/test_custom_fields.py
Normal file
440
backend/tests/test_custom_fields.py
Normal file
@@ -0,0 +1,440 @@
|
||||
"""
|
||||
Tests for Custom Fields feature.
|
||||
"""
|
||||
import pytest
|
||||
from app.models import User, Space, Project, Task, TaskStatus, CustomField, TaskCustomValue
|
||||
from app.services.formula_service import FormulaService, FormulaError, CircularReferenceError
|
||||
|
||||
|
||||
class TestCustomFieldsCRUD:
|
||||
"""Test custom fields CRUD operations."""
|
||||
|
||||
def setup_project(self, db, owner_id: str):
|
||||
"""Create a space and project for testing."""
|
||||
space = Space(
|
||||
id="test-space-001",
|
||||
name="Test Space",
|
||||
owner_id=owner_id,
|
||||
)
|
||||
db.add(space)
|
||||
|
||||
project = Project(
|
||||
id="test-project-001",
|
||||
space_id=space.id,
|
||||
title="Test Project",
|
||||
owner_id=owner_id,
|
||||
)
|
||||
db.add(project)
|
||||
|
||||
# Add default task status
|
||||
status = TaskStatus(
|
||||
id="test-status-001",
|
||||
project_id=project.id,
|
||||
name="To Do",
|
||||
color="#3B82F6",
|
||||
position=0,
|
||||
)
|
||||
db.add(status)
|
||||
|
||||
db.commit()
|
||||
return project
|
||||
|
||||
def test_create_text_field(self, client, db, admin_token):
|
||||
"""Test creating a text custom field."""
|
||||
project = self.setup_project(db, "00000000-0000-0000-0000-000000000001")
|
||||
|
||||
response = client.post(
|
||||
f"/api/projects/{project.id}/custom-fields",
|
||||
json={
|
||||
"name": "Sprint Number",
|
||||
"field_type": "text",
|
||||
"is_required": False,
|
||||
},
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
assert data["name"] == "Sprint Number"
|
||||
assert data["field_type"] == "text"
|
||||
assert data["is_required"] is False
|
||||
|
||||
def test_create_number_field(self, client, db, admin_token):
|
||||
"""Test creating a number custom field."""
|
||||
project = self.setup_project(db, "00000000-0000-0000-0000-000000000001")
|
||||
|
||||
response = client.post(
|
||||
f"/api/projects/{project.id}/custom-fields",
|
||||
json={
|
||||
"name": "Story Points",
|
||||
"field_type": "number",
|
||||
"is_required": True,
|
||||
},
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
assert data["name"] == "Story Points"
|
||||
assert data["field_type"] == "number"
|
||||
assert data["is_required"] is True
|
||||
|
||||
def test_create_dropdown_field(self, client, db, admin_token):
|
||||
"""Test creating a dropdown custom field."""
|
||||
project = self.setup_project(db, "00000000-0000-0000-0000-000000000001")
|
||||
|
||||
response = client.post(
|
||||
f"/api/projects/{project.id}/custom-fields",
|
||||
json={
|
||||
"name": "Component",
|
||||
"field_type": "dropdown",
|
||||
"options": ["Frontend", "Backend", "Database", "API"],
|
||||
"is_required": False,
|
||||
},
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
assert data["name"] == "Component"
|
||||
assert data["field_type"] == "dropdown"
|
||||
assert data["options"] == ["Frontend", "Backend", "Database", "API"]
|
||||
|
||||
def test_create_dropdown_field_without_options_fails(self, client, db, admin_token):
|
||||
"""Test that creating a dropdown field without options fails."""
|
||||
project = self.setup_project(db, "00000000-0000-0000-0000-000000000001")
|
||||
|
||||
response = client.post(
|
||||
f"/api/projects/{project.id}/custom-fields",
|
||||
json={
|
||||
"name": "Component",
|
||||
"field_type": "dropdown",
|
||||
"options": [],
|
||||
},
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 422 # Validation error
|
||||
|
||||
def test_create_formula_field(self, client, db, admin_token):
|
||||
"""Test creating a formula custom field."""
|
||||
project = self.setup_project(db, "00000000-0000-0000-0000-000000000001")
|
||||
|
||||
# First create a number field to reference
|
||||
client.post(
|
||||
f"/api/projects/{project.id}/custom-fields",
|
||||
json={
|
||||
"name": "hours_worked",
|
||||
"field_type": "number",
|
||||
},
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
|
||||
# Create formula field
|
||||
response = client.post(
|
||||
f"/api/projects/{project.id}/custom-fields",
|
||||
json={
|
||||
"name": "Progress",
|
||||
"field_type": "formula",
|
||||
"formula": "{time_spent} / {original_estimate} * 100",
|
||||
},
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
assert data["name"] == "Progress"
|
||||
assert data["field_type"] == "formula"
|
||||
assert "{time_spent}" in data["formula"]
|
||||
|
||||
def test_list_custom_fields(self, client, db, admin_token):
|
||||
"""Test listing custom fields for a project."""
|
||||
project = self.setup_project(db, "00000000-0000-0000-0000-000000000001")
|
||||
|
||||
# Create some fields
|
||||
client.post(
|
||||
f"/api/projects/{project.id}/custom-fields",
|
||||
json={"name": "Field 1", "field_type": "text"},
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
client.post(
|
||||
f"/api/projects/{project.id}/custom-fields",
|
||||
json={"name": "Field 2", "field_type": "number"},
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
|
||||
response = client.get(
|
||||
f"/api/projects/{project.id}/custom-fields",
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] == 2
|
||||
assert len(data["fields"]) == 2
|
||||
|
||||
def test_update_custom_field(self, client, db, admin_token):
|
||||
"""Test updating a custom field."""
|
||||
project = self.setup_project(db, "00000000-0000-0000-0000-000000000001")
|
||||
|
||||
# Create a field
|
||||
create_response = client.post(
|
||||
f"/api/projects/{project.id}/custom-fields",
|
||||
json={"name": "Original Name", "field_type": "text"},
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
field_id = create_response.json()["id"]
|
||||
|
||||
# Update it
|
||||
response = client.put(
|
||||
f"/api/custom-fields/{field_id}",
|
||||
json={"name": "Updated Name", "is_required": True},
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["name"] == "Updated Name"
|
||||
assert data["is_required"] is True
|
||||
|
||||
def test_delete_custom_field(self, client, db, admin_token):
|
||||
"""Test deleting a custom field."""
|
||||
project = self.setup_project(db, "00000000-0000-0000-0000-000000000001")
|
||||
|
||||
# Create a field
|
||||
create_response = client.post(
|
||||
f"/api/projects/{project.id}/custom-fields",
|
||||
json={"name": "To Delete", "field_type": "text"},
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
field_id = create_response.json()["id"]
|
||||
|
||||
# Delete it
|
||||
response = client.delete(
|
||||
f"/api/custom-fields/{field_id}",
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 204
|
||||
|
||||
# Verify it's gone
|
||||
get_response = client.get(
|
||||
f"/api/custom-fields/{field_id}",
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
assert get_response.status_code == 404
|
||||
|
||||
def test_max_fields_limit(self, client, db, admin_token):
|
||||
"""Test that maximum 20 custom fields per project is enforced."""
|
||||
project = self.setup_project(db, "00000000-0000-0000-0000-000000000001")
|
||||
|
||||
# Create 20 fields
|
||||
for i in range(20):
|
||||
response = client.post(
|
||||
f"/api/projects/{project.id}/custom-fields",
|
||||
json={"name": f"Field {i}", "field_type": "text"},
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
assert response.status_code == 201
|
||||
|
||||
# Try to create the 21st field
|
||||
response = client.post(
|
||||
f"/api/projects/{project.id}/custom-fields",
|
||||
json={"name": "Field 21", "field_type": "text"},
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert "Maximum" in response.json()["detail"]
|
||||
|
||||
def test_duplicate_name_rejected(self, client, db, admin_token):
|
||||
"""Test that duplicate field names are rejected."""
|
||||
project = self.setup_project(db, "00000000-0000-0000-0000-000000000001")
|
||||
|
||||
# Create a field
|
||||
client.post(
|
||||
f"/api/projects/{project.id}/custom-fields",
|
||||
json={"name": "Unique Name", "field_type": "text"},
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
|
||||
# Try to create another with same name
|
||||
response = client.post(
|
||||
f"/api/projects/{project.id}/custom-fields",
|
||||
json={"name": "Unique Name", "field_type": "number"},
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert "already exists" in response.json()["detail"]
|
||||
|
||||
|
||||
class TestFormulaService:
|
||||
"""Test formula parsing and calculation."""
|
||||
|
||||
def test_extract_field_references(self):
|
||||
"""Test extracting field references from formulas."""
|
||||
formula = "{time_spent} / {original_estimate} * 100"
|
||||
refs = FormulaService.extract_field_references(formula)
|
||||
assert refs == {"time_spent", "original_estimate"}
|
||||
|
||||
def test_extract_multiple_references(self):
|
||||
"""Test extracting multiple field references."""
|
||||
formula = "{field_a} + {field_b} - {field_c}"
|
||||
refs = FormulaService.extract_field_references(formula)
|
||||
assert refs == {"field_a", "field_b", "field_c"}
|
||||
|
||||
def test_safe_eval_addition(self):
|
||||
"""Test safe evaluation of addition."""
|
||||
result = FormulaService._safe_eval("10 + 5")
|
||||
assert float(result) == 15.0
|
||||
|
||||
def test_safe_eval_division(self):
|
||||
"""Test safe evaluation of division."""
|
||||
result = FormulaService._safe_eval("20 / 4")
|
||||
assert float(result) == 5.0
|
||||
|
||||
def test_safe_eval_complex_expression(self):
|
||||
"""Test safe evaluation of complex expression."""
|
||||
result = FormulaService._safe_eval("(10 + 5) * 2 / 3")
|
||||
assert float(result) == 10.0
|
||||
|
||||
def test_safe_eval_division_by_zero(self):
|
||||
"""Test that division by zero returns 0."""
|
||||
result = FormulaService._safe_eval("10 / 0")
|
||||
assert float(result) == 0.0
|
||||
|
||||
def test_safe_eval_negative_numbers(self):
|
||||
"""Test safe evaluation with negative numbers."""
|
||||
result = FormulaService._safe_eval("-5 + 10")
|
||||
assert float(result) == 5.0
|
||||
|
||||
|
||||
class TestCustomValuesWithTasks:
|
||||
"""Test custom values integration with tasks."""
|
||||
|
||||
def setup_project_with_fields(self, db, client, admin_token, owner_id: str):
|
||||
"""Create a project with custom fields for testing."""
|
||||
space = Space(
|
||||
id="test-space-002",
|
||||
name="Test Space",
|
||||
owner_id=owner_id,
|
||||
)
|
||||
db.add(space)
|
||||
|
||||
project = Project(
|
||||
id="test-project-002",
|
||||
space_id=space.id,
|
||||
title="Test Project",
|
||||
owner_id=owner_id,
|
||||
)
|
||||
db.add(project)
|
||||
|
||||
status = TaskStatus(
|
||||
id="test-status-002",
|
||||
project_id=project.id,
|
||||
name="To Do",
|
||||
color="#3B82F6",
|
||||
position=0,
|
||||
)
|
||||
db.add(status)
|
||||
db.commit()
|
||||
|
||||
# Create custom fields via API
|
||||
text_response = client.post(
|
||||
f"/api/projects/{project.id}/custom-fields",
|
||||
json={"name": "sprint_number", "field_type": "text"},
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
text_field_id = text_response.json()["id"]
|
||||
|
||||
number_response = client.post(
|
||||
f"/api/projects/{project.id}/custom-fields",
|
||||
json={"name": "story_points", "field_type": "number"},
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
number_field_id = number_response.json()["id"]
|
||||
|
||||
return project, text_field_id, number_field_id
|
||||
|
||||
def test_create_task_with_custom_values(self, client, db, admin_token):
|
||||
"""Test creating a task with custom values."""
|
||||
project, text_field_id, number_field_id = self.setup_project_with_fields(
|
||||
db, client, admin_token, "00000000-0000-0000-0000-000000000001"
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
f"/api/projects/{project.id}/tasks",
|
||||
json={
|
||||
"title": "Test Task",
|
||||
"custom_values": [
|
||||
{"field_id": text_field_id, "value": "Sprint 5"},
|
||||
{"field_id": number_field_id, "value": "8"},
|
||||
],
|
||||
},
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
|
||||
def test_get_task_includes_custom_values(self, client, db, admin_token):
|
||||
"""Test that getting a task includes custom values."""
|
||||
project, text_field_id, number_field_id = self.setup_project_with_fields(
|
||||
db, client, admin_token, "00000000-0000-0000-0000-000000000001"
|
||||
)
|
||||
|
||||
# Create task with custom values
|
||||
create_response = client.post(
|
||||
f"/api/projects/{project.id}/tasks",
|
||||
json={
|
||||
"title": "Test Task",
|
||||
"custom_values": [
|
||||
{"field_id": text_field_id, "value": "Sprint 5"},
|
||||
{"field_id": number_field_id, "value": "8"},
|
||||
],
|
||||
},
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
task_id = create_response.json()["id"]
|
||||
|
||||
# Get task and check custom values
|
||||
get_response = client.get(
|
||||
f"/api/tasks/{task_id}",
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
|
||||
assert get_response.status_code == 200
|
||||
data = get_response.json()
|
||||
assert data["custom_values"] is not None
|
||||
assert len(data["custom_values"]) >= 2
|
||||
|
||||
def test_update_task_custom_values(self, client, db, admin_token):
|
||||
"""Test updating custom values on a task."""
|
||||
project, text_field_id, number_field_id = self.setup_project_with_fields(
|
||||
db, client, admin_token, "00000000-0000-0000-0000-000000000001"
|
||||
)
|
||||
|
||||
# Create task
|
||||
create_response = client.post(
|
||||
f"/api/projects/{project.id}/tasks",
|
||||
json={
|
||||
"title": "Test Task",
|
||||
"custom_values": [
|
||||
{"field_id": text_field_id, "value": "Sprint 5"},
|
||||
],
|
||||
},
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
task_id = create_response.json()["id"]
|
||||
|
||||
# Update custom values
|
||||
update_response = client.patch(
|
||||
f"/api/tasks/{task_id}",
|
||||
json={
|
||||
"custom_values": [
|
||||
{"field_id": text_field_id, "value": "Sprint 6"},
|
||||
{"field_id": number_field_id, "value": "13"},
|
||||
],
|
||||
},
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
|
||||
assert update_response.status_code == 200
|
||||
275
backend/tests/test_encryption.py
Normal file
275
backend/tests/test_encryption.py
Normal file
@@ -0,0 +1,275 @@
|
||||
"""
|
||||
Tests for the file encryption functionality.
|
||||
|
||||
Tests cover:
|
||||
- Encryption service (key generation, encrypt/decrypt)
|
||||
- Encryption key management API
|
||||
- Attachment upload with encryption
|
||||
- Attachment download with decryption
|
||||
"""
|
||||
import pytest
|
||||
import base64
|
||||
import secrets
|
||||
from io import BytesIO
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from app.services.encryption_service import (
|
||||
EncryptionService,
|
||||
encryption_service,
|
||||
MasterKeyNotConfiguredError,
|
||||
DecryptionError,
|
||||
KEY_SIZE,
|
||||
NONCE_SIZE,
|
||||
)
|
||||
|
||||
|
||||
class TestEncryptionService:
|
||||
"""Tests for the encryption service."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_master_key(self):
|
||||
"""Generate a valid test master key."""
|
||||
return base64.urlsafe_b64encode(secrets.token_bytes(32)).decode()
|
||||
|
||||
@pytest.fixture
|
||||
def service_with_key(self, mock_master_key):
|
||||
"""Create an encryption service with a mock master key."""
|
||||
with patch('app.services.encryption_service.settings') as mock_settings:
|
||||
mock_settings.ENCRYPTION_MASTER_KEY = mock_master_key
|
||||
service = EncryptionService()
|
||||
yield service
|
||||
|
||||
def test_generate_key(self, service_with_key):
|
||||
"""Test that generate_key produces a 32-byte key."""
|
||||
key = service_with_key.generate_key()
|
||||
assert len(key) == KEY_SIZE
|
||||
assert isinstance(key, bytes)
|
||||
|
||||
def test_generate_key_uniqueness(self, service_with_key):
|
||||
"""Test that each generated key is unique."""
|
||||
keys = [service_with_key.generate_key() for _ in range(10)]
|
||||
unique_keys = set(keys)
|
||||
assert len(unique_keys) == 10
|
||||
|
||||
def test_encrypt_decrypt_key(self, service_with_key):
|
||||
"""Test encryption and decryption of a file encryption key."""
|
||||
# Generate a key to encrypt
|
||||
original_key = service_with_key.generate_key()
|
||||
|
||||
# Encrypt the key
|
||||
encrypted_key = service_with_key.encrypt_key(original_key)
|
||||
assert isinstance(encrypted_key, str)
|
||||
assert encrypted_key != base64.urlsafe_b64encode(original_key).decode()
|
||||
|
||||
# Decrypt the key
|
||||
decrypted_key = service_with_key.decrypt_key(encrypted_key)
|
||||
assert decrypted_key == original_key
|
||||
|
||||
def test_encrypt_decrypt_file(self, service_with_key):
|
||||
"""Test file encryption and decryption."""
|
||||
# Create test file content
|
||||
original_content = b"This is a test file content for encryption."
|
||||
file_obj = BytesIO(original_content)
|
||||
|
||||
# Generate encryption key
|
||||
key = service_with_key.generate_key()
|
||||
|
||||
# Encrypt
|
||||
encrypted_content = service_with_key.encrypt_file(file_obj, key)
|
||||
assert encrypted_content != original_content
|
||||
assert len(encrypted_content) > len(original_content) # Due to nonce and tag
|
||||
|
||||
# Decrypt
|
||||
encrypted_file = BytesIO(encrypted_content)
|
||||
decrypted_content = service_with_key.decrypt_file(encrypted_file, key)
|
||||
assert decrypted_content == original_content
|
||||
|
||||
def test_encrypt_decrypt_bytes(self, service_with_key):
|
||||
"""Test bytes encryption and decryption convenience methods."""
|
||||
original_data = b"Test data for encryption"
|
||||
key = service_with_key.generate_key()
|
||||
|
||||
# Encrypt
|
||||
encrypted_data = service_with_key.encrypt_bytes(original_data, key)
|
||||
assert encrypted_data != original_data
|
||||
|
||||
# Decrypt
|
||||
decrypted_data = service_with_key.decrypt_bytes(encrypted_data, key)
|
||||
assert decrypted_data == original_data
|
||||
|
||||
def test_encrypt_large_file(self, service_with_key):
|
||||
"""Test encryption of a larger file (1MB)."""
|
||||
# Create 1MB of random data
|
||||
original_content = secrets.token_bytes(1024 * 1024)
|
||||
file_obj = BytesIO(original_content)
|
||||
key = service_with_key.generate_key()
|
||||
|
||||
# Encrypt
|
||||
encrypted_content = service_with_key.encrypt_file(file_obj, key)
|
||||
|
||||
# Decrypt
|
||||
encrypted_file = BytesIO(encrypted_content)
|
||||
decrypted_content = service_with_key.decrypt_file(encrypted_file, key)
|
||||
|
||||
assert decrypted_content == original_content
|
||||
|
||||
def test_decrypt_with_wrong_key(self, service_with_key):
|
||||
"""Test that decryption fails with wrong key."""
|
||||
original_content = b"Secret content"
|
||||
file_obj = BytesIO(original_content)
|
||||
|
||||
key1 = service_with_key.generate_key()
|
||||
key2 = service_with_key.generate_key()
|
||||
|
||||
# Encrypt with key1
|
||||
encrypted_content = service_with_key.encrypt_file(file_obj, key1)
|
||||
|
||||
# Try to decrypt with key2
|
||||
encrypted_file = BytesIO(encrypted_content)
|
||||
with pytest.raises(DecryptionError):
|
||||
service_with_key.decrypt_file(encrypted_file, key2)
|
||||
|
||||
def test_decrypt_corrupted_data(self, service_with_key):
|
||||
"""Test that decryption fails with corrupted data."""
|
||||
original_content = b"Secret content"
|
||||
file_obj = BytesIO(original_content)
|
||||
key = service_with_key.generate_key()
|
||||
|
||||
# Encrypt
|
||||
encrypted_content = service_with_key.encrypt_file(file_obj, key)
|
||||
|
||||
# Corrupt the encrypted data
|
||||
corrupted = bytearray(encrypted_content)
|
||||
corrupted[20] ^= 0xFF # Flip some bits
|
||||
corrupted_content = bytes(corrupted)
|
||||
|
||||
# Try to decrypt
|
||||
encrypted_file = BytesIO(corrupted_content)
|
||||
with pytest.raises(DecryptionError):
|
||||
service_with_key.decrypt_file(encrypted_file, key)
|
||||
|
||||
def test_is_encryption_available_with_key(self, mock_master_key):
|
||||
"""Test is_encryption_available returns True when key is configured."""
|
||||
with patch('app.services.encryption_service.settings') as mock_settings:
|
||||
mock_settings.ENCRYPTION_MASTER_KEY = mock_master_key
|
||||
service = EncryptionService()
|
||||
assert service.is_encryption_available() is True
|
||||
|
||||
def test_is_encryption_available_without_key(self):
|
||||
"""Test is_encryption_available returns False when key is not configured."""
|
||||
with patch('app.services.encryption_service.settings') as mock_settings:
|
||||
mock_settings.ENCRYPTION_MASTER_KEY = None
|
||||
service = EncryptionService()
|
||||
assert service.is_encryption_available() is False
|
||||
|
||||
def test_master_key_not_configured_error(self):
|
||||
"""Test that operations fail when master key is not configured."""
|
||||
with patch('app.services.encryption_service.settings') as mock_settings:
|
||||
mock_settings.ENCRYPTION_MASTER_KEY = None
|
||||
service = EncryptionService()
|
||||
|
||||
key = secrets.token_bytes(32)
|
||||
with pytest.raises(MasterKeyNotConfiguredError):
|
||||
service.encrypt_key(key)
|
||||
|
||||
def test_encrypted_key_format(self, service_with_key):
|
||||
"""Test that encrypted key is valid base64."""
|
||||
key = service_with_key.generate_key()
|
||||
encrypted_key = service_with_key.encrypt_key(key)
|
||||
|
||||
# Should be valid base64
|
||||
decoded = base64.urlsafe_b64decode(encrypted_key)
|
||||
# Should contain nonce + ciphertext + tag
|
||||
assert len(decoded) >= NONCE_SIZE + KEY_SIZE + 16 # 16 = GCM tag size
|
||||
|
||||
|
||||
class TestEncryptionServiceStreaming:
|
||||
"""Tests for streaming encryption (for large files)."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_master_key(self):
|
||||
"""Generate a valid test master key."""
|
||||
return base64.urlsafe_b64encode(secrets.token_bytes(32)).decode()
|
||||
|
||||
@pytest.fixture
|
||||
def service_with_key(self, mock_master_key):
|
||||
"""Create an encryption service with a mock master key."""
|
||||
with patch('app.services.encryption_service.settings') as mock_settings:
|
||||
mock_settings.ENCRYPTION_MASTER_KEY = mock_master_key
|
||||
service = EncryptionService()
|
||||
yield service
|
||||
|
||||
def test_streaming_encrypt_decrypt(self, service_with_key):
|
||||
"""Test streaming encryption and decryption."""
|
||||
# Create test content
|
||||
original_content = b"Test content for streaming encryption. " * 1000
|
||||
file_obj = BytesIO(original_content)
|
||||
key = service_with_key.generate_key()
|
||||
|
||||
# Encrypt using streaming
|
||||
encrypted_chunks = list(service_with_key.encrypt_file_streaming(file_obj, key))
|
||||
encrypted_content = b''.join(encrypted_chunks)
|
||||
|
||||
# Decrypt using streaming
|
||||
encrypted_file = BytesIO(encrypted_content)
|
||||
decrypted_chunks = list(service_with_key.decrypt_file_streaming(encrypted_file, key))
|
||||
decrypted_content = b''.join(decrypted_chunks)
|
||||
|
||||
assert decrypted_content == original_content
|
||||
|
||||
|
||||
class TestEncryptionKeyValidation:
|
||||
"""Tests for encryption key validation in config."""
|
||||
|
||||
def test_valid_master_key(self):
|
||||
"""Test that a valid master key passes validation."""
|
||||
from app.core.config import Settings
|
||||
|
||||
valid_key = base64.urlsafe_b64encode(secrets.token_bytes(32)).decode()
|
||||
|
||||
# This should not raise
|
||||
with patch.dict('os.environ', {
|
||||
'JWT_SECRET_KEY': 'test-secret-key-that-is-valid',
|
||||
'ENCRYPTION_MASTER_KEY': valid_key
|
||||
}):
|
||||
settings = Settings()
|
||||
assert settings.ENCRYPTION_MASTER_KEY == valid_key
|
||||
|
||||
def test_invalid_master_key_length(self):
|
||||
"""Test that an invalid length master key fails validation."""
|
||||
from app.core.config import Settings
|
||||
|
||||
# 16 bytes instead of 32
|
||||
invalid_key = base64.urlsafe_b64encode(secrets.token_bytes(16)).decode()
|
||||
|
||||
with patch.dict('os.environ', {
|
||||
'JWT_SECRET_KEY': 'test-secret-key-that-is-valid',
|
||||
'ENCRYPTION_MASTER_KEY': invalid_key
|
||||
}):
|
||||
with pytest.raises(ValueError, match="must be a base64-encoded 32-byte key"):
|
||||
Settings()
|
||||
|
||||
def test_invalid_master_key_format(self):
|
||||
"""Test that an invalid format master key fails validation."""
|
||||
from app.core.config import Settings
|
||||
from pydantic import ValidationError
|
||||
|
||||
invalid_key = "not-valid-base64!@#$"
|
||||
|
||||
with patch.dict('os.environ', {
|
||||
'JWT_SECRET_KEY': 'test-secret-key-that-is-valid',
|
||||
'ENCRYPTION_MASTER_KEY': invalid_key
|
||||
}):
|
||||
with pytest.raises(ValidationError, match="ENCRYPTION_MASTER_KEY"):
|
||||
Settings()
|
||||
|
||||
def test_empty_master_key_allowed(self):
|
||||
"""Test that empty master key is allowed (encryption disabled)."""
|
||||
from app.core.config import Settings
|
||||
|
||||
with patch.dict('os.environ', {
|
||||
'JWT_SECRET_KEY': 'test-secret-key-that-is-valid',
|
||||
'ENCRYPTION_MASTER_KEY': ''
|
||||
}):
|
||||
settings = Settings()
|
||||
assert settings.ENCRYPTION_MASTER_KEY is None
|
||||
1433
backend/tests/test_task_dependencies.py
Normal file
1433
backend/tests/test_task_dependencies.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -114,3 +114,383 @@ class TestSubtaskDepth:
|
||||
"""Test that MAX_SUBTASK_DEPTH is defined."""
|
||||
from app.api.tasks.router import MAX_SUBTASK_DEPTH
|
||||
assert MAX_SUBTASK_DEPTH == 2
|
||||
|
||||
|
||||
class TestDateRangeFilter:
|
||||
"""Test date range filter for calendar view."""
|
||||
|
||||
def test_due_after_filter(self, client, db, admin_token):
|
||||
"""Test filtering tasks with due_date >= due_after."""
|
||||
from datetime import datetime, timedelta
|
||||
from app.models import Space, Project, Task, TaskStatus
|
||||
|
||||
# Create test data
|
||||
space = Space(
|
||||
id="test-space-id",
|
||||
name="Test Space",
|
||||
owner_id="00000000-0000-0000-0000-000000000001",
|
||||
)
|
||||
db.add(space)
|
||||
|
||||
project = Project(
|
||||
id="test-project-id",
|
||||
space_id="test-space-id",
|
||||
title="Test Project",
|
||||
owner_id="00000000-0000-0000-0000-000000000001",
|
||||
security_level="public",
|
||||
)
|
||||
db.add(project)
|
||||
|
||||
status = TaskStatus(
|
||||
id="test-status-id",
|
||||
project_id="test-project-id",
|
||||
name="To Do",
|
||||
color="#808080",
|
||||
position=0,
|
||||
)
|
||||
db.add(status)
|
||||
|
||||
# Create tasks with different due dates
|
||||
now = datetime.now()
|
||||
task1 = Task(
|
||||
id="task-1",
|
||||
project_id="test-project-id",
|
||||
title="Task Due Yesterday",
|
||||
priority="medium",
|
||||
created_by="00000000-0000-0000-0000-000000000001",
|
||||
status_id="test-status-id",
|
||||
due_date=now - timedelta(days=1),
|
||||
)
|
||||
task2 = Task(
|
||||
id="task-2",
|
||||
project_id="test-project-id",
|
||||
title="Task Due Today",
|
||||
priority="medium",
|
||||
created_by="00000000-0000-0000-0000-000000000001",
|
||||
status_id="test-status-id",
|
||||
due_date=now,
|
||||
)
|
||||
task3 = Task(
|
||||
id="task-3",
|
||||
project_id="test-project-id",
|
||||
title="Task Due Tomorrow",
|
||||
priority="medium",
|
||||
created_by="00000000-0000-0000-0000-000000000001",
|
||||
status_id="test-status-id",
|
||||
due_date=now + timedelta(days=1),
|
||||
)
|
||||
task4 = Task(
|
||||
id="task-4",
|
||||
project_id="test-project-id",
|
||||
title="Task No Due Date",
|
||||
priority="medium",
|
||||
created_by="00000000-0000-0000-0000-000000000001",
|
||||
status_id="test-status-id",
|
||||
due_date=None,
|
||||
)
|
||||
db.add_all([task1, task2, task3, task4])
|
||||
db.commit()
|
||||
|
||||
# Filter tasks due today or later
|
||||
due_after = now.isoformat()
|
||||
response = client.get(
|
||||
f"/api/projects/test-project-id/tasks?due_after={due_after}",
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
# Should return task2 and task3 (due today and tomorrow)
|
||||
assert data["total"] == 2
|
||||
task_ids = [t["id"] for t in data["tasks"]]
|
||||
assert "task-2" in task_ids
|
||||
assert "task-3" in task_ids
|
||||
assert "task-1" not in task_ids
|
||||
assert "task-4" not in task_ids
|
||||
|
||||
def test_due_before_filter(self, client, db, admin_token):
|
||||
"""Test filtering tasks with due_date <= due_before."""
|
||||
from datetime import datetime, timedelta
|
||||
from app.models import Space, Project, Task, TaskStatus
|
||||
|
||||
# Create test data
|
||||
space = Space(
|
||||
id="test-space-id-2",
|
||||
name="Test Space 2",
|
||||
owner_id="00000000-0000-0000-0000-000000000001",
|
||||
)
|
||||
db.add(space)
|
||||
|
||||
project = Project(
|
||||
id="test-project-id-2",
|
||||
space_id="test-space-id-2",
|
||||
title="Test Project 2",
|
||||
owner_id="00000000-0000-0000-0000-000000000001",
|
||||
security_level="public",
|
||||
)
|
||||
db.add(project)
|
||||
|
||||
status = TaskStatus(
|
||||
id="test-status-id-2",
|
||||
project_id="test-project-id-2",
|
||||
name="To Do",
|
||||
color="#808080",
|
||||
position=0,
|
||||
)
|
||||
db.add(status)
|
||||
|
||||
# Create tasks with different due dates
|
||||
now = datetime.now()
|
||||
task1 = Task(
|
||||
id="task-b-1",
|
||||
project_id="test-project-id-2",
|
||||
title="Task Due Yesterday",
|
||||
priority="medium",
|
||||
created_by="00000000-0000-0000-0000-000000000001",
|
||||
status_id="test-status-id-2",
|
||||
due_date=now - timedelta(days=1),
|
||||
)
|
||||
task2 = Task(
|
||||
id="task-b-2",
|
||||
project_id="test-project-id-2",
|
||||
title="Task Due Today",
|
||||
priority="medium",
|
||||
created_by="00000000-0000-0000-0000-000000000001",
|
||||
status_id="test-status-id-2",
|
||||
due_date=now,
|
||||
)
|
||||
task3 = Task(
|
||||
id="task-b-3",
|
||||
project_id="test-project-id-2",
|
||||
title="Task Due Tomorrow",
|
||||
priority="medium",
|
||||
created_by="00000000-0000-0000-0000-000000000001",
|
||||
status_id="test-status-id-2",
|
||||
due_date=now + timedelta(days=1),
|
||||
)
|
||||
db.add_all([task1, task2, task3])
|
||||
db.commit()
|
||||
|
||||
# Filter tasks due today or earlier
|
||||
due_before = now.isoformat()
|
||||
response = client.get(
|
||||
f"/api/projects/test-project-id-2/tasks?due_before={due_before}",
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
# Should return task1 and task2 (due yesterday and today)
|
||||
assert data["total"] == 2
|
||||
task_ids = [t["id"] for t in data["tasks"]]
|
||||
assert "task-b-1" in task_ids
|
||||
assert "task-b-2" in task_ids
|
||||
assert "task-b-3" not in task_ids
|
||||
|
||||
def test_date_range_filter_combined(self, client, db, admin_token):
|
||||
"""Test filtering tasks within a date range (due_after AND due_before)."""
|
||||
from datetime import datetime, timedelta
|
||||
from app.models import Space, Project, Task, TaskStatus
|
||||
|
||||
# Create test data
|
||||
space = Space(
|
||||
id="test-space-id-3",
|
||||
name="Test Space 3",
|
||||
owner_id="00000000-0000-0000-0000-000000000001",
|
||||
)
|
||||
db.add(space)
|
||||
|
||||
project = Project(
|
||||
id="test-project-id-3",
|
||||
space_id="test-space-id-3",
|
||||
title="Test Project 3",
|
||||
owner_id="00000000-0000-0000-0000-000000000001",
|
||||
security_level="public",
|
||||
)
|
||||
db.add(project)
|
||||
|
||||
status = TaskStatus(
|
||||
id="test-status-id-3",
|
||||
project_id="test-project-id-3",
|
||||
name="To Do",
|
||||
color="#808080",
|
||||
position=0,
|
||||
)
|
||||
db.add(status)
|
||||
|
||||
# Create tasks spanning a week
|
||||
now = datetime.now()
|
||||
start_of_week = now - timedelta(days=now.weekday()) # Monday
|
||||
end_of_week = start_of_week + timedelta(days=6) # Sunday
|
||||
|
||||
task_before = Task(
|
||||
id="task-c-before",
|
||||
project_id="test-project-id-3",
|
||||
title="Task Before Week",
|
||||
priority="medium",
|
||||
created_by="00000000-0000-0000-0000-000000000001",
|
||||
status_id="test-status-id-3",
|
||||
due_date=start_of_week - timedelta(days=1),
|
||||
)
|
||||
task_in_week = Task(
|
||||
id="task-c-in-week",
|
||||
project_id="test-project-id-3",
|
||||
title="Task In Week",
|
||||
priority="medium",
|
||||
created_by="00000000-0000-0000-0000-000000000001",
|
||||
status_id="test-status-id-3",
|
||||
due_date=start_of_week + timedelta(days=3),
|
||||
)
|
||||
task_after = Task(
|
||||
id="task-c-after",
|
||||
project_id="test-project-id-3",
|
||||
title="Task After Week",
|
||||
priority="medium",
|
||||
created_by="00000000-0000-0000-0000-000000000001",
|
||||
status_id="test-status-id-3",
|
||||
due_date=end_of_week + timedelta(days=1),
|
||||
)
|
||||
db.add_all([task_before, task_in_week, task_after])
|
||||
db.commit()
|
||||
|
||||
# Filter tasks within this week
|
||||
due_after = start_of_week.isoformat()
|
||||
due_before = end_of_week.isoformat()
|
||||
response = client.get(
|
||||
f"/api/projects/test-project-id-3/tasks?due_after={due_after}&due_before={due_before}",
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
# Should only return the task in the week
|
||||
assert data["total"] == 1
|
||||
assert data["tasks"][0]["id"] == "task-c-in-week"
|
||||
|
||||
def test_date_filter_with_no_due_date(self, client, db, admin_token):
|
||||
"""Test that tasks without due_date are excluded from date range filters."""
|
||||
from datetime import datetime, timedelta
|
||||
from app.models import Space, Project, Task, TaskStatus
|
||||
|
||||
# Create test data
|
||||
space = Space(
|
||||
id="test-space-id-4",
|
||||
name="Test Space 4",
|
||||
owner_id="00000000-0000-0000-0000-000000000001",
|
||||
)
|
||||
db.add(space)
|
||||
|
||||
project = Project(
|
||||
id="test-project-id-4",
|
||||
space_id="test-space-id-4",
|
||||
title="Test Project 4",
|
||||
owner_id="00000000-0000-0000-0000-000000000001",
|
||||
security_level="public",
|
||||
)
|
||||
db.add(project)
|
||||
|
||||
status = TaskStatus(
|
||||
id="test-status-id-4",
|
||||
project_id="test-project-id-4",
|
||||
name="To Do",
|
||||
color="#808080",
|
||||
position=0,
|
||||
)
|
||||
db.add(status)
|
||||
|
||||
# Create tasks - some with due_date, some without
|
||||
now = datetime.now()
|
||||
task_with_date = Task(
|
||||
id="task-d-with-date",
|
||||
project_id="test-project-id-4",
|
||||
title="Task With Due Date",
|
||||
priority="medium",
|
||||
created_by="00000000-0000-0000-0000-000000000001",
|
||||
status_id="test-status-id-4",
|
||||
due_date=now,
|
||||
)
|
||||
task_without_date = Task(
|
||||
id="task-d-without-date",
|
||||
project_id="test-project-id-4",
|
||||
title="Task Without Due Date",
|
||||
priority="medium",
|
||||
created_by="00000000-0000-0000-0000-000000000001",
|
||||
status_id="test-status-id-4",
|
||||
due_date=None,
|
||||
)
|
||||
db.add_all([task_with_date, task_without_date])
|
||||
db.commit()
|
||||
|
||||
# When using date filter, tasks without due_date should be excluded
|
||||
due_after = (now - timedelta(days=1)).isoformat()
|
||||
due_before = (now + timedelta(days=1)).isoformat()
|
||||
response = client.get(
|
||||
f"/api/projects/test-project-id-4/tasks?due_after={due_after}&due_before={due_before}",
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
# Should only return the task with due_date
|
||||
assert data["total"] == 1
|
||||
assert data["tasks"][0]["id"] == "task-d-with-date"
|
||||
|
||||
def test_date_filter_backward_compatibility(self, client, db, admin_token):
|
||||
"""Test that not providing date filters returns all tasks (backward compatibility)."""
|
||||
from datetime import datetime, timedelta
|
||||
from app.models import Space, Project, Task, TaskStatus
|
||||
|
||||
# Create test data
|
||||
space = Space(
|
||||
id="test-space-id-5",
|
||||
name="Test Space 5",
|
||||
owner_id="00000000-0000-0000-0000-000000000001",
|
||||
)
|
||||
db.add(space)
|
||||
|
||||
project = Project(
|
||||
id="test-project-id-5",
|
||||
space_id="test-space-id-5",
|
||||
title="Test Project 5",
|
||||
owner_id="00000000-0000-0000-0000-000000000001",
|
||||
security_level="public",
|
||||
)
|
||||
db.add(project)
|
||||
|
||||
status = TaskStatus(
|
||||
id="test-status-id-5",
|
||||
project_id="test-project-id-5",
|
||||
name="To Do",
|
||||
color="#808080",
|
||||
position=0,
|
||||
)
|
||||
db.add(status)
|
||||
|
||||
# Create tasks with and without due dates
|
||||
now = datetime.now()
|
||||
task1 = Task(
|
||||
id="task-e-1",
|
||||
project_id="test-project-id-5",
|
||||
title="Task 1",
|
||||
priority="medium",
|
||||
created_by="00000000-0000-0000-0000-000000000001",
|
||||
status_id="test-status-id-5",
|
||||
due_date=now,
|
||||
)
|
||||
task2 = Task(
|
||||
id="task-e-2",
|
||||
project_id="test-project-id-5",
|
||||
title="Task 2",
|
||||
priority="medium",
|
||||
created_by="00000000-0000-0000-0000-000000000001",
|
||||
status_id="test-status-id-5",
|
||||
due_date=None,
|
||||
)
|
||||
db.add_all([task1, task2])
|
||||
db.commit()
|
||||
|
||||
# Request without date filters - should return all tasks
|
||||
response = client.get(
|
||||
"/api/projects/test-project-id-5/tasks",
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
# Should return both tasks
|
||||
assert data["total"] == 2
|
||||
|
||||
Reference in New Issue
Block a user