Files
OCR/backend/app/services/file_manager.py
beabigegg da700721fa first
2025-11-12 22:53:17 +08:00

421 lines
13 KiB
Python

"""
Tool_OCR - File Management Service
Handles file uploads, storage, validation, and cleanup
"""
import logging
import shutil
import uuid
from pathlib import Path
from typing import List, Tuple, Optional
from datetime import datetime, timedelta
from fastapi import UploadFile
from sqlalchemy.orm import Session
from app.core.config import settings
from app.models.ocr import OCRBatch, OCRFile, FileStatus
from app.services.preprocessor import DocumentPreprocessor
logger = logging.getLogger(__name__)
class FileManagementError(Exception):
    """Raised when an upload, storage, validation, or cleanup operation fails."""
class FileManager:
    """
    File management service for upload, storage, and cleanup.

    Directory structure:
    uploads/
    ├── batches/
    │   └── {batch_id}/
    │       ├── inputs/        # Original uploaded files
    │       ├── outputs/       # OCR results
    │       │   ├── markdown/  # Markdown files
    │       │   ├── json/      # JSON files
    │       │   └── images/    # Extracted images
    │       └── exports/       # Export files (PDF, Excel, etc.)
    """

    def __init__(self):
        """Set up the upload root directory and the document preprocessor."""
        # Ensure the configured upload root exists before any batch is created
        self.base_upload_dir = Path(settings.upload_dir)
        self.base_upload_dir.mkdir(parents=True, exist_ok=True)
        self.preprocessor = DocumentPreprocessor()
def create_batch_directory(self, batch_id: int) -> Path:
    """
    Create the on-disk directory tree for a batch.

    Args:
        batch_id: Batch ID

    Returns:
        Path: Root directory of the batch
    """
    batch_dir = self.base_upload_dir / "batches" / str(batch_id)
    # Standard layout: inputs, the three output kinds, and exports
    subdirs = (
        batch_dir / "inputs",
        batch_dir / "outputs" / "markdown",
        batch_dir / "outputs" / "json",
        batch_dir / "outputs" / "images",
        batch_dir / "exports",
    )
    for subdir in subdirs:
        subdir.mkdir(parents=True, exist_ok=True)
    logger.info(f"Created batch directory: {batch_dir}")
    return batch_dir
def get_batch_directory(self, batch_id: int) -> Path:
    """
    Resolve the root directory for a batch (pure path math, no filesystem access).

    Args:
        batch_id: Batch ID

    Returns:
        Path: Batch directory path
    """
    batches_root = self.base_upload_dir / "batches"
    return batches_root / str(batch_id)
def validate_upload(self, file: UploadFile) -> Tuple[bool, Optional[str]]:
    """
    Validate an uploaded file before saving it to disk.

    Checks filename presence, non-zero size, the configured size limit,
    and the file extension whitelist.

    Args:
        file: Uploaded file

    Returns:
        Tuple of (is_valid, error_message); error_message is None when valid
    """
    # Check filename
    if not file.filename:
        return False, "文件名不能為空"
    # Determine size by seeking the underlying stream; reset to the start
    # afterwards so a later save still reads the whole file.
    file.file.seek(0, 2)  # Seek to end
    file_size = file.file.tell()
    file.file.seek(0)  # Reset to beginning
    if file_size == 0:
        return False, "文件為空"
    if file_size > settings.max_upload_size:
        max_mb = settings.max_upload_size / (1024 * 1024)
        # :g drops a trailing ".0" so the limit reads naturally (e.g. "50MB")
        return False, f"文件大小超過限制 ({max_mb:g}MB)"
    # Check file extension
    file_ext = Path(file.filename).suffix.lower()
    allowed_extensions = {'.png', '.jpg', '.jpeg', '.pdf', '.doc', '.docx', '.ppt', '.pptx'}
    if file_ext not in allowed_extensions:
        # sorted() keeps the message deterministic — set iteration order is not
        return False, f"不支持的文件格式 ({file_ext}),僅支持: {', '.join(sorted(allowed_extensions))}"
    return True, None
