feat: implement hybrid image extraction and memory management

Backend:
- Add hybrid image extraction for Direct track (inline image blocks)
- Add render_inline_image_regions() fallback when OCR doesn't find images
- Add check_document_for_missing_images() for detecting missing images
- Add memory management system (MemoryGuard, ModelManager, ServicePool)
- Update pdf_generator_service to handle HYBRID processing track
- Add ElementType.LOGO for logo extraction

Frontend:
- Fix PDF viewer re-rendering issues with memoization
- Add TaskNotFound component and useTaskValidation hook
- Disable StrictMode due to react-pdf incompatibility
- Fix task detail and results page loading states

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
egg
2025-11-26 10:56:22 +08:00
parent ba8ddf2b68
commit 1afdb822c3
26 changed files with 8273 additions and 366 deletions

View File

@@ -247,9 +247,11 @@ class DirectExtractionEngine:
element_counter += len(image_elements)
# Extract vector graphics (charts, diagrams) from drawing commands
# Pass table_bboxes to filter out table border drawings before clustering
if self.enable_image_extraction:
vector_elements = self._extract_vector_graphics(
page, page_num, document_id, element_counter, output_dir
page, page_num, document_id, element_counter, output_dir,
table_bboxes=table_bboxes
)
elements.extend(vector_elements)
element_counter += len(vector_elements)
@@ -705,40 +707,52 @@ class DirectExtractionEngine:
y1=bbox_data[3]
)
# Extract column widths from table cells
# Extract column widths from table cells by analyzing X boundaries
column_widths = []
if hasattr(table, 'cells') and table.cells:
# Group cells by column
cols_x = {}
# Collect all unique X boundaries (both left and right edges)
x_boundaries = set()
for cell in table.cells:
col_idx = None
# Determine column index by x0 position
for idx, x0 in enumerate(sorted(set(c[0] for c in table.cells))):
if abs(cell[0] - x0) < 1.0: # Within 1pt tolerance
col_idx = idx
break
x_boundaries.add(round(cell[0], 1)) # x0 (left edge)
x_boundaries.add(round(cell[2], 1)) # x1 (right edge)
if col_idx is not None:
if col_idx not in cols_x:
cols_x[col_idx] = {'x0': cell[0], 'x1': cell[2]}
else:
cols_x[col_idx]['x1'] = max(cols_x[col_idx]['x1'], cell[2])
# Sort boundaries to get column edges
sorted_x = sorted(x_boundaries)
# Calculate width for each column
for col_idx in sorted(cols_x.keys()):
width = cols_x[col_idx]['x1'] - cols_x[col_idx]['x0']
column_widths.append(width)
# Calculate column widths from adjacent boundaries
if len(sorted_x) >= 2:
column_widths = [sorted_x[i+1] - sorted_x[i] for i in range(len(sorted_x)-1)]
logger.debug(f"Calculated column widths from {len(sorted_x)} boundaries: {column_widths}")
# Extract row heights from table cells by analyzing Y boundaries
row_heights = []
if hasattr(table, 'cells') and table.cells:
# Collect all unique Y boundaries (both top and bottom edges)
y_boundaries = set()
for cell in table.cells:
y_boundaries.add(round(cell[1], 1)) # y0 (top edge)
y_boundaries.add(round(cell[3], 1)) # y1 (bottom edge)
# Sort boundaries to get row edges
sorted_y = sorted(y_boundaries)
# Calculate row heights from adjacent boundaries
if len(sorted_y) >= 2:
row_heights = [sorted_y[i+1] - sorted_y[i] for i in range(len(sorted_y)-1)]
logger.debug(f"Calculated row heights from {len(sorted_y)} boundaries: {row_heights}")
# Create table cells
# Note: Include ALL cells (even empty ones) to preserve table structure
# This is critical for correct HTML generation and PDF rendering
cells = []
for row_idx, row in enumerate(data):
for col_idx, cell_text in enumerate(row):
if cell_text:
cells.append(TableCell(
row=row_idx,
col=col_idx,
content=str(cell_text) if cell_text else ""
))
# Always add cell, even if empty, to maintain table structure
cells.append(TableCell(
row=row_idx,
col=col_idx,
content=str(cell_text) if cell_text else ""
))
# Create table data
table_data = TableData(
@@ -748,8 +762,13 @@ class DirectExtractionEngine:
headers=data[0] if data else None # Assume first row is header
)
# Store column widths in metadata
metadata = {"column_widths": column_widths} if column_widths else None
# Store column widths and row heights in metadata
metadata = {}
if column_widths:
metadata["column_widths"] = column_widths
if row_heights:
metadata["row_heights"] = row_heights
metadata = metadata if metadata else None
return DocumentElement(
element_id=f"table_{page_num}_{counter}",
@@ -978,7 +997,9 @@ class DirectExtractionEngine:
image_filename = f"{document_id}_p{page_num}_img{img_idx}.png"
image_path = output_dir / image_filename
pix.save(str(image_path))
image_data["saved_path"] = str(image_path)
# Store relative filename only (consistent with OCR track)
# PDF generator will join with result_dir to get full path
image_data["saved_path"] = image_filename
logger.debug(f"Saved image to {image_path}")
element = DocumentElement(
@@ -1001,12 +1022,272 @@ class DirectExtractionEngine:
return elements
def has_missing_images(self, page: fitz.Page) -> bool:
    """
    Report whether a page probably contains graphics that get_images() missed.

    A page qualifies when page.get_images() returns nothing while the page's
    text dictionary holds at least 10 inline image blocks (type == 1).

    Args:
        page: PyMuPDF page object

    Returns:
        True when the page likely needs a fallback image extraction,
        False otherwise (including when the check itself fails)
    """
    try:
        # Anything reported by get_images() means the standard path already works
        if page.get_images():
            return False

        # Tally the inline image blocks (type=1) in the text dictionary
        inline_total = 0
        for entry in page.get_text("dict", sort=True).get("blocks", []):
            if entry.get("type") == 1:
                inline_total += 1

        # Too few blocks to suggest a composed logo/graphic
        if inline_total < 10:
            return False

        logger.info(f"Detected {inline_total} inline image blocks - may need OCR for image extraction")
        return True

    except Exception as exc:
        # Best-effort probe: a failure is logged and treated as "nothing missing"
        logger.warning(f"Error checking for missing images: {exc}")
        return False
def check_document_for_missing_images(self, pdf_path: Path) -> List[int]:
    """
    Check a PDF document for pages that likely have missing images.

    Opens the PDF and runs has_missing_images() on every page.

    Args:
        pdf_path: Path to the PDF file

    Returns:
        List of page numbers (1-indexed) that have missing images. Errors are
        logged and whatever was collected before the failure is returned.
    """
    pages_with_missing_images: List[int] = []

    try:
        # The context manager releases the document handle even if a page
        # check raises; a trailing doc.close() would be skipped in that case.
        with fitz.open(str(pdf_path)) as doc:
            for page_number, page in enumerate(doc, start=1):
                if self.has_missing_images(page):
                    pages_with_missing_images.append(page_number)  # 1-indexed

        if pages_with_missing_images:
            logger.info(f"Document has missing images on pages: {pages_with_missing_images}")

    except Exception as e:
        logger.error(f"Error checking document for missing images: {e}")

    return pages_with_missing_images
def render_inline_image_regions(
    self,
    pdf_path: Path,
    unified_doc: 'UnifiedDocument',
    pages: List[int],
    output_dir: Optional[Path] = None
) -> int:
    """
    Render inline image regions and add them to the unified document.

    This is a fallback when OCR doesn't detect images. For each requested
    page it collects inline image blocks (type=1), clusters them into
    regions, renders each accepted region at 2x resolution and appends it
    to the matching page of unified_doc as an ElementType.LOGO element.

    Pages are skipped when they are out of range, hold fewer than 5 inline
    image blocks, or have no counterpart in unified_doc. Regions are skipped
    when smaller than 30x30 or when they start below the top 40% of the page.

    Args:
        pdf_path: Path to the PDF file
        unified_doc: UnifiedDocument to add images to
        pages: List of page numbers (1-indexed) to process
        output_dir: Directory to save rendered images; when omitted the
            region is still added but no file is written

    Returns:
        Number of images added. Errors are logged, never raised.
    """
    images_added = 0

    try:
        # The context manager closes the PDF even when processing fails
        # part-way; a trailing doc.close() would be skipped on an exception.
        with fitz.open(str(pdf_path)) as doc:
            for page_num in pages:
                if page_num < 1 or page_num > len(doc):
                    continue

                page = doc[page_num - 1]  # 0-indexed
                page_rect = page.rect

                # Get inline image blocks
                text_dict = page.get_text("dict", sort=True)
                blocks = text_dict.get("blocks", [])

                image_blocks = [
                    fitz.Rect(block.get("bbox"))
                    for block in blocks
                    if block.get("type") == 1 and block.get("bbox")
                ]

                if len(image_blocks) < 5:  # Reduced from 10
                    logger.debug(f"Page {page_num}: Only {len(image_blocks)} inline image blocks, skipping")
                    continue

                logger.info(f"Page {page_num}: Found {len(image_blocks)} inline image blocks")

                # Cluster nearby image blocks
                regions = self._cluster_nearby_rects(image_blocks, tolerance=5.0)
                logger.info(f"Page {page_num}: Clustered into {len(regions)} regions")

                # Find the corresponding page in unified_doc (first match wins)
                target_page = next(
                    (p for p in unified_doc.pages if p.page_number == page_num),
                    None
                )
                if not target_page:
                    continue

                # Regions starting below this line are assumed to belong to the table area
                page_40_pct = page_rect.height * 0.4

                for region_idx, region_rect in enumerate(regions):
                    logger.info(f"Page {page_num} region {region_idx}: {region_rect} (w={region_rect.width:.1f}, h={region_rect.height:.1f})")

                    # Skip very small regions
                    if region_rect.width < 30 or region_rect.height < 30:
                        logger.info(" -> Skipped: too small (min 30x30)")
                        continue

                    # Skip regions that are primarily in the table area (below top 40%)
                    # But allow regions that START in the top portion
                    if region_rect.y0 > page_40_pct:
                        logger.info(f" -> Skipped: y0={region_rect.y0:.1f} > 40% of page ({page_40_pct:.1f})")
                        continue

                    logger.info(f"Rendering inline image region {region_idx} on page {page_num}: {region_rect}")

                    try:
                        # Add small padding, then keep the clip inside the page
                        clip_rect = region_rect + (-2, -2, 2, 2)
                        clip_rect.intersect(page_rect)

                        # Render at 2x resolution
                        mat = fitz.Matrix(2, 2)
                        pix = page.get_pixmap(clip=clip_rect, matrix=mat, alpha=False)

                        # Create bounding box
                        bbox = BoundingBox(
                            x0=clip_rect.x0,
                            y0=clip_rect.y0,
                            x1=clip_rect.x1,
                            y1=clip_rect.y1
                        )

                        image_data = {
                            "width": pix.width,
                            "height": pix.height,
                            "colorspace": "rgb",
                            "type": "inline_region"
                        }

                        # Save image if output directory provided
                        if output_dir:
                            output_dir.mkdir(parents=True, exist_ok=True)
                            doc_id = unified_doc.document_id or "unknown"
                            image_filename = f"{doc_id}_p{page_num}_logo{region_idx}.png"
                            image_path = output_dir / image_filename
                            pix.save(str(image_path))
                            # Only the relative filename is stored
                            image_data["saved_path"] = image_filename
                            logger.info(f"Saved inline image region to {image_path}")

                        element = DocumentElement(
                            element_id=f"logo_{page_num}_{region_idx}",
                            type=ElementType.LOGO,
                            content=image_data,
                            bbox=bbox,
                            confidence=0.9,
                            metadata={
                                "region_type": "inline_image_blocks",
                                "block_count": len(image_blocks)
                            }
                        )
                        target_page.elements.append(element)
                        images_added += 1

                        pix = None  # Free memory

                    except Exception as e:
                        # One bad region must not abort the remaining regions
                        logger.error(f"Error rendering inline image region {region_idx}: {e}")

        if images_added > 0:
            current_images = unified_doc.metadata.total_images or 0
            unified_doc.metadata.total_images = current_images + images_added
            logger.info(f"Added {images_added} inline image regions to document")

    except Exception as e:
        logger.error(f"Error rendering inline image regions: {e}")

    return images_added
def _cluster_nearby_rects(self, rects: List[fitz.Rect], tolerance: float = 5.0) -> List[fitz.Rect]:
    """
    Merge rectangles lying within `tolerance` of each other into regions.

    Args:
        rects: Rectangles to cluster
        tolerance: Margin by which a region is grown before testing
            whether another rectangle touches it

    Returns:
        List of merged regions; empty when `rects` is empty
    """
    if not rects:
        return []

    # Margin applied to a region before each intersection test
    grow = (-tolerance, -tolerance, tolerance, tolerance)

    # Pass 1: walk rects top-to-bottom, left-to-right, folding each one into
    # the first region it touches, or opening a new region
    regions = []
    for candidate in sorted(rects, key=lambda r: (r.y0, r.x0)):
        for idx, existing in enumerate(regions):
            if (existing + grow).intersects(candidate):
                regions[idx] = existing | candidate
                break
        else:
            regions.append(candidate)

    # Pass 2: regions may have grown into one another; repeat pairwise
    # merging until a full sweep absorbs nothing
    while True:
        absorbed = set()
        collapsed = []

        for i, seed in enumerate(regions):
            if i in absorbed:
                continue

            union = seed
            for j in range(i + 1, len(regions)):
                if j in absorbed:
                    continue
                if (union + grow).intersects(regions[j]):
                    union = union | regions[j]
                    absorbed.add(j)

            collapsed.append(union)

        regions = collapsed
        if not absorbed:
            return regions
def _extract_vector_graphics(self,
page: fitz.Page,
page_num: int,
document_id: str,
counter: int,
output_dir: Optional[Path]) -> List[DocumentElement]:
output_dir: Optional[Path],
table_bboxes: Optional[List[BoundingBox]] = None) -> List[DocumentElement]:
"""
Extract vector graphics (charts, diagrams) from page.
@@ -1020,6 +1301,7 @@ class DirectExtractionEngine:
document_id: Unique document identifier
counter: Starting counter for element IDs
output_dir: Directory to save rendered graphics
table_bboxes: List of table bounding boxes to exclude table border drawings
Returns:
List of DocumentElement objects representing vector graphics
@@ -1034,16 +1316,25 @@ class DirectExtractionEngine:
logger.debug(f"Page {page_num} contains {len(drawings)} vector drawing commands")
# Filter out drawings that are likely table borders
# Table borders are typically thin rectangular lines within table regions
non_table_drawings = self._filter_table_border_drawings(drawings, table_bboxes)
logger.debug(f"After filtering table borders: {len(non_table_drawings)} drawings remain")
if not non_table_drawings:
logger.debug("All drawings appear to be table borders, no vector graphics to extract")
return elements
# Cluster drawings into groups (charts, diagrams, etc.)
try:
# PyMuPDF's cluster_drawings() groups nearby drawings automatically
drawing_clusters = page.cluster_drawings()
# Use custom clustering that only considers non-table drawings
drawing_clusters = self._cluster_non_table_drawings(page, non_table_drawings)
logger.debug(f"Clustered into {len(drawing_clusters)} groups")
except (AttributeError, TypeError) as e:
# cluster_drawings not available or has different signature
# Fallback: try to identify charts by analyzing drawing density
logger.warning(f"cluster_drawings() failed ({e}), using fallback method")
drawing_clusters = self._cluster_drawings_fallback(page, drawings)
logger.warning(f"Custom clustering failed ({e}), using fallback method")
drawing_clusters = self._cluster_drawings_fallback(page, non_table_drawings)
for cluster_idx, bbox in enumerate(drawing_clusters):
# Ignore small regions (likely noise or separator lines)
@@ -1148,6 +1439,124 @@ class DirectExtractionEngine:
return filtered_clusters
def _filter_table_border_drawings(self, drawings: list, table_bboxes: Optional[List[BoundingBox]]) -> list:
    """
    Filter out drawings that are likely table borders.

    A drawing is treated as a table border when BOTH hold:
    - It is thin (width or height < 5pt)
    - It lies mostly (> 80% of its area) inside a table bounding box grown
      by 5pt on each side; a zero-area stroke must lie entirely inside

    Args:
        drawings: List of PyMuPDF drawing objects
        table_bboxes: List of table bounding boxes

    Returns:
        List of drawings that are NOT table borders (likely logos, charts, etc.).
        When `table_bboxes` is empty or None the input list is returned
        unchanged; otherwise drawings without a 'rect' are dropped.
    """
    if not table_bboxes:
        return drawings

    # Grow each table rect slightly so border lines sitting on the edge count
    # as inside; built once rather than once per drawing
    expanded_tables = [
        fitz.Rect(tb.x0, tb.y0, tb.x1, tb.y1) + (-5, -5, 5, 5)
        for tb in table_bboxes
    ]

    def mostly_inside(draw_rect, expanded_table) -> bool:
        """Return True when draw_rect lies (almost) entirely within expanded_table."""
        area = draw_rect.get_area()
        if area <= 0:
            # A zero-width/zero-height stroke has no area, so an area ratio
            # would always be 0 and hairline borders would never be filtered.
            # Use plain coordinate containment instead.
            return (
                expanded_table.x0 <= draw_rect.x0
                and draw_rect.x1 <= expanded_table.x1
                and expanded_table.y0 <= draw_rect.y0
                and draw_rect.y1 <= expanded_table.y1
            )
        intersection = draw_rect & expanded_table
        if intersection.is_empty:
            return False
        return intersection.get_area() / area > 0.8

    non_table_drawings = []
    table_border_count = 0

    for drawing in drawings:
        rect = drawing.get('rect')
        if not rect:
            continue

        draw_rect = fitz.Rect(rect)

        # Only thin shapes can be borders; thicker shapes (logos, charts)
        # are kept without consulting the tables at all
        is_thin_line = draw_rect.width < 5 or draw_rect.height < 5

        if is_thin_line and any(mostly_inside(draw_rect, table) for table in expanded_tables):
            table_border_count += 1
        else:
            non_table_drawings.append(drawing)

    if table_border_count > 0:
        logger.debug(f"Filtered out {table_border_count} table border drawings")

    return non_table_drawings
def _cluster_non_table_drawings(self, page: fitz.Page, drawings: list) -> list:
    """
    Cluster non-table drawings into groups.

    Takes drawings that were pre-filtered to exclude table borders, merges
    bounding boxes lying within 10pt of each other, and drops clusters
    smaller than 30x30.

    Merging is delegated to _cluster_nearby_rects so that it runs to a fixed
    point: a single greedy pass depends on input order and can leave two
    clusters unmerged after a later drawing has bridged them.

    Args:
        page: PyMuPDF page object (not used by the clustering itself)
        drawings: Pre-filtered list of drawings (excluding table borders)

    Returns:
        List of fitz.Rect representing clustered drawing regions
    """
    if not drawings:
        return []

    # Collect all drawing bounding boxes, ignoring drawings without a rect
    raw_rects = (drawing.get('rect') for drawing in drawings)
    bboxes = [fitz.Rect(rect) for rect in raw_rects if rect]

    if not bboxes:
        return []

    # Smaller tolerance than the fallback clustering (was 20) so that
    # distant graphics are not grouped together
    tolerance = 10
    clusters = self._cluster_nearby_rects(bboxes, tolerance=tolerance)

    # Filter out very small clusters (noise)
    # Keep minimum 30x30 for logos (smaller than default 50x50)
    filtered_clusters = [c for c in clusters if c.width >= 30 and c.height >= 30]

    logger.debug(f"Non-table clustering: {len(bboxes)} drawings -> {len(clusters)} clusters -> {len(filtered_clusters)} filtered")

    return filtered_clusters
def _deduplicate_table_chart_overlap(self, elements: List[DocumentElement]) -> List[DocumentElement]:
"""
Intelligently resolve TABLE-CHART overlaps based on table structure completeness.

File diff suppressed because it is too large Load Diff

View File

@@ -25,6 +25,7 @@ except ImportError:
from app.core.config import settings
from app.services.office_converter import OfficeConverter, OfficeConverterError
from app.services.memory_manager import get_model_manager, MemoryConfig, MemoryGuard, prediction_context
# Import dual-track components
try:
@@ -96,6 +97,26 @@ class OCRService:
self._model_last_used = {} # Track last usage time for each model
self._memory_warning_logged = False
# Initialize MemoryGuard for enhanced memory monitoring
self._memory_guard = None
if settings.enable_model_lifecycle_management:
try:
memory_config = MemoryConfig(
warning_threshold=settings.memory_warning_threshold,
critical_threshold=settings.memory_critical_threshold,
emergency_threshold=settings.memory_emergency_threshold,
model_idle_timeout_seconds=settings.pp_structure_idle_timeout_seconds,
gpu_memory_limit_mb=settings.gpu_memory_limit_mb,
enable_cpu_fallback=settings.enable_cpu_fallback,
)
self._memory_guard = MemoryGuard(memory_config)
logger.debug("MemoryGuard initialized for OCRService")
except Exception as e:
logger.warning(f"Failed to initialize MemoryGuard: {e}")
# Track if CPU fallback was activated
self._cpu_fallback_active = False
self._detect_and_configure_gpu()
# Log GPU optimization settings
@@ -217,53 +238,91 @@ class OCRService:
def _check_gpu_memory_usage(self):
"""
Check GPU memory usage and log warnings if approaching limits.
Implements memory optimization for RTX 4060 8GB.
Uses MemoryGuard for enhanced monitoring with multiple backends.
"""
if not self.use_gpu or not settings.enable_memory_optimization:
return
try:
device_id = self.gpu_info.get('device_id', 0)
memory_allocated = paddle.device.cuda.memory_allocated(device_id)
memory_allocated_mb = memory_allocated / (1024**2)
memory_limit_mb = settings.gpu_memory_limit_mb
# Use MemoryGuard if available for better monitoring
if self._memory_guard:
stats = self._memory_guard.get_memory_stats()
utilization = (memory_allocated_mb / memory_limit_mb * 100) if memory_limit_mb > 0 else 0
# Log based on usage ratio
if stats.gpu_used_ratio > 0.90 and not self._memory_warning_logged:
logger.warning(
f"GPU memory usage critical: {stats.gpu_used_mb:.0f}MB / {stats.gpu_total_mb:.0f}MB "
f"({stats.gpu_used_ratio*100:.1f}%)"
)
logger.warning("Consider enabling auto_unload_unused_models or reducing batch size")
self._memory_warning_logged = True
if utilization > 90 and not self._memory_warning_logged:
logger.warning(f"GPU memory usage high: {memory_allocated_mb:.0f}MB / {memory_limit_mb}MB ({utilization:.1f}%)")
logger.warning("Consider enabling auto_unload_unused_models or reducing batch size")
self._memory_warning_logged = True
elif utilization > 75:
logger.info(f"GPU memory: {memory_allocated_mb:.0f}MB / {memory_limit_mb}MB ({utilization:.1f}%)")
# Trigger emergency cleanup if enabled
if settings.enable_emergency_cleanup:
self._cleanup_unused_models()
self._memory_guard.clear_gpu_cache()
elif stats.gpu_used_ratio > 0.75:
logger.info(
f"GPU memory: {stats.gpu_used_mb:.0f}MB / {stats.gpu_total_mb:.0f}MB "
f"({stats.gpu_used_ratio*100:.1f}%)"
)
else:
# Fallback to original implementation
device_id = self.gpu_info.get('device_id', 0)
memory_allocated = paddle.device.cuda.memory_allocated(device_id)
memory_allocated_mb = memory_allocated / (1024**2)
memory_limit_mb = settings.gpu_memory_limit_mb
utilization = (memory_allocated_mb / memory_limit_mb * 100) if memory_limit_mb > 0 else 0
if utilization > 90 and not self._memory_warning_logged:
logger.warning(f"GPU memory usage high: {memory_allocated_mb:.0f}MB / {memory_limit_mb}MB ({utilization:.1f}%)")
logger.warning("Consider enabling auto_unload_unused_models or reducing batch size")
self._memory_warning_logged = True
elif utilization > 75:
logger.info(f"GPU memory: {memory_allocated_mb:.0f}MB / {memory_limit_mb}MB ({utilization:.1f}%)")
except Exception as e:
logger.debug(f"Memory check failed: {e}")
def _cleanup_unused_models(self):
"""
Clean up unused language models to free GPU memory.
Clean up unused models (including PP-StructureV3) to free GPU memory.
Models idle longer than model_idle_timeout_seconds will be unloaded.
Note: PP-StructureV3 is NO LONGER exempted from cleanup - it will be
unloaded based on pp_structure_idle_timeout_seconds configuration.
"""
if not settings.auto_unload_unused_models:
return
current_time = datetime.now()
timeout = settings.model_idle_timeout_seconds
models_to_remove = []
for lang, last_used in self._model_last_used.items():
if lang == 'structure': # Don't unload structure engine
continue
# Use different timeout for structure engine vs language models
if lang == 'structure':
timeout = settings.pp_structure_idle_timeout_seconds
else:
timeout = settings.model_idle_timeout_seconds
idle_seconds = (current_time - last_used).total_seconds()
if idle_seconds > timeout:
models_to_remove.append(lang)
for lang in models_to_remove:
if lang in self.ocr_engines:
logger.info(f"Unloading idle OCR engine for {lang} (idle {timeout}s)")
del self.ocr_engines[lang]
del self._model_last_used[lang]
for model_key in models_to_remove:
if model_key == 'structure':
if self.structure_engine is not None:
logger.info(f"Unloading idle PP-StructureV3 engine (idle {settings.pp_structure_idle_timeout_seconds}s)")
self._unload_structure_engine()
if model_key in self._model_last_used:
del self._model_last_used[model_key]
elif model_key in self.ocr_engines:
logger.info(f"Unloading idle OCR engine for {model_key} (idle {settings.model_idle_timeout_seconds}s)")
del self.ocr_engines[model_key]
if model_key in self._model_last_used:
del self._model_last_used[model_key]
if models_to_remove and self.use_gpu:
# Clear CUDA cache
@@ -273,6 +332,41 @@ class OCRService:
except Exception as e:
logger.debug(f"Cache clear failed: {e}")
def _unload_structure_engine(self):
    """
    Drop the PP-StructureV3 engine and release the memory it holds.

    Does nothing when no engine is loaded. Any error during teardown is
    logged and the engine reference is still reset to None.
    """
    if self.structure_engine is None:
        return

    # Sub-components that are detached from the engine when present
    component_names = (
        'table_engine',
        'text_detector',
        'text_recognizer',
        'layout_predictor',
    )

    try:
        # Detach internal components without keeping a local reference to
        # the engine, so nothing here keeps it alive past the del below
        for component in component_names:
            if hasattr(self.structure_engine, component):
                setattr(self.structure_engine, component, None)

        # Remove the engine, then leave the attribute defined as None
        del self.structure_engine
        self.structure_engine = None

        # Collect the released objects before clearing the GPU cache
        gc.collect()

        if self.use_gpu:
            paddle.device.cuda.empty_cache()

        logger.info("PP-StructureV3 engine unloaded successfully")

    except Exception as exc:
        logger.warning(f"Error unloading PP-StructureV3: {exc}")
        self.structure_engine = None
def clear_gpu_cache(self):
"""
Manually clear GPU memory cache.
@@ -519,46 +613,160 @@ class OCRService:
logger.warning(f"GPU memory cleanup failed (non-critical): {e}")
# Don't fail the processing if cleanup fails
def check_gpu_memory(self, required_mb: int = 2000) -> bool:
def check_gpu_memory(self, required_mb: int = 2000, enable_fallback: bool = True) -> bool:
"""
Check if sufficient GPU memory is available.
Check if sufficient GPU memory is available using MemoryGuard.
This method now uses MemoryGuard for accurate memory queries across
multiple backends (pynvml, torch, paddle) instead of returning True
blindly for PaddlePaddle-only environments.
Args:
required_mb: Required memory in MB (default 2000MB for OCR models)
enable_fallback: If True and CPU fallback is enabled, switch to CPU mode
when memory is insufficient instead of returning False
Returns:
True if sufficient memory is available or GPU is not used
True if sufficient memory is available, GPU is not used, or CPU fallback activated
"""
try:
# Check GPU memory using torch if available, otherwise use PaddlePaddle
free_memory = None
# If not using GPU, always return True
if not self.use_gpu:
return True
if TORCH_AVAILABLE and torch.cuda.is_available():
free_memory = torch.cuda.mem_get_info()[0] / 1024**2
elif paddle.device.is_compiled_with_cuda():
# PaddlePaddle doesn't have direct API to get free memory,
# so we rely on cleanup and continue
logger.debug("Using PaddlePaddle GPU, memory info not directly available")
try:
# Use MemoryGuard if available for accurate multi-backend memory queries
if self._memory_guard:
is_available, stats = self._memory_guard.check_memory(
required_mb=required_mb,
device_id=self.gpu_info.get('device_id', 0)
)
if not is_available:
logger.warning(
f"GPU memory check failed: {stats.gpu_free_mb:.0f}MB free, "
f"{required_mb}MB required ({stats.gpu_used_ratio*100:.1f}% used)"
)
# Try to free memory
logger.info("Attempting memory cleanup before retry...")
self._cleanup_unused_models()
self._memory_guard.clear_gpu_cache()
# Check again
is_available, stats = self._memory_guard.check_memory(required_mb=required_mb)
if not is_available:
# Memory still insufficient after cleanup
if enable_fallback and settings.enable_cpu_fallback:
logger.warning(
f"Insufficient GPU memory ({stats.gpu_free_mb:.0f}MB) after cleanup. "
f"Activating CPU fallback mode."
)
self._activate_cpu_fallback()
return True # Continue with CPU
else:
logger.error(
f"Insufficient GPU memory: {stats.gpu_free_mb:.0f}MB available, "
f"{required_mb}MB required"
)
return False
logger.debug(
f"GPU memory check passed: {stats.gpu_free_mb:.0f}MB free "
f"({stats.gpu_used_ratio*100:.1f}% used)"
)
return True
if free_memory is not None:
if free_memory < required_mb:
logger.warning(f"Low GPU memory: {free_memory:.0f}MB available, {required_mb}MB required")
# Try to free memory
self.cleanup_gpu_memory()
# Check again
if TORCH_AVAILABLE and torch.cuda.is_available():
free_memory = torch.cuda.mem_get_info()[0] / 1024**2
if free_memory < required_mb:
logger.error(f"Insufficient GPU memory after cleanup: {free_memory:.0f}MB")
return False
logger.debug(f"GPU memory check passed: {free_memory:.0f}MB available")
else:
# Fallback to original implementation
free_memory = None
if TORCH_AVAILABLE and torch.cuda.is_available():
free_memory = torch.cuda.mem_get_info()[0] / 1024**2
elif paddle.device.is_compiled_with_cuda():
# PaddlePaddle doesn't have direct API to get free memory,
# use allocated memory to estimate
device_id = self.gpu_info.get('device_id', 0)
allocated = paddle.device.cuda.memory_allocated(device_id) / (1024**2)
total = settings.gpu_memory_limit_mb
free_memory = max(0, total - allocated)
logger.debug(f"Estimated free GPU memory: {free_memory:.0f}MB (total: {total}MB, allocated: {allocated:.0f}MB)")
if free_memory is not None:
if free_memory < required_mb:
logger.warning(f"Low GPU memory: {free_memory:.0f}MB available, {required_mb}MB required")
self.cleanup_gpu_memory()
# Recheck
if TORCH_AVAILABLE and torch.cuda.is_available():
free_memory = torch.cuda.mem_get_info()[0] / 1024**2
else:
allocated = paddle.device.cuda.memory_allocated(device_id) / (1024**2)
free_memory = max(0, total - allocated)
if free_memory < required_mb:
if enable_fallback and settings.enable_cpu_fallback:
logger.warning(f"Insufficient GPU memory after cleanup. Activating CPU fallback.")
self._activate_cpu_fallback()
return True
else:
logger.error(f"Insufficient GPU memory after cleanup: {free_memory:.0f}MB")
return False
logger.debug(f"GPU memory check passed: {free_memory:.0f}MB available")
return True
return True
except Exception as e:
logger.warning(f"GPU memory check failed: {e}")
return True # Continue processing even if check fails
def _activate_cpu_fallback(self):
    """
    Switch this service instance to CPU processing.

    Called when GPU memory is insufficient. Idempotent: a second call while
    the fallback is already active does nothing.
    """
    if self._cpu_fallback_active:
        return  # Already in CPU mode

    banner = (
        "=== CPU FALLBACK MODE ACTIVATED ===",
        "GPU memory insufficient, switching to CPU processing",
        "Performance will be significantly reduced",
    )
    for message in banner:
        logger.warning(message)

    self._cpu_fallback_active = True
    self.use_gpu = False

    # Record the fallback in the GPU info for later inspection
    fallback_info = (
        ('cpu_fallback', True),
        ('fallback_reason', 'GPU memory insufficient'),
    )
    for key, value in fallback_info:
        self.gpu_info[key] = value

    # Clear the GPU cache now that the GPU is no longer used
    if self._memory_guard:
        self._memory_guard.clear_gpu_cache()
def _restore_gpu_mode(self):
    """
    Attempt to restore GPU mode after CPU fallback.

    Does nothing unless the fallback is active, a GPU is available and a
    MemoryGuard exists. GPU mode is re-enabled only when the guard reports
    at least settings.structure_model_memory_mb available on the configured
    device.
    """
    if not self._cpu_fallback_active or not self.gpu_available:
        return

    # Without a MemoryGuard there is no way to confirm memory was freed
    if not self._memory_guard:
        return

    # Probe the same device check_gpu_memory() uses; omitting device_id
    # would test the guard's default device rather than the configured one
    is_available, _stats = self._memory_guard.check_memory(
        required_mb=settings.structure_model_memory_mb,
        device_id=self.gpu_info.get('device_id', 0)
    )
    if not is_available:
        return

    logger.info("GPU memory available, restoring GPU mode")
    self._cpu_fallback_active = False
    self.use_gpu = True
    # Remove the markers written when the fallback was activated
    self.gpu_info.pop('cpu_fallback', None)
    self.gpu_info.pop('fallback_reason', None)
def convert_pdf_to_images(self, pdf_path: Path, output_dir: Path) -> List[Path]:
"""
Convert PDF to images (one per page)
@@ -626,6 +834,24 @@ class OCRService:
threshold = confidence_threshold if confidence_threshold is not None else self.confidence_threshold
try:
# Pre-operation memory check: Try to restore GPU if in fallback and memory available
if self._cpu_fallback_active:
self._restore_gpu_mode()
if not self._cpu_fallback_active:
logger.info("GPU mode restored for processing")
# Initial memory check before starting any heavy processing
# Estimate memory requirement based on image type
estimated_memory_mb = 2500 # Conservative estimate for full OCR + layout
if detect_layout:
estimated_memory_mb += 500 # Additional for PP-StructureV3
if not self.check_gpu_memory(required_mb=estimated_memory_mb, enable_fallback=True):
logger.warning(
f"Pre-operation memory check failed ({estimated_memory_mb}MB required). "
f"Processing will attempt to proceed but may encounter issues."
)
# Check if file is Office document
if self.office_converter.is_office_document(image_path):
logger.info(f"Detected Office document: {image_path.name}, converting to PDF")
@@ -748,9 +974,12 @@ class OCRService:
# Get OCR engine (for non-PDF images)
ocr_engine = self.get_ocr_engine(lang)
# Check GPU memory before OCR processing
if not self.check_gpu_memory(required_mb=1500):
logger.warning("Insufficient GPU memory for OCR, attempting to proceed anyway")
# Secondary memory check before OCR processing
if not self.check_gpu_memory(required_mb=1500, enable_fallback=True):
logger.warning(
f"OCR memory check: insufficient GPU memory (1500MB required). "
f"Mode: {'CPU fallback' if self._cpu_fallback_active else 'GPU (low memory)'}"
)
# Get the actual image dimensions that OCR will use
from PIL import Image
@@ -950,6 +1179,18 @@ class OCRService:
Tuple of (layout_data, images_metadata)
"""
try:
# Pre-operation memory check for layout analysis
if self._cpu_fallback_active:
self._restore_gpu_mode()
if not self._cpu_fallback_active:
logger.info("GPU mode restored for layout analysis")
if not self.check_gpu_memory(required_mb=2000, enable_fallback=True):
logger.warning(
f"Layout analysis pre-check: insufficient GPU memory (2000MB required). "
f"Mode: {'CPU fallback' if self._cpu_fallback_active else 'GPU'}"
)
structure_engine = self._ensure_structure_engine(pp_structure_params)
# Try enhanced processing first
@@ -998,11 +1239,21 @@ class OCRService:
# Standard processing (original implementation)
logger.info(f"Running standard layout analysis on {image_path.name}")
# Check GPU memory before processing
if not self.check_gpu_memory(required_mb=2000):
logger.warning("Insufficient GPU memory for PP-StructureV3, attempting to proceed anyway")
# Memory check before PP-StructureV3 processing
if not self.check_gpu_memory(required_mb=2000, enable_fallback=True):
logger.warning(
f"PP-StructureV3 memory check: insufficient GPU memory (2000MB required). "
f"Mode: {'CPU fallback' if self._cpu_fallback_active else 'GPU (low memory)'}"
)
results = structure_engine.predict(str(image_path))
# Use prediction semaphore to control concurrent predictions
# This prevents OOM errors from multiple simultaneous PP-StructureV3.predict() calls
with prediction_context(timeout=settings.service_acquire_timeout_seconds) as acquired:
if not acquired:
logger.error("Failed to acquire prediction slot (timeout), returning empty layout")
return None, []
results = structure_engine.predict(str(image_path))
layout_elements = []
images_metadata = []
@@ -1254,6 +1505,46 @@ class OCRService:
if temp_pdf_path:
unified_doc.metadata.original_filename = file_path.name
# HYBRID MODE: Check if Direct track missed images (e.g., inline image blocks)
# If so, use OCR to extract images and merge them into the Direct result
pages_with_missing_images = self.direct_extraction_engine.check_document_for_missing_images(
actual_file_path
)
if pages_with_missing_images:
logger.info(f"Hybrid mode: Direct track missing images on pages {pages_with_missing_images}, using OCR to extract images")
try:
# Run OCR on the file to extract images
ocr_result = self.process_file_traditional(
actual_file_path, lang, detect_layout=True,
confidence_threshold=confidence_threshold,
output_dir=output_dir, pp_structure_params=pp_structure_params
)
# Convert OCR result to extract images
ocr_unified = self.ocr_to_unified_converter.convert(
ocr_result, actual_file_path, 0.0, lang
)
# Merge OCR-extracted images into Direct track result
images_added = self._merge_ocr_images_into_direct(
unified_doc, ocr_unified, pages_with_missing_images
)
if images_added > 0:
logger.info(f"Hybrid mode: Added {images_added} images from OCR to Direct track result")
unified_doc.metadata.processing_track = ProcessingTrack.HYBRID
else:
# Fallback: OCR didn't find images either, render inline image blocks directly
logger.info("Hybrid mode: OCR didn't find images, falling back to inline image rendering")
images_added = self.direct_extraction_engine.render_inline_image_regions(
actual_file_path, unified_doc, pages_with_missing_images, output_dir
)
if images_added > 0:
logger.info(f"Hybrid mode: Rendered {images_added} inline image regions")
unified_doc.metadata.processing_track = ProcessingTrack.HYBRID
except Exception as e:
logger.warning(f"Hybrid mode image extraction failed: {e}")
# Continue with Direct track result without images
# Use OCR track (either by recommendation or fallback)
if recommendation.track == "ocr":
# Use OCR for scanned documents, images, etc.
@@ -1269,17 +1560,19 @@ class OCRService:
)
unified_doc.document_id = document_id
# Update processing track metadata
unified_doc.metadata.processing_track = (
ProcessingTrack.DIRECT if recommendation.track == "direct"
else ProcessingTrack.OCR
)
# Update processing track metadata (only if not already set to HYBRID)
if unified_doc.metadata.processing_track != ProcessingTrack.HYBRID:
unified_doc.metadata.processing_track = (
ProcessingTrack.DIRECT if recommendation.track == "direct"
else ProcessingTrack.OCR
)
# Calculate total processing time
processing_time = (datetime.now() - start_time).total_seconds()
unified_doc.metadata.processing_time = processing_time
logger.info(f"Document processing completed in {processing_time:.2f}s using {recommendation.track} track")
actual_track = unified_doc.metadata.processing_track.value
logger.info(f"Document processing completed in {processing_time:.2f}s using {actual_track} track")
return unified_doc
@@ -1290,6 +1583,75 @@ class OCRService:
file_path, lang, detect_layout, confidence_threshold, output_dir, pp_structure_params
)
def _merge_ocr_images_into_direct(
    self,
    direct_doc: 'UnifiedDocument',
    ocr_doc: 'UnifiedDocument',
    pages_with_missing_images: List[int]
) -> int:
    """
    Copy image-like elements from an OCR result into a Direct track result.

    For every requested page number, the first page with that
    ``page_number`` is looked up in both documents. Each FIGURE / IMAGE /
    LOGO element of the OCR page is re-labelled with a ``hybrid_`` id
    prefix (the OCR element object itself is mutated) and appended to the
    Direct page's element list.

    Args:
        direct_doc: UnifiedDocument from the Direct track; modified in place
        ocr_doc: UnifiedDocument from the OCR track; its image elements are
            renamed and shared with ``direct_doc``
        pages_with_missing_images: Page numbers to merge, matched against
            each page's ``page_number``

    Returns:
        Number of elements appended. Any exception is logged and swallowed,
        so the count may reflect a partial merge.
    """
    images_added = 0

    try:
        wanted_types = {ElementType.FIGURE, ElementType.IMAGE, ElementType.LOGO}

        for page_num in pages_with_missing_images:
            # Destination page; skip the page number if Direct has no such page
            target_page = next(
                (candidate for candidate in direct_doc.pages
                 if candidate.page_number == page_num),
                None
            )
            if not target_page:
                continue

            # Source page; skip the page number if OCR has no such page
            source_page = next(
                (candidate for candidate in ocr_doc.pages
                 if candidate.page_number == page_num),
                None
            )
            if not source_page:
                continue

            for item in source_page.elements:
                if item.type not in wanted_types:
                    continue

                # Prefix the id so it cannot clash with Direct track ids
                hybrid_id = f"hybrid_{item.element_id}"
                item.element_id = hybrid_id

                target_page.elements.append(item)
                images_added += 1
                logger.debug(f"Added image element {hybrid_id} to page {page_num}")

        # Keep the document-level image counter in step with the merge
        if images_added > 0:
            previous_total = direct_doc.metadata.total_images or 0
            direct_doc.metadata.total_images = previous_total + images_added

    except Exception as e:
        logger.error(f"Error merging OCR images into Direct track: {e}")

    return images_added
def process_file_traditional(
self,
file_path: Path,
@@ -1441,13 +1803,16 @@ class OCRService:
UnifiedDocument if dual-track is enabled and use_dual_track=True,
Dict with legacy format otherwise
"""
if use_dual_track and self.dual_track_enabled:
# Use dual-track processing
# Use dual-track processing if:
# 1. use_dual_track is True (auto-detection), OR
# 2. force_track is specified (explicit track selection)
if (use_dual_track or force_track) and self.dual_track_enabled:
# Use dual-track processing (or forced track)
return self.process_with_dual_track(
file_path, lang, detect_layout, confidence_threshold, output_dir, force_track, pp_structure_params
)
else:
# Use traditional OCR processing
# Use traditional OCR processing (no force_track support)
return self.process_file_traditional(
file_path, lang, detect_layout, confidence_threshold, output_dir, pp_structure_params
)

View File

@@ -572,8 +572,10 @@ class PDFGeneratorService:
processing_track = unified_doc.metadata.get('processing_track')
# Route to track-specific rendering method
is_direct_track = (processing_track == 'direct' or
processing_track == ProcessingTrack.DIRECT)
# ProcessingTrack is (str, Enum), so comparing with enum value works for both string and enum
# HYBRID track uses Direct track rendering (Direct text/tables + OCR images)
is_direct_track = (processing_track == ProcessingTrack.DIRECT or
processing_track == ProcessingTrack.HYBRID)
logger.info(f"Processing track: {processing_track}, using {'Direct' if is_direct_track else 'OCR'} track rendering")
@@ -675,8 +677,11 @@ class PDFGeneratorService:
logger.info("=== Direct Track PDF Generation ===")
logger.info(f"Total pages: {len(unified_doc.pages)}")
# Set current track for helper methods
self.current_processing_track = 'direct'
# Set current track for helper methods (may be DIRECT or HYBRID)
if hasattr(unified_doc, 'metadata') and unified_doc.metadata:
self.current_processing_track = unified_doc.metadata.processing_track
else:
self.current_processing_track = ProcessingTrack.DIRECT
# Get page dimensions from first page (for canvas initialization)
if not unified_doc.pages:
@@ -1074,11 +1079,16 @@ class PDFGeneratorService:
# *** 優先級 1: 檢查 ocr_dimensions (UnifiedDocument 轉換來的) ***
if 'ocr_dimensions' in ocr_data:
dims = ocr_data['ocr_dimensions']
w = float(dims.get('width', 0))
h = float(dims.get('height', 0))
if w > 0 and h > 0:
logger.info(f"使用 ocr_dimensions 欄位的頁面尺寸: {w:.1f} x {h:.1f}")
return (w, h)
# Handle both dict format {'width': w, 'height': h} and
# list format [{'page': 1, 'width': w, 'height': h}, ...]
if isinstance(dims, list) and len(dims) > 0:
dims = dims[0] # Use first page dimensions
if isinstance(dims, dict):
w = float(dims.get('width', 0))
h = float(dims.get('height', 0))
if w > 0 and h > 0:
logger.info(f"使用 ocr_dimensions 欄位的頁面尺寸: {w:.1f} x {h:.1f}")
return (w, h)
# *** 優先級 2: 檢查原始 JSON 的 dimensions ***
if 'dimensions' in ocr_data:
@@ -1418,8 +1428,8 @@ class PDFGeneratorService:
# Set font with track-specific styling
# Note: OCR track has no StyleInfo (extracted from images), so no advanced formatting
style_info = region.get('style')
is_direct_track = (self.current_processing_track == 'direct' or
self.current_processing_track == ProcessingTrack.DIRECT)
is_direct_track = (self.current_processing_track == ProcessingTrack.DIRECT or
self.current_processing_track == ProcessingTrack.HYBRID)
if style_info and is_direct_track:
# Direct track: Apply rich styling from StyleInfo
@@ -1661,10 +1671,15 @@ class PDFGeneratorService:
return
# Construct full path to image
# saved_path is relative to result_dir (e.g., "imgs/element_id.png")
image_path = result_dir / image_path_str
# Fallback for legacy data
if not image_path.exists():
logger.warning(f"Image not found: {image_path}")
image_path = result_dir / Path(image_path_str).name
if not image_path.exists():
logger.warning(f"Image not found: {image_path_str} (in {result_dir})")
return
# Get bbox for positioning
@@ -2289,12 +2304,30 @@ class PDFGeneratorService:
col_widths = element.metadata['column_widths']
logger.debug(f"Using extracted column widths: {col_widths}")
# Create table without rowHeights (will use canvas scaling instead)
t = Table(table_content, colWidths=col_widths)
# Use original row heights from extraction if available
# Row heights must match the number of data rows exactly
row_heights_list = None
if element.metadata and 'row_heights' in element.metadata:
extracted_row_heights = element.metadata['row_heights']
num_data_rows = len(table_content)
num_height_rows = len(extracted_row_heights)
if num_height_rows == num_data_rows:
row_heights_list = extracted_row_heights
logger.debug(f"Using extracted row heights ({num_height_rows} rows): {row_heights_list}")
else:
# Row counts don't match - this can happen with merged cells or empty rows
logger.warning(f"Row height mismatch: {num_height_rows} heights for {num_data_rows} data rows, falling back to auto-sizing")
# Create table with both column widths and row heights for accurate sizing
t = Table(table_content, colWidths=col_widths, rowHeights=row_heights_list)
# Apply style with minimal padding to reduce table extension
# Use Chinese font to support special characters (℃, μm, ≦, ×, Ω, etc.)
font_for_table = self.font_name if self.font_registered else 'Helvetica'
style = TableStyle([
('GRID', (0, 0), (-1, -1), 0.5, colors.grey),
('FONTNAME', (0, 0), (-1, -1), font_for_table),
('FONTSIZE', (0, 0), (-1, -1), 8),
('ALIGN', (0, 0), (-1, -1), 'LEFT'),
('VALIGN', (0, 0), (-1, -1), 'TOP'),
@@ -2307,8 +2340,8 @@ class PDFGeneratorService:
])
t.setStyle(style)
# CRITICAL: Use canvas scaling to fit table within bbox
# This is more reliable than rowHeights which doesn't always work
# Use canvas scaling as fallback to fit table within bbox
# With proper row heights, scaling should be minimal (close to 1.0)
# Step 1: Wrap to get actual rendered size
actual_width, actual_height = t.wrapOn(pdf_canvas, table_width * 10, table_height * 10)
@@ -2358,11 +2391,16 @@ class PDFGeneratorService:
logger.warning(f"No image path for element {element.element_id}")
return
# Construct full path
# Construct full path to image
# saved_path is relative to result_dir (e.g., "document_id_p1_img0.png")
image_path = result_dir / image_path_str
# Fallback for legacy data
if not image_path.exists():
logger.warning(f"Image not found: {image_path}")
image_path = result_dir / Path(image_path_str).name
if not image_path.exists():
logger.warning(f"Image not found: {image_path_str} (in {result_dir})")
return
# Get bbox
@@ -2388,7 +2426,7 @@ class PDFGeneratorService:
preserveAspectRatio=True
)
logger.debug(f"Drew image: {image_path_str}")
logger.debug(f"Drew image: {image_path} (from: {original_path_str})")
except Exception as e:
logger.error(f"Failed to draw image element {element.element_id}: {e}")

