""" Tool_OCR - File Management Service Handles file uploads, storage, validation, and cleanup """ import logging import shutil import uuid from pathlib import Path from typing import List, Tuple, Optional from datetime import datetime, timedelta from fastapi import UploadFile from sqlalchemy.orm import Session from app.core.config import settings from app.models.ocr import OCRBatch, OCRFile, FileStatus from app.services.preprocessor import DocumentPreprocessor logger = logging.getLogger(__name__) class FileManagementError(Exception): """Exception raised for file management errors""" pass class FileManager: """ File management service for upload, storage, and cleanup Directory structure: uploads/ ├── batches/ │ └── {batch_id}/ │ ├── inputs/ # Original uploaded files │ ├── outputs/ # OCR results │ │ ├── markdown/ # Markdown files │ │ ├── json/ # JSON files │ │ └── images/ # Extracted images │ └── exports/ # Export files (PDF, Excel, etc.) """ def __init__(self): """Initialize file manager""" self.preprocessor = DocumentPreprocessor() self.base_upload_dir = Path(settings.upload_dir) self.base_upload_dir.mkdir(parents=True, exist_ok=True) def create_batch_directory(self, batch_id: int) -> Path: """ Create directory structure for a batch Args: batch_id: Batch ID Returns: Path: Batch directory path """ batch_dir = self.base_upload_dir / "batches" / str(batch_id) # Create subdirectories (batch_dir / "inputs").mkdir(parents=True, exist_ok=True) (batch_dir / "outputs" / "markdown").mkdir(parents=True, exist_ok=True) (batch_dir / "outputs" / "json").mkdir(parents=True, exist_ok=True) (batch_dir / "outputs" / "images").mkdir(parents=True, exist_ok=True) (batch_dir / "exports").mkdir(parents=True, exist_ok=True) logger.info(f"Created batch directory: {batch_dir}") return batch_dir def get_batch_directory(self, batch_id: int) -> Path: """ Get batch directory path Args: batch_id: Batch ID Returns: Path: Batch directory path """ return self.base_upload_dir / "batches" / str(batch_id) def validate_upload(self, file: UploadFile) -> Tuple[bool, Optional[str]]: """ Validate uploaded file before saving Args: file: Uploaded file Returns: Tuple of (is_valid, error_message) """ # Check filename if not file.filename: return False, "文件名不能為空" # Check file size (read content size) file.file.seek(0, 2) # Seek to end file_size = file.file.tell() file.file.seek(0) # Reset to beginning if file_size == 0: return False, "文件為空" if file_size > settings.max_upload_size: max_mb = settings.max_upload_size / (1024 * 1024) return False, f"文件大小超過限制 ({max_mb}MB)" # Check file extension file_ext = Path(file.filename).suffix.lower() allowed_extensions = {'.png', '.jpg', '.jpeg', '.pdf', '.doc', '.docx', '.ppt', '.pptx'} if file_ext not in allowed_extensions: return False, f"不支持的文件格式 ({file_ext}),僅支持: {', '.join(allowed_extensions)}" return True, None def save_upload( self, file: UploadFile, batch_id: int, validate: bool = True ) -> Tuple[Path, str]: """ Save uploaded file to batch directory Args: file: Uploaded file batch_id: Batch ID validate: Whether to validate file Returns: Tuple of (file_path, original_filename) Raises: FileManagementError: If file validation or saving fails """ # Validate if requested if validate: is_valid, error_msg = self.validate_upload(file) if not is_valid: raise FileManagementError(error_msg) # Generate unique filename to avoid conflicts original_filename = file.filename file_ext = Path(original_filename).suffix unique_filename = f"{uuid.uuid4()}{file_ext}" # Get batch input directory batch_dir = 
    def save_upload(
        self,
        file: UploadFile,
        batch_id: int,
        validate: bool = True
    ) -> Tuple[Path, str]:
        """
        Save uploaded file to batch directory

        Args:
            file: Uploaded file
            batch_id: Batch ID
            validate: Whether to validate file

        Returns:
            Tuple of (file_path, original_filename)

        Raises:
            FileManagementError: If file validation or saving fails
        """
        # Validate if requested
        if validate:
            is_valid, error_msg = self.validate_upload(file)
            if not is_valid:
                raise FileManagementError(error_msg)

        # Generate unique filename to avoid conflicts
        original_filename = file.filename
        file_ext = Path(original_filename).suffix
        unique_filename = f"{uuid.uuid4()}{file_ext}"

        # Get batch input directory
        batch_dir = self.get_batch_directory(batch_id)
        input_dir = batch_dir / "inputs"
        input_dir.mkdir(parents=True, exist_ok=True)

        # Save file
        file_path = input_dir / unique_filename

        try:
            with file_path.open("wb") as buffer:
                shutil.copyfileobj(file.file, buffer)

            logger.info(f"Saved upload: {file_path} (original: {original_filename})")
            return file_path, original_filename

        except Exception as e:
            # Clean up partial file if it exists
            file_path.unlink(missing_ok=True)
            raise FileManagementError(f"Failed to save file: {e}")

    def validate_saved_file(self, file_path: Path) -> Tuple[bool, Optional[str], Optional[str]]:
        """
        Validate saved file using preprocessor

        Args:
            file_path: Path to saved file

        Returns:
            Tuple of (is_valid, error_message, detected_format)
        """
        return self.preprocessor.validate_file(file_path)

    def create_batch(
        self,
        db: Session,
        user_id: int,
        batch_name: Optional[str] = None
    ) -> OCRBatch:
        """
        Create new OCR batch

        Args:
            db: Database session
            user_id: User ID
            batch_name: Optional batch name

        Returns:
            OCRBatch: Created batch object
        """
        # Create batch record
        batch = OCRBatch(
            user_id=user_id,
            batch_name=batch_name or f"Batch_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        )
        db.add(batch)
        db.commit()
        db.refresh(batch)

        # Create directory structure
        self.create_batch_directory(batch.id)

        logger.info(f"Created batch: {batch.id} for user {user_id}")
        return batch

    def add_file_to_batch(
        self,
        db: Session,
        batch_id: int,
        file: UploadFile
    ) -> OCRFile:
        """
        Add file to batch and save to disk

        Args:
            db: Database session
            batch_id: Batch ID
            file: Uploaded file

        Returns:
            OCRFile: Created file record

        Raises:
            FileManagementError: If file operations fail
        """
        # Save file to disk
        file_path, original_filename = self.save_upload(file, batch_id)

        # Validate saved file; unpack in the (is_valid, error_message,
        # detected_format) order that validate_saved_file documents
        is_valid, error_msg, detected_format = self.validate_saved_file(file_path)

        # Create file record
        ocr_file = OCRFile(
            batch_id=batch_id,
            filename=file_path.name,
            original_filename=original_filename,
            file_path=str(file_path),
            file_size=file_path.stat().st_size,
            file_format=detected_format or Path(original_filename).suffix.lower().lstrip('.'),
            status=FileStatus.PENDING if is_valid else FileStatus.FAILED,
            error_message=error_msg if not is_valid else None
        )
        db.add(ocr_file)

        # Update batch total_files count
        batch = db.query(OCRBatch).filter(OCRBatch.id == batch_id).first()
        if batch:
            batch.total_files += 1
            if not is_valid:
                batch.failed_files += 1

        db.commit()
        db.refresh(ocr_file)

        logger.info(f"Added file to batch {batch_id}: {ocr_file.id} (status: {ocr_file.status})")
        return ocr_file

    def add_files_to_batch(
        self,
        db: Session,
        batch_id: int,
        files: List[UploadFile]
    ) -> List[OCRFile]:
        """
        Add multiple files to batch

        Args:
            db: Database session
            batch_id: Batch ID
            files: List of uploaded files

        Returns:
            List[OCRFile]: List of created file records
        """
        ocr_files = []

        for file in files:
            try:
                ocr_file = self.add_file_to_batch(db, batch_id, file)
                ocr_files.append(ocr_file)
            except FileManagementError as e:
                logger.error(f"Failed to add file {file.filename} to batch {batch_id}: {e}")
                # Continue with other files
                continue

        return ocr_files
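    # Usage sketch (illustrative only): the typical upload flow from an
    # endpoint. `db` is a SQLAlchemy Session and `files` a List[UploadFile];
    # `current_user` is a hypothetical auth dependency, not part of this
    # module.
    #
    #   manager = FileManager()
    #   batch = manager.create_batch(db, user_id=current_user.id)
    #   records = manager.add_files_to_batch(db, batch.id, files)
    #   failed = [r for r in records if r.status == FileStatus.FAILED]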
    def get_file_paths(self, batch_id: int, file_id: int) -> dict:
        """
        Get all paths for a file in a batch

        Args:
            batch_id: Batch ID
            file_id: File ID

        Returns:
            Dict containing all relevant paths
        """
        batch_dir = self.get_batch_directory(batch_id)

        return {
            "input_dir": batch_dir / "inputs",
            "output_dir": batch_dir / "outputs",
            "markdown_dir": batch_dir / "outputs" / "markdown",
            "json_dir": batch_dir / "outputs" / "json",
            "images_dir": batch_dir / "outputs" / "images" / str(file_id),
            "export_dir": batch_dir / "exports",
        }

    def cleanup_expired_batches(self, db: Session, retention_hours: int = 24) -> int:
        """
        Clean up expired batch files

        Args:
            db: Database session
            retention_hours: Number of hours to retain files

        Returns:
            int: Number of batches cleaned up
        """
        cutoff_time = datetime.utcnow() - timedelta(hours=retention_hours)

        # Find expired batches
        expired_batches = db.query(OCRBatch).filter(
            OCRBatch.created_at < cutoff_time
        ).all()

        cleaned_count = 0

        for batch in expired_batches:
            try:
                # Delete batch directory
                batch_dir = self.get_batch_directory(batch.id)
                if batch_dir.exists():
                    shutil.rmtree(batch_dir)
                    logger.info(f"Deleted batch directory: {batch_dir}")

                # Delete database records (cascade will handle related records)
                db.delete(batch)
                cleaned_count += 1

            except Exception as e:
                logger.error(f"Failed to cleanup batch {batch.id}: {e}")
                continue

        if cleaned_count > 0:
            db.commit()
            logger.info(f"Cleaned up {cleaned_count} expired batches")

        return cleaned_count

    def verify_file_ownership(
        self,
        db: Session,
        user_id: int,
        batch_id: int
    ) -> bool:
        """
        Verify user owns the batch

        Args:
            db: Database session
            user_id: User ID
            batch_id: Batch ID

        Returns:
            bool: True if user owns batch, False otherwise
        """
        batch = db.query(OCRBatch).filter(
            OCRBatch.id == batch_id,
            OCRBatch.user_id == user_id
        ).first()

        return batch is not None

    def get_batch_statistics(self, db: Session, batch_id: int) -> dict:
        """
        Get statistics for a batch

        Args:
            db: Database session
            batch_id: Batch ID

        Returns:
            Dict containing batch statistics
        """
        batch = db.query(OCRBatch).filter(OCRBatch.id == batch_id).first()

        if not batch:
            return {}

        # Calculate total file size
        total_size = sum(f.file_size for f in batch.files)

        # Calculate processing time
        processing_time = None
        if batch.completed_at and batch.started_at:
            processing_time = (batch.completed_at - batch.started_at).total_seconds()

        return {
            "batch_id": batch.id,
            "batch_name": batch.batch_name,
            "status": batch.status,
            "total_files": batch.total_files,
            "completed_files": batch.completed_files,
            "failed_files": batch.failed_files,
            "pending_files": batch.total_files - batch.completed_files - batch.failed_files,
            "progress_percentage": batch.progress_percentage,
            "total_file_size": total_size,
            "total_file_size_mb": round(total_size / (1024 * 1024), 2),
            "created_at": batch.created_at.isoformat(),
            "started_at": batch.started_at.isoformat() if batch.started_at else None,
            "completed_at": batch.completed_at.isoformat() if batch.completed_at else None,
            "processing_time": processing_time,
        }
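
# Minimal smoke check (illustrative sketch, not part of the service): prints
# where a hypothetical batch's files would live on disk. It touches no
# database; the only side effect is that FileManager.__init__ creates the
# base upload directory if it is missing.
if __name__ == "__main__":
    manager = FileManager()
    print("batch dir:", manager.get_batch_directory(1))
    for name, path in manager.get_file_paths(batch_id=1, file_id=1).items():
        print(f"{name}: {path}")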