feat: implement custom fields, gantt view, calendar view, and file encryption

- Custom Fields (FEAT-001):
  - CustomField and TaskCustomValue models with formula support
  - CRUD API for custom field management
  - Formula engine for calculated fields
  - Frontend: CustomFieldEditor, CustomFieldInput, ProjectSettings page
  - Task list API now includes custom_values
  - KanbanBoard displays custom field values

- Gantt View (FEAT-003):
  - TaskDependency model with FS/SS/FF/SF dependency types
  - Dependency CRUD API with cycle detection
  - start_date field added to tasks
  - GanttChart component with Frappe Gantt integration
  - Dependency type selector in UI

- Calendar View (FEAT-004):
  - CalendarView component with FullCalendar integration
  - Date range filtering API for tasks
  - Drag-and-drop date updates
  - View mode switching in Tasks page

- File Encryption (FEAT-010):
  - AES-256-GCM encryption service
  - EncryptionKey model with key rotation support
  - Admin API for key management
  - Encrypted upload/download for confidential projects

- Migrations: 011 (custom fields), 012 (encryption keys), 013 (task dependencies)
- Updated issues.md with completion status

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
beabigegg
2026-01-05 23:39:12 +08:00
parent 69b81d9241
commit 2d80a8384e
65 changed files with 11045 additions and 82 deletions

View File

@@ -0,0 +1,278 @@
"""
Service for managing task custom values.
"""
import uuid
from typing import List, Dict, Any, Optional
from decimal import Decimal, InvalidOperation
from datetime import datetime
from sqlalchemy.orm import Session
from app.models import Task, CustomField, TaskCustomValue, User
from app.schemas.task import CustomValueInput, CustomValueResponse
from app.services.formula_service import FormulaService
class CustomValueService:
"""Service for managing custom field values on tasks."""
@staticmethod
def get_custom_values_for_task(
db: Session,
task: Task,
include_formula_calculations: bool = True,
) -> List[CustomValueResponse]:
"""
Get all custom field values for a task.
Args:
db: Database session
task: The task to get values for
include_formula_calculations: Whether to calculate formula field values
Returns:
List of CustomValueResponse objects
"""
# Get all custom fields for the project
fields = db.query(CustomField).filter(
CustomField.project_id == task.project_id
).order_by(CustomField.position).all()
if not fields:
return []
# Get stored values
stored_values = db.query(TaskCustomValue).filter(
TaskCustomValue.task_id == task.id
).all()
value_map = {v.field_id: v.value for v in stored_values}
# Calculate formula values if requested
formula_values = {}
if include_formula_calculations:
formula_values = FormulaService.calculate_all_formulas_for_task(db, task)
result = []
for field in fields:
if field.field_type == "formula":
# Use calculated formula value
calculated = formula_values.get(field.id)
value = str(calculated) if calculated is not None else None
display_value = CustomValueService._format_display_value(
field, value, db
)
else:
# Use stored value
value = value_map.get(field.id)
display_value = CustomValueService._format_display_value(
field, value, db
)
result.append(CustomValueResponse(
field_id=field.id,
field_name=field.name,
field_type=field.field_type,
value=value,
display_value=display_value,
))
return result
@staticmethod
def _format_display_value(
field: CustomField,
value: Optional[str],
db: Session,
) -> Optional[str]:
"""Format a value for display based on field type."""
if value is None:
return None
field_type = field.field_type
if field_type == "person":
# Look up user name
from app.models import User
user = db.query(User).filter(User.id == value).first()
return user.name if user else value
elif field_type == "number" or field_type == "formula":
# Format number
try:
num = Decimal(value)
# Remove trailing zeros after decimal point
formatted = f"{num:,.4f}".rstrip('0').rstrip('.')
return formatted
except (InvalidOperation, ValueError):
return value
elif field_type == "date":
# Format date
try:
dt = datetime.fromisoformat(value.replace('Z', '+00:00'))
return dt.strftime('%Y-%m-%d')
except (ValueError, AttributeError):
return value
else:
return value
@staticmethod
def save_custom_values(
db: Session,
task: Task,
custom_values: List[CustomValueInput],
) -> List[str]:
"""
Save custom field values for a task.
Args:
db: Database session
task: The task to save values for
custom_values: List of values to save
Returns:
List of field IDs that were updated (for formula recalculation)
"""
if not custom_values:
return []
updated_field_ids = []
for cv in custom_values:
field = db.query(CustomField).filter(
CustomField.id == cv.field_id,
CustomField.project_id == task.project_id,
).first()
if not field:
continue
# Skip formula fields - they are calculated, not stored directly
if field.field_type == "formula":
continue
# Validate value based on field type
validated_value = CustomValueService._validate_value(
field, cv.value, db
)
# Find existing value or create new
existing = db.query(TaskCustomValue).filter(
TaskCustomValue.task_id == task.id,
TaskCustomValue.field_id == cv.field_id,
).first()
if existing:
if existing.value != validated_value:
existing.value = validated_value
updated_field_ids.append(cv.field_id)
else:
new_value = TaskCustomValue(
id=str(uuid.uuid4()),
task_id=task.id,
field_id=cv.field_id,
value=validated_value,
)
db.add(new_value)
updated_field_ids.append(cv.field_id)
# Recalculate formula fields if any values were updated
if updated_field_ids:
for field_id in updated_field_ids:
FormulaService.recalculate_dependent_formulas(db, task, field_id)
return updated_field_ids
@staticmethod
def _validate_value(
field: CustomField,
value: Any,
db: Session,
) -> Optional[str]:
"""
Validate and normalize a value based on field type.
Returns the validated value as a string, or None.
"""
if value is None or value == "":
if field.is_required:
raise ValueError(f"Field '{field.name}' is required")
return None
field_type = field.field_type
str_value = str(value)
if field_type == "text":
return str_value
elif field_type == "number":
try:
Decimal(str_value)
return str_value
except (InvalidOperation, ValueError):
raise ValueError(f"Invalid number for field '{field.name}'")
elif field_type == "dropdown":
if field.options and str_value not in field.options:
raise ValueError(
f"Invalid option for field '{field.name}'. "
f"Must be one of: {', '.join(field.options)}"
)
return str_value
elif field_type == "date":
# Validate date format
try:
datetime.fromisoformat(str_value.replace('Z', '+00:00'))
return str_value
except ValueError:
# Try parsing as date only
try:
datetime.strptime(str_value, '%Y-%m-%d')
return str_value
except ValueError:
raise ValueError(f"Invalid date for field '{field.name}'")
elif field_type == "person":
# Validate user exists
from app.models import User
user = db.query(User).filter(User.id == str_value).first()
if not user:
raise ValueError(f"Invalid user ID for field '{field.name}'")
return str_value
return str_value
@staticmethod
def validate_required_fields(
db: Session,
project_id: str,
custom_values: Optional[List[CustomValueInput]],
) -> List[str]:
"""
Validate that all required custom fields have values.
Returns list of missing required field names.
"""
required_fields = db.query(CustomField).filter(
CustomField.project_id == project_id,
CustomField.is_required == True,
CustomField.field_type != "formula", # Formula fields are calculated
).all()
if not required_fields:
return []
provided_field_ids = set()
if custom_values:
for cv in custom_values:
if cv.value is not None and cv.value != "":
provided_field_ids.add(cv.field_id)
missing = []
for field in required_fields:
if field.id not in provided_field_ids:
missing.append(field.name)
return missing

