feat: add dual-track API endpoints for document processing
- Add ProcessingTrackEnum, ProcessingOptions, ProcessingMetadata schemas
- Add DocumentAnalysisResponse for document type detection
- Update /start endpoint with dual-track query parameters
- Add /analyze endpoint for document type detection with confidence scores
- Add /metadata endpoint for processing track information
- Add /download/unified endpoint for UnifiedDocument format export
- Update tasks.md to mark Section 6 API updates as completed

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
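As a review aid, a minimal client sketch for the new /analyze endpoint added in this commit. The base URL and bearer-token auth are illustrative assumptions; only the route and the response fields come from the diff below.

# Sketch: probe the recommended processing track for an uploaded task.
# BASE and the auth scheme are assumptions, not part of this commit.
import requests

BASE = "http://localhost:8000/api/v2/tasks"    # assumed host/port
HEADERS = {"Authorization": "Bearer <token>"}  # assumed auth scheme

def analyze(task_id: str) -> dict:
    resp = requests.post(f"{BASE}/{task_id}/analyze", headers=HEADERS)
    resp.raise_for_status()
    # DocumentAnalysisResponse: recommended_track, confidence, reason, ...
    return resp.json()

info = analyze("<task-uuid>")
print(info["recommended_track"], info["confidence"], info["reason"])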
@@ -28,25 +28,50 @@ from app.schemas.task import (
     TaskStatsResponse,
     TaskStatusEnum,
     UploadResponse,
+    ProcessingTrackEnum,
+    ProcessingOptions,
+    AnalyzeRequest,
+    DocumentAnalysisResponse,
+    ProcessingMetadata,
+    TaskResponseWithMetadata,
+    ExportOptions,
 )
 from app.services.task_service import task_service
 from app.services.file_access_service import file_access_service
 from app.services.ocr_service import OCRService
 
+# Import dual-track components
+try:
+    from app.services.document_type_detector import DocumentTypeDetector
+    DUAL_TRACK_AVAILABLE = True
+except ImportError:
+    DUAL_TRACK_AVAILABLE = False
+
 logger = logging.getLogger(__name__)
 
 router = APIRouter(prefix="/api/v2/tasks", tags=["Tasks"])
 
 
-def process_task_ocr(task_id: str, task_db_id: int, file_path: str, filename: str):
+def process_task_ocr(
+    task_id: str,
+    task_db_id: int,
+    file_path: str,
+    filename: str,
+    use_dual_track: bool = True,
+    force_track: Optional[str] = None,
+    language: str = 'ch'
+):
     """
-    Background task to process OCR for a task
+    Background task to process OCR for a task with dual-track support
 
     Args:
         task_id: Task UUID string
         task_db_id: Task database ID
         file_path: Path to uploaded file
         filename: Original filename
+        use_dual_track: Enable dual-track processing
+        force_track: Force specific track ('ocr' or 'direct')
+        language: OCR language code
     """
     from app.core.database import SessionLocal
     from app.models.task import Task
@@ -56,6 +81,7 @@ def process_task_ocr(task_id: str, task_db_id: int, file_path: str, filename: st
 
     try:
         logger.info(f"Starting OCR processing for task {task_id}, file: {filename}")
+        logger.info(f"Processing options: dual_track={use_dual_track}, force_track={force_track}, lang={language}")
 
         # Get task directly by database ID (bypass user isolation for background task)
         task = db.query(Task).filter(Task.id == task_db_id).first()
@@ -70,13 +96,25 @@ def process_task_ocr(task_id: str, task_db_id: int, file_path: str, filename: st
         result_dir = Path(settings.result_dir) / task_id
         result_dir.mkdir(parents=True, exist_ok=True)
 
-        # Process the file with OCR
-        ocr_result = ocr_service.process_image(
-            image_path=Path(file_path),
-            lang='ch',
-            detect_layout=True,
-            output_dir=result_dir
-        )
+        # Process the file with OCR (use dual-track if available)
+        if use_dual_track and hasattr(ocr_service, 'process'):
+            # Use new dual-track processing
+            ocr_result = ocr_service.process(
+                file_path=Path(file_path),
+                lang=language,
+                detect_layout=True,
+                output_dir=result_dir,
+                use_dual_track=use_dual_track,
+                force_track=force_track
+            )
+        else:
+            # Fall back to traditional processing
+            ocr_result = ocr_service.process_image(
+                image_path=Path(file_path),
+                lang=language,
+                detect_layout=True,
+                output_dir=result_dir
+            )
 
         # Calculate processing time
         processing_time_ms = int((datetime.now() - start_time).total_seconds() * 1000)
@@ -574,13 +612,19 @@ async def download_pdf(
 async def start_task(
     task_id: str,
     background_tasks: BackgroundTasks,
+    use_dual_track: bool = Query(True, description="Enable dual-track processing"),
+    force_track: Optional[str] = Query(None, description="Force track: 'ocr' or 'direct'"),
+    language: str = Query("ch", description="OCR language code"),
     db: Session = Depends(get_db),
     current_user: User = Depends(get_current_user)
 ):
     """
-    Start processing a pending task
+    Start processing a pending task with dual-track support
 
     - **task_id**: Task UUID
+    - **use_dual_track**: Enable intelligent track selection (default: true)
+    - **force_track**: Force specific processing track ('ocr' or 'direct')
+    - **language**: OCR language code (default: 'ch')
     """
     try:
         # Get task details
@@ -619,16 +663,20 @@ async def start_task(
             status=TaskStatus.PROCESSING
         )
 
-        # Start OCR processing in background
+        # Start OCR processing in background with dual-track parameters
         background_tasks.add_task(
             process_task_ocr,
             task_id=task_id,
             task_db_id=task.id,
             file_path=task_file.stored_path,
-            filename=task_file.original_name
+            filename=task_file.original_name,
+            use_dual_track=use_dual_track,
+            force_track=force_track,
+            language=language
         )
 
         logger.info(f"Started OCR processing task {task_id} for user {current_user.email}")
+        logger.info(f"Options: dual_track={use_dual_track}, force_track={force_track}, lang={language}")
         return task
 
     except HTTPException:
@@ -747,3 +795,226 @@ async def retry_task(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
             detail=f"Failed to retry task: {str(e)}"
         )
+
+
+# ===== Document Analysis Endpoints =====
+
+@router.post("/{task_id}/analyze", response_model=DocumentAnalysisResponse, summary="Analyze document type")
+async def analyze_document(
+    task_id: str,
+    db: Session = Depends(get_db),
+    current_user: User = Depends(get_current_user)
+):
+    """
+    Analyze document to determine recommended processing track
+
+    Returns document type analysis with recommended processing track
+    (OCR for scanned documents, DIRECT for editable PDFs)
+
+    - **task_id**: Task UUID
+    """
+    try:
+        if not DUAL_TRACK_AVAILABLE:
+            raise HTTPException(
+                status_code=status.HTTP_501_NOT_IMPLEMENTED,
+                detail="Dual-track processing not available"
+            )
+
+        # Get task details
+        task = task_service.get_task_by_id(
+            db=db,
+            task_id=task_id,
+            user_id=current_user.id
+        )
+
+        if not task:
+            raise HTTPException(
+                status_code=status.HTTP_404_NOT_FOUND,
+                detail="Task not found"
+            )
+
+        # Get task file
+        task_file = db.query(TaskFile).filter(TaskFile.task_id == task.id).first()
+        if not task_file:
+            raise HTTPException(
+                status_code=status.HTTP_404_NOT_FOUND,
+                detail="Task file not found"
+            )
+
+        # Analyze document
+        detector = DocumentTypeDetector()
+        recommendation = detector.analyze(Path(task_file.stored_path))
+
+        # Build response
+        response = DocumentAnalysisResponse(
+            task_id=task_id,
+            filename=task_file.original_name or "",
+            recommended_track=ProcessingTrackEnum(recommendation.track),
+            confidence=recommendation.confidence,
+            reason=recommendation.reason,
+            document_info=recommendation.document_info or {},
+            is_editable=recommendation.track == "direct",
+            text_coverage=recommendation.document_info.get("text_coverage") if recommendation.document_info else None,
+            page_count=recommendation.document_info.get("page_count") if recommendation.document_info else None
+        )
+
+        logger.info(f"Document analysis for task {task_id}: {recommendation.track} (confidence: {recommendation.confidence})")
+        return response
+
+    except HTTPException:
+        raise
+    except Exception as e:
+        logger.exception(f"Failed to analyze document for task {task_id}")
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"Failed to analyze document: {str(e)}"
+        )
+
+
+@router.get("/{task_id}/metadata", response_model=ProcessingMetadata, summary="Get processing metadata")
+async def get_processing_metadata(
+    task_id: str,
+    db: Session = Depends(get_db),
+    current_user: User = Depends(get_current_user)
+):
+    """
+    Get processing metadata for a completed task
+
+    Returns detailed processing information including track used,
+    element counts, and statistics.
+
+    - **task_id**: Task UUID
+    """
+    try:
+        # Get task details
+        task = task_service.get_task_by_id(
+            db=db,
+            task_id=task_id,
+            user_id=current_user.id
+        )
+
+        if not task:
+            raise HTTPException(
+                status_code=status.HTTP_404_NOT_FOUND,
+                detail="Task not found"
+            )
+
+        if task.status != TaskStatus.COMPLETED:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail="Task not completed"
+            )
+
+        # Load JSON result to get metadata
+        if not task.result_json_path:
+            raise HTTPException(
+                status_code=status.HTTP_404_NOT_FOUND,
+                detail="Result JSON not found"
+            )
+
+        json_path = Path(task.result_json_path)
+        if not json_path.exists():
+            raise HTTPException(
+                status_code=status.HTTP_404_NOT_FOUND,
+                detail="Result file not found"
+            )
+
+        with open(json_path, 'r', encoding='utf-8') as f:
+            result_data = json.load(f)
+
+        # Extract metadata
+        metadata = result_data.get('metadata', {})
+        statistics = result_data.get('statistics', {})
+
+        response = ProcessingMetadata(
+            processing_track=ProcessingTrackEnum(metadata.get('processing_track', 'ocr')),
+            processing_time_seconds=metadata.get('processing_time', 0),
+            language=metadata.get('language', 'ch'),
+            page_count=statistics.get('page_count', 1),
+            total_elements=statistics.get('total_elements', 0),
+            total_text_regions=len(result_data.get('text_regions', [])) if 'text_regions' in result_data else statistics.get('total_elements', 0),
+            total_tables=statistics.get('total_tables', 0),
+            total_images=statistics.get('total_images', 0),
+            average_confidence=result_data.get('average_confidence'),
+            unified_format=metadata.get('processing_info', {}).get('export_format') == 'unified_document_v1'
+        )
+
+        return response
+
+    except HTTPException:
+        raise
+    except Exception as e:
+        logger.exception(f"Failed to get metadata for task {task_id}")
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"Failed to get metadata: {str(e)}"
+        )
+
+
+@router.get("/{task_id}/download/unified", summary="Download unified format")
+async def download_unified(
+    task_id: str,
+    include_metadata: bool = Query(True, description="Include processing metadata"),
+    include_statistics: bool = Query(True, description="Include document statistics"),
+    db: Session = Depends(get_db),
+    current_user: User = Depends(get_current_user)
+):
+    """
+    Download results in unified document format
+
+    Returns JSON with full UnifiedDocument structure including
+    all elements, coordinates, and metadata.
+
+    - **task_id**: Task UUID
+    - **include_metadata**: Include processing metadata
+    - **include_statistics**: Include document statistics
+    """
+    try:
+        # Get task details
+        task = task_service.get_task_by_id(
+            db=db,
+            task_id=task_id,
+            user_id=current_user.id
+        )
+
+        if not task:
+            raise HTTPException(
+                status_code=status.HTTP_404_NOT_FOUND,
+                detail="Task not found"
+            )
+
+        if task.status != TaskStatus.COMPLETED:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail="Task not completed"
+            )
+
+        # Get JSON result path
+        if not task.result_json_path:
+            raise HTTPException(
+                status_code=status.HTTP_404_NOT_FOUND,
+                detail="Result JSON not found"
+            )
+
+        json_path = Path(task.result_json_path)
+        if not json_path.exists():
+            raise HTTPException(
+                status_code=status.HTTP_404_NOT_FOUND,
+                detail="Result file not found"
+            )
+
+        # Return the unified format JSON
+        return FileResponse(
+            path=str(json_path),
+            filename=f"{task_id}_unified.json",
+            media_type="application/json"
+        )
+
+    except HTTPException:
+        raise
+    except Exception as e:
+        logger.exception(f"Failed to download unified format for task {task_id}")
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"Failed to download: {str(e)}"
+        )
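A usage sketch tying the new endpoints together: start a task with the dual-track query parameters, wait for completion via /metadata (which returns 400 until the task is completed), then fetch the UnifiedDocument export. The host, auth, HTTP method for /start, and the polling policy are assumptions for illustration; routes and query parameters are taken from the diff above.

# Sketch: drive a task through the dual-track endpoints added in this commit.
# BASE, the token, and the retry policy are illustrative assumptions.
import time
import requests

BASE = "http://localhost:8000/api/v2/tasks"    # assumed host/port
HEADERS = {"Authorization": "Bearer <token>"}  # assumed auth scheme

def run_dual_track(task_id: str) -> None:
    # Start processing; force_track="ocr" or "direct" would override detection.
    requests.post(
        f"{BASE}/{task_id}/start",
        params={"use_dual_track": True, "language": "ch"},
        headers=HEADERS,
    ).raise_for_status()

    # Poll /metadata until the task completes (capped so a failed task cannot spin forever).
    for _ in range(60):
        resp = requests.get(f"{BASE}/{task_id}/metadata", headers=HEADERS)
        if resp.status_code == 200:
            break
        time.sleep(2)
    resp.raise_for_status()
    meta = resp.json()
    print("track used:", meta["processing_track"], "pages:", meta["page_count"])

    # Download the UnifiedDocument JSON export.
    unified = requests.get(f"{BASE}/{task_id}/download/unified", headers=HEADERS)
    unified.raise_for_status()
    with open(f"{task_id}_unified.json", "wb") as f:
        f.write(unified.content)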