View File

@@ -21,6 +21,8 @@ except ImportError:
import paddle
from paddleocr import PPStructureV3
from app.models.unified_document import ElementType
from app.core.config import settings
from app.services.memory_manager import prediction_context
logger = logging.getLogger(__name__)
@@ -96,8 +98,22 @@ class PPStructureEnhanced:
try:
logger.info(f"Enhanced PP-StructureV3 analysis on {image_path.name}")
# Perform structure analysis
results = self.structure_engine.predict(str(image_path))
# Perform structure analysis with semaphore control
# This prevents OOM errors from multiple simultaneous predictions
with prediction_context(timeout=settings.service_acquire_timeout_seconds) as acquired:
if not acquired:
logger.error("Failed to acquire prediction slot (timeout), returning empty result")
return {
'has_parsing_res_list': False,
'elements': [],
'total_elements': 0,
'images': [],
'tables': [],
'element_types': {},
'error': 'Prediction slot timeout'
}
results = self.structure_engine.predict(str(image_path))
all_elements = []
all_images = []

View File

@@ -0,0 +1,468 @@
"""
Tool_OCR - OCR Service Pool
Manages a pool of OCRService instances to prevent duplicate model loading
and control concurrent GPU operations.
"""
import asyncio
import logging
import threading
import time
from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, TYPE_CHECKING
from app.services.memory_manager import get_model_manager, MemoryConfig
if TYPE_CHECKING:
from app.services.ocr_service import OCRService
logger = logging.getLogger(__name__)
class ServiceState(Enum):
    """Lifecycle states a pooled service can be in."""

    # Idle and ready to be handed out
    AVAILABLE = "available"
    # Currently held by a caller
    IN_USE = "in_use"
    # Taken out of rotation after repeated errors
    UNHEALTHY = "unhealthy"
    # Declared for services that are still being set up
    INITIALIZING = "initializing"