View File

@@ -0,0 +1,424 @@
"""
Dependency Service
Handles task dependency validation including:
- Circular dependency detection using DFS
- Date constraint validation based on dependency types
- Self-reference prevention
- Cross-project dependency prevention
"""
from typing import List, Optional, Set, Tuple, Dict, Any
from collections import defaultdict
from sqlalchemy.orm import Session
from datetime import datetime, timedelta
from app.models import Task, TaskDependency
class DependencyValidationError(Exception):
"""Custom exception for dependency validation errors."""
def __init__(self, error_type: str, message: str, details: Optional[dict] = None):
self.error_type = error_type
self.message = message
self.details = details or {}
super().__init__(message)
class DependencyService:
"""Service for managing task dependencies with validation."""
# Maximum number of direct dependencies per task (as per spec)
MAX_DIRECT_DEPENDENCIES = 10
@staticmethod
def detect_circular_dependency(
db: Session,
predecessor_id: str,
successor_id: str,
project_id: str
) -> Optional[List[str]]:
"""
Detect if adding a dependency would create a circular reference.
Uses DFS to traverse from the successor to check if we can reach
the predecessor through existing dependencies.
Args:
db: Database session
predecessor_id: The task that must complete first
successor_id: The task that depends on the predecessor
project_id: Project ID to scope the query
Returns:
List of task IDs forming the cycle if circular, None otherwise
"""
# If adding predecessor -> successor, check if successor can reach predecessor
# This would mean predecessor depends (transitively) on successor, creating a cycle
# Build adjacency list for the project's dependencies
dependencies = db.query(TaskDependency).join(
Task, TaskDependency.successor_id == Task.id
).filter(Task.project_id == project_id).all()
# Graph: successor -> [predecessors]
# We need to check if predecessor is reachable from successor
# by following the chain of "what does this task depend on"
graph: Dict[str, List[str]] = defaultdict(list)
for dep in dependencies:
graph[dep.successor_id].append(dep.predecessor_id)
# Simulate adding the new edge
graph[successor_id].append(predecessor_id)
# DFS to find if there's a path from predecessor back to successor
# (which would complete a cycle)
visited: Set[str] = set()
path: List[str] = []
in_path: Set[str] = set()
def dfs(node: str) -> Optional[List[str]]:
"""DFS traversal to detect cycles."""
if node in in_path:
# Found a cycle - return the cycle path
cycle_start = path.index(node)
return path[cycle_start:] + [node]
if node in visited:
return None
visited.add(node)
in_path.add(node)
path.append(node)
for neighbor in graph.get(node, []):
result = dfs(neighbor)
if result:
return result
path.pop()
in_path.remove(node)
return None
# Start DFS from the successor to check if we can reach back to it
return dfs(successor_id)
@staticmethod
def validate_dependency(
db: Session,
predecessor_id: str,
successor_id: str
) -> None:
"""
Validate that a dependency can be created.
Raises DependencyValidationError if validation fails.
Checks:
1. Self-reference
2. Both tasks exist
3. Both tasks are in the same project
4. No duplicate dependency
5. No circular dependency
6. Dependency limit not exceeded
"""
# Check self-reference
if predecessor_id == successor_id:
raise DependencyValidationError(
error_type="self_reference",
message="A task cannot depend on itself"
)
# Get both tasks
predecessor = db.query(Task).filter(Task.id == predecessor_id).first()
successor = db.query(Task).filter(Task.id == successor_id).first()
if not predecessor:
raise DependencyValidationError(
error_type="not_found",
message="Predecessor task not found",
details={"task_id": predecessor_id}
)
if not successor:
raise DependencyValidationError(
error_type="not_found",
message="Successor task not found",
details={"task_id": successor_id}
)
# Check same project
if predecessor.project_id != successor.project_id:
raise DependencyValidationError(
error_type="cross_project",
message="Dependencies can only be created between tasks in the same project",
details={
"predecessor_project_id": predecessor.project_id,
"successor_project_id": successor.project_id
}
)
# Check duplicate
existing = db.query(TaskDependency).filter(
TaskDependency.predecessor_id == predecessor_id,
TaskDependency.successor_id == successor_id
).first()
if existing:
raise DependencyValidationError(
error_type="duplicate",
message="This dependency already exists"
)
# Check dependency limit
current_count = db.query(TaskDependency).filter(
TaskDependency.successor_id == successor_id
).count()
if current_count >= DependencyService.MAX_DIRECT_DEPENDENCIES:
raise DependencyValidationError(
error_type="limit_exceeded",
message=f"A task cannot have more than {DependencyService.MAX_DIRECT_DEPENDENCIES} direct dependencies",
details={"current_count": current_count}
)
# Check circular dependency
cycle = DependencyService.detect_circular_dependency(
db, predecessor_id, successor_id, predecessor.project_id
)
if cycle:
raise DependencyValidationError(
error_type="circular",
message="Adding this dependency would create a circular reference",
details={"cycle": cycle}
)
@staticmethod
def validate_date_constraints(
task: Task,
start_date: Optional[datetime],
due_date: Optional[datetime],
db: Session
) -> List[Dict[str, Any]]:
"""
Validate date changes against dependency constraints.
Returns a list of constraint violations (empty if valid).
Dependency type meanings:
- FS: predecessor.due_date + lag <= successor.start_date
- SS: predecessor.start_date + lag <= successor.start_date
- FF: predecessor.due_date + lag <= successor.due_date
- SF: predecessor.start_date + lag <= successor.due_date
"""
violations = []
# Use provided dates or fall back to current task dates
new_start = start_date if start_date is not None else task.start_date
new_due = due_date if due_date is not None else task.due_date
# Basic date validation
if new_start and new_due and new_start > new_due:
violations.append({
"type": "date_order",
"message": "Start date cannot be after due date",
"start_date": str(new_start),
"due_date": str(new_due)
})
# Get dependencies where this task is the successor (predecessors)
predecessors = db.query(TaskDependency).filter(
TaskDependency.successor_id == task.id
).all()
for dep in predecessors:
pred_task = dep.predecessor
if not pred_task:
continue
lag = timedelta(days=dep.lag_days)
violation = None
if dep.dependency_type == "FS":
# Predecessor must finish before successor starts
if pred_task.due_date and new_start:
required_start = pred_task.due_date + lag
if new_start < required_start:
violation = {
"type": "dependency_constraint",
"dependency_type": "FS",
"predecessor_id": pred_task.id,
"predecessor_title": pred_task.title,
"message": f"Start date must be on or after {required_start.date()} (predecessor due date + {dep.lag_days} days lag)"
}
elif dep.dependency_type == "SS":
# Predecessor must start before successor starts
if pred_task.start_date and new_start:
required_start = pred_task.start_date + lag
if new_start < required_start:
violation = {
"type": "dependency_constraint",
"dependency_type": "SS",
"predecessor_id": pred_task.id,
"predecessor_title": pred_task.title,
"message": f"Start date must be on or after {required_start.date()} (predecessor start date + {dep.lag_days} days lag)"
}
elif dep.dependency_type == "FF":
# Predecessor must finish before successor finishes
if pred_task.due_date and new_due:
required_due = pred_task.due_date + lag
if new_due < required_due:
violation = {
"type": "dependency_constraint",
"dependency_type": "FF",
"predecessor_id": pred_task.id,
"predecessor_title": pred_task.title,
"message": f"Due date must be on or after {required_due.date()} (predecessor due date + {dep.lag_days} days lag)"
}
elif dep.dependency_type == "SF":
# Predecessor must start before successor finishes
if pred_task.start_date and new_due:
required_due = pred_task.start_date + lag
if new_due < required_due:
violation = {
"type": "dependency_constraint",
"dependency_type": "SF",
"predecessor_id": pred_task.id,
"predecessor_title": pred_task.title,
"message": f"Due date must be on or after {required_due.date()} (predecessor start date + {dep.lag_days} days lag)"
}
if violation:
violations.append(violation)
# Get dependencies where this task is the predecessor (successors)
successors = db.query(TaskDependency).filter(
TaskDependency.predecessor_id == task.id
).all()
for dep in successors:
succ_task = dep.successor
if not succ_task:
continue
lag = timedelta(days=dep.lag_days)
violation = None
if dep.dependency_type == "FS":
# This task must finish before successor starts
if new_due and succ_task.start_date:
required_due = succ_task.start_date - lag
if new_due > required_due:
violation = {
"type": "dependency_constraint",
"dependency_type": "FS",
"successor_id": succ_task.id,
"successor_title": succ_task.title,
"message": f"Due date must be on or before {required_due.date()} (successor start date - {dep.lag_days} days lag)"
}
elif dep.dependency_type == "SS":
# This task must start before successor starts
if new_start and succ_task.start_date:
required_start = succ_task.start_date - lag
if new_start > required_start:
violation = {
"type": "dependency_constraint",
"dependency_type": "SS",
"successor_id": succ_task.id,
"successor_title": succ_task.title,
"message": f"Start date must be on or before {required_start.date()} (successor start date - {dep.lag_days} days lag)"
}
elif dep.dependency_type == "FF":
# This task must finish before successor finishes
if new_due and succ_task.due_date:
required_due = succ_task.due_date - lag
if new_due > required_due:
violation = {
"type": "dependency_constraint",
"dependency_type": "FF",
"successor_id": succ_task.id,
"successor_title": succ_task.title,
"message": f"Due date must be on or before {required_due.date()} (successor due date - {dep.lag_days} days lag)"
}
elif dep.dependency_type == "SF":
# This task must start before successor finishes
if new_start and succ_task.due_date:
required_start = succ_task.due_date - lag
if new_start > required_start:
violation = {
"type": "dependency_constraint",
"dependency_type": "SF",
"successor_id": succ_task.id,
"successor_title": succ_task.title,
"message": f"Start date must be on or before {required_start.date()} (successor due date - {dep.lag_days} days lag)"
}
if violation:
violations.append(violation)
return violations
@staticmethod
def get_all_predecessors(db: Session, task_id: str) -> List[str]:
"""
Get all transitive predecessors of a task.
Uses BFS to find all tasks that this task depends on (directly or indirectly).
"""
visited: Set[str] = set()
queue = [task_id]
predecessors = []
while queue:
current = queue.pop(0)
if current in visited:
continue
visited.add(current)
deps = db.query(TaskDependency).filter(
TaskDependency.successor_id == current
).all()
for dep in deps:
if dep.predecessor_id not in visited:
predecessors.append(dep.predecessor_id)
queue.append(dep.predecessor_id)
return predecessors
@staticmethod
def get_all_successors(db: Session, task_id: str) -> List[str]:
"""
Get all transitive successors of a task.
Uses BFS to find all tasks that depend on this task (directly or indirectly).
"""
visited: Set[str] = set()
queue = [task_id]
successors = []
while queue:
current = queue.pop(0)
if current in visited:
continue
visited.add(current)
deps = db.query(TaskDependency).filter(
TaskDependency.predecessor_id == current
).all()
for dep in deps:
if dep.successor_id not in visited:
successors.append(dep.successor_id)
queue.append(dep.successor_id)
return successors

