import os
import hashlib
import shutil
import logging
from pathlib import Path
from typing import BinaryIO, Optional, Tuple

from fastapi import UploadFile, HTTPException

from app.core.config import settings

logger = logging.getLogger(__name__)


class PathTraversalError(Exception):
    """Raised when a path traversal attempt is detected."""
    pass


class StorageValidationError(Exception):
    """Raised when storage validation fails."""
    pass


class FileStorageService:
    """Service for handling file storage operations."""

    # Common NAS mount points to detect
    NAS_MOUNT_INDICATORS = [
        "/mnt/", "/mount/", "/nas/", "/nfs/", "/smb/", "/cifs/",
        "/Volumes/", "/media/", "/srv/", "/storage/"
    ]

    def __init__(self):
        self.base_dir = Path(settings.UPLOAD_DIR).resolve()
        # Backward-compatible attribute name for tests and older code
        self.upload_dir = self.base_dir
        self._storage_status = {
            "validated": False,
            "path_exists": False,
            "writable": False,
            "is_nas": False,
            "error": None,
        }
        self._validate_storage_on_init()

    def _validate_storage_on_init(self):
        """Validate storage configuration on service initialization."""
        try:
            # Step 1: Ensure directory exists
            self._ensure_base_dir()
            self._storage_status["path_exists"] = True

            # Step 2: Check write permissions
            self._check_write_permissions()
            self._storage_status["writable"] = True

            # Step 3: Check if using NAS
            is_nas = self._detect_nas_storage()
            self._storage_status["is_nas"] = is_nas

            if not is_nas:
                logger.warning(
                    "Storage directory '%s' appears to be local storage, not NAS. "
                    "Consider configuring UPLOAD_DIR to a NAS mount point for production use.",
                    self.base_dir
                )

            self._storage_status["validated"] = True
            logger.info(
                "Storage validated successfully: path=%s, is_nas=%s",
                self.base_dir, is_nas
            )

        except StorageValidationError as e:
            self._storage_status["error"] = str(e)
            logger.error("Storage validation failed: %s", e)
            raise
        except Exception as e:
            self._storage_status["error"] = str(e)
            logger.error("Unexpected error during storage validation: %s", e)
            raise StorageValidationError(f"Storage validation failed: {e}") from e

    def _check_write_permissions(self):
        """Check if the storage directory has write permissions."""
        test_file = self.base_dir / f".write_test_{os.getpid()}"
        try:
            # Try to create and write to a test file
            test_file.write_text("write_test")
            # Verify we can read it back
            content = test_file.read_text()
            if content != "write_test":
                raise StorageValidationError(
                    f"Write verification failed for directory: {self.base_dir}"
                )
            # Clean up
            test_file.unlink()
            logger.debug("Write permission check passed for %s", self.base_dir)
        except PermissionError as e:
            raise StorageValidationError(
                f"No write permission for storage directory '{self.base_dir}': {e}"
            ) from e
        except OSError as e:
            raise StorageValidationError(
                f"Failed to verify write permissions for '{self.base_dir}': {e}"
            ) from e
        finally:
            # Ensure test file is removed even on partial failure
            if test_file.exists():
                try:
                    test_file.unlink()
                except Exception:
                    pass

    def _detect_nas_storage(self) -> bool:
        """
        Detect if the storage directory appears to be on a NAS mount.

        This is a best-effort detection based on common mount point patterns.
        """
        path_str = str(self.base_dir)

        # Check common NAS mount point patterns
        for indicator in self.NAS_MOUNT_INDICATORS:
            if indicator in path_str:
                logger.debug("NAS storage detected: path contains '%s'", indicator)
                return True

        # Check if it's a mount point (Unix-like systems)
        try:
            if self.base_dir.is_mount():
                logger.debug("NAS storage detected: path is a mount point")
                return True
        except Exception:
            pass

        # Check mount info on Linux
        try:
            with open("/proc/mounts", "r") as f:
                mounts = f.read()
                if path_str in mounts:
                    # Check for network filesystem types
                    for line in mounts.splitlines():
                        if path_str in line:
                            fs_type = line.split()[2] if len(line.split()) > 2 else ""
                            if fs_type in ["nfs", "nfs4", "cifs", "smb", "smbfs"]:
                                logger.debug("NAS storage detected: mounted as %s", fs_type)
                                return True
        except FileNotFoundError:
            pass  # Not on Linux
        except Exception as e:
            logger.debug("Could not check /proc/mounts: %s", e)

        return False

    def get_storage_status(self) -> dict:
        """Get current storage status for health checks."""
        return {
            **self._storage_status,
            "base_dir": str(self.base_dir),
            "exists": self.base_dir.exists(),
            "is_directory": self.base_dir.is_dir() if self.base_dir.exists() else False,
        }
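
    # Illustrative shape of the dict returned by get_storage_status() on a healthy
    # instance (the keys come from the code above; the values are examples only):
    #
    #     {"validated": True, "path_exists": True, "writable": True, "is_nas": True,
    #      "error": None, "base_dir": "/mnt/nas/uploads", "exists": True,
    #      "is_directory": True}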

    def _ensure_base_dir(self):
        """Ensure the base upload directory exists."""
        self.base_dir.mkdir(parents=True, exist_ok=True)

    def _validate_path_component(self, component: str, component_name: str) -> None:
        """
        Validate a path component to prevent path traversal attacks.

        Args:
            component: The path component to validate (e.g., project_id, task_id)
            component_name: Name of the component for error messages

        Raises:
            PathTraversalError: If the component contains path traversal patterns
        """
        if not component:
            raise PathTraversalError(f"Empty {component_name} is not allowed")

        # Check for path traversal patterns
        dangerous_patterns = ['..', '/', '\\', '\x00']
        for pattern in dangerous_patterns:
            if pattern in component:
                logger.warning(
                    "Path traversal attempt detected in %s: %r",
                    component_name,
                    component
                )
                raise PathTraversalError(
                    f"Invalid characters in {component_name}: path traversal not allowed"
                )

        # Additional check: component should not start with special characters
        if component.startswith('.') or component.startswith('-'):
            logger.warning(
                "Suspicious path component in %s: %r",
                component_name,
                component
            )
            raise PathTraversalError(
                f"Invalid {component_name}: cannot start with '.' or '-'"
            )

    def _validate_path_in_base_dir(self, path: Path, context: str = "") -> Path:
        """
        Validate that a resolved path is within the base directory.

        Args:
            path: The path to validate
            context: Additional context for logging

        Returns:
            The resolved path if valid

        Raises:
            PathTraversalError: If the path is outside the base directory
        """
        resolved_path = path.resolve()
        base_dir = self.base_dir.resolve()

        # Check if the resolved path is within the base directory
        try:
            resolved_path.relative_to(base_dir)
        except ValueError:
            logger.warning(
                "Path traversal attempt detected: path %s is outside base directory %s. Context: %s",
                resolved_path,
                base_dir,
                context
            )
            raise PathTraversalError(
                "Access denied: path is outside the allowed directory"
            )

        return resolved_path

    def _get_file_path(self, project_id: str, task_id: str, attachment_id: str, version: int) -> Path:
        """
        Generate the file path for an attachment version.

        Args:
            project_id: The project ID
            task_id: The task ID
            attachment_id: The attachment ID
            version: The version number

        Returns:
            Safe path within the base directory

        Raises:
            PathTraversalError: If any component contains path traversal patterns
        """
        # Validate all path components
        self._validate_path_component(project_id, "project_id")
        self._validate_path_component(task_id, "task_id")
        self._validate_path_component(attachment_id, "attachment_id")

        if version < 0:
            raise PathTraversalError("Version must be non-negative")

        # Build the path
        path = self.base_dir / project_id / task_id / attachment_id / str(version)

        # Validate the final path is within base directory
        return self._validate_path_in_base_dir(
            path,
            f"project={project_id}, task={task_id}, attachment={attachment_id}, version={version}"
        )
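
    # Resulting on-disk layout (for illustration; the trailing filename is added by
    # save_file() below):
    #
    #     <UPLOAD_DIR>/<project_id>/<task_id>/<attachment_id>/<version>/file.<ext>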

    @staticmethod
    def calculate_checksum(file: BinaryIO) -> str:
        """Calculate SHA-256 checksum of a file."""
        sha256_hash = hashlib.sha256()
        # Read in chunks to handle large files
        for chunk in iter(lambda: file.read(8192), b""):
            sha256_hash.update(chunk)
        file.seek(0)  # Reset file position
        return sha256_hash.hexdigest()

    @staticmethod
    def get_extension(filename: str) -> str:
        """Get file extension in lowercase."""
        return filename.rsplit(".", 1)[-1].lower() if "." in filename else ""

    @staticmethod
    def validate_file(file: UploadFile) -> Tuple[str, str]:
        """
        Validate file size and type.
        Returns (extension, mime_type) if valid.
        Raises HTTPException if invalid.
        """
        # Check file size
        file.file.seek(0, 2)  # Seek to end
        file_size = file.file.tell()
        file.file.seek(0)  # Reset

        if file_size > settings.MAX_FILE_SIZE:
            raise HTTPException(
                status_code=400,
                detail=f"File too large. Maximum size is {settings.MAX_FILE_SIZE_MB}MB"
            )

        if file_size == 0:
            raise HTTPException(status_code=400, detail="Empty file not allowed")

        # Get extension
        extension = FileStorageService.get_extension(file.filename or "")

        # Check blocked extensions
        if extension in settings.BLOCKED_EXTENSIONS:
            raise HTTPException(
                status_code=400,
                detail=f"File type '.{extension}' is not allowed for security reasons"
            )

        # Check allowed extensions (if whitelist is enabled)
        if settings.ALLOWED_EXTENSIONS and extension not in settings.ALLOWED_EXTENSIONS:
            raise HTTPException(
                status_code=400,
                detail=f"File type '.{extension}' is not supported"
            )

        mime_type = file.content_type or "application/octet-stream"
        return extension, mime_type
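
    # validate_file() relies on a few fields from app.core.config.settings. A sketch
    # of the assumed shape (illustrative assumptions, not the actual config):
    #
    #     MAX_FILE_SIZE: int       # maximum upload size in bytes
    #     MAX_FILE_SIZE_MB: int    # same limit in MB, used only for the error message
    #     BLOCKED_EXTENSIONS       # e.g. {"exe", "bat"}; always rejected
    #     ALLOWED_EXTENSIONS       # optional whitelist; empty/falsy disables the check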

    async def save_file(
        self,
        file: UploadFile,
        project_id: str,
        task_id: str,
        attachment_id: str,
        version: int
    ) -> Tuple[str, int, str]:
        """
        Save uploaded file to storage.
        Returns (file_path, file_size, checksum).

        Raises:
            HTTPException: If file validation fails or a path component is invalid
            PathTraversalError: If the final resolved path escapes the base directory
        """
        # Validate file
        extension, _ = self.validate_file(file)

        # Calculate checksum first
        checksum = self.calculate_checksum(file.file)

        # Create directory structure (path validation is done in _get_file_path)
        try:
            dir_path = self._get_file_path(project_id, task_id, attachment_id, version)
        except PathTraversalError as e:
            logger.error("Path traversal attempt during file save: %s", e)
            raise HTTPException(status_code=400, detail=str(e))

        dir_path.mkdir(parents=True, exist_ok=True)

        # Save file with original extension
        filename = f"file.{extension}" if extension else "file"
        file_path = dir_path / filename

        # Final validation of the file path
        self._validate_path_in_base_dir(file_path, f"saving file {filename}")

        # Get file size
        file.file.seek(0, 2)
        file_size = file.file.tell()
        file.file.seek(0)

        # Write file in chunks (streaming)
        with open(file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        return str(file_path), file_size, checksum

    def get_file(
        self,
        project_id: str,
        task_id: str,
        attachment_id: str,
        version: int
    ) -> Optional[Path]:
        """
        Get the file path for an attachment version.
        Returns None if the file doesn't exist or if the path components are invalid.
        """
        try:
            dir_path = self._get_file_path(project_id, task_id, attachment_id, version)
        except PathTraversalError as e:
            logger.error("Path traversal attempt during file retrieval: %s", e)
            return None

        if not dir_path.exists():
            return None

        # Find the file in the directory
        files = list(dir_path.iterdir())
        if not files:
            return None

        return files[0]

    def get_file_by_path(self, file_path: str) -> Optional[Path]:
        """
        Get file by stored path. Handles both absolute and relative paths.

        This method validates that the requested path is within the base directory
        to prevent path traversal attacks.

        Args:
            file_path: The stored file path

        Returns:
            Path object if file exists and is within base directory, None otherwise
        """
        if not file_path:
            return None

        path = Path(file_path)

        # For absolute paths, validate they are within base_dir
        if path.is_absolute():
            try:
                validated_path = self._validate_path_in_base_dir(
                    path,
                    f"get_file_by_path absolute: {file_path}"
                )
                if validated_path.exists():
                    return validated_path
            except PathTraversalError:
                return None
            return None

        # For relative paths, resolve from base_dir
        full_path = self.base_dir / path

        try:
            validated_path = self._validate_path_in_base_dir(
                full_path,
                f"get_file_by_path relative: {file_path}"
            )
            if validated_path.exists():
                return validated_path
        except PathTraversalError:
            return None

        return None

    def delete_file(
        self,
        project_id: str,
        task_id: str,
        attachment_id: str,
        version: Optional[int] = None
    ) -> bool:
        """
        Delete file(s) from storage.
        If version is None, deletes all versions.
        Returns True on success, False if nothing was found or if path
        traversal is detected.
        """
        try:
            if version is not None:
                # Delete specific version
                dir_path = self._get_file_path(project_id, task_id, attachment_id, version)
            else:
                # Delete all versions (attachment directory)
                # Validate components first
                self._validate_path_component(project_id, "project_id")
                self._validate_path_component(task_id, "task_id")
                self._validate_path_component(attachment_id, "attachment_id")

                dir_path = self.base_dir / project_id / task_id / attachment_id
                dir_path = self._validate_path_in_base_dir(
                    dir_path,
                    f"delete attachment: project={project_id}, task={task_id}, attachment={attachment_id}"
                )
        except PathTraversalError as e:
            logger.error("Path traversal attempt during file deletion: %s", e)
            return False

        if dir_path.exists():
            shutil.rmtree(dir_path)
            return True
        return False

    def delete_task_files(self, project_id: str, task_id: str) -> bool:
        """
        Delete all files for a task.
        Returns True on success, False if nothing was found or if path
        traversal is detected.
        """
        try:
            # Validate components
            self._validate_path_component(project_id, "project_id")
            self._validate_path_component(task_id, "task_id")

            dir_path = self.base_dir / project_id / task_id
            dir_path = self._validate_path_in_base_dir(
                dir_path,
                f"delete task files: project={project_id}, task={task_id}"
            )
        except PathTraversalError as e:
            logger.error("Path traversal attempt during task file deletion: %s", e)
            return False

        if dir_path.exists():
            shutil.rmtree(dir_path)
            return True
        return False


# Singleton instance
file_storage_service = FileStorageService()
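

# Illustrative usage sketch (comments only). The route path, parameter names, and
# the import path of this module are assumptions for the example; only
# file_storage_service and its methods come from this file.
#
#   from fastapi import APIRouter, UploadFile
#   from app.services.file_storage import file_storage_service  # assumed module path
#
#   router = APIRouter()
#
#   @router.post("/projects/{project_id}/tasks/{task_id}/attachments/{attachment_id}")
#   async def upload_attachment(
#       project_id: str, task_id: str, attachment_id: str, file: UploadFile
#   ) -> dict:
#       path, size, checksum = await file_storage_service.save_file(
#           file, project_id, task_id, attachment_id, version=1
#       )
#       return {"path": path, "size": size, "checksum": checksum}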