@dataclass
class PooledService:
    """Bookkeeping record that wraps one OCRService held by the pool."""

    # The wrapped service object (typed Any to avoid importing OCRService here)
    service: Any
    # Device label this service was created for, e.g. "GPU:0"
    device: str
    # Current lifecycle state; new wrappers start out available
    state: ServiceState = ServiceState.AVAILABLE
    # time.time() timestamp taken when the wrapper is constructed
    created_at: float = field(default_factory=time.time)
    # time.time() timestamp of the most recent use; starts at construction time
    last_used: float = field(default_factory=time.time)
    # Number of times the service has been handed out
    use_count: int = 0
    # Error counter maintained by the pool
    error_count: int = 0
    # Identifier of the task currently holding the service, if any
    current_task_id: Optional[str] = None
@dataclass(eq=False)
class PoolConfig:
    """
    Configuration for the service pool.

    Declared as a dataclass so the constructor is generated from the field
    list and instances get a readable ``repr`` for logs. ``eq=False`` keeps
    the identity-based equality and hashing of a plain class.

    The field order defines the positional-argument order of the constructor.
    """

    # Upper bound on pooled services for a single device
    max_services_per_device: int = 1
    # Upper bound on pooled services across all devices
    max_total_services: int = 2
    # Default wait, in seconds, when acquiring a service
    acquire_timeout_seconds: float = 300.0
    # Maximum queue size setting
    max_queue_size: int = 50
    # Interval setting, in seconds, for health checks
    health_check_interval_seconds: int = 60
    # Error count at which a service is marked unhealthy
    max_consecutive_errors: int = 3
    # Idle timeout setting, in seconds
    service_idle_timeout_seconds: int = 600
    # Auto-scaling toggle
    enable_auto_scaling: bool = False
