feat: implement actual OCR processing in start_task endpoint
Changes: - Add process_task_ocr background function to execute OCR processing - Initialize OCRService and process uploaded file - Save OCR results to JSON and Markdown files - Update task status to COMPLETED/FAILED based on processing outcome - Use FastAPI BackgroundTasks for async processing - Direct database updates in background task (bypass user isolation) Features: - Real OCR processing with GPU/CPU acceleration - Processing time tracking - Error handling and status updates - Result files saved in task-specific directories Fixes: - Task status stuck in PROCESSING (no actual OCR execution) - No CPU/GPU utilization during "processing" 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -9,9 +9,11 @@ from pathlib import Path
|
|||||||
import shutil
|
import shutil
|
||||||
import hashlib
|
import hashlib
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status, Query, UploadFile, File
|
from fastapi import APIRouter, Depends, HTTPException, status, Query, UploadFile, File, BackgroundTasks
|
||||||
from fastapi.responses import FileResponse
|
from fastapi.responses import FileResponse
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
import json
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
from app.core.deps import get_db, get_current_user
|
from app.core.deps import get_db, get_current_user
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
@@ -29,12 +31,96 @@ from app.schemas.task import (
|
|||||||
)
|
)
|
||||||
from app.services.task_service import task_service
|
from app.services.task_service import task_service
|
||||||
from app.services.file_access_service import file_access_service
|
from app.services.file_access_service import file_access_service
|
||||||
|
from app.services.ocr_service import OCRService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/v2/tasks", tags=["Tasks"])
|
router = APIRouter(prefix="/api/v2/tasks", tags=["Tasks"])
|
||||||
|
|
||||||
|
|
||||||
|
def process_task_ocr(task_id: str, task_db_id: int, file_path: str, filename: str):
|
||||||
|
"""
|
||||||
|
Background task to process OCR for a task
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: Task UUID string
|
||||||
|
task_db_id: Task database ID
|
||||||
|
file_path: Path to uploaded file
|
||||||
|
filename: Original filename
|
||||||
|
"""
|
||||||
|
from app.core.database import SessionLocal
|
||||||
|
from app.models.task import Task
|
||||||
|
|
||||||
|
db = SessionLocal()
|
||||||
|
start_time = datetime.now()
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.info(f"Starting OCR processing for task {task_id}, file: {filename}")
|
||||||
|
|
||||||
|
# Get task directly by database ID (bypass user isolation for background task)
|
||||||
|
task = db.query(Task).filter(Task.id == task_db_id).first()
|
||||||
|
if not task:
|
||||||
|
logger.error(f"Task {task_id} not found in database")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Initialize OCR service
|
||||||
|
ocr_service = OCRService()
|
||||||
|
|
||||||
|
# Process the file with OCR
|
||||||
|
ocr_result = ocr_service.process_image(
|
||||||
|
image_path=Path(file_path),
|
||||||
|
lang='ch',
|
||||||
|
detect_layout=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Calculate processing time
|
||||||
|
processing_time_ms = int((datetime.now() - start_time).total_seconds() * 1000)
|
||||||
|
|
||||||
|
# Create result directory
|
||||||
|
result_dir = Path(settings.result_dir) / task_id
|
||||||
|
result_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Save JSON result
|
||||||
|
json_path = result_dir / f"{Path(filename).stem}_result.json"
|
||||||
|
with open(json_path, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(ocr_result, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
# Save Markdown result
|
||||||
|
markdown_path = result_dir / f"{Path(filename).stem}_result.md"
|
||||||
|
markdown_content = ocr_result.get('markdown', '')
|
||||||
|
with open(markdown_path, 'w', encoding='utf-8') as f:
|
||||||
|
f.write(markdown_content)
|
||||||
|
|
||||||
|
# Update task with results (direct database update)
|
||||||
|
task.result_json_path = str(json_path)
|
||||||
|
task.result_markdown_path = str(markdown_path)
|
||||||
|
task.processing_time_ms = processing_time_ms
|
||||||
|
task.status = TaskStatus.COMPLETED
|
||||||
|
task.completed_at = datetime.utcnow()
|
||||||
|
task.updated_at = datetime.utcnow()
|
||||||
|
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
logger.info(f"OCR processing completed for task {task_id} in {processing_time_ms}ms")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"OCR processing failed for task {task_id}")
|
||||||
|
|
||||||
|
# Update task status to failed (direct database update)
|
||||||
|
try:
|
||||||
|
task = db.query(Task).filter(Task.id == task_db_id).first()
|
||||||
|
if task:
|
||||||
|
task.status = TaskStatus.FAILED
|
||||||
|
task.error_message = str(e)
|
||||||
|
task.updated_at = datetime.utcnow()
|
||||||
|
db.commit()
|
||||||
|
except Exception as update_error:
|
||||||
|
logger.error(f"Failed to update task status: {update_error}")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
@router.post("/", response_model=TaskResponse, status_code=status.HTTP_201_CREATED)
|
@router.post("/", response_model=TaskResponse, status_code=status.HTTP_201_CREATED)
|
||||||
async def create_task(
|
async def create_task(
|
||||||
task_data: TaskCreate,
|
task_data: TaskCreate,
|
||||||
@@ -425,6 +511,7 @@ async def download_pdf(
|
|||||||
@router.post("/{task_id}/start", response_model=TaskResponse, summary="Start task processing")
|
@router.post("/{task_id}/start", response_model=TaskResponse, summary="Start task processing")
|
||||||
async def start_task(
|
async def start_task(
|
||||||
task_id: str,
|
task_id: str,
|
||||||
|
background_tasks: BackgroundTasks,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
@@ -434,11 +521,11 @@ async def start_task(
|
|||||||
- **task_id**: Task UUID
|
- **task_id**: Task UUID
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
task = task_service.update_task_status(
|
# Get task details
|
||||||
|
task = task_service.get_task_by_id(
|
||||||
db=db,
|
db=db,
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
user_id=current_user.id,
|
user_id=current_user.id
|
||||||
status=TaskStatus.PROCESSING
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if not task:
|
if not task:
|
||||||
@@ -447,7 +534,39 @@ async def start_task(
|
|||||||
detail="Task not found"
|
detail="Task not found"
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Started task {task_id} for user {current_user.email}")
|
# Check if task is in pending status
|
||||||
|
if task.status != TaskStatus.PENDING:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"Cannot start task in '{task.status.value}' status"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get task file
|
||||||
|
task_file = db.query(TaskFile).filter(TaskFile.task_id == task.id).first()
|
||||||
|
if not task_file:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Task file not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update task status to processing
|
||||||
|
task = task_service.update_task_status(
|
||||||
|
db=db,
|
||||||
|
task_id=task_id,
|
||||||
|
user_id=current_user.id,
|
||||||
|
status=TaskStatus.PROCESSING
|
||||||
|
)
|
||||||
|
|
||||||
|
# Start OCR processing in background
|
||||||
|
background_tasks.add_task(
|
||||||
|
process_task_ocr,
|
||||||
|
task_id=task_id,
|
||||||
|
task_db_id=task.id,
|
||||||
|
file_path=task_file.stored_path,
|
||||||
|
filename=task_file.original_name
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Started OCR processing task {task_id} for user {current_user.email}")
|
||||||
return task
|
return task
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
|
|||||||
Reference in New Issue
Block a user