Commit: first

backend/app/services/file_manager.py (new file, 420 lines)

@@ -0,0 +1,420 @@
"""
Tool_OCR - File Management Service

Handles file uploads, storage, validation, and cleanup.
"""

import logging
import shutil
import uuid
from datetime import datetime, timedelta
from pathlib import Path
from typing import List, Optional, Tuple

from fastapi import UploadFile
from sqlalchemy.orm import Session

from app.core.config import settings
from app.models.ocr import OCRBatch, OCRFile, FileStatus
from app.services.preprocessor import DocumentPreprocessor


logger = logging.getLogger(__name__)


class FileManagementError(Exception):
    """Exception raised for file management errors."""


class FileManager:
    """
    File management service for upload, storage, and cleanup.

    Directory structure:
        uploads/
        ├── batches/
        │   └── {batch_id}/
        │       ├── inputs/          # Original uploaded files
        │       ├── outputs/         # OCR results
        │       │   ├── markdown/    # Markdown files
        │       │   ├── json/        # JSON files
        │       │   └── images/      # Extracted images
        │       └── exports/         # Export files (PDF, Excel, etc.)
    """

    def __init__(self):
        """Initialize the file manager and ensure the base upload directory exists."""
        self.preprocessor = DocumentPreprocessor()
        self.base_upload_dir = Path(settings.upload_dir)
        self.base_upload_dir.mkdir(parents=True, exist_ok=True)

    def create_batch_directory(self, batch_id: int) -> Path:
        """
        Create the directory structure for a batch.

        Args:
            batch_id: Batch ID

        Returns:
            Path: Batch directory path
        """
        batch_dir = self.base_upload_dir / "batches" / str(batch_id)

        # Create subdirectories
        (batch_dir / "inputs").mkdir(parents=True, exist_ok=True)
        (batch_dir / "outputs" / "markdown").mkdir(parents=True, exist_ok=True)
        (batch_dir / "outputs" / "json").mkdir(parents=True, exist_ok=True)
        (batch_dir / "outputs" / "images").mkdir(parents=True, exist_ok=True)
        (batch_dir / "exports").mkdir(parents=True, exist_ok=True)

        logger.info(f"Created batch directory: {batch_dir}")
        return batch_dir

    def get_batch_directory(self, batch_id: int) -> Path:
        """
        Get the batch directory path.

        Args:
            batch_id: Batch ID

        Returns:
            Path: Batch directory path
        """
        return self.base_upload_dir / "batches" / str(batch_id)

    def validate_upload(self, file: UploadFile) -> Tuple[bool, Optional[str]]:
        """
        Validate an uploaded file before saving.

        Args:
            file: Uploaded file

        Returns:
            Tuple of (is_valid, error_message)
        """
        # Check filename
        if not file.filename:
            return False, "Filename must not be empty"

        # Check file size by seeking to the end of the stream
        file.file.seek(0, 2)  # Seek to end
        file_size = file.file.tell()
        file.file.seek(0)  # Reset to beginning

        if file_size == 0:
            return False, "File is empty"

        if file_size > settings.max_upload_size:
            max_mb = settings.max_upload_size / (1024 * 1024)
            return False, f"File size exceeds the limit ({max_mb:.0f}MB)"

        # Check file extension
        file_ext = Path(file.filename).suffix.lower()
        allowed_extensions = {'.png', '.jpg', '.jpeg', '.pdf', '.doc', '.docx', '.ppt', '.pptx'}
        if file_ext not in allowed_extensions:
            return False, f"Unsupported file format ({file_ext}); supported formats: {', '.join(sorted(allowed_extensions))}"

        return True, None

    def save_upload(
        self,
        file: UploadFile,
        batch_id: int,
        validate: bool = True
    ) -> Tuple[Path, str]:
        """
        Save an uploaded file to the batch directory.

        Args:
            file: Uploaded file
            batch_id: Batch ID
            validate: Whether to validate the file first

        Returns:
            Tuple of (file_path, original_filename)

        Raises:
            FileManagementError: If file validation or saving fails
        """
        # Validate if requested
        if validate:
            is_valid, error_msg = self.validate_upload(file)
            if not is_valid:
                raise FileManagementError(error_msg)

        # Generate a unique filename to avoid conflicts
        original_filename = file.filename
        file_ext = Path(original_filename).suffix
        unique_filename = f"{uuid.uuid4()}{file_ext}"

        # Get the batch input directory
        batch_dir = self.get_batch_directory(batch_id)
        input_dir = batch_dir / "inputs"
        input_dir.mkdir(parents=True, exist_ok=True)

        # Save file
        file_path = input_dir / unique_filename
        try:
            with file_path.open("wb") as buffer:
                shutil.copyfileobj(file.file, buffer)

            logger.info(f"Saved upload: {file_path} (original: {original_filename})")
            return file_path, original_filename

        except Exception as e:
            # Clean up the partial file if it exists
            file_path.unlink(missing_ok=True)
            raise FileManagementError(f"Failed to save file: {e}") from e

    def validate_saved_file(self, file_path: Path) -> Tuple[bool, Optional[str], Optional[str]]:
        """
        Validate a saved file using the preprocessor.

        Args:
            file_path: Path to saved file

        Returns:
            Tuple of (is_valid, error_message, detected_format)
        """
        return self.preprocessor.validate_file(file_path)

    def create_batch(
        self,
        db: Session,
        user_id: int,
        batch_name: Optional[str] = None
    ) -> OCRBatch:
        """
        Create a new OCR batch.

        Args:
            db: Database session
            user_id: User ID
            batch_name: Optional batch name

        Returns:
            OCRBatch: Created batch object
        """
        # Create batch record
        batch = OCRBatch(
            user_id=user_id,
            batch_name=batch_name or f"Batch_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        )
        db.add(batch)
        db.commit()
        db.refresh(batch)

        # Create directory structure
        self.create_batch_directory(batch.id)

        logger.info(f"Created batch: {batch.id} for user {user_id}")
        return batch

    def add_file_to_batch(
        self,
        db: Session,
        batch_id: int,
        file: UploadFile
    ) -> OCRFile:
        """
        Add a file to a batch and save it to disk.

        Args:
            db: Database session
            batch_id: Batch ID
            file: Uploaded file

        Returns:
            OCRFile: Created file record

        Raises:
            FileManagementError: If file operations fail
        """
        # Save file to disk
        file_path, original_filename = self.save_upload(file, batch_id)

        # Validate saved file; validate_saved_file returns
        # (is_valid, error_message, detected_format)
        is_valid, error_msg, detected_format = self.validate_saved_file(file_path)

        # Create file record
        ocr_file = OCRFile(
            batch_id=batch_id,
            filename=file_path.name,
            original_filename=original_filename,
            file_path=str(file_path),
            file_size=file_path.stat().st_size,
            file_format=detected_format or Path(original_filename).suffix.lower().lstrip('.'),
            status=FileStatus.PENDING if is_valid else FileStatus.FAILED,
            error_message=error_msg if not is_valid else None
        )

        db.add(ocr_file)

        # Update batch file counters
        batch = db.query(OCRBatch).filter(OCRBatch.id == batch_id).first()
        if batch:
            batch.total_files += 1
            if not is_valid:
                batch.failed_files += 1

        db.commit()
        db.refresh(ocr_file)

        logger.info(f"Added file to batch {batch_id}: {ocr_file.id} (status: {ocr_file.status})")
        return ocr_file

    def add_files_to_batch(
        self,
        db: Session,
        batch_id: int,
        files: List[UploadFile]
    ) -> List[OCRFile]:
        """
        Add multiple files to a batch.

        Args:
            db: Database session
            batch_id: Batch ID
            files: List of uploaded files

        Returns:
            List[OCRFile]: Created file records
        """
        ocr_files = []
        for file in files:
            try:
                ocr_file = self.add_file_to_batch(db, batch_id, file)
                ocr_files.append(ocr_file)
            except FileManagementError as e:
                logger.error(f"Failed to add file {file.filename} to batch {batch_id}: {e}")
                # Continue with the remaining files
                continue

        return ocr_files

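    # Note on failure modes: files that fail upload validation or saving raise
    # FileManagementError and are skipped entirely, while files that save but
    # fail content validation are still recorded with FAILED status. A hedged
    # caller sketch (`manager` and `uploads` are assumed names, not defined here):
    #
    #     saved = manager.add_files_to_batch(db, batch.id, uploads)
    #     if len(saved) < len(uploads):
    #         logger.warning("%d upload(s) were rejected", len(uploads) - len(saved))
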
    def get_file_paths(self, batch_id: int, file_id: int) -> dict:
        """
        Get all relevant paths for a file in a batch.

        Args:
            batch_id: Batch ID
            file_id: File ID

        Returns:
            Dict containing all relevant paths
        """
        batch_dir = self.get_batch_directory(batch_id)

        return {
            "input_dir": batch_dir / "inputs",
            "output_dir": batch_dir / "outputs",
            "markdown_dir": batch_dir / "outputs" / "markdown",
            "json_dir": batch_dir / "outputs" / "json",
            "images_dir": batch_dir / "outputs" / "images" / str(file_id),
            "export_dir": batch_dir / "exports",
        }

    def cleanup_expired_batches(self, db: Session, retention_hours: int = 24) -> int:
        """
        Clean up expired batch files.

        Args:
            db: Database session
            retention_hours: Number of hours to retain files

        Returns:
            int: Number of batches cleaned up
        """
        cutoff_time = datetime.utcnow() - timedelta(hours=retention_hours)

        # Find expired batches
        expired_batches = db.query(OCRBatch).filter(
            OCRBatch.created_at < cutoff_time
        ).all()

        cleaned_count = 0
        for batch in expired_batches:
            try:
                # Delete batch directory
                batch_dir = self.get_batch_directory(batch.id)
                if batch_dir.exists():
                    shutil.rmtree(batch_dir)
                    logger.info(f"Deleted batch directory: {batch_dir}")

                # Delete database records (cascade handles related records)
                db.delete(batch)
                cleaned_count += 1

            except Exception as e:
                logger.error(f"Failed to clean up batch {batch.id}: {e}")
                continue

        if cleaned_count > 0:
            db.commit()
            logger.info(f"Cleaned up {cleaned_count} expired batches")

        return cleaned_count

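    # Scheduling sketch (an assumption about the wider app, not shown in this
    # commit): with an APScheduler instance `scheduler` and a SQLAlchemy session
    # factory `SessionLocal`, cleanup could run hourly:
    #
    #     @scheduler.scheduled_job("interval", hours=1)
    #     def cleanup_job():
    #         db = SessionLocal()
    #         try:
    #             FileManager().cleanup_expired_batches(db, retention_hours=24)
    #         finally:
    #             db.close()
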
    def verify_file_ownership(
        self,
        db: Session,
        user_id: int,
        batch_id: int
    ) -> bool:
        """
        Verify that the user owns the batch.

        Args:
            db: Database session
            user_id: User ID
            batch_id: Batch ID

        Returns:
            bool: True if the user owns the batch, False otherwise
        """
        batch = db.query(OCRBatch).filter(
            OCRBatch.id == batch_id,
            OCRBatch.user_id == user_id
        ).first()

        return batch is not None

    def get_batch_statistics(self, db: Session, batch_id: int) -> dict:
        """
        Get statistics for a batch.

        Args:
            db: Database session
            batch_id: Batch ID

        Returns:
            Dict containing batch statistics
        """
        batch = db.query(OCRBatch).filter(OCRBatch.id == batch_id).first()
        if not batch:
            return {}

        # Calculate total file size
        total_size = sum(f.file_size for f in batch.files)

        # Calculate processing time
        processing_time = None
        if batch.completed_at and batch.started_at:
            processing_time = (batch.completed_at - batch.started_at).total_seconds()

        return {
            "batch_id": batch.id,
            "batch_name": batch.batch_name,
            "status": batch.status,
            "total_files": batch.total_files,
            "completed_files": batch.completed_files,
            "failed_files": batch.failed_files,
            "pending_files": batch.total_files - batch.completed_files - batch.failed_files,
            "progress_percentage": batch.progress_percentage,
            "total_file_size": total_size,
            "total_file_size_mb": round(total_size / (1024 * 1024), 2),
            "created_at": batch.created_at.isoformat(),
            "started_at": batch.started_at.isoformat() if batch.started_at else None,
            "completed_at": batch.completed_at.isoformat() if batch.completed_at else None,
            "processing_time": processing_time,
        }
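

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): one way a FastAPI endpoint could drive
# FileManager end to end. The router wiring and the `get_db` /
# `get_current_user` dependencies are assumptions about the surrounding
# application, not part of this module.
# ---------------------------------------------------------------------------
#
#     from fastapi import APIRouter, Depends, File
#
#     router = APIRouter()
#     file_manager = FileManager()
#
#     @router.post("/batches/upload")
#     def upload_batch(
#         files: List[UploadFile] = File(...),
#         db: Session = Depends(get_db),
#         user=Depends(get_current_user),
#     ):
#         batch = file_manager.create_batch(db, user_id=user.id)
#         saved = file_manager.add_files_to_batch(db, batch.id, files)
#         return {
#             "batch_id": batch.id,
#             "accepted": len(saved),
#             "statistics": file_manager.get_batch_statistics(db, batch.id),
#         }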