class OCRServicePool:
    """
    Process-wide pool of OCRService instances with concurrency control.

    The class is a singleton: every ``OCRServicePool(...)`` call returns the
    same object, and only the first call's ``config`` is applied.

    Features:
    - Per-device bookkeeping of pooled services (keyed by device string)
    - Blocking acquire with timeout, backed by a condition variable
    - Lazy service creation up to the per-device / total limits
    - Error counting, health reporting and replacement of unhealthy services

    Notes:
    - ``semaphores`` and ``queues`` are created per device but are not
      consulted by any method of this class.
    - An UNHEALTHY service still counts toward the capacity limits until
      ``recover_unhealthy()`` replaces it.
    """

    _instance = None
    _lock = threading.Lock()

    def __new__(cls, *args, **kwargs):
        """Return the singleton instance, creating it on first use."""
        with cls._lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
                cls._instance._initialized = False
        return cls._instance

    def __init__(self, config: Optional[PoolConfig] = None):
        """
        Initialise pool state once; later constructions are no-ops.

        Args:
            config: Pool configuration; defaults to ``PoolConfig()``
        """
        # Check-and-initialise under the singleton lock so two threads that
        # construct the pool at the same time cannot both run the setup and
        # overwrite each other's state.
        with self._lock:
            if self._initialized:
                return

            self.config = config or PoolConfig()
            self.services: Dict[str, List[PooledService]] = {}
            self.semaphores: Dict[str, threading.Semaphore] = {}
            self.queues: Dict[str, List] = {}
            self._pool_lock = threading.RLock()
            self._condition = threading.Condition(self._pool_lock)

            # Metrics
            self._metrics = {
                "total_acquisitions": 0,
                "total_releases": 0,
                "total_timeouts": 0,
                "total_errors": 0,
                "queue_waits": 0,
            }

            # Initialize default device pool
            self._initialize_device("GPU:0")

            self._initialized = True
        logger.info("OCRServicePool initialized")

    def _initialize_device(self, device: str):
        """Create the per-device containers for ``device`` if they do not exist yet."""
        with self._pool_lock:
            if device not in self.services:
                self.services[device] = []
                self.semaphores[device] = threading.Semaphore(
                    self.config.max_services_per_device
                )
                self.queues[device] = []
                logger.info(f"Initialized pool for device {device}")

    def _create_service(self, device: str) -> PooledService:
        """
        Create a new OCRService instance for the pool.

        ``device`` is only recorded on the wrapper; ``OCRService()`` itself is
        constructed without arguments.

        Args:
            device: Device identifier (e.g., "GPU:0", "CPU")

        Returns:
            PooledService wrapper in the AVAILABLE state
        """
        # Import here to avoid circular imports
        from app.services.ocr_service import OCRService

        logger.info(f"Creating new OCRService for device {device}")
        start_time = time.time()

        # Create service instance
        service = OCRService()

        creation_time = time.time() - start_time
        logger.info(f"OCRService created in {creation_time:.2f}s for device {device}")

        return PooledService(
            service=service,
            device=device,
            state=ServiceState.AVAILABLE
        )

    def acquire(
        self,
        device: str = "GPU:0",
        timeout: Optional[float] = None,
        task_id: Optional[str] = None
    ) -> Optional[PooledService]:
        """
        Acquire an OCRService from the pool.

        An AVAILABLE service is reused if there is one; otherwise a new one is
        created when the limits allow (creation runs while the pool lock is
        held). Failing both, the call waits until a service is released or
        the timeout expires.

        Args:
            device: Preferred device (e.g., "GPU:0")
            timeout: Maximum time to wait in seconds. ``None`` uses the
                configured default; ``0`` makes a single non-blocking attempt.
            task_id: Optional task ID for tracking

        Returns:
            PooledService if available, None if timeout
        """
        # Only None selects the default: an explicit 0 must stay 0 instead of
        # silently turning into the (long) configured timeout.
        if timeout is None:
            timeout = self.config.acquire_timeout_seconds
        self._initialize_device(device)

        start_time = time.time()
        deadline = start_time + timeout

        with self._condition:
            while True:
                # Try to get an available service
                service = self._try_acquire_service(device, task_id)
                if service is not None:
                    self._metrics["total_acquisitions"] += 1
                    return service

                # Check if we can create a new service
                if self._can_create_service(device):
                    try:
                        pooled = self._create_service(device)
                        pooled.state = ServiceState.IN_USE
                        pooled.current_task_id = task_id
                        pooled.use_count += 1
                        self.services[device].append(pooled)
                        self._metrics["total_acquisitions"] += 1
                        logger.info(f"Created and acquired new service for {device}")
                        return pooled
                    except Exception as e:
                        # Creation failed: fall through and wait; creation is
                        # retried on the next loop iteration.
                        logger.error(f"Failed to create service for {device}: {e}")
                        self._metrics["total_errors"] += 1

                # Wait for a service to become available
                remaining = deadline - time.time()
                if remaining <= 0:
                    self._metrics["total_timeouts"] += 1
                    logger.warning(f"Timeout waiting for service on {device}")
                    return None

                self._metrics["queue_waits"] += 1
                logger.debug(f"Waiting for service on {device} (timeout: {remaining:.1f}s)")
                # Wake up at least once per second to re-check the pool
                self._condition.wait(timeout=min(remaining, 1.0))

    def _try_acquire_service(self, device: str, task_id: Optional[str]) -> Optional[PooledService]:
        """Mark and return the first AVAILABLE service for ``device``, or None without waiting."""
        for pooled in self.services.get(device, []):
            if pooled.state == ServiceState.AVAILABLE:
                pooled.state = ServiceState.IN_USE
                pooled.last_used = time.time()
                pooled.use_count += 1
                pooled.current_task_id = task_id
                logger.debug(f"Acquired existing service for {device} (use #{pooled.use_count})")
                return pooled
        return None

    def _can_create_service(self, device: str) -> bool:
        """Return True if both the per-device and the total service limits leave room."""
        device_count = len(self.services.get(device, []))
        total_count = sum(len(services) for services in self.services.values())

        return (
            device_count < self.config.max_services_per_device and
            total_count < self.config.max_total_services
        )

    def release(self, pooled: PooledService, error: Optional[Exception] = None):
        """
        Release a service back to the pool.

        A release without error resets the error counter. A release with an
        error increments it and marks the service UNHEALTHY once the counter
        reaches ``max_consecutive_errors``.

        Args:
            pooled: The pooled service to release
            error: Optional error that occurred during use
        """
        with self._condition:
            if error:
                pooled.error_count += 1
                self._metrics["total_errors"] += 1
                logger.warning(f"Service released with error: {error}")

                # Mark unhealthy if too many errors
                if pooled.error_count >= self.config.max_consecutive_errors:
                    pooled.state = ServiceState.UNHEALTHY
                    logger.error(f"Service marked unhealthy after {pooled.error_count} errors")
                else:
                    pooled.state = ServiceState.AVAILABLE
            else:
                pooled.error_count = 0  # Reset error count on success
                pooled.state = ServiceState.AVAILABLE

            pooled.last_used = time.time()
            pooled.current_task_id = None
            self._metrics["total_releases"] += 1

            # Clean up GPU memory after release (best effort)
            try:
                model_manager = get_model_manager()
                model_manager.memory_guard.clear_gpu_cache()
            except Exception as e:
                logger.debug(f"Cache clear after release failed: {e}")

            # Notify waiting threads
            self._condition.notify_all()

        logger.debug(f"Service released for device {pooled.device}")

    @contextmanager
    def acquire_context(
        self,
        device: str = "GPU:0",
        timeout: Optional[float] = None,
        task_id: Optional[str] = None
    ):
        """
        Context manager for acquiring and releasing a service.

        Usage:
            with pool.acquire_context("GPU:0") as pooled:
                result = pooled.service.process(...)

        Raises:
            TimeoutError: If no service could be acquired in time
        """
        pooled = None
        error = None
        try:
            pooled = self.acquire(device, timeout, task_id)
            if pooled is None:
                raise TimeoutError(f"Failed to acquire service for {device}")
            yield pooled
        except Exception as e:
            # Remember the failure so release() can count it, then propagate
            error = e
            raise
        finally:
            if pooled:
                self.release(pooled, error)

    def get_service(self, device: str = "GPU:0") -> Optional["OCRService"]:
        """
        Get a service directly (for backward compatibility).

        This acquires a service and returns the underlying OCRService.
        The caller is responsible for calling release_service() with the
        returned object when done; until then it stays IN_USE.

        Args:
            device: Device identifier

        Returns:
            OCRService instance or None
        """
        pooled = self.acquire(device)
        if pooled:
            return pooled.service
        return None

    def release_service(
        self,
        service: "OCRService",
        error: Optional[Exception] = None
    ) -> bool:
        """
        Release a service previously obtained through get_service().

        Looks up the IN_USE wrapper holding exactly this object and hands it
        to release().

        Args:
            service: The OCRService returned by get_service()
            error: Optional error that occurred during use

        Returns:
            True if a matching in-use service was released, False otherwise
        """
        with self._pool_lock:
            for services in self.services.values():
                for pooled in services:
                    if pooled.service is service and pooled.state == ServiceState.IN_USE:
                        # release() re-enters the same RLock via the condition
                        self.release(pooled, error)
                        return True

        logger.warning("release_service called with a service that is not in use by this pool")
        return False

    def get_pool_stats(self) -> Dict:
        """Get current pool statistics: per-device counts, totals and a copy of the metrics."""
        with self._pool_lock:
            stats = {
                "devices": {},
                "metrics": self._metrics.copy(),
                "total_services": 0,
                "available_services": 0,
                "in_use_services": 0,
            }

            for device, services in self.services.items():
                available = sum(1 for s in services if s.state == ServiceState.AVAILABLE)
                in_use = sum(1 for s in services if s.state == ServiceState.IN_USE)
                unhealthy = sum(1 for s in services if s.state == ServiceState.UNHEALTHY)

                stats["devices"][device] = {
                    "total": len(services),
                    "available": available,
                    "in_use": in_use,
                    "unhealthy": unhealthy,
                    "max_allowed": self.config.max_services_per_device,
                }

                stats["total_services"] += len(services)
                stats["available_services"] += available
                stats["in_use_services"] += in_use

            return stats

    def health_check(self) -> Dict:
        """
        Perform health check on all pooled services.

        An AVAILABLE service counts as responsive when it exposes both
        ``process`` and ``get_gpu_status``; any other service is responsive
        unless it is UNHEALTHY. The overall ``healthy`` flag is False as soon
        as one service is not responsive.

        Returns:
            Health check results
        """
        results = {
            "healthy": True,
            "services": [],
            "timestamp": time.time()
        }

        with self._pool_lock:
            for device, services in self.services.items():
                for idx, pooled in enumerate(services):
                    service_health = {
                        "device": device,
                        "index": idx,
                        "state": pooled.state.value,
                        "error_count": pooled.error_count,
                        "use_count": pooled.use_count,
                        "idle_seconds": time.time() - pooled.last_used,
                    }

                    # Check if service is responsive
                    if pooled.state == ServiceState.AVAILABLE:
                        try:
                            # Simple check - verify service has required attributes
                            has_process = hasattr(pooled.service, 'process')
                            has_gpu_status = hasattr(pooled.service, 'get_gpu_status')
                            service_health["responsive"] = has_process and has_gpu_status
                        except Exception as e:
                            service_health["responsive"] = False
                            service_health["error"] = str(e)
                    else:
                        service_health["responsive"] = pooled.state != ServiceState.UNHEALTHY

                    # Keep the summary flag consistent with the per-service
                    # result, including an AVAILABLE service that lacks the
                    # expected attributes.
                    if not service_health["responsive"]:
                        results["healthy"] = False

                    results["services"].append(service_health)

        return results

    def recover_unhealthy(self):
        """
        Attempt to recover unhealthy services.

        Each UNHEALTHY service is removed and a freshly created one is
        appended in its place. If creation fails the old service stays
        removed, which frees its slot for a later acquire().
        """
        with self._condition:
            changed = False
            for device, services in self.services.items():
                # Iterate over a snapshot: the live list is mutated below, and
                # removing from a list while iterating it skips the element
                # that follows each removal.
                for idx, pooled in enumerate(list(services)):
                    if pooled.state != ServiceState.UNHEALTHY:
                        continue

                    logger.info(f"Attempting to recover unhealthy service {device}:{idx}")
                    try:
                        # Remove old service
                        services.remove(pooled)
                        changed = True

                        # Create new service
                        new_pooled = self._create_service(device)
                        services.append(new_pooled)

                        logger.info(f"Successfully recovered service {device}:{idx}")
                    except Exception as e:
                        logger.error(f"Failed to recover service {device}:{idx}: {e}")

            # Capacity changed: let threads blocked in acquire() re-check now
            if changed:
                self._condition.notify_all()

    def shutdown(self):
        """
        Shutdown the pool and cleanup all services.

        Calls ``cleanup_gpu_memory()`` on every service that provides it, then
        empties all per-device containers.
        """
        logger.info("OCRServicePool shutdown started")

        with self._pool_lock:
            for device, services in self.services.items():
                for pooled in services:
                    try:
                        # Clean up service resources
                        if hasattr(pooled.service, 'cleanup_gpu_memory'):
                            pooled.service.cleanup_gpu_memory()
                    except Exception as e:
                        logger.warning(f"Error cleaning up service: {e}")

            # Clear all pools
            self.services.clear()
            self.semaphores.clear()
            self.queues.clear()

        logger.info("OCRServicePool shutdown completed")
# Module-level handle to the pool; filled in lazily by get_service_pool()
_service_pool: Optional[OCRServicePool] = None


def get_service_pool(config: Optional[PoolConfig] = None) -> OCRServicePool:
    """
    Return the module-wide OCRServicePool, constructing it if none is set.

    Args:
        config: Optional configuration, passed to the constructor only when
            the pool handle is currently unset

    Returns:
        OCRServicePool singleton instance
    """
    global _service_pool
    pool = _service_pool
    if pool is None:
        pool = _service_pool = OCRServicePool(config)
    return pool
def shutdown_service_pool():
    """Shut down the module-wide service pool, if one exists, and drop the handle."""
    global _service_pool
    if _service_pool is None:
        # Nothing was ever created (or it was already shut down)
        return
    _service_pool.shutdown()
    _service_pool = None