import os
import hashlib
import re
import shutil
import logging
import tempfile
from pathlib import Path
from typing import BinaryIO, Optional, Tuple

from fastapi import UploadFile, HTTPException

from app.core.config import settings

logger = logging.getLogger(__name__)


class PathTraversalError(Exception):
    """Raised when a path traversal attempt is detected."""
    pass


class StorageValidationError(Exception):
    """Raised when storage validation fails."""
    pass


class FileStorageService:
    """Service for handling file storage operations.

    Files are stored under ``settings.UPLOAD_DIR`` using the layout
    ``<base_dir>/<project_id>/<task_id>/<attachment_id>/<version>/file[.ext]``.
    Every caller-supplied path component is validated and every resulting
    path is checked to resolve inside the base directory.
    """

    # Common NAS mount points to detect
    NAS_MOUNT_INDICATORS = [
        "/mnt/", "/mount/", "/nas/", "/nfs/", "/smb/", "/cifs/",
        "/Volumes/", "/media/", "/srv/", "/storage/"
    ]

    # Filesystem types in /proc/mounts that are treated as network storage
    _NETWORK_FS_TYPES = frozenset({"nfs", "nfs4", "cifs", "smb", "smbfs"})

    # /proc/mounts writes whitespace and backslashes in fields as 3-digit
    # octal escapes (e.g. a space is "\040").
    _MOUNT_ESCAPE = re.compile(r"\\([0-7]{3})")

    def __init__(self):
        """Resolve the upload directory and validate it immediately.

        Raises:
            StorageValidationError: If the directory cannot be created or is
                not writable.
        """
        self.base_dir = Path(settings.UPLOAD_DIR).resolve()
        # Backward-compatible attribute name for tests and older code
        self.upload_dir = self.base_dir
        self._storage_status = {
            "validated": False,
            "path_exists": False,
            "writable": False,
            "is_nas": False,
            "error": None,
        }
        self._validate_storage_on_init()

    def _validate_storage_on_init(self):
        """Validate storage configuration on service initialization.

        Updates ``self._storage_status`` step by step so a partial failure is
        visible in ``get_storage_status()``.

        Raises:
            StorageValidationError: If any validation step fails; unexpected
                errors are wrapped and chained to the original exception.
        """
        try:
            # Step 1: Ensure directory exists
            self._ensure_base_dir()
            self._storage_status["path_exists"] = True

            # Step 2: Check write permissions
            self._check_write_permissions()
            self._storage_status["writable"] = True

            # Step 3: Check if using NAS
            is_nas = self._detect_nas_storage()
            self._storage_status["is_nas"] = is_nas

            if not is_nas:
                logger.warning(
                    "Storage directory '%s' appears to be local storage, not NAS. "
                    "Consider configuring UPLOAD_DIR to a NAS mount point for production use.",
                    self.base_dir
                )

            self._storage_status["validated"] = True
            logger.info(
                "Storage validated successfully: path=%s, is_nas=%s",
                self.base_dir, is_nas
            )
        except StorageValidationError as e:
            self._storage_status["error"] = str(e)
            logger.error("Storage validation failed: %s", e)
            raise
        except Exception as e:
            self._storage_status["error"] = str(e)
            logger.error("Unexpected error during storage validation: %s", e)
            # Chain the cause so the original traceback is not lost.
            raise StorageValidationError(f"Storage validation failed: {e}") from e

    def _check_write_permissions(self):
        """Check if the storage directory has write permissions.

        Writes a small probe file, reads it back, and removes it.

        Raises:
            StorageValidationError: If the probe cannot be written, read back
                or removed.
        """
        test_file = self.base_dir / f".write_test_{os.getpid()}"
        try:
            # Try to create and write to a test file
            test_file.write_text("write_test")

            # Verify we can read it back
            content = test_file.read_text()
            if content != "write_test":
                raise StorageValidationError(
                    f"Write verification failed for directory: {self.base_dir}"
                )

            # Clean up
            test_file.unlink()
            logger.debug("Write permission check passed for %s", self.base_dir)
        except PermissionError as e:
            raise StorageValidationError(
                f"No write permission for storage directory '{self.base_dir}': {e}"
            ) from e
        except OSError as e:
            raise StorageValidationError(
                f"Failed to verify write permissions for '{self.base_dir}': {e}"
            ) from e
        finally:
            # Ensure test file is removed even on partial failure. This is
            # best-effort: a cleanup problem must not mask the real error.
            try:
                test_file.unlink()
            except FileNotFoundError:
                pass  # Already removed on the success path (or never created)
            except OSError as cleanup_error:
                logger.debug(
                    "Could not remove write test file %s: %s",
                    test_file, cleanup_error
                )

    @classmethod
    def _filesystem_type_for(cls, path: Path, mounts: str) -> str:
        """Return the filesystem type of the deepest mount containing *path*.

        Args:
            path: Absolute path to look up.
            mounts: Text in ``/proc/mounts`` format
                (``device mount_point fs_type options ...`` per line).

        Returns:
            The filesystem type of the longest mount point that equals *path*
            or is one of its ancestors, or ``""`` if none matches.
        """
        best_depth = -1
        best_type = ""
        ancestors = path.parents
        for line in mounts.splitlines():
            fields = line.split()
            if len(fields) < 3:
                continue
            mount_point = Path(
                cls._MOUNT_ESCAPE.sub(lambda m: chr(int(m.group(1), 8)), fields[1])
            )
            if mount_point != path and mount_point not in ancestors:
                continue
            depth = len(mount_point.parts)
            # ">=" so that a later entry for the same mount point (an
            # over-mount) wins over the earlier one.
            if depth >= best_depth:
                best_depth = depth
                best_type = fields[2]
        return best_type

    def _detect_nas_storage(self) -> bool:
        """
        Detect if the storage directory appears to be on a NAS mount.

        This is a best-effort detection based on common mount point patterns,
        whether the directory itself is a mount point, and (on Linux) the
        filesystem type of the mount that contains it.
        """
        path_str = str(self.base_dir)

        # Check common NAS mount point patterns
        for indicator in self.NAS_MOUNT_INDICATORS:
            if indicator in path_str:
                logger.debug("NAS storage detected: path contains '%s'", indicator)
                return True

        # Check if it's a mount point (Unix-like systems)
        try:
            if self.base_dir.is_mount():
                logger.debug("NAS storage detected: path is a mount point")
                return True
        except (OSError, NotImplementedError) as e:
            # is_mount() is not available on every platform; keep going.
            logger.debug("Could not determine mount status of %s: %s", self.base_dir, e)

        # Check mount info on Linux
        try:
            with open("/proc/mounts", "r") as f:
                mounts = f.read()
            # Look at the mount that actually contains base_dir instead of
            # substring-matching the path against raw lines: the upload dir is
            # usually a subdirectory of the mount point, not the mount itself.
            fs_type = self._filesystem_type_for(self.base_dir, mounts)
            if fs_type in self._NETWORK_FS_TYPES:
                logger.debug("NAS storage detected: mounted as %s", fs_type)
                return True
        except FileNotFoundError:
            pass  # Not on Linux
        except Exception as e:
            logger.debug("Could not check /proc/mounts: %s", e)

        return False

    def get_storage_status(self) -> dict:
        """Get current storage status for health checks.

        Returns:
            The validation flags recorded at start-up plus the live
            ``base_dir``, ``exists`` and ``is_directory`` values.
        """
        return {
            **self._storage_status,
            "base_dir": str(self.base_dir),
            "exists": self.base_dir.exists(),
            # Path.is_dir() is already False for a missing path.
            "is_directory": self.base_dir.is_dir(),
        }

    def _ensure_base_dir(self):
        """Ensure the base upload directory exists."""
        self.base_dir.mkdir(parents=True, exist_ok=True)

    def _validate_path_component(self, component: str, component_name: str) -> None:
        """
        Validate a path component to prevent path traversal attacks.

        Args:
            component: The path component to validate (e.g., project_id, task_id)
            component_name: Name of the component for error messages

        Raises:
            PathTraversalError: If the component is empty, contains path
                traversal patterns, or starts with '.' or '-'
        """
        if not component:
            raise PathTraversalError(f"Empty {component_name} is not allowed")

        # Check for path traversal patterns
        dangerous_patterns = ('..', '/', '\\', '\x00')
        if any(pattern in component for pattern in dangerous_patterns):
            logger.warning(
                "Path traversal attempt detected in %s: %r",
                component_name, component
            )
            raise PathTraversalError(
                f"Invalid characters in {component_name}: path traversal not allowed"
            )

        # Additional check: component should not start with special characters
        if component.startswith(('.', '-')):
            logger.warning(
                "Suspicious path component in %s: %r",
                component_name, component
            )
            raise PathTraversalError(
                f"Invalid {component_name}: cannot start with '.' or '-'"
            )

    def _validate_path_in_base_dir(self, path: Path, context: str = "") -> Path:
        """
        Validate that a resolved path is within the base directory.

        Args:
            path: The path to validate
            context: Additional context for logging

        Returns:
            The resolved path if valid

        Raises:
            PathTraversalError: If the path is outside the base directory
        """
        resolved_path = path.resolve()
        base_dir = self.base_dir.resolve()

        # Check if the resolved path is within the base directory
        try:
            resolved_path.relative_to(base_dir)
        except ValueError:
            logger.warning(
                "Path traversal attempt detected: path %s is outside base directory %s. Context: %s",
                resolved_path, base_dir, context
            )
            # The ValueError is an implementation detail of the containment
            # check, so it is not chained onto the error callers see.
            raise PathTraversalError(
                "Access denied: path is outside the allowed directory"
            ) from None

        return resolved_path

    def _get_file_path(self, project_id: str, task_id: str, attachment_id: str, version: int) -> Path:
        """
        Generate the directory path for an attachment version.

        Args:
            project_id: The project ID
            task_id: The task ID
            attachment_id: The attachment ID
            version: The version number

        Returns:
            Safe, resolved path within the base directory

        Raises:
            PathTraversalError: If any component contains path traversal
                patterns or the version is negative
        """
        # Validate all path components
        self._validate_path_component(project_id, "project_id")
        self._validate_path_component(task_id, "task_id")
        self._validate_path_component(attachment_id, "attachment_id")

        if version < 0:
            raise PathTraversalError("Version must be non-negative")

        # Build the path
        path = self.base_dir / project_id / task_id / attachment_id / str(version)

        # Validate the final path is within base directory
        return self._validate_path_in_base_dir(
            path,
            f"project={project_id}, task={task_id}, attachment={attachment_id}, version={version}"
        )

    @staticmethod
    def calculate_checksum(file: BinaryIO) -> str:
        """Calculate SHA-256 checksum of a file.

        The stream is rewound before hashing so the digest always covers the
        whole file, and rewound again afterwards so it can be read by the
        caller.

        Args:
            file: A seekable binary stream.

        Returns:
            The hexadecimal SHA-256 digest.
        """
        sha256_hash = hashlib.sha256()
        file.seek(0)  # Hash from the beginning regardless of current position
        # Read in chunks to handle large files
        for chunk in iter(lambda: file.read(8192), b""):
            sha256_hash.update(chunk)
        file.seek(0)  # Reset file position
        return sha256_hash.hexdigest()

    @staticmethod
    def get_extension(filename: str) -> str:
        """Get file extension in lowercase (text after the last '.', or '')."""
        return filename.rsplit(".", 1)[-1].lower() if "." in filename else ""

    @staticmethod
    def validate_file(file: UploadFile) -> Tuple[str, str]:
        """
        Validate file size and type.

        Returns (extension, mime_type) if valid.
        Raises HTTPException if invalid.
        """
        # Check file size
        file.file.seek(0, 2)  # Seek to end
        file_size = file.file.tell()
        file.file.seek(0)  # Reset

        if file_size > settings.MAX_FILE_SIZE:
            raise HTTPException(
                status_code=400,
                detail=f"File too large. Maximum size is {settings.MAX_FILE_SIZE_MB}MB"
            )

        if file_size == 0:
            raise HTTPException(status_code=400, detail="Empty file not allowed")

        # Get extension
        extension = FileStorageService.get_extension(file.filename or "")

        # Check blocked extensions
        if extension in settings.BLOCKED_EXTENSIONS:
            raise HTTPException(
                status_code=400,
                detail=f"File type '.{extension}' is not allowed for security reasons"
            )

        # Check allowed extensions (if whitelist is enabled)
        if settings.ALLOWED_EXTENSIONS and extension not in settings.ALLOWED_EXTENSIONS:
            raise HTTPException(
                status_code=400,
                detail=f"File type '.{extension}' is not supported"
            )

        mime_type = file.content_type or "application/octet-stream"
        return extension, mime_type

    async def save_file(
        self,
        file: UploadFile,
        project_id: str,
        task_id: str,
        attachment_id: str,
        version: int
    ) -> Tuple[str, int, str]:
        """
        Save uploaded file to storage.

        Returns (file_path, file_size, checksum).

        Raises:
            HTTPException: If file validation fails, the extension contains
                path separators, or path traversal is detected (status 400)
        """
        # Validate file
        extension, _ = self.validate_file(file)

        # The extension comes from the client-supplied filename (e.g.
        # "a.b/c" yields "b/c"); it must never introduce extra path segments.
        if any(ch in extension for ch in ("/", "\\", "\x00")):
            logger.warning("Rejected upload with invalid extension: %r", extension)
            raise HTTPException(status_code=400, detail="Invalid file extension")

        # Calculate checksum first
        checksum = self.calculate_checksum(file.file)

        # Create directory structure (path validation is done in _get_file_path)
        try:
            dir_path = self._get_file_path(project_id, task_id, attachment_id, version)
        except PathTraversalError as e:
            logger.error("Path traversal attempt during file save: %s", e)
            raise HTTPException(status_code=400, detail=str(e)) from e

        dir_path.mkdir(parents=True, exist_ok=True)

        # Save file with original extension
        filename = f"file.{extension}" if extension else "file"
        file_path = dir_path / filename

        # Final validation of the file path
        try:
            self._validate_path_in_base_dir(file_path, f"saving file {filename}")
        except PathTraversalError as e:
            logger.error("Path traversal attempt during file save: %s", e)
            raise HTTPException(status_code=400, detail=str(e)) from e

        # Get file size
        file.file.seek(0, 2)
        file_size = file.file.tell()
        file.file.seek(0)

        # Write file in chunks (streaming)
        with open(file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        return str(file_path), file_size, checksum

    def get_file(
        self,
        project_id: str,
        task_id: str,
        attachment_id: str,
        version: int
    ) -> Optional[Path]:
        """
        Get the file path for an attachment version.

        Returns None if the file doesn't exist or if path traversal is
        detected (the attempt is logged, no exception is raised).
        """
        try:
            dir_path = self._get_file_path(project_id, task_id, attachment_id, version)
        except PathTraversalError as e:
            logger.error("Path traversal attempt during file retrieval: %s", e)
            return None

        if not dir_path.is_dir():
            return None

        # Find the file in the directory. Directory iteration order is
        # arbitrary, so sort and keep regular files only to make the result
        # deterministic.
        files = sorted(entry for entry in dir_path.iterdir() if entry.is_file())
        return files[0] if files else None

    def get_file_by_path(self, file_path: str) -> Optional[Path]:
        """
        Get file by stored path. Handles both absolute and relative paths.

        This method validates that the requested path is within the base
        directory to prevent path traversal attacks.

        Args:
            file_path: The stored file path

        Returns:
            Path object if file exists and is within base directory,
            None otherwise
        """
        if not file_path:
            return None

        path = Path(file_path)

        # Absolute paths are validated as-is; relative paths are resolved
        # from base_dir.
        if path.is_absolute():
            candidate, kind = path, "absolute"
        else:
            candidate, kind = self.base_dir / path, "relative"

        try:
            validated_path = self._validate_path_in_base_dir(
                candidate,
                f"get_file_by_path {kind}: {file_path}"
            )
        except (PathTraversalError, ValueError):
            # ValueError: the path cannot be resolved at all (e.g. it contains
            # an embedded NUL byte); treat it as "not found".
            return None

        return validated_path if validated_path.exists() else None

    def delete_file(
        self,
        project_id: str,
        task_id: str,
        attachment_id: str,
        version: Optional[int] = None
    ) -> bool:
        """
        Delete file(s) from storage.

        If version is None, deletes all versions.

        Returns:
            True if a directory was removed; False if nothing existed at the
            target path or if path traversal was detected (the attempt is
            logged, no exception is raised).
        """
        try:
            if version is not None:
                # Delete specific version
                dir_path = self._get_file_path(project_id, task_id, attachment_id, version)
            else:
                # Delete all versions (attachment directory)
                # Validate components first
                self._validate_path_component(project_id, "project_id")
                self._validate_path_component(task_id, "task_id")
                self._validate_path_component(attachment_id, "attachment_id")

                dir_path = self.base_dir / project_id / task_id / attachment_id
                dir_path = self._validate_path_in_base_dir(
                    dir_path,
                    f"delete attachment: project={project_id}, task={task_id}, attachment={attachment_id}"
                )
        except PathTraversalError as e:
            logger.error("Path traversal attempt during file deletion: %s", e)
            return False

        if dir_path.exists():
            shutil.rmtree(dir_path)
            return True
        return False

    def delete_task_files(self, project_id: str, task_id: str) -> bool:
        """
        Delete all files for a task.

        Returns:
            True if the task directory was removed; False if it did not exist
            or if path traversal was detected (the attempt is logged, no
            exception is raised).
        """
        try:
            # Validate components
            self._validate_path_component(project_id, "project_id")
            self._validate_path_component(task_id, "task_id")

            dir_path = self.base_dir / project_id / task_id
            dir_path = self._validate_path_in_base_dir(
                dir_path,
                f"delete task files: project={project_id}, task={task_id}"
            )
        except PathTraversalError as e:
            logger.error("Path traversal attempt during task file deletion: %s", e)
            return False

        if dir_path.exists():
            shutil.rmtree(dir_path)
            return True
        return False


# Singleton instance
file_storage_service = FileStorageService()