View File

@@ -0,0 +1,300 @@
"""
Encryption service for AES-256-GCM file encryption.
This service handles:
- File encryption key generation and management
- Encrypting/decrypting file encryption keys with Master Key
- Streaming file encryption/decryption with AES-256-GCM
"""
import os
import base64
import secrets
import logging
from typing import BinaryIO, Tuple, Optional, Generator
from io import BytesIO
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from cryptography.hazmat.backends import default_backend
from app.core.config import settings
logger = logging.getLogger(__name__)
# Constants
KEY_SIZE = 32 # 256 bits for AES-256
NONCE_SIZE = 12 # 96 bits for GCM recommended nonce size
TAG_SIZE = 16 # 128 bits for GCM authentication tag
CHUNK_SIZE = 64 * 1024 # 64KB chunks for streaming
class EncryptionError(Exception):
"""Base exception for encryption errors."""
pass
class MasterKeyNotConfiguredError(EncryptionError):
"""Raised when master key is not configured."""
pass
class DecryptionError(EncryptionError):
"""Raised when decryption fails."""
pass
class EncryptionService:
"""
Service for file encryption using AES-256-GCM.
Key hierarchy:
1. Master Key (from environment) -> encrypts file encryption keys
2. File Encryption Keys (stored in DB) -> encrypt actual files
"""
def __init__(self):
self._master_key: Optional[bytes] = None
@property
def master_key(self) -> bytes:
"""Get the master key, loading from config if needed."""
if self._master_key is None:
if not settings.ENCRYPTION_MASTER_KEY:
raise MasterKeyNotConfiguredError(
"ENCRYPTION_MASTER_KEY is not configured. "
"File encryption is disabled."
)
self._master_key = base64.urlsafe_b64decode(settings.ENCRYPTION_MASTER_KEY)
return self._master_key
def is_encryption_available(self) -> bool:
"""Check if encryption is available (master key configured)."""
return settings.ENCRYPTION_MASTER_KEY is not None
def generate_key(self) -> bytes:
"""
Generate a new AES-256 encryption key.
Returns:
32-byte random key
"""
return secrets.token_bytes(KEY_SIZE)
def encrypt_key(self, key: bytes) -> str:
"""
Encrypt a file encryption key using the Master Key.
Args:
key: The raw 32-byte file encryption key
Returns:
Base64-encoded encrypted key (nonce + ciphertext + tag)
"""
aesgcm = AESGCM(self.master_key)
nonce = secrets.token_bytes(NONCE_SIZE)
# Encrypt the key
ciphertext = aesgcm.encrypt(nonce, key, None)
# Combine nonce + ciphertext (includes tag)
encrypted_data = nonce + ciphertext
return base64.urlsafe_b64encode(encrypted_data).decode('utf-8')
def decrypt_key(self, encrypted_key: str) -> bytes:
"""
Decrypt a file encryption key using the Master Key.
Args:
encrypted_key: Base64-encoded encrypted key
Returns:
The raw 32-byte file encryption key
"""
try:
encrypted_data = base64.urlsafe_b64decode(encrypted_key)
# Extract nonce and ciphertext
nonce = encrypted_data[:NONCE_SIZE]
ciphertext = encrypted_data[NONCE_SIZE:]
# Decrypt
aesgcm = AESGCM(self.master_key)
return aesgcm.decrypt(nonce, ciphertext, None)
except Exception as e:
logger.error(f"Failed to decrypt encryption key: {e}")
raise DecryptionError("Failed to decrypt file encryption key")
def encrypt_file(self, file_content: BinaryIO, key: bytes) -> bytes:
"""
Encrypt file content using AES-256-GCM.
For smaller files, encrypts the entire content at once.
The format is: nonce (12 bytes) + ciphertext + tag (16 bytes)
Args:
file_content: File-like object to encrypt
key: 32-byte AES-256 key
Returns:
Encrypted bytes (nonce + ciphertext + tag)
"""
# Read all content
plaintext = file_content.read()
# Generate nonce
nonce = secrets.token_bytes(NONCE_SIZE)
# Encrypt
aesgcm = AESGCM(key)
ciphertext = aesgcm.encrypt(nonce, plaintext, None)
# Return nonce + ciphertext (tag is appended by encrypt)
return nonce + ciphertext
def decrypt_file(self, encrypted_content: BinaryIO, key: bytes) -> bytes:
"""
Decrypt file content using AES-256-GCM.
Args:
encrypted_content: File-like object containing encrypted data
key: 32-byte AES-256 key
Returns:
Decrypted bytes
"""
try:
# Read all encrypted content
encrypted_data = encrypted_content.read()
# Extract nonce and ciphertext
nonce = encrypted_data[:NONCE_SIZE]
ciphertext = encrypted_data[NONCE_SIZE:]
# Decrypt
aesgcm = AESGCM(key)
plaintext = aesgcm.decrypt(nonce, ciphertext, None)
return plaintext
except Exception as e:
logger.error(f"Failed to decrypt file: {e}")
raise DecryptionError("Failed to decrypt file. The file may be corrupted or the key is incorrect.")
def encrypt_file_streaming(self, file_content: BinaryIO, key: bytes) -> Generator[bytes, None, None]:
"""
Encrypt file content using AES-256-GCM with streaming.
For large files, encrypts in chunks. Each chunk has its own nonce.
Format per chunk: chunk_size (4 bytes) + nonce (12 bytes) + ciphertext + tag
Args:
file_content: File-like object to encrypt
key: 32-byte AES-256 key
Yields:
Encrypted chunks
"""
aesgcm = AESGCM(key)
# Write header with version byte
yield b'\x01' # Version 1 for streaming format
while True:
chunk = file_content.read(CHUNK_SIZE)
if not chunk:
break
# Generate nonce for this chunk
nonce = secrets.token_bytes(NONCE_SIZE)
# Encrypt chunk
ciphertext = aesgcm.encrypt(nonce, chunk, None)
# Write chunk size (4 bytes, little endian)
chunk_size = len(ciphertext) + NONCE_SIZE
yield chunk_size.to_bytes(4, 'little')
# Write nonce + ciphertext
yield nonce + ciphertext
# Write end marker (zero size)
yield b'\x00\x00\x00\x00'
def decrypt_file_streaming(self, encrypted_content: BinaryIO, key: bytes) -> Generator[bytes, None, None]:
"""
Decrypt file content using AES-256-GCM with streaming.
Args:
encrypted_content: File-like object containing encrypted data
key: 32-byte AES-256 key
Yields:
Decrypted chunks
"""
aesgcm = AESGCM(key)
# Read version byte
version = encrypted_content.read(1)
if version != b'\x01':
raise DecryptionError(f"Unknown encryption format version")
while True:
# Read chunk size
size_bytes = encrypted_content.read(4)
if len(size_bytes) < 4:
raise DecryptionError("Unexpected end of file")
chunk_size = int.from_bytes(size_bytes, 'little')
# Check for end marker
if chunk_size == 0:
break
# Read chunk (nonce + ciphertext)
chunk = encrypted_content.read(chunk_size)
if len(chunk) < chunk_size:
raise DecryptionError("Unexpected end of file")
# Extract nonce and ciphertext
nonce = chunk[:NONCE_SIZE]
ciphertext = chunk[NONCE_SIZE:]
try:
# Decrypt
plaintext = aesgcm.decrypt(nonce, ciphertext, None)
yield plaintext
except Exception as e:
raise DecryptionError(f"Failed to decrypt chunk: {e}")
def encrypt_bytes(self, data: bytes, key: bytes) -> bytes:
"""
Encrypt bytes directly (convenience method).
Args:
data: Bytes to encrypt
key: 32-byte AES-256 key
Returns:
Encrypted bytes
"""
return self.encrypt_file(BytesIO(data), key)
def decrypt_bytes(self, encrypted_data: bytes, key: bytes) -> bytes:
"""
Decrypt bytes directly (convenience method).
Args:
encrypted_data: Encrypted bytes
key: 32-byte AES-256 key
Returns:
Decrypted bytes
"""
return self.decrypt_file(BytesIO(encrypted_data), key)
# Singleton instance
encryption_service = EncryptionService()

