first
This commit is contained in:
244
backend/app/routers/ocr.py
Normal file
244
backend/app/routers/ocr.py
Normal file
@@ -0,0 +1,244 @@
|
||||
"""
|
||||
Tool_OCR - OCR Router
|
||||
File upload, OCR processing, and status endpoints
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, UploadFile, File, BackgroundTasks
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.deps import get_db, get_current_active_user
|
||||
from app.models.user import User
|
||||
from app.models.ocr import OCRBatch, OCRFile, OCRResult, BatchStatus, FileStatus
|
||||
from app.schemas.ocr import (
|
||||
OCRBatchResponse,
|
||||
BatchStatusResponse,
|
||||
FileStatusResponse,
|
||||
OCRResultDetailResponse,
|
||||
UploadBatchResponse,
|
||||
ProcessRequest,
|
||||
ProcessResponse,
|
||||
)
|
||||
from app.services.file_manager import FileManager, FileManagementError
|
||||
from app.services.ocr_service import OCRService
|
||||
from app.services.background_tasks import process_batch_files_with_retry
|
||||
|
||||
|
||||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Every endpoint registered on this router is served under the /api/v1
# prefix and carries the "OCR" tag.
router = APIRouter(
    prefix="/api/v1",
    tags=["OCR"],
)

# Service objects, constructed once when this module is imported.
file_manager = FileManager()
ocr_service = OCRService()
|
||||
|
||||
|
||||
@router.post("/upload", response_model=UploadBatchResponse, summary="Upload files for OCR")
async def upload_files(
    files: List[UploadFile] = File(..., description="Files to upload (PNG, JPG, PDF)"),
    batch_name: str = None,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_active_user)
):
    """
    Upload files for OCR processing

    Creates a new batch owned by the current user and adds the uploaded
    files to it.

    - **files**: List of files to upload (PNG, JPG, JPEG, PDF)
    - **batch_name**: Optional name for the batch

    Returns the new batch ID together with the uploaded file records.

    Responds with 400 when no files are supplied or when the file manager
    rejects the upload, and with 500 on any other unexpected failure.
    """
    if not files:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="No files provided"
        )

    try:
        # Create the batch first so the files have something to attach to.
        batch = file_manager.create_batch(db, current_user.id, batch_name)

        # Store the uploads under the new batch.
        uploaded_files = file_manager.add_files_to_batch(db, batch.id, files)

        logger.info(f"Uploaded {len(uploaded_files)} files to batch {batch.id} for user {current_user.id}")

        # Reload the batch from the database before building the response.
        db.refresh(batch)

        # Response shape matches what the frontend expects.
        return {
            "batch_id": batch.id,
            "files": uploaded_files
        }

    except HTTPException:
        # An HTTP error raised deliberately inside the block must reach the
        # client unchanged rather than being masked as a generic 500 below.
        raise
    except FileManagementError as e:
        logger.error(f"File upload error: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        ) from e
    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        logger.exception(f"Unexpected error during upload: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to upload files"
        ) from e
|
||||
|
||||
|
||||
# NOTE: process_batch_files function moved to app.services.background_tasks
|
||||
# Now using process_batch_files_with_retry with retry logic
|
||||
|
||||
@router.post("/ocr/process", response_model=ProcessResponse, summary="Trigger OCR processing")
async def process_ocr(
    request: ProcessRequest,
    background_tasks: BackgroundTasks,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_active_user)
):
    """
    Trigger OCR processing for a batch

    Schedules background processing of the batch and returns immediately.

    - **batch_id**: Batch ID to process
    - **lang**: Language code (ch, en, japan, korean)
    - **detect_layout**: Enable layout detection

    Responds with 404 when the batch does not exist or belongs to another
    user, and with 400 when the batch is not in the pending state.
    """
    # Filtering on user_id as well as id means a batch owned by someone else
    # is indistinguishable from a missing one.
    owned_batch_query = db.query(OCRBatch).filter(
        OCRBatch.id == request.batch_id,
        OCRBatch.user_id == current_user.id
    )
    batch = owned_batch_query.first()

    if not batch:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Batch not found"
        )

    # Only batches still in the pending state may be started.
    if batch.status != BatchStatus.PENDING:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Batch is already {batch.status.value}"
        )

    # The background task receives its own, freshly created session instead
    # of the request-scoped `db`. SessionLocal is imported at the bottom of
    # this module.
    # NOTE(review): presumably the task is responsible for closing this
    # session - confirm in app.services.background_tasks.
    task_kwargs = {
        "batch_id": batch.id,
        "lang": request.lang,
        "detect_layout": request.detect_layout,
        "db": SessionLocal(),
    }
    background_tasks.add_task(process_batch_files_with_retry, **task_kwargs)

    logger.info(f"Started OCR processing for batch {batch.id}")

    return {
        "message": "OCR processing started",
        "batch_id": batch.id,
        "total_files": batch.total_files,
        "status": "processing"
    }
|
||||
|
||||
|
||||
@router.get("/batch/{batch_id}/status", response_model=BatchStatusResponse, summary="Get batch status")
async def get_batch_status(
    batch_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_active_user)
):
    """
    Get batch processing status

    Returns the batch record together with every file that belongs to it.

    - **batch_id**: Batch ID

    Responds with 404 when the batch does not exist or belongs to another
    user.
    """
    # Look the batch up by id and owner in one query, so other users'
    # batches are reported as not found.
    owned_batch_query = db.query(OCRBatch).filter(
        OCRBatch.id == batch_id,
        OCRBatch.user_id == current_user.id
    )
    batch = owned_batch_query.first()

    if not batch:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Batch not found"
        )

    # Collect every file row attached to this batch.
    batch_files = db.query(OCRFile).filter(OCRFile.batch_id == batch_id).all()

    return {"batch": batch, "files": batch_files}
|
||||
|
||||
|
||||
@router.get("/ocr/result/{file_id}", response_model=OCRResultDetailResponse, summary="Get OCR result")
async def get_ocr_result(
    file_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_active_user)
):
    """
    Get OCR result for a file

    Returns a single flattened object combining the file record with its
    OCR result (when one exists), suitable for the frontend preview.

    - **file_id**: File ID

    Responds with 404 when the file does not exist or belongs to a batch
    owned by another user. Result-derived fields are null when no OCR result
    has been stored for the file.
    """
    # Joining through the batch limits the lookup to files whose batch is
    # owned by the current user.
    owned_file_query = db.query(OCRFile).join(OCRBatch).filter(
        OCRFile.id == file_id,
        OCRBatch.user_id == current_user.id
    )
    ocr_file = owned_file_query.first()

    if not ocr_file:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="File not found"
        )

    # The OCR result row is optional.
    result = db.query(OCRResult).filter(OCRResult.file_id == file_id).first()

    # Load the markdown text only when a result exists, it records a path,
    # and that path is present on disk. A read failure is logged and leaves
    # the content as None rather than failing the request.
    markdown_content = None
    if result and result.markdown_path and Path(result.markdown_path).exists():
        try:
            markdown_content = Path(result.markdown_path).read_text(encoding='utf-8')
        except Exception as e:
            logger.warning(f"Failed to read markdown file {result.markdown_path}: {e}")

    # Structured summary of the result, or None when there is no result.
    json_data = {
        "total_text_regions": result.total_text_regions,
        "average_confidence": result.average_confidence,
        "detected_language": result.detected_language,
        "layout_data": result.layout_data,
        "images_metadata": result.images_metadata,
    } if result else None

    confidence = result.average_confidence if result else None

    # Flattened structure matching what the frontend expects.
    return {
        "file_id": ocr_file.id,
        "filename": ocr_file.filename,
        "status": ocr_file.status.value,
        "markdown_content": markdown_content,
        "json_data": json_data,
        "confidence": confidence,
        "processing_time": ocr_file.processing_time,
    }
|
||||
|
||||
|
||||
# Import SessionLocal for background tasks
|
||||
from app.core.database import SessionLocal
|
||||
Reference in New Issue
Block a user