245 lines
7.2 KiB
Python
245 lines
7.2 KiB
Python
"""
|
|
Tool_OCR - OCR Router
|
|
File upload, OCR processing, and status endpoints
|
|
"""
|
|
|
|
import logging
from pathlib import Path
from typing import List, Optional

from fastapi import APIRouter, Depends, HTTPException, status, UploadFile, File, BackgroundTasks
from sqlalchemy.orm import Session

from app.core.deps import get_db, get_current_active_user
from app.models.user import User
from app.models.ocr import OCRBatch, OCRFile, OCRResult, BatchStatus, FileStatus
from app.schemas.ocr import (
    OCRBatchResponse,
    BatchStatusResponse,
    FileStatusResponse,
    OCRResultDetailResponse,
    UploadBatchResponse,
    ProcessRequest,
    ProcessResponse,
)
from app.services.file_manager import FileManager, FileManagementError
from app.services.ocr_service import OCRService
from app.services.background_tasks import process_batch_files_with_retry
|
|
|
|
|
|
logger = logging.getLogger(name=__name__)

# All endpoints below are mounted under /api/v1 and grouped as "OCR" in the API docs.
router = APIRouter(tags=["OCR"], prefix="/api/v1")

# Module-level service instances shared by every request handled by this router.
# Constructed at import time, in this order.
file_manager = FileManager()
# NOTE(review): ocr_service is not referenced anywhere in this module; it may be
# imported from here by other modules - confirm before removing it.
ocr_service = OCRService()
|
|
|
|
|
|
@router.post("/upload", response_model=UploadBatchResponse, summary="Upload files for OCR")
|
|
async def upload_files(
|
|
files: List[UploadFile] = File(..., description="Files to upload (PNG, JPG, PDF)"),
|
|
batch_name: str = None,
|
|
db: Session = Depends(get_db),
|
|
current_user: User = Depends(get_current_active_user)
|
|
):
|
|
"""
|
|
Upload files for OCR processing
|
|
|
|
Creates a new batch and uploads files to it
|
|
|
|
- **files**: List of files to upload (PNG, JPG, JPEG, PDF)
|
|
- **batch_name**: Optional name for the batch
|
|
"""
|
|
if not files:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_400_BAD_REQUEST,
|
|
detail="No files provided"
|
|
)
|
|
|
|
try:
|
|
# Create batch
|
|
batch = file_manager.create_batch(db, current_user.id, batch_name)
|
|
|
|
# Upload files
|
|
uploaded_files = file_manager.add_files_to_batch(db, batch.id, files)
|
|
|
|
logger.info(f"Uploaded {len(uploaded_files)} files to batch {batch.id} for user {current_user.id}")
|
|
|
|
# Refresh batch to get updated counts
|
|
db.refresh(batch)
|
|
|
|
# Return response matching frontend expectations
|
|
return {
|
|
"batch_id": batch.id,
|
|
"files": uploaded_files
|
|
}
|
|
|
|
except FileManagementError as e:
|
|
logger.error(f"File upload error: {e}")
|
|
raise HTTPException(
|
|
status_code=status.HTTP_400_BAD_REQUEST,
|
|
detail=str(e)
|
|
)
|
|
except Exception as e:
|
|
logger.error(f"Unexpected error during upload: {e}")
|
|
raise HTTPException(
|
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
detail="Failed to upload files"
|
|
)
|
|
|
|
|
|
# NOTE: process_batch_files was moved to app.services.background_tasks.
# Processing is now scheduled via process_batch_files_with_retry, which adds retry logic.
|
|
|
|
@router.post("/ocr/process", response_model=ProcessResponse, summary="Trigger OCR processing")
|
|
async def process_ocr(
|
|
request: ProcessRequest,
|
|
background_tasks: BackgroundTasks,
|
|
db: Session = Depends(get_db),
|
|
current_user: User = Depends(get_current_active_user)
|
|
):
|
|
"""
|
|
Trigger OCR processing for a batch
|
|
|
|
Starts background processing of all files in the batch
|
|
|
|
- **batch_id**: Batch ID to process
|
|
- **lang**: Language code (ch, en, japan, korean)
|
|
- **detect_layout**: Enable layout detection
|
|
"""
|
|
# Verify batch ownership
|
|
batch = db.query(OCRBatch).filter(
|
|
OCRBatch.id == request.batch_id,
|
|
OCRBatch.user_id == current_user.id
|
|
).first()
|
|
|
|
if not batch:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
detail="Batch not found"
|
|
)
|
|
|
|
if batch.status != BatchStatus.PENDING:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_400_BAD_REQUEST,
|
|
detail=f"Batch is already {batch.status.value}"
|
|
)
|
|
|
|
# Start background processing with retry logic
|
|
background_tasks.add_task(
|
|
process_batch_files_with_retry,
|
|
batch_id=batch.id,
|
|
lang=request.lang,
|
|
detect_layout=request.detect_layout,
|
|
db=SessionLocal() # Create new session for background task
|
|
)
|
|
|
|
logger.info(f"Started OCR processing for batch {batch.id}")
|
|
|
|
return {
|
|
"message": "OCR processing started",
|
|
"batch_id": batch.id,
|
|
"total_files": batch.total_files,
|
|
"status": "processing"
|
|
}
|
|
|
|
|
|
@router.get("/batch/{batch_id}/status", response_model=BatchStatusResponse, summary="Get batch status")
|
|
async def get_batch_status(
|
|
batch_id: int,
|
|
db: Session = Depends(get_db),
|
|
current_user: User = Depends(get_current_active_user)
|
|
):
|
|
"""
|
|
Get batch processing status
|
|
|
|
Returns batch information and all files in the batch
|
|
|
|
- **batch_id**: Batch ID
|
|
"""
|
|
# Verify batch ownership
|
|
batch = db.query(OCRBatch).filter(
|
|
OCRBatch.id == batch_id,
|
|
OCRBatch.user_id == current_user.id
|
|
).first()
|
|
|
|
if not batch:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
detail="Batch not found"
|
|
)
|
|
|
|
# Get all files in batch
|
|
files = db.query(OCRFile).filter(OCRFile.batch_id == batch_id).all()
|
|
|
|
return {
|
|
"batch": batch,
|
|
"files": files
|
|
}
|
|
|
|
|
|
@router.get("/ocr/result/{file_id}", response_model=OCRResultDetailResponse, summary="Get OCR result")
|
|
async def get_ocr_result(
|
|
file_id: int,
|
|
db: Session = Depends(get_db),
|
|
current_user: User = Depends(get_current_active_user)
|
|
):
|
|
"""
|
|
Get OCR result for a file
|
|
|
|
Returns flattened file and OCR result information for frontend preview
|
|
|
|
- **file_id**: File ID
|
|
"""
|
|
# Get file
|
|
ocr_file = db.query(OCRFile).join(OCRBatch).filter(
|
|
OCRFile.id == file_id,
|
|
OCRBatch.user_id == current_user.id
|
|
).first()
|
|
|
|
if not ocr_file:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
detail="File not found"
|
|
)
|
|
|
|
# Get result if exists
|
|
result = db.query(OCRResult).filter(OCRResult.file_id == file_id).first()
|
|
|
|
# Read markdown content if result exists
|
|
markdown_content = None
|
|
if result and result.markdown_path:
|
|
markdown_file = Path(result.markdown_path)
|
|
if markdown_file.exists():
|
|
try:
|
|
markdown_content = markdown_file.read_text(encoding='utf-8')
|
|
except Exception as e:
|
|
logger.warning(f"Failed to read markdown file {result.markdown_path}: {e}")
|
|
|
|
# Build JSON data from result if available
|
|
json_data = None
|
|
if result:
|
|
json_data = {
|
|
"total_text_regions": result.total_text_regions,
|
|
"average_confidence": result.average_confidence,
|
|
"detected_language": result.detected_language,
|
|
"layout_data": result.layout_data,
|
|
"images_metadata": result.images_metadata,
|
|
}
|
|
|
|
# Return flattened structure matching frontend expectations
|
|
return {
|
|
"file_id": ocr_file.id,
|
|
"filename": ocr_file.filename,
|
|
"status": ocr_file.status.value,
|
|
"markdown_content": markdown_content,
|
|
"json_data": json_data,
|
|
"confidence": result.average_confidence if result else None,
|
|
"processing_time": ocr_file.processing_time,
|
|
}
|
|
|
|
|
|
# Import SessionLocal for background tasks
|
|
from app.core.database import SessionLocal
|