View File

@@ -0,0 +1,420 @@
"""
Formula Service for Custom Fields
Supports:
- Basic math operations: +, -, *, /
- Field references: {field_name}
- Built-in task fields: {original_estimate}, {time_spent}
- Parentheses for grouping
Example formulas:
- "{time_spent} / {original_estimate} * 100"
- "{cost_per_hour} * {hours_worked}"
- "({field_a} + {field_b}) / 2"
"""
import re
import ast
import operator
from typing import Dict, Any, Optional, List, Set, Tuple
from decimal import Decimal, InvalidOperation
from sqlalchemy.orm import Session
from app.models import Task, CustomField, TaskCustomValue
class FormulaError(Exception):
"""Exception raised for formula parsing or calculation errors."""
pass
class CircularReferenceError(FormulaError):
"""Exception raised when circular references are detected in formulas."""
pass
class FormulaService:
"""Service for parsing and calculating formula fields."""
# Built-in task fields that can be referenced in formulas
BUILTIN_FIELDS = {
"original_estimate",
"time_spent",
}
# Supported operators
OPERATORS = {
ast.Add: operator.add,
ast.Sub: operator.sub,
ast.Mult: operator.mul,
ast.Div: operator.truediv,
ast.USub: operator.neg,
}
@staticmethod
def extract_field_references(formula: str) -> Set[str]:
"""
Extract all field references from a formula.
Field references are in the format {field_name}.
Returns a set of field names referenced in the formula.
"""
pattern = r'\{([^}]+)\}'
matches = re.findall(pattern, formula)
return set(matches)
@staticmethod
def validate_formula(
formula: str,
project_id: str,
db: Session,
current_field_id: Optional[str] = None,
) -> Tuple[bool, Optional[str]]:
"""
Validate a formula expression.
Checks:
1. Syntax is valid
2. All referenced fields exist
3. Referenced fields are number or formula type
4. No circular references
Returns (is_valid, error_message)
"""
if not formula or not formula.strip():
return False, "Formula cannot be empty"
# Extract field references
references = FormulaService.extract_field_references(formula)
if not references:
return False, "Formula must reference at least one field"
# Validate syntax by trying to parse
try:
# Replace field references with dummy numbers for syntax check
test_formula = formula
for ref in references:
test_formula = test_formula.replace(f"{{{ref}}}", "1")
# Try to parse and evaluate with safe operations
FormulaService._safe_eval(test_formula)
except Exception as e:
return False, f"Invalid formula syntax: {str(e)}"
# Separate builtin and custom field references
custom_references = references - FormulaService.BUILTIN_FIELDS
# Validate custom field references exist and are numeric types
if custom_references:
fields = db.query(CustomField).filter(
CustomField.project_id == project_id,
CustomField.name.in_(custom_references),
).all()
found_names = {f.name for f in fields}
missing = custom_references - found_names
if missing:
return False, f"Unknown field references: {', '.join(missing)}"
# Check field types (must be number or formula)
for field in fields:
if field.field_type not in ("number", "formula"):
return False, f"Field '{field.name}' is not a numeric type"
# Check for circular references
if current_field_id:
try:
FormulaService._check_circular_references(
db, project_id, current_field_id, references
)
except CircularReferenceError as e:
return False, str(e)
return True, None
@staticmethod
def _check_circular_references(
db: Session,
project_id: str,
field_id: str,
references: Set[str],
visited: Optional[Set[str]] = None,
) -> None:
"""
Check for circular references in formula fields.
Raises CircularReferenceError if a cycle is detected.
"""
if visited is None:
visited = set()
# Get the current field's name
current_field = db.query(CustomField).filter(
CustomField.id == field_id
).first()
if current_field:
if current_field.name in references:
raise CircularReferenceError(
f"Circular reference detected: field cannot reference itself"
)
# Get all referenced formula fields
custom_references = references - FormulaService.BUILTIN_FIELDS
if not custom_references:
return
formula_fields = db.query(CustomField).filter(
CustomField.project_id == project_id,
CustomField.name.in_(custom_references),
CustomField.field_type == "formula",
).all()
for field in formula_fields:
if field.id in visited:
raise CircularReferenceError(
f"Circular reference detected involving field '{field.name}'"
)
visited.add(field.id)
if field.formula:
nested_refs = FormulaService.extract_field_references(field.formula)
if current_field and current_field.name in nested_refs:
raise CircularReferenceError(
f"Circular reference detected: '{field.name}' references the current field"
)
FormulaService._check_circular_references(
db, project_id, field_id, nested_refs, visited
)
@staticmethod
def _safe_eval(expression: str) -> Decimal:
"""
Safely evaluate a mathematical expression.
Only allows basic arithmetic operations (+, -, *, /).
"""
try:
node = ast.parse(expression, mode='eval')
return FormulaService._eval_node(node.body)
except Exception as e:
raise FormulaError(f"Failed to evaluate expression: {str(e)}")
@staticmethod
def _eval_node(node: ast.AST) -> Decimal:
"""Recursively evaluate an AST node."""
if isinstance(node, ast.Constant):
if isinstance(node.value, (int, float)):
return Decimal(str(node.value))
raise FormulaError(f"Invalid constant: {node.value}")
elif isinstance(node, ast.BinOp):
left = FormulaService._eval_node(node.left)
right = FormulaService._eval_node(node.right)
op = FormulaService.OPERATORS.get(type(node.op))
if op is None:
raise FormulaError(f"Unsupported operator: {type(node.op).__name__}")
# Handle division by zero
if isinstance(node.op, ast.Div) and right == 0:
return Decimal('0') # Return 0 instead of raising error
return Decimal(str(op(float(left), float(right))))
elif isinstance(node, ast.UnaryOp):
operand = FormulaService._eval_node(node.operand)
op = FormulaService.OPERATORS.get(type(node.op))
if op is None:
raise FormulaError(f"Unsupported operator: {type(node.op).__name__}")
return Decimal(str(op(float(operand))))
else:
raise FormulaError(f"Unsupported expression type: {type(node).__name__}")
@staticmethod
def calculate_formula(
formula: str,
task: Task,
db: Session,
calculated_cache: Optional[Dict[str, Decimal]] = None,
) -> Optional[Decimal]:
"""
Calculate the value of a formula for a given task.
Args:
formula: The formula expression
task: The task to calculate for
db: Database session
calculated_cache: Cache for already calculated formula values (for recursion)
Returns:
The calculated value, or None if calculation fails
"""
if calculated_cache is None:
calculated_cache = {}
references = FormulaService.extract_field_references(formula)
values: Dict[str, Decimal] = {}
# Get builtin field values
for ref in references:
if ref in FormulaService.BUILTIN_FIELDS:
task_value = getattr(task, ref, None)
if task_value is not None:
values[ref] = Decimal(str(task_value))
else:
values[ref] = Decimal('0')
# Get custom field values
custom_references = references - FormulaService.BUILTIN_FIELDS
if custom_references:
# Get field definitions
fields = db.query(CustomField).filter(
CustomField.project_id == task.project_id,
CustomField.name.in_(custom_references),
).all()
field_map = {f.name: f for f in fields}
# Get custom values for this task
custom_values = db.query(TaskCustomValue).filter(
TaskCustomValue.task_id == task.id,
TaskCustomValue.field_id.in_([f.id for f in fields]),
).all()
value_map = {cv.field_id: cv.value for cv in custom_values}
for ref in custom_references:
field = field_map.get(ref)
if not field:
values[ref] = Decimal('0')
continue
if field.field_type == "formula":
# Recursively calculate formula fields
if field.id in calculated_cache:
values[ref] = calculated_cache[field.id]
else:
nested_value = FormulaService.calculate_formula(
field.formula, task, db, calculated_cache
)
values[ref] = nested_value if nested_value is not None else Decimal('0')
calculated_cache[field.id] = values[ref]
else:
# Get stored value
stored_value = value_map.get(field.id)
if stored_value:
try:
values[ref] = Decimal(str(stored_value))
except (InvalidOperation, ValueError):
values[ref] = Decimal('0')
else:
values[ref] = Decimal('0')
# Substitute values into formula
expression = formula
for ref, value in values.items():
expression = expression.replace(f"{{{ref}}}", str(value))
# Evaluate the expression
try:
result = FormulaService._safe_eval(expression)
# Round to 4 decimal places
return result.quantize(Decimal('0.0001'))
except Exception:
return None
@staticmethod
def recalculate_dependent_formulas(
db: Session,
task: Task,
changed_field_id: str,
) -> Dict[str, Decimal]:
"""
Recalculate all formula fields that depend on a changed field.
Returns a dict of field_id -> calculated_value for updated formulas.
"""
# Get the changed field
changed_field = db.query(CustomField).filter(
CustomField.id == changed_field_id
).first()
if not changed_field:
return {}
# Find all formula fields in the project
formula_fields = db.query(CustomField).filter(
CustomField.project_id == task.project_id,
CustomField.field_type == "formula",
).all()
results = {}
calculated_cache: Dict[str, Decimal] = {}
for field in formula_fields:
if not field.formula:
continue
# Check if this formula depends on the changed field
references = FormulaService.extract_field_references(field.formula)
if changed_field.name in references or changed_field.name in FormulaService.BUILTIN_FIELDS:
value = FormulaService.calculate_formula(
field.formula, task, db, calculated_cache
)
if value is not None:
results[field.id] = value
calculated_cache[field.id] = value
# Update or create the custom value
existing = db.query(TaskCustomValue).filter(
TaskCustomValue.task_id == task.id,
TaskCustomValue.field_id == field.id,
).first()
if existing:
existing.value = str(value)
else:
import uuid
new_value = TaskCustomValue(
id=str(uuid.uuid4()),
task_id=task.id,
field_id=field.id,
value=str(value),
)
db.add(new_value)
return results
@staticmethod
def calculate_all_formulas_for_task(
db: Session,
task: Task,
) -> Dict[str, Decimal]:
"""
Calculate all formula fields for a task.
Used when loading a task to get current formula values.
"""
formula_fields = db.query(CustomField).filter(
CustomField.project_id == task.project_id,
CustomField.field_type == "formula",
).all()
results = {}
calculated_cache: Dict[str, Decimal] = {}
for field in formula_fields:
if not field.formula:
continue
value = FormulaService.calculate_formula(
field.formula, task, db, calculated_cache
)
if value is not None:
results[field.id] = value
calculated_cache[field.id] = value
return results