"""
|
|
Tool_OCR - File Management Service
|
|
Handles file uploads, storage, validation, and cleanup
|
|
"""
|
|
|
|
import logging
|
|
import shutil
|
|
import uuid
|
|
from pathlib import Path
|
|
from typing import List, Tuple, Optional
|
|
from datetime import datetime, timedelta
|
|
|
|
from fastapi import UploadFile
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.core.config import settings
|
|
from app.models.ocr import OCRBatch, OCRFile, FileStatus
|
|
from app.services.preprocessor import DocumentPreprocessor
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|


class FileManagementError(Exception):
    """Exception raised for file management errors"""


class FileManager:
    """
    File management service for upload, storage, and cleanup

    Directory structure:
    uploads/
    ├── batches/
    │   └── {batch_id}/
    │       ├── inputs/          # Original uploaded files
    │       ├── outputs/         # OCR results
    │       │   ├── markdown/    # Markdown files
    │       │   ├── json/        # JSON files
    │       │   └── images/      # Extracted images
    │       └── exports/         # Export files (PDF, Excel, etc.)
    """

    def __init__(self):
        """Initialize file manager"""
        self.preprocessor = DocumentPreprocessor()
        self.base_upload_dir = Path(settings.upload_dir)
        self.base_upload_dir.mkdir(parents=True, exist_ok=True)

    def create_batch_directory(self, batch_id: int) -> Path:
        """
        Create directory structure for a batch

        Args:
            batch_id: Batch ID

        Returns:
            Path: Batch directory path
        """
        batch_dir = self.base_upload_dir / "batches" / str(batch_id)

        # Create subdirectories
        (batch_dir / "inputs").mkdir(parents=True, exist_ok=True)
        (batch_dir / "outputs" / "markdown").mkdir(parents=True, exist_ok=True)
        (batch_dir / "outputs" / "json").mkdir(parents=True, exist_ok=True)
        (batch_dir / "outputs" / "images").mkdir(parents=True, exist_ok=True)
        (batch_dir / "exports").mkdir(parents=True, exist_ok=True)

        logger.info(f"Created batch directory: {batch_dir}")
        return batch_dir

    def get_batch_directory(self, batch_id: int) -> Path:
        """
        Get batch directory path

        Args:
            batch_id: Batch ID

        Returns:
            Path: Batch directory path
        """
        return self.base_upload_dir / "batches" / str(batch_id)

    def validate_upload(self, file: UploadFile) -> Tuple[bool, Optional[str]]:
        """
        Validate uploaded file before saving

        Args:
            file: Uploaded file

        Returns:
            Tuple of (is_valid, error_message)
        """
        # Check filename
        if not file.filename:
            return False, "Filename must not be empty"

        # Check file size: seek to end to measure, then rewind
        file.file.seek(0, 2)  # Seek to end
        file_size = file.file.tell()
        file.file.seek(0)  # Reset to beginning

        if file_size == 0:
            return False, "File is empty"

        if file_size > settings.max_upload_size:
            max_mb = settings.max_upload_size / (1024 * 1024)
            return False, f"File size exceeds the {max_mb:g} MB limit"

        # Check file extension
        file_ext = Path(file.filename).suffix.lower()
        allowed_extensions = {'.png', '.jpg', '.jpeg', '.pdf', '.doc', '.docx', '.ppt', '.pptx'}
        if file_ext not in allowed_extensions:
            return False, f"Unsupported file format ({file_ext}); supported formats: {', '.join(sorted(allowed_extensions))}"

        return True, None

    def save_upload(
        self,
        file: UploadFile,
        batch_id: int,
        validate: bool = True
    ) -> Tuple[Path, str]:
        """
        Save uploaded file to batch directory

        Args:
            file: Uploaded file
            batch_id: Batch ID
            validate: Whether to validate file

        Returns:
            Tuple of (file_path, original_filename)

        Raises:
            FileManagementError: If file validation or saving fails
        """
        # Validate if requested
        if validate:
            is_valid, error_msg = self.validate_upload(file)
            if not is_valid:
                raise FileManagementError(error_msg)

        # Generate unique filename to avoid conflicts
        original_filename = file.filename
        file_ext = Path(original_filename).suffix
        unique_filename = f"{uuid.uuid4()}{file_ext}"

        # Get batch input directory
        batch_dir = self.get_batch_directory(batch_id)
        input_dir = batch_dir / "inputs"
        input_dir.mkdir(parents=True, exist_ok=True)

        # Save file
        file_path = input_dir / unique_filename
        try:
            with file_path.open("wb") as buffer:
                shutil.copyfileobj(file.file, buffer)

            logger.info(f"Saved upload: {file_path} (original: {original_filename})")
            return file_path, original_filename

        except Exception as e:
            # Clean up partial file if it exists
            file_path.unlink(missing_ok=True)
            raise FileManagementError(f"Failed to save file: {e}") from e

    def validate_saved_file(self, file_path: Path) -> Tuple[bool, Optional[str], Optional[str]]:
        """
        Validate saved file using preprocessor

        Args:
            file_path: Path to saved file

        Returns:
            Tuple of (is_valid, error_message, detected_format)
        """
        return self.preprocessor.validate_file(file_path)

    def create_batch(
        self,
        db: Session,
        user_id: int,
        batch_name: Optional[str] = None
    ) -> OCRBatch:
        """
        Create new OCR batch

        Args:
            db: Database session
            user_id: User ID
            batch_name: Optional batch name

        Returns:
            OCRBatch: Created batch object
        """
        # Create batch record
        batch = OCRBatch(
            user_id=user_id,
            batch_name=batch_name or f"Batch_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        )
        db.add(batch)
        db.commit()
        db.refresh(batch)

        # Create directory structure
        self.create_batch_directory(batch.id)

        logger.info(f"Created batch: {batch.id} for user {user_id}")
        return batch

    def add_file_to_batch(
        self,
        db: Session,
        batch_id: int,
        file: UploadFile
    ) -> OCRFile:
        """
        Add file to batch and save to disk

        Args:
            db: Database session
            batch_id: Batch ID
            file: Uploaded file

        Returns:
            OCRFile: Created file record

        Raises:
            FileManagementError: If file operations fail
        """
        # Save file to disk
        file_path, original_filename = self.save_upload(file, batch_id)

        # Validate saved file; validate_saved_file returns
        # (is_valid, error_message, detected_format) per its documented contract
        is_valid, error_msg, detected_format = self.validate_saved_file(file_path)

        # Create file record
        ocr_file = OCRFile(
            batch_id=batch_id,
            filename=file_path.name,
            original_filename=original_filename,
            file_path=str(file_path),
            file_size=file_path.stat().st_size,
            file_format=detected_format or Path(original_filename).suffix.lower().lstrip('.'),
            status=FileStatus.PENDING if is_valid else FileStatus.FAILED,
            error_message=error_msg if not is_valid else None
        )

        db.add(ocr_file)

        # Update batch total_files count
        batch = db.query(OCRBatch).filter(OCRBatch.id == batch_id).first()
        if batch:
            batch.total_files += 1
            if not is_valid:
                batch.failed_files += 1

        db.commit()
        db.refresh(ocr_file)

        logger.info(f"Added file to batch {batch_id}: {ocr_file.id} (status: {ocr_file.status})")
        return ocr_file

    def add_files_to_batch(
        self,
        db: Session,
        batch_id: int,
        files: List[UploadFile]
    ) -> List[OCRFile]:
        """
        Add multiple files to batch

        Args:
            db: Database session
            batch_id: Batch ID
            files: List of uploaded files

        Returns:
            List[OCRFile]: List of created file records
        """
        ocr_files = []
        for file in files:
            try:
                ocr_file = self.add_file_to_batch(db, batch_id, file)
                ocr_files.append(ocr_file)
            except FileManagementError as e:
                logger.error(f"Failed to add file {file.filename} to batch {batch_id}: {e}")
                # Continue with other files
                continue

        return ocr_files

    def get_file_paths(self, batch_id: int, file_id: int) -> dict:
        """
        Get all paths for a file in a batch

        Args:
            batch_id: Batch ID
            file_id: File ID

        Returns:
            Dict containing all relevant paths
        """
        batch_dir = self.get_batch_directory(batch_id)

        return {
            "input_dir": batch_dir / "inputs",
            "output_dir": batch_dir / "outputs",
            "markdown_dir": batch_dir / "outputs" / "markdown",
            "json_dir": batch_dir / "outputs" / "json",
            "images_dir": batch_dir / "outputs" / "images" / str(file_id),
            "export_dir": batch_dir / "exports",
        }

    def cleanup_expired_batches(self, db: Session, retention_hours: int = 24) -> int:
        """
        Clean up expired batch files

        Args:
            db: Database session
            retention_hours: Number of hours to retain files

        Returns:
            int: Number of batches cleaned up
        """
        cutoff_time = datetime.utcnow() - timedelta(hours=retention_hours)

        # Find expired batches
        expired_batches = db.query(OCRBatch).filter(
            OCRBatch.created_at < cutoff_time
        ).all()

        cleaned_count = 0
        for batch in expired_batches:
            try:
                # Delete batch directory
                batch_dir = self.get_batch_directory(batch.id)
                if batch_dir.exists():
                    shutil.rmtree(batch_dir)
                    logger.info(f"Deleted batch directory: {batch_dir}")

                # Delete database records (cascade will handle related records)
                db.delete(batch)
                cleaned_count += 1

            except Exception as e:
                logger.error(f"Failed to cleanup batch {batch.id}: {e}")
                continue

        if cleaned_count > 0:
            db.commit()
            logger.info(f"Cleaned up {cleaned_count} expired batches")

        return cleaned_count
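
    # Scheduling sketch (hypothetical, not wired up anywhere in this module):
    # a periodic job could drive the cleanup, assuming the app exposes a
    # `SessionLocal` session factory. Illustrative only:
    #
    #   def cleanup_job() -> None:
    #       db = SessionLocal()
    #       try:
    #           FileManager().cleanup_expired_batches(db, retention_hours=24)
    #       finally:
    #           db.close()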

    def verify_file_ownership(
        self,
        db: Session,
        user_id: int,
        batch_id: int
    ) -> bool:
        """
        Verify user owns the batch

        Args:
            db: Database session
            user_id: User ID
            batch_id: Batch ID

        Returns:
            bool: True if user owns batch, False otherwise
        """
        batch = db.query(OCRBatch).filter(
            OCRBatch.id == batch_id,
            OCRBatch.user_id == user_id
        ).first()

        return batch is not None
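
    # Example guard (sketch): a route handler might check ownership before
    # serving batch data; `file_manager` and `HTTPException` are assumed to
    # be in scope. Illustrative only:
    #
    #   if not file_manager.verify_file_ownership(db, user.id, batch_id):
    #       raise HTTPException(status_code=404, detail="Batch not found")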

    def get_batch_statistics(self, db: Session, batch_id: int) -> dict:
        """
        Get statistics for a batch

        Args:
            db: Database session
            batch_id: Batch ID

        Returns:
            Dict containing batch statistics
        """
        batch = db.query(OCRBatch).filter(OCRBatch.id == batch_id).first()
        if not batch:
            return {}

        # Calculate total file size
        total_size = sum(f.file_size for f in batch.files)

        # Calculate processing time
        processing_time = None
        if batch.completed_at and batch.started_at:
            processing_time = (batch.completed_at - batch.started_at).total_seconds()

        return {
            "batch_id": batch.id,
            "batch_name": batch.batch_name,
            "status": batch.status,
            "total_files": batch.total_files,
            "completed_files": batch.completed_files,
            "failed_files": batch.failed_files,
            "pending_files": batch.total_files - batch.completed_files - batch.failed_files,
            "progress_percentage": batch.progress_percentage,
            "total_file_size": total_size,
            "total_file_size_mb": round(total_size / (1024 * 1024), 2),
            "created_at": batch.created_at.isoformat(),
            "started_at": batch.started_at.isoformat() if batch.started_at else None,
            "completed_at": batch.completed_at.isoformat() if batch.completed_at else None,
            "processing_time": processing_time,
        }
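

# Usage sketch (hypothetical, not part of this module): wiring FileManager into
# a FastAPI upload route. `get_db` and `get_current_user` are assumed app-level
# dependencies; adjust names to the real project. Illustrative only:
#
#   from fastapi import APIRouter, Depends, File, HTTPException
#
#   router = APIRouter()
#   file_manager = FileManager()
#
#   @router.post("/batches")
#   def upload_batch(
#       files: List[UploadFile] = File(...),
#       db: Session = Depends(get_db),
#       user=Depends(get_current_user),
#   ):
#       batch = file_manager.create_batch(db, user_id=user.id)
#       saved = file_manager.add_files_to_batch(db, batch.id, files)
#       if not saved:
#           raise HTTPException(status_code=400, detail="No valid files uploaded")
#       return file_manager.get_batch_statistics(db, batch.id)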