def save_upload(
    self,
    file: UploadFile,
    batch_id: int,
    validate: bool = True
) -> Tuple[Path, str]:
    """
    Save an uploaded file into the batch's inputs/ directory.

    A UUID-based filename is generated so concurrent uploads with the
    same original name cannot collide on disk.

    Args:
        file: Uploaded file
        batch_id: Batch ID
        validate: Whether to validate file

    Returns:
        Tuple of (file_path, original_filename)

    Raises:
        FileManagementError: If file validation or saving fails
    """
    # Validate if requested
    if validate:
        is_valid, error_msg = self.validate_upload(file)
        if not is_valid:
            raise FileManagementError(error_msg)
    original_filename = file.filename
    if not original_filename:
        # Only reachable with validate=False; fail with a clear error
        # instead of crashing on Path(None) below.
        raise FileManagementError("文件名不能為空")
    # Generate unique filename to avoid conflicts
    file_ext = Path(original_filename).suffix
    unique_filename = f"{uuid.uuid4()}{file_ext}"
    # Get batch input directory
    batch_dir = self.get_batch_directory(batch_id)
    input_dir = batch_dir / "inputs"
    input_dir.mkdir(parents=True, exist_ok=True)
    # Save file
    file_path = input_dir / unique_filename
    try:
        with file_path.open("wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
        logger.info(f"Saved upload: {file_path} (original: {original_filename})")
        return file_path, original_filename
    except Exception as e:
        # Clean up partial file if it exists, then preserve the cause chain
        file_path.unlink(missing_ok=True)
        raise FileManagementError(f"保存文件失敗: {str(e)}") from e
def validate_saved_file(self, file_path: Path) -> Tuple[bool, Optional[str], Optional[str]]:
    """
    Validate a file already on disk by delegating to the preprocessor.

    Args:
        file_path: Path to saved file

    Returns:
        Tuple of (is_valid, error_message, detected_format)
    """
    result = self.preprocessor.validate_file(file_path)
    return result
def create_batch(
    self,
    db: Session,
    user_id: int,
    batch_name: Optional[str] = None
) -> OCRBatch:
    """
    Create a new OCR batch record and its on-disk directory tree.

    Args:
        db: Database session
        user_id: User ID
        batch_name: Optional batch name; defaults to a timestamped name

    Returns:
        OCRBatch: Created batch object
    """
    # Fall back to a timestamped name when none (or an empty one) is given
    if not batch_name:
        batch_name = f"Batch_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    batch = OCRBatch(user_id=user_id, batch_name=batch_name)
    db.add(batch)
    db.commit()
    db.refresh(batch)
    # The directory tree needs the DB-assigned batch id, so create it after commit
    self.create_batch_directory(batch.id)
    logger.info(f"Created batch: {batch.id} for user {user_id}")
    return batch
def add_file_to_batch(
    self,
    db: Session,
    batch_id: int,
    file: UploadFile
) -> OCRFile:
    """
    Save an uploaded file to disk and register it on the batch.

    Files that fail post-save validation are still recorded, with
    status FAILED and the validation error attached.

    Args:
        db: Database session
        batch_id: Batch ID
        file: Uploaded file

    Returns:
        OCRFile: Created file record

    Raises:
        FileManagementError: If file operations fail
    """
    # Save file to disk
    file_path, original_filename = self.save_upload(file, batch_id)
    # Validate saved file.
    # NOTE: validate_saved_file documents its return as
    # (is_valid, error_message, detected_format); the previous code unpacked
    # the last two in the wrong order, storing the error text as the format.
    is_valid, error_msg, detected_format = self.validate_saved_file(file_path)
    # Create file record
    ocr_file = OCRFile(
        batch_id=batch_id,
        filename=file_path.name,
        original_filename=original_filename,
        file_path=str(file_path),
        file_size=file_path.stat().st_size,
        file_format=detected_format or Path(original_filename).suffix.lower().lstrip('.'),
        status=FileStatus.PENDING if is_valid else FileStatus.FAILED,
        error_message=error_msg if not is_valid else None
    )
    db.add(ocr_file)
    # Keep the batch's aggregate counters in sync with the new record
    batch = db.query(OCRBatch).filter(OCRBatch.id == batch_id).first()
    if batch:
        batch.total_files += 1
        if not is_valid:
            batch.failed_files += 1
    db.commit()
    db.refresh(ocr_file)
    logger.info(f"Added file to batch {batch_id}: {ocr_file.id} (status: {ocr_file.status})")
    return ocr_file
def add_files_to_batch(
    self,
    db: Session,
    batch_id: int,
    files: List[UploadFile]
) -> List[OCRFile]:
    """
    Add every uploaded file to the batch, skipping ones that fail.

    Args:
        db: Database session
        batch_id: Batch ID
        files: List of uploaded files

    Returns:
        List[OCRFile]: Records for the files that were added successfully
    """
    saved: List[OCRFile] = []
    for file in files:
        try:
            saved.append(self.add_file_to_batch(db, batch_id, file))
        except FileManagementError as e:
            # Log and move on so one bad upload doesn't block the rest
            logger.error(f"Failed to add file {file.filename} to batch {batch_id}: {e}")
    return saved
def get_file_paths(self, batch_id: int, file_id: int) -> dict:
    """
    Build the standard directory layout paths for one file in a batch.

    Args:
        batch_id: Batch ID
        file_id: File ID

    Returns:
        Dict containing all relevant paths
    """
    root = self.get_batch_directory(batch_id)
    outputs = root / "outputs"
    paths = {
        "input_dir": root / "inputs",
        "output_dir": outputs,
        "markdown_dir": outputs / "markdown",
        "json_dir": outputs / "json",
        # Extracted images are namespaced per file to avoid collisions
        "images_dir": outputs / "images" / str(file_id),
        "export_dir": root / "exports",
    }
    return paths
def cleanup_expired_batches(self, db: Session, retention_hours: int = 24) -> int:
    """
    Delete on-disk files and DB records for batches past the retention window.

    Args:
        db: Database session
        retention_hours: Number of hours to retain files

    Returns:
        int: Number of batches cleaned up
    """
    cutoff = datetime.utcnow() - timedelta(hours=retention_hours)
    expired = db.query(OCRBatch).filter(OCRBatch.created_at < cutoff).all()
    cleaned_count = 0
    for batch in expired:
        try:
            batch_dir = self.get_batch_directory(batch.id)
            if batch_dir.exists():
                shutil.rmtree(batch_dir)
                logger.info(f"Deleted batch directory: {batch_dir}")
            # DB-level cascade removes the batch's related records
            db.delete(batch)
            cleaned_count += 1
        except Exception as e:
            # Best-effort: a failed batch must not abort the whole sweep
            logger.error(f"Failed to cleanup batch {batch.id}: {e}")
    # Commit once for the whole sweep, only if anything was removed
    if cleaned_count > 0:
        db.commit()
        logger.info(f"Cleaned up {cleaned_count} expired batches")
    return cleaned_count
def verify_file_ownership(
    self,
    db: Session,
    user_id: int,
    batch_id: int
) -> bool:
    """
    Check that the given user owns the given batch.

    Args:
        db: Database session
        user_id: User ID
        batch_id: Batch ID

    Returns:
        bool: True if user owns batch, False otherwise
    """
    match = (
        db.query(OCRBatch)
        .filter(OCRBatch.id == batch_id, OCRBatch.user_id == user_id)
        .first()
    )
    return match is not None
def get_batch_statistics(self, db: Session, batch_id: int) -> dict:
    """
    Summarize a batch: file counts, sizes, progress, and timing.

    Args:
        db: Database session
        batch_id: Batch ID

    Returns:
        Dict containing batch statistics; empty dict if the batch is unknown
    """
    batch = db.query(OCRBatch).filter(OCRBatch.id == batch_id).first()
    if batch is None:
        return {}

    def iso(ts):
        # Optional timestamps serialize to None when unset
        return ts.isoformat() if ts else None

    size_bytes = sum(item.file_size for item in batch.files)
    elapsed = None
    if batch.completed_at and batch.started_at:
        elapsed = (batch.completed_at - batch.started_at).total_seconds()
    return {
        "batch_id": batch.id,
        "batch_name": batch.batch_name,
        "status": batch.status,
        "total_files": batch.total_files,
        "completed_files": batch.completed_files,
        "failed_files": batch.failed_files,
        "pending_files": batch.total_files - batch.completed_files - batch.failed_files,
        "progress_percentage": batch.progress_percentage,
        "total_file_size": size_bytes,
        "total_file_size_mb": round(size_bytes / (1024 * 1024), 2),
        "created_at": batch.created_at.isoformat(),
        "started_at": iso(batch.started_at),
        "completed_at": iso(batch.completed_at),
        "processing_time": elapsed,
    }