feat: implement hybrid image extraction and memory management

Backend:
- Add hybrid image extraction for Direct track (inline image blocks)
- Add render_inline_image_regions() fallback when OCR doesn't find images
- Add check_document_for_missing_images() for detecting missing images
- Add memory management system (MemoryGuard, ModelManager, ServicePool)
- Update pdf_generator_service to handle HYBRID processing track
- Add ElementType.LOGO for logo extraction

Frontend:
- Fix PDF viewer re-rendering issues with memoization
- Add TaskNotFound component and useTaskValidation hook
- Disable StrictMode due to react-pdf incompatibility
- Fix task detail and results page loading states

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
egg
2025-11-26 10:56:22 +08:00
parent ba8ddf2b68
commit 1afdb822c3
26 changed files with 8273 additions and 366 deletions

View File

@@ -0,0 +1,468 @@
"""
Tool_OCR - OCR Service Pool
Manages a pool of OCRService instances to prevent duplicate model loading
and control concurrent GPU operations.
"""
import asyncio
import logging
import threading
import time
from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, TYPE_CHECKING
from app.services.memory_manager import get_model_manager, MemoryConfig
if TYPE_CHECKING:
from app.services.ocr_service import OCRService
logger = logging.getLogger(__name__)
class ServiceState(Enum):
"""State of a pooled service"""
AVAILABLE = "available"
IN_USE = "in_use"
UNHEALTHY = "unhealthy"
INITIALIZING = "initializing"
@dataclass
class PooledService:
"""Wrapper for a pooled OCRService instance"""
service: Any # OCRService
device: str
state: ServiceState = ServiceState.AVAILABLE
created_at: float = field(default_factory=time.time)
last_used: float = field(default_factory=time.time)
use_count: int = 0
error_count: int = 0
current_task_id: Optional[str] = None
class PoolConfig:
"""Configuration for the service pool"""
def __init__(
self,
max_services_per_device: int = 1,
max_total_services: int = 2,
acquire_timeout_seconds: float = 300.0,
max_queue_size: int = 50,
health_check_interval_seconds: int = 60,
max_consecutive_errors: int = 3,
service_idle_timeout_seconds: int = 600,
enable_auto_scaling: bool = False,
):
self.max_services_per_device = max_services_per_device
self.max_total_services = max_total_services
self.acquire_timeout_seconds = acquire_timeout_seconds
self.max_queue_size = max_queue_size
self.health_check_interval_seconds = health_check_interval_seconds
self.max_consecutive_errors = max_consecutive_errors
self.service_idle_timeout_seconds = service_idle_timeout_seconds
self.enable_auto_scaling = enable_auto_scaling
class OCRServicePool:
"""
Pool of OCRService instances with concurrency control.
Features:
- Per-device instance management (one service per GPU)
- Queue-based task distribution
- Semaphore-based concurrency limits
- Health monitoring
- Automatic service recovery
"""
_instance = None
_lock = threading.Lock()
def __new__(cls, *args, **kwargs):
"""Singleton pattern"""
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self, config: Optional[PoolConfig] = None):
if self._initialized:
return
self.config = config or PoolConfig()
self.services: Dict[str, List[PooledService]] = {}
self.semaphores: Dict[str, threading.Semaphore] = {}
self.queues: Dict[str, List] = {}
self._pool_lock = threading.RLock()
self._condition = threading.Condition(self._pool_lock)
# Metrics
self._metrics = {
"total_acquisitions": 0,
"total_releases": 0,
"total_timeouts": 0,
"total_errors": 0,
"queue_waits": 0,
}
# Initialize default device pool
self._initialize_device("GPU:0")
self._initialized = True
logger.info("OCRServicePool initialized")
def _initialize_device(self, device: str):
"""Initialize pool resources for a device"""
with self._pool_lock:
if device not in self.services:
self.services[device] = []
self.semaphores[device] = threading.Semaphore(
self.config.max_services_per_device
)
self.queues[device] = []
logger.info(f"Initialized pool for device {device}")
def _create_service(self, device: str) -> PooledService:
"""
Create a new OCRService instance for the pool.
Args:
device: Device identifier (e.g., "GPU:0", "CPU")
Returns:
PooledService wrapper
"""
# Import here to avoid circular imports
from app.services.ocr_service import OCRService
logger.info(f"Creating new OCRService for device {device}")
start_time = time.time()
# Create service instance
service = OCRService()
creation_time = time.time() - start_time
logger.info(f"OCRService created in {creation_time:.2f}s for device {device}")
return PooledService(
service=service,
device=device,
state=ServiceState.AVAILABLE
)
def acquire(
self,
device: str = "GPU:0",
timeout: Optional[float] = None,
task_id: Optional[str] = None
) -> Optional[PooledService]:
"""
Acquire an OCRService from the pool.
Args:
device: Preferred device (e.g., "GPU:0")
timeout: Maximum time to wait for a service
task_id: Optional task ID for tracking
Returns:
PooledService if available, None if timeout
"""
timeout = timeout or self.config.acquire_timeout_seconds
self._initialize_device(device)
start_time = time.time()
deadline = start_time + timeout
with self._condition:
while True:
# Try to get an available service
service = self._try_acquire_service(device, task_id)
if service:
self._metrics["total_acquisitions"] += 1
return service
# Check if we can create a new service
if self._can_create_service(device):
try:
pooled = self._create_service(device)
pooled.state = ServiceState.IN_USE
pooled.current_task_id = task_id
pooled.use_count += 1
self.services[device].append(pooled)
self._metrics["total_acquisitions"] += 1
logger.info(f"Created and acquired new service for {device}")
return pooled
except Exception as e:
logger.error(f"Failed to create service for {device}: {e}")
self._metrics["total_errors"] += 1
# Wait for a service to become available
remaining = deadline - time.time()
if remaining <= 0:
self._metrics["total_timeouts"] += 1
logger.warning(f"Timeout waiting for service on {device}")
return None
self._metrics["queue_waits"] += 1
logger.debug(f"Waiting for service on {device} (timeout: {remaining:.1f}s)")
self._condition.wait(timeout=min(remaining, 1.0))
def _try_acquire_service(self, device: str, task_id: Optional[str]) -> Optional[PooledService]:
"""Try to acquire an available service without waiting"""
for pooled in self.services.get(device, []):
if pooled.state == ServiceState.AVAILABLE:
pooled.state = ServiceState.IN_USE
pooled.last_used = time.time()
pooled.use_count += 1
pooled.current_task_id = task_id
logger.debug(f"Acquired existing service for {device} (use #{pooled.use_count})")
return pooled
return None
def _can_create_service(self, device: str) -> bool:
"""Check if a new service can be created"""
device_count = len(self.services.get(device, []))
total_count = sum(len(services) for services in self.services.values())
return (
device_count < self.config.max_services_per_device and
total_count < self.config.max_total_services
)
def release(self, pooled: PooledService, error: Optional[Exception] = None):
"""
Release a service back to the pool.
Args:
pooled: The pooled service to release
error: Optional error that occurred during use
"""
with self._condition:
if error:
pooled.error_count += 1
self._metrics["total_errors"] += 1
logger.warning(f"Service released with error: {error}")
# Mark unhealthy if too many errors
if pooled.error_count >= self.config.max_consecutive_errors:
pooled.state = ServiceState.UNHEALTHY
logger.error(f"Service marked unhealthy after {pooled.error_count} errors")
else:
pooled.state = ServiceState.AVAILABLE
else:
pooled.error_count = 0 # Reset error count on success
pooled.state = ServiceState.AVAILABLE
pooled.last_used = time.time()
pooled.current_task_id = None
self._metrics["total_releases"] += 1
# Clean up GPU memory after release
try:
model_manager = get_model_manager()
model_manager.memory_guard.clear_gpu_cache()
except Exception as e:
logger.debug(f"Cache clear after release failed: {e}")
# Notify waiting threads
self._condition.notify_all()
logger.debug(f"Service released for device {pooled.device}")
@contextmanager
def acquire_context(
self,
device: str = "GPU:0",
timeout: Optional[float] = None,
task_id: Optional[str] = None
):
"""
Context manager for acquiring and releasing a service.
Usage:
with pool.acquire_context("GPU:0") as pooled:
result = pooled.service.process(...)
"""
pooled = None
error = None
try:
pooled = self.acquire(device, timeout, task_id)
if pooled is None:
raise TimeoutError(f"Failed to acquire service for {device}")
yield pooled
except Exception as e:
error = e
raise
finally:
if pooled:
self.release(pooled, error)
def get_service(self, device: str = "GPU:0") -> Optional["OCRService"]:
"""
Get a service directly (for backward compatibility).
This acquires a service and returns the underlying OCRService.
The caller is responsible for calling release_service() when done.
Args:
device: Device identifier
Returns:
OCRService instance or None
"""
pooled = self.acquire(device)
if pooled:
return pooled.service
return None
def get_pool_stats(self) -> Dict:
"""Get current pool statistics"""
with self._pool_lock:
stats = {
"devices": {},
"metrics": self._metrics.copy(),
"total_services": 0,
"available_services": 0,
"in_use_services": 0,
}
for device, services in self.services.items():
available = sum(1 for s in services if s.state == ServiceState.AVAILABLE)
in_use = sum(1 for s in services if s.state == ServiceState.IN_USE)
unhealthy = sum(1 for s in services if s.state == ServiceState.UNHEALTHY)
stats["devices"][device] = {
"total": len(services),
"available": available,
"in_use": in_use,
"unhealthy": unhealthy,
"max_allowed": self.config.max_services_per_device,
}
stats["total_services"] += len(services)
stats["available_services"] += available
stats["in_use_services"] += in_use
return stats
def health_check(self) -> Dict:
"""
Perform health check on all pooled services.
Returns:
Health check results
"""
results = {
"healthy": True,
"services": [],
"timestamp": time.time()
}
with self._pool_lock:
for device, services in self.services.items():
for idx, pooled in enumerate(services):
service_health = {
"device": device,
"index": idx,
"state": pooled.state.value,
"error_count": pooled.error_count,
"use_count": pooled.use_count,
"idle_seconds": time.time() - pooled.last_used,
}
# Check if service is responsive
if pooled.state == ServiceState.AVAILABLE:
try:
# Simple check - verify service has required attributes
has_process = hasattr(pooled.service, 'process')
has_gpu_status = hasattr(pooled.service, 'get_gpu_status')
service_health["responsive"] = has_process and has_gpu_status
except Exception as e:
service_health["responsive"] = False
service_health["error"] = str(e)
results["healthy"] = False
else:
service_health["responsive"] = pooled.state != ServiceState.UNHEALTHY
if pooled.state == ServiceState.UNHEALTHY:
results["healthy"] = False
results["services"].append(service_health)
return results
def recover_unhealthy(self):
"""
Attempt to recover unhealthy services.
"""
with self._pool_lock:
for device, services in self.services.items():
for idx, pooled in enumerate(services):
if pooled.state == ServiceState.UNHEALTHY:
logger.info(f"Attempting to recover unhealthy service {device}:{idx}")
try:
# Remove old service
services.remove(pooled)
# Create new service
new_pooled = self._create_service(device)
services.append(new_pooled)
logger.info(f"Successfully recovered service {device}:{idx}")
except Exception as e:
logger.error(f"Failed to recover service {device}:{idx}: {e}")
def shutdown(self):
"""
Shutdown the pool and cleanup all services.
"""
logger.info("OCRServicePool shutdown started")
with self._pool_lock:
for device, services in self.services.items():
for pooled in services:
try:
# Clean up service resources
if hasattr(pooled.service, 'cleanup_gpu_memory'):
pooled.service.cleanup_gpu_memory()
except Exception as e:
logger.warning(f"Error cleaning up service: {e}")
# Clear all pools
self.services.clear()
self.semaphores.clear()
self.queues.clear()
logger.info("OCRServicePool shutdown completed")
# Global singleton instance
_service_pool: Optional[OCRServicePool] = None
def get_service_pool(config: Optional[PoolConfig] = None) -> OCRServicePool:
"""
Get the global OCRServicePool instance.
Args:
config: Optional configuration (only used on first call)
Returns:
OCRServicePool singleton instance
"""
global _service_pool
if _service_pool is None:
_service_pool = OCRServicePool(config)
return _service_pool
def shutdown_service_pool():
"""Shutdown the global service pool"""
global _service_pool
if _service_pool is not None:
_service_pool.shutdown()
_service_pool = None