Files
OCR/frontend/src/services/batchProcessing.ts
egg d20751d56b feat: add batch processing for multiple file uploads
- Add BatchState management in taskStore with progress tracking
- Implement batch processing service with concurrency control
  - Direct Track: max 5 parallel tasks
  - OCR Track: sequential processing (GPU VRAM limit)
- Refactor ProcessingPage to support batch mode with BatchProcessingPanel
- Update UploadPage to initialize batch state for multi-file uploads
- Add i18n translations for batch processing (zh-TW, en-US)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-12 17:05:16 +08:00

325 lines
8.7 KiB
TypeScript

/**
* Batch Processing Service
*
* Handles batch processing of multiple tasks with parallel execution for Direct Track
* and queue processing for OCR Track (GPU VRAM limitation).
*
* Concurrency limits:
* - Direct Track: Max 5 concurrent tasks (CPU-bound)
* - OCR Track: Max 1 concurrent task (GPU VRAM limited)
*/
import { apiClientV2 } from '@/services/apiV2'
import { useTaskStore, type BatchTaskState, type BatchStrategy } from '@/store/taskStore'
import type { ProcessingTrack, ProcessingOptions } from '@/types/apiV2'
/** Maximum number of Direct Track tasks processed at the same time. */
const DIRECT_TRACK_CONCURRENCY = 5

/** OCR Track tasks run one at a time because of the GPU VRAM limit. */
const OCR_TRACK_CONCURRENCY = 1

/** Delay between two task-status polls, in milliseconds. */
const POLL_INTERVAL = 2000
/**
 * Request a document analysis for every task of the batch so that each task
 * gets a recommended processing track.
 *
 * All requests are issued at once. As each one settles, the task's batch state
 * receives either the recommendation plus the full analysis result, or an
 * error message; one failing analysis does not stop the others. The store's
 * "analyzing" flag is raised for the duration and always cleared afterwards.
 *
 * @param taskIds - IDs of the tasks to analyze.
 */
export async function analyzeBatchTasks(taskIds: string[]): Promise<void> {
  const { setBatchAnalyzing: setAnalyzing, updateBatchTaskState: patchTask } =
    useTaskStore.getState()

  setAnalyzing(true)

  // Analyze one task and record the outcome; never rejects
  const analyzeOne = async (taskId: string): Promise<void> => {
    try {
      const analysis = await apiClientV2.analyzeDocument(taskId)
      patchTask(taskId, {
        recommendedTrack: analysis.recommended_track,
        analysisResult: analysis,
      })
    } catch (error) {
      patchTask(taskId, {
        error: error instanceof Error ? error.message : 'Analysis failed',
      })
    }
  }

  try {
    await Promise.all(taskIds.map(analyzeOne))
  } finally {
    setAnalyzing(false)
  }
}
/**
 * Pick the processing track for one task.
 *
 * 'force_ocr' and 'force_direct' override the analysis. For 'auto' (and any
 * other strategy value) the analysis recommendation is used, or 'auto' when
 * there is no recommendation, leaving the choice to the backend.
 */
function determineTrack(
  strategy: BatchStrategy,
  recommendedTrack: ProcessingTrack | null
): ProcessingTrack {
  if (strategy === 'force_ocr') return 'ocr'
  if (strategy === 'force_direct') return 'direct'

  return recommendedTrack ? recommendedTrack : 'auto'
}
/**
 * Assemble the per-task `ProcessingOptions` sent to the backend.
 *
 * Every task gets dual-track mode, the resolved track and the language. The
 * layout model and preprocessing mode only apply to the OCR track, so they are
 * added only when `track` is 'ocr' (the keys are absent otherwise).
 */
function buildProcessingOptions(
  track: ProcessingTrack,
  batchOptions: {
    layoutModel: 'chinese' | 'default' | 'cdla'
    preprocessingMode: 'auto' | 'manual' | 'disabled'
    language: string
  }
): ProcessingOptions {
  const { layoutModel, preprocessingMode, language } = batchOptions

  const ocrOnly: Partial<ProcessingOptions> =
    track === 'ocr'
      ? { layout_model: layoutModel, preprocessing_mode: preprocessingMode }
      : {}

  return { use_dual_track: true, force_track: track, language, ...ocrOnly }
}
/**
 * Start a single task and poll its status until the backend reports it as
 * 'completed' or 'failed'.
 *
 * `onStatusUpdate` receives 'processing' before `startTask` is called, then
 * exactly one terminal status: 'completed', or 'failed' with an error message.
 * Failures are reported through `onStatusUpdate` instead of being thrown.
 *
 * A failed status request does not fail the task immediately. The task keeps
 * running on the backend, so reporting it as failed after one network blip
 * would free its concurrency slot early. That would let the next OCR task
 * start while this one still occupies the GPU. Only
 * MAX_CONSECUTIVE_POLL_ERRORS consecutive request failures mark the task as
 * failed. A failing `startTask` still fails the task right away.
 *
 * NOTE(review): polling has no upper time bound. Only the backend statuses
 * 'completed' and 'failed' end the loop, so a task that never reaches either
 * (e.g. whatever status it has after `cancelTask`) is polled indefinitely.
 * Confirm the backend's status values.
 *
 * @param taskId - Task to start and monitor.
 * @param options - Options forwarded to `apiClientV2.startTask`.
 * @param onStatusUpdate - Receives this task's status transitions.
 * @returns true when the task completed, false otherwise.
 */
