Files
PROJECT-CONTORL/backend/app/services/file_storage_service.py
beabigegg 679b89ae4c feat: implement security, error resilience, and query optimization proposals
Security Validation (enhance-security-validation):
- JWT secret validation with entropy checking and pattern detection
- CSRF protection middleware with token generation/validation
- Frontend CSRF token auto-injection for DELETE/PUT/PATCH requests
- MIME type validation with magic bytes detection for file uploads

Error Resilience (add-error-resilience):
- React ErrorBoundary component with fallback UI and retry functionality
- ErrorBoundaryWithI18n wrapper for internationalization support
- Page-level and section-level error boundaries in App.tsx

Query Performance (optimize-query-performance):
- Query monitoring utility with threshold warnings
- N+1 query fixes using joinedload/selectinload
- Optimized project members, tasks, and subtasks endpoints

Bug Fixes:
- WebSocket session management (P0): Return primitives instead of ORM objects
- LIKE query injection (P1): Escape special characters in search queries

Tests: 543 backend tests, 56 frontend tests passing

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-11 18:41:19 +08:00

555 lines
19 KiB
Python

import os
import hashlib
import shutil
import logging
import tempfile
from pathlib import Path
from typing import BinaryIO, Optional, Tuple
from fastapi import UploadFile, HTTPException
from app.core.config import settings
# Module-level logger named after this module's import path; every class and
# function in this file reports through it.
logger = logging.getLogger(__name__)
class PathTraversalError(Exception):
    """Raised when a path traversal attempt is detected.

    Signals that a path component or a resolved path was rejected because it
    is empty, contains traversal characters, or points outside the storage
    base directory.
    """
class StorageValidationError(Exception):
    """Raised when storage validation fails.

    Signals that the configured upload directory could not be created,
    written to, or otherwise verified while the storage service starts up.
    """