async function processTask(
  taskId: string,
  options: ProcessingOptions,
  onStatusUpdate: (status: BatchTaskState['status'], error?: string) => void
): Promise<boolean> {
  // Consecutive failed status requests tolerated before giving up on the task
  const MAX_CONSECUTIVE_POLL_ERRORS = 3

  try {
    onStatusUpdate('processing')
    await apiClientV2.startTask(taskId, options)

    let consecutivePollErrors = 0
    while (true) {
      await sleep(POLL_INTERVAL)

      let task: Awaited<ReturnType<typeof apiClientV2.getTask>>
      try {
        task = await apiClientV2.getTask(taskId)
      } catch (pollError) {
        consecutivePollErrors += 1
        // Give up only after repeated failures; the outer catch reports the last error
        if (consecutivePollErrors >= MAX_CONSECUTIVE_POLL_ERRORS) throw pollError
        continue
      }
      consecutivePollErrors = 0

      if (task.status === 'completed') {
        onStatusUpdate('completed')
        return true
      }
      if (task.status === 'failed') {
        onStatusUpdate('failed', task.error_message || 'Processing failed')
        return false
      }
      // Still running: poll again after the next interval
    }
  } catch (error) {
    const errorMessage = error instanceof Error ? error.message : 'Unknown error'
    onStatusUpdate('failed', errorMessage)
    return false
  }
}
/**
 * Return a promise that resolves, with no value, after `ms` milliseconds.
 * Used to space out task-status polls.
 */
function sleep(ms: number): Promise<void> {
  return new Promise<void>((wake) => {
    setTimeout(wake, ms)
  })
}
/**
 * Run `tasks` through `processTask`, keeping at most `concurrency` of them in
 * flight at once and starting them in the given order.
 *
 * Implemented as a small worker pool: each worker takes the next task from a
 * shared queue as soon as its previous one finishes. Resolves once every task
 * has been processed; resolves immediately for an empty list.
 *
 * @param tasks - Tasks to run, in start order. The array itself is not mutated.
 * @param concurrency - Maximum number of simultaneous tasks. Values below 1
 *   (or NaN) are treated as 1, because a limit below 1 would never start any
 *   task and the function would wait forever. Fractional limits round up.
 * @param onTaskUpdate - Receives every status change, tagged with its task ID.
 */
async function processWithConcurrency(
  tasks: Array<{ taskId: string; options: ProcessingOptions }>,
  concurrency: number,
  onTaskUpdate: (taskId: string, status: BatchTaskState['status'], error?: string) => void
): Promise<void> {
  const queue = [...tasks]
  const limit = Math.max(1, Math.ceil(concurrency) || 1)
  const workerCount = Math.min(limit, queue.length)

  const worker = async (): Promise<void> => {
    // shift() runs synchronously, so two workers never take the same task
    let next = queue.shift()
    while (next !== undefined) {
      const { taskId, options } = next
      await processTask(taskId, options, (status, error) => {
        onTaskUpdate(taskId, status, error)
      })
      next = queue.shift()
    }
  }

  await Promise.all(Array.from({ length: workerCount }, () => worker()))
}
/**
 * Run every still-pending task of the active batch.
 *
 * Each pending task's track is resolved from the batch strategy and the task's
 * analysis recommendation, then stored on the task. Tasks resolved to 'direct'
 * go to a queue that runs up to DIRECT_TRACK_CONCURRENCY tasks at once. Every
 * other track ('ocr', 'hybrid', 'auto') goes to a queue that runs
 * OCR_TRACK_CONCURRENCY tasks at once. Both queues run concurrently. The batch's
 * processing state is stopped when they finish, even if an error escapes.
 *
 * Logs a warning and does nothing when no batch is active or it has no tasks.
 */
export async function processBatch(): Promise<void> {
  const { batchState, updateBatchTaskState, startBatchProcessing, stopBatchProcessing } =
    useTaskStore.getState()

  const nothingToDo = !batchState.isActive || batchState.taskIds.length === 0
  if (nothingToDo) {
    console.warn('No batch to process')
    return
  }

  startBatchProcessing()

  type QueuedTask = { taskId: string; options: ProcessingOptions }
  const directQueue: QueuedTask[] = []
  const ocrQueue: QueuedTask[] = []
  const { taskIds, taskStates, processingOptions: batchOptions } = batchState

  for (const taskId of taskIds) {
    const state = taskStates[taskId]
    // Only tasks that have not been started (or finished) yet are queued
    if (!state || state.status !== 'pending') continue

    const track = determineTrack(batchOptions.strategy, state.recommendedTrack)
    const options = buildProcessingOptions(track, batchOptions)
    updateBatchTaskState(taskId, { track })

    // Only 'direct' gets the parallel queue; 'ocr', 'hybrid' and 'auto' share the GPU-bound one
    const queue = track === 'direct' ? directQueue : ocrQueue
    queue.push({ taskId, options })
  }

  // Translate status transitions into store patches with timestamps
  const recordStatus = (
    taskId: string,
    status: BatchTaskState['status'],
    error?: string
  ): void => {
    const patch: Partial<BatchTaskState> = { status }
    switch (status) {
      case 'processing':
        patch.startedAt = new Date().toISOString()
        break
      case 'completed':
      case 'failed':
        patch.completedAt = new Date().toISOString()
        if (error) patch.error = error
        break
    }
    updateBatchTaskState(taskId, patch)
  }

  try {
    await Promise.all([
      processWithConcurrency(directQueue, DIRECT_TRACK_CONCURRENCY, recordStatus),
      processWithConcurrency(ocrQueue, OCR_TRACK_CONCURRENCY, recordStatus),
    ])
  } finally {
    stopBatchProcessing()
  }
}
/**
 * Cancel the batch tasks that are currently 'processing', then mark the batch
 * as no longer processing.
 *
 * Cancel requests for all running tasks are sent concurrently. A task whose
 * request succeeds is recorded as 'failed' with the error "Cancelled by user".
 * A task whose request fails is logged and left unchanged.
 *
 * NOTE(review): tasks still 'pending' are not touched. Nothing in this file
 * makes the queues built by `processBatch` check for cancellation, so those
 * tasks will still be started as running ones finish. Confirm whether
 * `stopBatchProcessing` is meant to prevent that.
 */
export async function cancelBatch(): Promise<void> {
  const { batchState, updateBatchTaskState, stopBatchProcessing } = useTaskStore.getState()

  const runningTaskIds = batchState.taskIds.filter(
    (taskId) => batchState.taskStates[taskId]?.status === 'processing'
  )

  // Each request handles its own failure, so one bad cancel cannot block the rest
  await Promise.all(
    runningTaskIds.map(async (taskId) => {
      try {
        await apiClientV2.cancelTask(taskId)
        updateBatchTaskState(taskId, {
          status: 'failed',
          error: 'Cancelled by user',
          completedAt: new Date().toISOString(),
        })
      } catch (error) {
        console.error(`Failed to cancel task ${taskId}:`, error)
      }
    })
  )

  stopBatchProcessing()
}
/**
 * Summarise the current batch from the task store.
 *
 * `total` counts every task ID in the batch. The other counters only consider
 * tasks that have a state entry:
 * - Track counters tally 'direct' and 'ocr'; other tracks are not counted.
 * - Status counters tally 'pending', 'completed' and 'failed'; tasks that are
 *   still 'processing' are not counted.
 */
export function getBatchSummary(): {
  total: number
  directCount: number
  ocrCount: number
  pendingCount: number
  completedCount: number
  failedCount: number
} {
  const { taskIds, taskStates } = useTaskStore.getState().batchState

  const summary = {
    total: taskIds.length,
    directCount: 0,
    ocrCount: 0,
    pendingCount: 0,
    completedCount: 0,
    failedCount: 0,
  }

  for (const id of taskIds) {
    const state = taskStates[id]
    if (!state) continue

    if (state.track === 'direct') {
      summary.directCount += 1
    } else if (state.track === 'ocr') {
      summary.ocrCount += 1
    }

    if (state.status === 'pending') {
      summary.pendingCount += 1
    } else if (state.status === 'completed') {
      summary.completedCount += 1
    } else if (state.status === 'failed') {
      summary.failedCount += 1
    }
  }

  return summary
}