class FileStorageService:
    """Service for handling file storage operations.

    Attachment versions are stored under
    ``<base_dir>/<project_id>/<task_id>/<attachment_id>/<version>/file[.<ext>]``
    where ``base_dir`` is the resolved ``settings.UPLOAD_DIR``. Every path
    built from caller-supplied identifiers is validated so that it cannot
    escape ``base_dir``.
    """

    # Common NAS mount points to detect (substring match, best effort only)
    NAS_MOUNT_INDICATORS = [
        "/mnt/", "/mount/", "/nas/", "/nfs/", "/smb/", "/cifs/",
        "/Volumes/", "/media/", "/srv/", "/storage/"
    ]

    # Filesystem types in /proc/mounts that are treated as network storage
    NETWORK_FS_TYPES = frozenset({"nfs", "nfs4", "cifs", "smb", "smbfs"})

    def __init__(self):
        """Resolve the upload directory and validate it immediately.

        Raises:
            StorageValidationError: If the directory cannot be created or
                written to.
        """
        self.base_dir = Path(settings.UPLOAD_DIR).resolve()
        # Backward-compatible attribute name for tests and older code
        self.upload_dir = self.base_dir
        self._storage_status = {
            "validated": False,
            "path_exists": False,
            "writable": False,
            "is_nas": False,
            "error": None,
        }
        self._validate_storage_on_init()

    def _validate_storage_on_init(self):
        """Validate storage configuration on service initialization.

        Updates ``self._storage_status`` step by step so that a failure
        leaves a record of how far validation got.

        Raises:
            StorageValidationError: If any validation step fails; unexpected
                errors are wrapped and chained to the original exception.
        """
        try:
            # Step 1: Ensure directory exists
            self._ensure_base_dir()
            self._storage_status["path_exists"] = True

            # Step 2: Check write permissions
            self._check_write_permissions()
            self._storage_status["writable"] = True

            # Step 3: Check if using NAS (informational only, never fatal)
            is_nas = self._detect_nas_storage()
            self._storage_status["is_nas"] = is_nas
            if not is_nas:
                logger.warning(
                    "Storage directory '%s' appears to be local storage, not NAS. "
                    "Consider configuring UPLOAD_DIR to a NAS mount point for production use.",
                    self.base_dir
                )

            self._storage_status["validated"] = True
            logger.info(
                "Storage validated successfully: path=%s, is_nas=%s",
                self.base_dir, is_nas
            )
        except StorageValidationError as e:
            self._storage_status["error"] = str(e)
            logger.error("Storage validation failed: %s", e)
            raise
        except Exception as e:
            self._storage_status["error"] = str(e)
            logger.error("Unexpected error during storage validation: %s", e)
            # Chain the cause so the original traceback is not lost
            raise StorageValidationError(f"Storage validation failed: {e}") from e

    def _check_write_permissions(self):
        """Check if the storage directory has write permissions.

        Writes a small probe file, reads it back, and removes it.

        Raises:
            StorageValidationError: If the probe file cannot be written,
                read back unchanged, or removed.
        """
        # The PID keeps probes from concurrently starting processes apart
        test_file = self.base_dir / f".write_test_{os.getpid()}"
        try:
            # Try to create and write to a test file
            test_file.write_text("write_test")

            # Verify we can read it back
            content = test_file.read_text()
            if content != "write_test":
                raise StorageValidationError(
                    f"Write verification failed for directory: {self.base_dir}"
                )

            # Clean up; failing to delete counts as a validation failure
            test_file.unlink()
            logger.debug("Write permission check passed for %s", self.base_dir)
        except PermissionError as e:
            raise StorageValidationError(
                f"No write permission for storage directory '{self.base_dir}': {e}"
            ) from e
        except OSError as e:
            raise StorageValidationError(
                f"Failed to verify write permissions for '{self.base_dir}': {e}"
            ) from e
        finally:
            # Ensure test file is removed even on partial failure. This is
            # best effort and must never mask the real validation error.
            try:
                test_file.unlink()
            except FileNotFoundError:
                pass  # Already removed on the success path
            except OSError as e:
                logger.debug("Could not remove write test file %s: %s", test_file, e)

    @staticmethod
    def _find_mount_fs_type(path_str: str, mounts_text: str) -> str:
        """Return the filesystem type of the mount that contains a path.

        Args:
            path_str: Absolute path to look up.
            mounts_text: Contents in ``/proc/mounts`` format, i.e. one mount
                per line as ``device mount_point fs_type options ...``.

        Returns:
            The ``fs_type`` field of the longest mount point that is equal to
            or a parent directory of ``path_str``; an empty string when no
            line matches. Escaped characters in mount points (such as
            ``\\040`` for a space) are not decoded.
        """
        best_mount = ""
        best_fs_type = ""
        for line in mounts_text.splitlines():
            fields = line.split()
            if len(fields) < 3:
                continue
            mount_point, fs_type = fields[1], fields[2]
            if mount_point == "/":
                contains_path = True  # The root mount contains every path
            else:
                prefix = mount_point.rstrip("/")
                # Compare whole components so /data/share does not match
                # /data/shared
                contains_path = (
                    path_str == prefix or path_str.startswith(prefix + "/")
                )
            # ">=" lets a later mount on the same point shadow an earlier one
            if contains_path and len(mount_point) >= len(best_mount):
                best_mount, best_fs_type = mount_point, fs_type
        return best_fs_type

    def _detect_nas_storage(self) -> bool:
        """
        Detect if the storage directory appears to be on a NAS mount.

        This is a best-effort detection: it checks common mount point name
        patterns, then whether the directory itself is a mount point, then
        the filesystem type reported by ``/proc/mounts`` on Linux.

        Returns:
            True if any of the heuristics matches, False otherwise.
        """
        path_str = str(self.base_dir)

        # Check common NAS mount point patterns
        for indicator in self.NAS_MOUNT_INDICATORS:
            if indicator in path_str:
                logger.debug("NAS storage detected: path contains '%s'", indicator)
                return True

        # Check if it's a mount point (Unix-like systems)
        try:
            if self.base_dir.is_mount():
                logger.debug("NAS storage detected: path is a mount point")
                return True
        except (OSError, NotImplementedError) as e:
            logger.debug("Could not determine whether %s is a mount point: %s", path_str, e)

        # Check mount info on Linux
        try:
            with open("/proc/mounts", "r") as f:
                mounts = f.read()
        except FileNotFoundError:
            return False  # Not on Linux
        except (OSError, ValueError) as e:
            logger.debug("Could not check /proc/mounts: %s", e)
            return False

        # Look at the mount that actually holds base_dir instead of any line
        # that happens to contain the path text.
        fs_type = self._find_mount_fs_type(path_str, mounts)
        if fs_type in self.NETWORK_FS_TYPES:
            logger.debug("NAS storage detected: mounted as %s", fs_type)
            return True
        return False

    def get_storage_status(self) -> dict:
        """Get current storage status for health checks.

        Returns:
            A copy of the validation status recorded at start-up, extended
            with the base directory path and live ``exists`` /
            ``is_directory`` checks.
        """
        return {
            **self._storage_status,
            "base_dir": str(self.base_dir),
            "exists": self.base_dir.exists(),
            "is_directory": self.base_dir.is_dir() if self.base_dir.exists() else False,
        }

    def _ensure_base_dir(self):
        """Ensure the base upload directory exists."""
        self.base_dir.mkdir(parents=True, exist_ok=True)

    def _validate_path_component(self, component: str, component_name: str) -> None:
        """
        Validate a path component to prevent path traversal attacks.

        Args:
            component: The path component to validate (e.g., project_id, task_id)
            component_name: Name of the component for error messages

        Raises:
            PathTraversalError: If the component is empty, contains path
                traversal patterns, or starts with '.' or '-'
        """
        if not component:
            raise PathTraversalError(f"Empty {component_name} is not allowed")

        # Check for path traversal patterns
        dangerous_patterns = ('..', '/', '\\', '\x00')
        if any(pattern in component for pattern in dangerous_patterns):
            logger.warning(
                "Path traversal attempt detected in %s: %r",
                component_name,
                component
            )
            raise PathTraversalError(
                f"Invalid characters in {component_name}: path traversal not allowed"
            )

        # Additional check: component should not start with special characters
        if component.startswith(('.', '-')):
            logger.warning(
                "Suspicious path component in %s: %r",
                component_name,
                component
            )
            raise PathTraversalError(
                f"Invalid {component_name}: cannot start with '.' or '-'"
            )

    def _validate_path_in_base_dir(self, path: Path, context: str = "") -> Path:
        """
        Validate that a resolved path is within the base directory.

        Args:
            path: The path to validate
            context: Additional context for logging

        Returns:
            The resolved path if valid

        Raises:
            PathTraversalError: If the path is outside the base directory
        """
        # resolve() collapses ".." and follows symlinks, so the containment
        # check below is made against the real target.
        resolved_path = path.resolve()
        base_dir = self.base_dir.resolve()

        # Check if the resolved path is within the base directory
        try:
            resolved_path.relative_to(base_dir)
        except ValueError:
            logger.warning(
                "Path traversal attempt detected: path %s is outside base directory %s. Context: %s",
                resolved_path,
                base_dir,
                context
            )
            raise PathTraversalError(
                "Access denied: path is outside the allowed directory"
            ) from None
        return resolved_path

    def _get_file_path(self, project_id: str, task_id: str, attachment_id: str, version: int) -> Path:
        """
        Generate the directory path for an attachment version.

        Args:
            project_id: The project ID
            task_id: The task ID
            attachment_id: The attachment ID
            version: The version number

        Returns:
            Resolved path ``base_dir/project_id/task_id/attachment_id/version``

        Raises:
            PathTraversalError: If any component contains path traversal
                patterns, the version is negative, or the resulting path is
                outside the base directory
        """
        # Validate all path components
        self._validate_path_component(project_id, "project_id")
        self._validate_path_component(task_id, "task_id")
        self._validate_path_component(attachment_id, "attachment_id")
        if version < 0:
            raise PathTraversalError("Version must be non-negative")

        # Build the path
        path = self.base_dir / project_id / task_id / attachment_id / str(version)

        # Validate the final path is within base directory
        return self._validate_path_in_base_dir(
            path,
            f"project={project_id}, task={task_id}, attachment={attachment_id}, version={version}"
        )

    @staticmethod
    def calculate_checksum(file: BinaryIO) -> str:
        """Calculate SHA-256 checksum of a file.

        The stream is rewound before hashing so the digest always covers the
        whole content, and rewound again afterwards for the next reader.

        Args:
            file: A seekable binary stream.

        Returns:
            The hexadecimal SHA-256 digest of the full stream content.
        """
        sha256_hash = hashlib.sha256()
        file.seek(0)  # Hash from the start regardless of prior reads
        # Read in chunks to handle large files
        for chunk in iter(lambda: file.read(8192), b""):
            sha256_hash.update(chunk)
        file.seek(0)  # Reset file position
        return sha256_hash.hexdigest()

    @staticmethod
    def get_extension(filename: str) -> str:
        """Get file extension in lowercase.

        Only the final path segment is considered, so a dot in a directory
        part of a client-supplied name cannot leak separators into the
        extension.

        Args:
            filename: File name, optionally with '/' or '\\' separated
                directory parts.

        Returns:
            The text after the last '.' of the final segment, lowercased, or
            an empty string when that segment has no '.'.
        """
        base_name = filename.replace("\\", "/").rsplit("/", 1)[-1]
        if "." not in base_name:
            return ""
        return base_name.rsplit(".", 1)[-1].lower()

    @staticmethod
    def validate_file(file: "UploadFile", validate_mime: bool = True) -> Tuple[str, str]:
        """
        Validate file size, type, and optionally MIME content.

        Args:
            file: The uploaded file
            validate_mime: If True, validate MIME type using magic bytes detection

        Returns:
            (extension, mime_type) if valid. The stream is left at position 0.

        Raises:
            HTTPException: With status 400 if the file is too large, empty,
                has a blocked or unsupported extension, or fails MIME
                validation.
        """
        # Check file size
        file.file.seek(0, 2)  # Seek to end
        file_size = file.file.tell()
        file.file.seek(0)  # Reset

        if file_size > settings.MAX_FILE_SIZE:
            raise HTTPException(
                status_code=400,
                detail=f"File too large. Maximum size is {settings.MAX_FILE_SIZE_MB}MB"
            )
        if file_size == 0:
            raise HTTPException(status_code=400, detail="Empty file not allowed")

        # Get extension
        extension = FileStorageService.get_extension(file.filename or "")

        # Check blocked extensions
        if extension in settings.BLOCKED_EXTENSIONS:
            raise HTTPException(
                status_code=400,
                detail=f"File type '.{extension}' is not allowed for security reasons"
            )

        # Check allowed extensions (if whitelist is enabled)
        if settings.ALLOWED_EXTENSIONS and extension not in settings.ALLOWED_EXTENSIONS:
            raise HTTPException(
                status_code=400,
                detail=f"File type '.{extension}' is not supported"
            )

        if not validate_mime:
            return extension, file.content_type or "application/octet-stream"

        # Validate MIME type using magic bytes detection.
        # Imported here rather than at module level, as in the original code.
        from app.services.mime_validation_service import mime_validation_service

        # Read first 16 bytes for magic detection (enough for most signatures)
        file_header = file.file.read(16)
        file.file.seek(0)  # Reset

        is_valid, detected_mime, error_message = mime_validation_service.validate_file_content(
            file_content=file_header,
            declared_extension=extension,
            declared_mime_type=file.content_type
        )
        if not is_valid:
            logger.warning(
                "MIME validation failed for file '%s': %s (detected: %s)",
                file.filename, error_message, detected_mime
            )
            raise HTTPException(
                status_code=400,
                detail=error_message or "File type validation failed"
            )

        # Use detected MIME type if available, otherwise fall back to declared
        mime_type = detected_mime if detected_mime else (file.content_type or "application/octet-stream")
        return extension, mime_type

    async def save_file(
        self,
        file: "UploadFile",
        project_id: str,
        task_id: str,
        attachment_id: str,
        version: int
    ) -> Tuple[str, int, str]:
        """
        Save uploaded file to storage.

        The file is written as ``file.<extension>`` (or ``file`` when there
        is no extension) inside the version directory.

        NOTE(review): the copy uses blocking file I/O inside an ``async``
        function - confirm this is acceptable for the expected upload sizes.

        Returns:
            (file_path, file_size, checksum)

        Raises:
            HTTPException: If file validation fails, or if an identifier or
                the derived file name is rejected as unsafe (status 400)
            PathTraversalError: If the final file path resolves outside the
                base directory
        """
        # Validate file
        extension, _ = self.validate_file(file)

        # Calculate checksum first
        checksum = self.calculate_checksum(file.file)

        # Save file with original extension
        filename = f"file.{extension}" if extension else "file"

        # Validate everything before touching the filesystem. The extension
        # is client-controlled, so the stored name is checked like any other
        # path component (path validation is done in _get_file_path).
        try:
            self._validate_path_component(filename, "filename")
            dir_path = self._get_file_path(project_id, task_id, attachment_id, version)
        except PathTraversalError as e:
            logger.error("Path traversal attempt during file save: %s", e)
            raise HTTPException(status_code=400, detail=str(e)) from e

        # Create directory structure
        dir_path.mkdir(parents=True, exist_ok=True)
        file_path = dir_path / filename

        # Final validation of the file path (also catches a symlink planted
        # at the target location)
        self._validate_path_in_base_dir(file_path, f"saving file {filename}")

        # Get file size
        file.file.seek(0, 2)
        file_size = file.file.tell()
        file.file.seek(0)

        # Write file in chunks (streaming)
        try:
            with open(file_path, "wb") as buffer:
                shutil.copyfileobj(file.file, buffer)
        except Exception:
            # Do not leave a truncated file behind for later reads
            try:
                file_path.unlink()
            except OSError as cleanup_error:
                logger.debug(
                    "Could not remove partial file %s: %s", file_path, cleanup_error
                )
            raise

        return str(file_path), file_size, checksum

    def get_file(
        self,
        project_id: str,
        task_id: str,
        attachment_id: str,
        version: int
    ) -> Optional[Path]:
        """
        Get the file path for an attachment version.

        Returns:
            The stored file, or None if the identifiers are rejected as
            unsafe, the version directory does not exist, or it holds no
            regular file. If several regular files are present, the first by
            name is returned so the result is deterministic.
        """
        try:
            dir_path = self._get_file_path(project_id, task_id, attachment_id, version)
        except PathTraversalError as e:
            logger.error("Path traversal attempt during file retrieval: %s", e)
            return None

        if not dir_path.is_dir():
            return None

        # Find the file in the directory, skipping subdirectories
        files = sorted(
            (entry for entry in dir_path.iterdir() if entry.is_file()),
            key=lambda entry: entry.name,
        )
        return files[0] if files else None

    def get_file_by_path(self, file_path: str) -> Optional[Path]:
        """
        Get file by stored path. Handles both absolute and relative paths.

        This method validates that the requested path is within the base
        directory to prevent path traversal attacks. Relative paths are
        interpreted relative to the base directory.

        Args:
            file_path: The stored file path

        Returns:
            Resolved Path object if the path exists and is within the base
            directory, None otherwise
        """
        if not file_path:
            return None

        path = Path(file_path)
        if path.is_absolute():
            candidate, kind = path, "absolute"
        else:
            candidate, kind = self.base_dir / path, "relative"

        try:
            validated_path = self._validate_path_in_base_dir(
                candidate,
                f"get_file_by_path {kind}: {file_path}"
            )
        except PathTraversalError:
            return None
        return validated_path if validated_path.exists() else None

    @staticmethod
    def _remove_dir(dir_path: Path) -> bool:
        """Remove a directory tree; return False if the path does not exist."""
        if not dir_path.exists():
            return False
        shutil.rmtree(dir_path)
        return True

    def delete_file(
        self,
        project_id: str,
        task_id: str,
        attachment_id: str,
        version: Optional[int] = None
    ) -> bool:
        """
        Delete file(s) from storage.

        If version is None, deletes all versions of the attachment.

        Returns:
            True if a directory was removed; False if nothing existed at the
            target or the identifiers were rejected as unsafe.
        """
        try:
            if version is not None:
                # Delete specific version
                dir_path = self._get_file_path(project_id, task_id, attachment_id, version)
            else:
                # Delete all versions (attachment directory)
                # Validate components first
                self._validate_path_component(project_id, "project_id")
                self._validate_path_component(task_id, "task_id")
                self._validate_path_component(attachment_id, "attachment_id")
                dir_path = self.base_dir / project_id / task_id / attachment_id
                dir_path = self._validate_path_in_base_dir(
                    dir_path,
                    f"delete attachment: project={project_id}, task={task_id}, attachment={attachment_id}"
                )
        except PathTraversalError as e:
            logger.error("Path traversal attempt during file deletion: %s", e)
            return False

        return self._remove_dir(dir_path)

    def delete_task_files(self, project_id: str, task_id: str) -> bool:
        """
        Delete all files for a task.

        Returns:
            True if the task directory was removed; False if it did not
            exist or the identifiers were rejected as unsafe.
        """
        try:
            # Validate components
            self._validate_path_component(project_id, "project_id")
            self._validate_path_component(task_id, "task_id")
            dir_path = self.base_dir / project_id / task_id
            dir_path = self._validate_path_in_base_dir(
                dir_path,
                f"delete task files: project={project_id}, task={task_id}"
            )
        except PathTraversalError as e:
            logger.error("Path traversal attempt during task file deletion: %s", e)
            return False

        return self._remove_dir(dir_path)
# Singleton instance shared by importers of this module.
# Constructing it runs storage validation immediately (directory creation and
# a write test), so importing this module raises StorageValidationError when
# that validation fails.
file_storage_service = FileStorageService()