Phase 1-3 implementation of extract-table-cell-boxes proposal: - Add enable_table_cell_boxes_extraction config option - Implement lazy-loaded SLANeXt model caching in PPStructureEnhanced - Add _extract_cell_boxes_with_slanet() method for direct model invocation - Supplement PPStructureV3 table processing with SLANeXt cell boxes - Add _compute_table_grid_from_cell_boxes() for column width calculation - Modify draw_table_region() to use cell_boxes for accurate layout Key features: - Auto-detect table type (wired/wireless) using PP-LCNet classifier - Convert 8-point polygon bbox to 4-point rectangle - Graceful fallback to equal distribution when cell_boxes unavailable - Proper coordinate transformation with scaling support 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
911 lines
38 KiB
Python
911 lines
38 KiB
Python
"""
|
||
Enhanced PP-StructureV3 processing with full element extraction
|
||
|
||
This module provides enhanced PP-StructureV3 processing that extracts all
|
||
23 element types with their bbox coordinates and reading order.
|
||
"""
|
||
|
||
import logging
|
||
from pathlib import Path
|
||
from typing import Dict, List, Optional, Tuple, Any, TYPE_CHECKING
|
||
import json
|
||
import gc
|
||
|
||
# Import ScalingInfo for type checking (avoid circular imports at runtime)
|
||
if TYPE_CHECKING:
|
||
from app.services.layout_preprocessing_service import ScalingInfo
|
||
|
||
# Optional torch import for additional GPU memory management
|
||
try:
|
||
import torch
|
||
TORCH_AVAILABLE = True
|
||
except ImportError:
|
||
TORCH_AVAILABLE = False
|
||
|
||
import paddle
|
||
from paddleocr import PPStructureV3
|
||
from PIL import Image
|
||
import numpy as np
|
||
from app.models.unified_document import ElementType
|
||
from app.core.config import settings
|
||
from app.services.memory_manager import prediction_context
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class PPStructureEnhanced:
|
||
"""
|
||
Enhanced PP-StructureV3 processor that extracts all available element types
|
||
and structure information from parsing_res_list.
|
||
"""
|
||
|
||
# Mapping from PP-StructureV3 types to our ElementType
|
||
ELEMENT_TYPE_MAPPING = {
|
||
'title': ElementType.TITLE,
|
||
'paragraph_title': ElementType.TITLE, # PP-StructureV3 block_label
|
||
'text': ElementType.TEXT,
|
||
'paragraph': ElementType.PARAGRAPH,
|
||
'figure': ElementType.FIGURE,
|
||
'figure_caption': ElementType.CAPTION,
|
||
'table': ElementType.TABLE,
|
||
'table_caption': ElementType.TABLE_CAPTION,
|
||
'header': ElementType.HEADER,
|
||
'footer': ElementType.FOOTER,
|
||
'reference': ElementType.REFERENCE,
|
||
'equation': ElementType.EQUATION,
|
||
'formula': ElementType.FORMULA,
|
||
'list-item': ElementType.LIST_ITEM,
|
||
'list': ElementType.LIST,
|
||
'code': ElementType.CODE,
|
||
'footnote': ElementType.FOOTNOTE,
|
||
'page-number': ElementType.PAGE_NUMBER,
|
||
'watermark': ElementType.WATERMARK,
|
||
'signature': ElementType.SIGNATURE,
|
||
'stamp': ElementType.STAMP,
|
||
'logo': ElementType.LOGO,
|
||
'barcode': ElementType.BARCODE,
|
||
'qr-code': ElementType.QR_CODE,
|
||
# Default fallback
|
||
'image': ElementType.IMAGE,
|
||
'chart': ElementType.CHART,
|
||
'diagram': ElementType.DIAGRAM,
|
||
}
|
||
|
||
def __init__(self, structure_engine: PPStructureV3):
|
||
"""
|
||
Initialize with existing PP-StructureV3 engine.
|
||
|
||
Args:
|
||
structure_engine: Initialized PPStructureV3 instance
|
||
"""
|
||
self.structure_engine = structure_engine
|
||
|
||
# Lazy-loaded SLANeXt models for cell boxes extraction
|
||
# These are loaded on-demand when enable_table_cell_boxes_extraction is True
|
||
self._slanet_wired_model = None
|
||
self._slanet_wireless_model = None
|
||
self._table_cls_model = None
|
||
|
||
def _get_slanet_model(self, is_wired: bool = True):
|
||
"""
|
||
Get or create SLANeXt model for cell boxes extraction (lazy loading).
|
||
|
||
Args:
|
||
is_wired: True for wired (bordered) tables, False for wireless
|
||
|
||
Returns:
|
||
SLANeXt model instance or None if loading fails
|
||
"""
|
||
if not settings.enable_table_cell_boxes_extraction:
|
||
return None
|
||
|
||
try:
|
||
from paddlex import create_model
|
||
|
||
if is_wired:
|
||
if self._slanet_wired_model is None:
|
||
model_name = settings.wired_table_model_name or "SLANeXt_wired"
|
||
logger.info(f"Loading SLANeXt wired model: {model_name}")
|
||
self._slanet_wired_model = create_model(model_name)
|
||
return self._slanet_wired_model
|
||
else:
|
||
if self._slanet_wireless_model is None:
|
||
model_name = settings.wireless_table_model_name or "SLANeXt_wireless"
|
||
logger.info(f"Loading SLANeXt wireless model: {model_name}")
|
||
self._slanet_wireless_model = create_model(model_name)
|
||
return self._slanet_wireless_model
|
||
except Exception as e:
|
||
logger.error(f"Failed to load SLANeXt model: {e}")
|
||
return None
|
||
|
||
def _get_table_classifier(self):
|
||
"""
|
||
Get or create table classification model (lazy loading).
|
||
|
||
Returns:
|
||
Table classifier model instance or None if loading fails
|
||
"""
|
||
if not settings.enable_table_cell_boxes_extraction:
|
||
return None
|
||
|
||
try:
|
||
from paddlex import create_model
|
||
|
||
if self._table_cls_model is None:
|
||
model_name = settings.table_classification_model_name or "PP-LCNet_x1_0_table_cls"
|
||
logger.info(f"Loading table classification model: {model_name}")
|
||
self._table_cls_model = create_model(model_name)
|
||
return self._table_cls_model
|
||
except Exception as e:
|
||
logger.error(f"Failed to load table classifier: {e}")
|
||
return None
|
||
|
||
def _extract_cell_boxes_with_slanet(
|
||
self,
|
||
table_image: np.ndarray,
|
||
table_bbox: List[float],
|
||
is_wired: Optional[bool] = None
|
||
) -> Optional[List[List[float]]]:
|
||
"""
|
||
Extract cell bounding boxes using direct SLANeXt model call.
|
||
|
||
This supplements PPStructureV3 which doesn't expose cell boxes in its output.
|
||
|
||
Args:
|
||
table_image: Cropped table image as numpy array (BGR format)
|
||
table_bbox: Table bounding box in page coordinates [x1, y1, x2, y2]
|
||
is_wired: If None, auto-detect using classifier. True for bordered tables.
|
||
|
||
Returns:
|
||
List of cell bounding boxes in page coordinates [[x1,y1,x2,y2], ...],
|
||
or None if extraction fails
|
||
"""
|
||
if not settings.enable_table_cell_boxes_extraction:
|
||
return None
|
||
|
||
try:
|
||
# Auto-detect table type if not specified
|
||
if is_wired is None:
|
||
classifier = self._get_table_classifier()
|
||
if classifier:
|
||
try:
|
||
cls_result = classifier.predict(table_image)
|
||
# PP-LCNet returns classification result
|
||
for res in cls_result:
|
||
label_names = res.get('label_names', [])
|
||
if label_names:
|
||
is_wired = 'wired' in str(label_names[0]).lower()
|
||
logger.debug(f"Table classified as: {'wired' if is_wired else 'wireless'}")
|
||
break
|
||
except Exception as e:
|
||
logger.warning(f"Table classification failed, defaulting to wired: {e}")
|
||
is_wired = True
|
||
else:
|
||
is_wired = True # Default to wired if classifier unavailable
|
||
|
||
# Get appropriate SLANeXt model
|
||
model = self._get_slanet_model(is_wired=is_wired)
|
||
if model is None:
|
||
return None
|
||
|
||
# Run SLANeXt prediction
|
||
results = model.predict(table_image)
|
||
|
||
# Extract cell boxes from result
|
||
cell_boxes = []
|
||
table_x, table_y = table_bbox[0], table_bbox[1]
|
||
|
||
for result in results:
|
||
# SLANeXt returns 'bbox' with 8-point polygon format
|
||
# [[x1,y1,x2,y2,x3,y3,x4,y4], ...]
|
||
boxes = result.get('bbox', [])
|
||
for box in boxes:
|
||
if isinstance(box, (list, tuple)):
|
||
if len(box) >= 8:
|
||
# 8-point polygon: convert to 4-point rectangle
|
||
xs = [box[i] for i in range(0, 8, 2)]
|
||
ys = [box[i] for i in range(1, 8, 2)]
|
||
x1, y1 = min(xs), min(ys)
|
||
x2, y2 = max(xs), max(ys)
|
||
elif len(box) >= 4:
|
||
# Already 4-point rectangle
|
||
x1, y1, x2, y2 = box[:4]
|
||
else:
|
||
continue
|
||
|
||
# Convert to absolute page coordinates
|
||
abs_box = [
|
||
float(x1 + table_x),
|
||
float(y1 + table_y),
|
||
float(x2 + table_x),
|
||
float(y2 + table_y)
|
||
]
|
||
cell_boxes.append(abs_box)
|
||
|
||
logger.info(f"SLANeXt extracted {len(cell_boxes)} cell boxes (is_wired={is_wired})")
|
||
return cell_boxes if cell_boxes else None
|
||
|
||
except Exception as e:
|
||
logger.error(f"Cell boxes extraction with SLANeXt failed: {e}")
|
||
return None
|
||
|
||
def release_slanet_models(self):
|
||
"""Release SLANeXt models to free GPU memory."""
|
||
if self._slanet_wired_model is not None:
|
||
del self._slanet_wired_model
|
||
self._slanet_wired_model = None
|
||
logger.info("Released SLANeXt wired model")
|
||
|
||
if self._slanet_wireless_model is not None:
|
||
del self._slanet_wireless_model
|
||
self._slanet_wireless_model = None
|
||
logger.info("Released SLANeXt wireless model")
|
||
|
||
if self._table_cls_model is not None:
|
||
del self._table_cls_model
|
||
self._table_cls_model = None
|
||
logger.info("Released table classifier model")
|
||
|
||
gc.collect()
|
||
if TORCH_AVAILABLE:
|
||
torch.cuda.empty_cache()
|
||
|
||
def analyze_with_full_structure(
|
||
self,
|
||
image_path: Path,
|
||
output_dir: Optional[Path] = None,
|
||
current_page: int = 0,
|
||
preprocessed_image: Optional[Image.Image] = None,
|
||
scaling_info: Optional['ScalingInfo'] = None
|
||
) -> Dict[str, Any]:
|
||
"""
|
||
Analyze document with full PP-StructureV3 capabilities.
|
||
|
||
Args:
|
||
image_path: Path to original image file (used for cropping extracted images)
|
||
output_dir: Optional output directory for saving extracted content
|
||
current_page: Current page number (0-based)
|
||
preprocessed_image: Optional preprocessed PIL Image for layout detection.
|
||
If provided, this is used for PP-Structure prediction,
|
||
but original image_path is still used for cropping images.
|
||
scaling_info: Optional ScalingInfo from preprocessing. If image was scaled
|
||
for layout detection, all bbox coordinates will be scaled back
|
||
to original image coordinates for proper cropping.
|
||
|
||
Returns:
|
||
Dictionary with complete structure information including:
|
||
- elements: List of all detected elements with types and bbox (in original coords)
|
||
- reading_order: Reading order indices
|
||
- images: Extracted images with metadata
|
||
- tables: Extracted tables with structure
|
||
"""
|
||
try:
|
||
logger.info(f"Enhanced PP-StructureV3 analysis on {image_path.name}")
|
||
if preprocessed_image:
|
||
logger.info("Using preprocessed image for layout detection")
|
||
|
||
# Perform structure analysis with semaphore control
|
||
# This prevents OOM errors from multiple simultaneous predictions
|
||
with prediction_context(timeout=settings.service_acquire_timeout_seconds) as acquired:
|
||
if not acquired:
|
||
logger.error("Failed to acquire prediction slot (timeout), returning empty result")
|
||
return {
|
||
'has_parsing_res_list': False,
|
||
'elements': [],
|
||
'total_elements': 0,
|
||
'images': [],
|
||
'tables': [],
|
||
'element_types': {},
|
||
'error': 'Prediction slot timeout'
|
||
}
|
||
|
||
# Use preprocessed image if provided, otherwise use original path
|
||
if preprocessed_image is not None:
|
||
# Convert PIL to numpy array (BGR format for PP-Structure)
|
||
predict_input = np.array(preprocessed_image)
|
||
if len(predict_input.shape) == 3 and predict_input.shape[2] == 3:
|
||
# Convert RGB to BGR
|
||
predict_input = predict_input[:, :, ::-1]
|
||
results = self.structure_engine.predict(predict_input)
|
||
else:
|
||
results = self.structure_engine.predict(str(image_path))
|
||
|
||
all_elements = []
|
||
all_images = []
|
||
all_tables = []
|
||
|
||
# Process each page result
|
||
for page_idx, page_result in enumerate(results):
|
||
# Try to access parsing_res_list (the complete structure)
|
||
parsing_res_list = None
|
||
|
||
# Method 1: Direct access to json attribute (check both top-level and res)
|
||
if hasattr(page_result, 'json'):
|
||
result_json = page_result.json
|
||
if isinstance(result_json, dict):
|
||
# Check top-level
|
||
if 'parsing_res_list' in result_json:
|
||
parsing_res_list = result_json['parsing_res_list']
|
||
logger.info(f"Found parsing_res_list at top level with {len(parsing_res_list)} elements")
|
||
# Check inside 'res' (new structure in paddlex)
|
||
elif 'res' in result_json and isinstance(result_json['res'], dict):
|
||
if 'parsing_res_list' in result_json['res']:
|
||
parsing_res_list = result_json['res']['parsing_res_list']
|
||
logger.info(f"Found parsing_res_list inside 'res' with {len(parsing_res_list)} elements")
|
||
|
||
# Method 2: Try direct dict access (LayoutParsingResultV2 inherits from dict)
|
||
elif isinstance(page_result, dict):
|
||
if 'parsing_res_list' in page_result:
|
||
parsing_res_list = page_result['parsing_res_list']
|
||
logger.info(f"Found parsing_res_list via dict access with {len(parsing_res_list)} elements")
|
||
elif 'res' in page_result and isinstance(page_result['res'], dict):
|
||
if 'parsing_res_list' in page_result['res']:
|
||
parsing_res_list = page_result['res']['parsing_res_list']
|
||
logger.info(f"Found parsing_res_list inside page_result['res'] with {len(parsing_res_list)} elements")
|
||
|
||
# Method 3: Try to access as attribute
|
||
elif hasattr(page_result, 'parsing_res_list'):
|
||
parsing_res_list = page_result.parsing_res_list
|
||
logger.info(f"Found parsing_res_list attribute with {len(parsing_res_list)} elements")
|
||
|
||
# Method 4: Check if result has to_dict method
|
||
elif hasattr(page_result, 'to_dict'):
|
||
result_dict = page_result.to_dict()
|
||
if 'parsing_res_list' in result_dict:
|
||
parsing_res_list = result_dict['parsing_res_list']
|
||
logger.info(f"Found parsing_res_list in to_dict with {len(parsing_res_list)} elements")
|
||
elif 'res' in result_dict and isinstance(result_dict['res'], dict):
|
||
if 'parsing_res_list' in result_dict['res']:
|
||
parsing_res_list = result_dict['res']['parsing_res_list']
|
||
logger.info(f"Found parsing_res_list in to_dict['res'] with {len(parsing_res_list)} elements")
|
||
|
||
# Process parsing_res_list if found
|
||
if parsing_res_list:
|
||
elements = self._process_parsing_res_list(
|
||
parsing_res_list, current_page, output_dir, image_path, scaling_info
|
||
)
|
||
all_elements.extend(elements)
|
||
|
||
# Extract tables and images from elements
|
||
for elem in elements:
|
||
if elem['type'] == ElementType.TABLE:
|
||
all_tables.append(elem)
|
||
elif elem['type'] in [ElementType.IMAGE, ElementType.FIGURE]:
|
||
all_images.append(elem)
|
||
else:
|
||
# Fallback to markdown if parsing_res_list not available
|
||
logger.warning("parsing_res_list not found, falling back to markdown")
|
||
elements = self._process_markdown_fallback(
|
||
page_result, current_page, output_dir
|
||
)
|
||
all_elements.extend(elements)
|
||
|
||
# Create reading order based on element positions
|
||
reading_order = self._determine_reading_order(all_elements)
|
||
|
||
return {
|
||
'elements': all_elements,
|
||
'total_elements': len(all_elements),
|
||
'reading_order': reading_order,
|
||
'tables': all_tables,
|
||
'images': all_images,
|
||
'element_types': self._count_element_types(all_elements),
|
||
'has_parsing_res_list': parsing_res_list is not None
|
||
}
|
||
|
||
except Exception as e:
|
||
logger.error(f"Enhanced PP-StructureV3 analysis error: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
|
||
# Clean up GPU memory on error
|
||
try:
|
||
if TORCH_AVAILABLE and torch.cuda.is_available():
|
||
torch.cuda.empty_cache()
|
||
torch.cuda.synchronize()
|
||
if paddle.device.is_compiled_with_cuda():
|
||
paddle.device.cuda.empty_cache()
|
||
gc.collect()
|
||
except:
|
||
pass # Ignore cleanup errors
|
||
|
||
return {
|
||
'elements': [],
|
||
'total_elements': 0,
|
||
'reading_order': [],
|
||
'tables': [],
|
||
'images': [],
|
||
'element_types': {},
|
||
'has_parsing_res_list': False,
|
||
'error': str(e)
|
||
}
|
||
|
||
def _process_parsing_res_list(
|
||
self,
|
||
parsing_res_list: List[Dict],
|
||
current_page: int,
|
||
output_dir: Optional[Path],
|
||
source_image_path: Optional[Path] = None,
|
||
scaling_info: Optional['ScalingInfo'] = None
|
||
) -> List[Dict[str, Any]]:
|
||
"""
|
||
Process parsing_res_list to extract all elements.
|
||
|
||
Args:
|
||
parsing_res_list: List of parsed elements from PP-StructureV3
|
||
scaling_info: Scaling information for bbox coordinate restoration
|
||
current_page: Current page number
|
||
output_dir: Optional output directory
|
||
source_image_path: Path to source image for cropping image regions
|
||
|
||
Returns:
|
||
List of processed elements with normalized structure
|
||
"""
|
||
elements = []
|
||
|
||
for idx, item in enumerate(parsing_res_list):
|
||
# Debug: log the structure of the first item
|
||
if idx == 0:
|
||
logger.info(f"First parsing_res_list item structure: {list(item.keys()) if isinstance(item, dict) else type(item)}")
|
||
logger.info(f"First parsing_res_list item sample: {str(item)[:500]}")
|
||
|
||
# Extract element type (check both 'type' and 'block_label')
|
||
element_type = item.get('type', '') or item.get('block_label', 'text')
|
||
element_type = element_type.lower()
|
||
mapped_type = self.ELEMENT_TYPE_MAPPING.get(
|
||
element_type, ElementType.TEXT
|
||
)
|
||
|
||
# Extract bbox (check multiple possible keys)
|
||
layout_bbox = (
|
||
item.get('layout_bbox', []) or
|
||
item.get('block_bbox', []) or
|
||
item.get('bbox', [])
|
||
)
|
||
|
||
# Ensure bbox has 4 values
|
||
if len(layout_bbox) >= 4:
|
||
bbox = list(layout_bbox[:4]) # [x1, y1, x2, y2]
|
||
else:
|
||
bbox = [0, 0, 0, 0] # Default if bbox missing
|
||
logger.warning(f"Element {idx} has invalid bbox: {layout_bbox}")
|
||
|
||
# Scale bbox back to original image coordinates if image was scaled
|
||
# This is critical for proper cropping from original high-resolution image
|
||
if scaling_info and scaling_info.was_scaled and bbox != [0, 0, 0, 0]:
|
||
scale_factor = scaling_info.scale_factor
|
||
bbox = [
|
||
bbox[0] * scale_factor, # x1
|
||
bbox[1] * scale_factor, # y1
|
||
bbox[2] * scale_factor, # x2
|
||
bbox[3] * scale_factor # y2
|
||
]
|
||
if idx == 0: # Log only for first element to avoid spam
|
||
logger.info(
|
||
f"Scaled bbox to original coords: "
|
||
f"{[round(x, 1) for x in layout_bbox[:4]]} -> {[round(x, 1) for x in bbox]} "
|
||
f"(factor={scale_factor:.3f})"
|
||
)
|
||
|
||
# Extract content (check multiple possible keys)
|
||
content = (
|
||
item.get('content', '') or
|
||
item.get('block_content', '') or
|
||
''
|
||
)
|
||
|
||
# Additional fallback for content in 'res' field
|
||
if not content and 'res' in item:
|
||
res = item.get('res', {})
|
||
if isinstance(res, dict):
|
||
content = res.get('content', '') or res.get('text', '')
|
||
elif isinstance(res, str):
|
||
content = res
|
||
|
||
# Content-based HTML table detection: PP-StructureV3 sometimes
|
||
# classifies tables as 'text' but returns HTML table content
|
||
html_table_content = None
|
||
if content and '<table' in content.lower():
|
||
if mapped_type == ElementType.TEXT or element_type == 'text':
|
||
logger.info(f"Element {idx}: Detected HTML table content in 'text' type, reclassifying to TABLE")
|
||
mapped_type = ElementType.TABLE
|
||
html_table_content = content # Store for later use
|
||
|
||
# Create element
|
||
element = {
|
||
'element_id': f"pp3_{current_page}_{idx}",
|
||
'type': mapped_type,
|
||
'original_type': element_type,
|
||
'content': content,
|
||
'page': current_page,
|
||
'bbox': bbox, # [x1, y1, x2, y2]
|
||
'index': idx, # Original index in reading order
|
||
'confidence': item.get('score', 1.0)
|
||
}
|
||
|
||
# Special handling for tables
|
||
if mapped_type == ElementType.TABLE:
|
||
# 1. 提取 HTML (原有邏輯)
|
||
html_content = html_table_content
|
||
res_data = {}
|
||
|
||
# 獲取 res 字典 (包含 html 和 boxes)
|
||
if 'res' in item and isinstance(item['res'], dict):
|
||
res_data = item['res']
|
||
logger.info(f"[TABLE] Found 'res' dict with keys: {list(res_data.keys())}")
|
||
if not html_content:
|
||
html_content = res_data.get('html', '')
|
||
else:
|
||
logger.info(f"[TABLE] No 'res' key in item. Available keys: {list(item.keys())}")
|
||
|
||
if html_content:
|
||
element['html'] = html_content
|
||
element['extracted_text'] = self._extract_text_from_html(html_content)
|
||
|
||
# 2. 提取 Cell 座標 (boxes)
|
||
# 優先使用 PPStructureV3 返回的 boxes,若無則調用 SLANeXt 補充
|
||
cell_boxes_extracted = False
|
||
|
||
if 'boxes' in res_data:
|
||
# PPStructureV3 returned cell boxes (unlikely in PaddleX 3.x)
|
||
cell_boxes = res_data['boxes']
|
||
logger.info(f"[TABLE] Found {len(cell_boxes)} cell boxes in res_data")
|
||
|
||
# 獲取表格自身的偏移量 (用於將 Cell 的相對座標轉為絕對座標)
|
||
table_x, table_y = 0, 0
|
||
if len(bbox) >= 2: # bbox is [x1, y1, x2, y2]
|
||
table_x, table_y = bbox[0], bbox[1]
|
||
|
||
processed_cells = []
|
||
for cell_box in cell_boxes:
|
||
# 確保格式正確
|
||
if isinstance(cell_box, (list, tuple)) and len(cell_box) >= 4:
|
||
# 轉換為絕對座標: Cell x + 表格 x
|
||
abs_cell_box = [
|
||
cell_box[0] + table_x,
|
||
cell_box[1] + table_y,
|
||
cell_box[2] + table_x,
|
||
cell_box[3] + table_y
|
||
]
|
||
processed_cells.append(abs_cell_box)
|
||
|
||
# 將處理後的 Cell 座標存入 element
|
||
element['cell_boxes'] = processed_cells
|
||
element['raw_cell_boxes'] = cell_boxes
|
||
element['cell_boxes_source'] = 'ppstructure'
|
||
logger.info(f"[TABLE] Processed {len(processed_cells)} cell boxes with table offset ({table_x}, {table_y})")
|
||
cell_boxes_extracted = True
|
||
|
||
# Supplement with direct SLANeXt call if PPStructureV3 didn't provide boxes
|
||
if not cell_boxes_extracted and source_image_path and bbox != [0, 0, 0, 0]:
|
||
logger.info(f"[TABLE] No boxes from PPStructureV3, attempting SLANeXt extraction...")
|
||
try:
|
||
# Load source image and crop table region
|
||
source_img = Image.open(source_image_path)
|
||
source_array = np.array(source_img)
|
||
|
||
# Crop table region (bbox is in original image coordinates)
|
||
x1, y1, x2, y2 = [int(round(c)) for c in bbox]
|
||
# Ensure coordinates are within image bounds
|
||
h, w = source_array.shape[:2]
|
||
x1, y1 = max(0, x1), max(0, y1)
|
||
x2, y2 = min(w, x2), min(h, y2)
|
||
|
||
if x2 > x1 and y2 > y1:
|
||
table_crop = source_array[y1:y2, x1:x2]
|
||
|
||
# Convert RGB to BGR for SLANeXt
|
||
if len(table_crop.shape) == 3 and table_crop.shape[2] == 3:
|
||
table_crop_bgr = table_crop[:, :, ::-1]
|
||
else:
|
||
table_crop_bgr = table_crop
|
||
|
||
# Extract cell boxes using SLANeXt
|
||
slanet_boxes = self._extract_cell_boxes_with_slanet(
|
||
table_crop_bgr,
|
||
bbox, # Pass original bbox for coordinate offset
|
||
is_wired=None # Auto-detect
|
||
)
|
||
|
||
if slanet_boxes:
|
||
element['cell_boxes'] = slanet_boxes
|
||
element['cell_boxes_source'] = 'slanet'
|
||
cell_boxes_extracted = True
|
||
logger.info(f"[TABLE] SLANeXt extracted {len(slanet_boxes)} cell boxes")
|
||
else:
|
||
logger.warning(f"[TABLE] Invalid crop region: ({x1},{y1})-({x2},{y2})")
|
||
|
||
except Exception as e:
|
||
logger.error(f"[TABLE] SLANeXt extraction failed: {e}")
|
||
|
||
if not cell_boxes_extracted:
|
||
logger.info(f"[TABLE] No cell boxes available. PPStructureV3 keys: {list(res_data.keys()) if res_data else 'empty'}")
|
||
|
||
# Special handling for images/figures
|
||
elif mapped_type in [ElementType.IMAGE, ElementType.FIGURE]:
|
||
# Save image if path provided
|
||
if 'img_path' in item and output_dir:
|
||
saved_path = self._save_image(item['img_path'], output_dir, element['element_id'])
|
||
if saved_path:
|
||
element['saved_path'] = saved_path
|
||
element['img_path'] = item['img_path'] # Keep original for reference
|
||
else:
|
||
logger.warning(f"Failed to save image for element {element['element_id']}")
|
||
# Crop image from source if no img_path but source image is available
|
||
elif source_image_path and output_dir and bbox != [0, 0, 0, 0]:
|
||
cropped_path = self._crop_and_save_image(
|
||
source_image_path, bbox, output_dir, element['element_id']
|
||
)
|
||
if cropped_path:
|
||
element['saved_path'] = cropped_path
|
||
element['img_path'] = cropped_path
|
||
logger.info(f"Cropped and saved image region for {element['element_id']}")
|
||
else:
|
||
logger.warning(f"Failed to crop image for element {element['element_id']}")
|
||
|
||
# Add any additional metadata
|
||
if 'metadata' in item:
|
||
element['metadata'] = item['metadata']
|
||
|
||
elements.append(element)
|
||
logger.debug(f"Processed element {idx}: type={mapped_type}, bbox={bbox}")
|
||
|
||
return elements
|
||
|
||
def _process_markdown_fallback(
|
||
self,
|
||
page_result: Any,
|
||
current_page: int,
|
||
output_dir: Optional[Path]
|
||
) -> List[Dict[str, Any]]:
|
||
"""
|
||
Fallback to markdown processing if parsing_res_list not available.
|
||
|
||
Args:
|
||
page_result: PP-StructureV3 page result
|
||
current_page: Current page number
|
||
output_dir: Optional output directory
|
||
|
||
Returns:
|
||
List of elements extracted from markdown
|
||
"""
|
||
elements = []
|
||
|
||
# Extract from markdown if available
|
||
if hasattr(page_result, 'markdown'):
|
||
markdown_dict = page_result.markdown
|
||
|
||
if isinstance(markdown_dict, dict):
|
||
# Extract markdown texts
|
||
markdown_texts = markdown_dict.get('markdown_texts', '')
|
||
if markdown_texts:
|
||
# Detect if it's a table
|
||
is_table = '<table' in markdown_texts.lower()
|
||
|
||
element = {
|
||
'element_id': f"md_{current_page}_0",
|
||
'type': ElementType.TABLE if is_table else ElementType.TEXT,
|
||
'content': markdown_texts,
|
||
'page': current_page,
|
||
'bbox': [0, 0, 0, 0], # No bbox in markdown
|
||
'index': 0,
|
||
'from_markdown': True
|
||
}
|
||
|
||
if is_table:
|
||
element['extracted_text'] = self._extract_text_from_html(markdown_texts)
|
||
|
||
elements.append(element)
|
||
|
||
# Process images
|
||
markdown_images = markdown_dict.get('markdown_images', {})
|
||
for img_idx, (img_path, img_obj) in enumerate(markdown_images.items()):
|
||
# Save image
|
||
if output_dir and hasattr(img_obj, 'save'):
|
||
self._save_pil_image(img_obj, output_dir, f"md_img_{current_page}_{img_idx}")
|
||
|
||
# Try to extract bbox from filename
|
||
bbox = self._extract_bbox_from_filename(img_path)
|
||
|
||
element = {
|
||
'element_id': f"md_img_{current_page}_{img_idx}",
|
||
'type': ElementType.IMAGE,
|
||
'content': img_path,
|
||
'page': current_page,
|
||
'bbox': bbox,
|
||
'index': img_idx + 1,
|
||
'from_markdown': True
|
||
}
|
||
elements.append(element)
|
||
|
||
return elements
|
||
|
||
def _determine_reading_order(self, elements: List[Dict]) -> List[int]:
|
||
"""
|
||
Determine reading order based on element positions.
|
||
|
||
Args:
|
||
elements: List of elements with bbox
|
||
|
||
Returns:
|
||
List of indices representing reading order
|
||
"""
|
||
if not elements:
|
||
return []
|
||
|
||
# If elements have original indices, use them
|
||
if all('index' in elem for elem in elements):
|
||
# Sort by original index
|
||
indexed_elements = [(i, elem['index']) for i, elem in enumerate(elements)]
|
||
indexed_elements.sort(key=lambda x: x[1])
|
||
return [i for i, _ in indexed_elements]
|
||
|
||
# Otherwise, sort by position (top to bottom, left to right)
|
||
indexed_elements = []
|
||
for i, elem in enumerate(elements):
|
||
bbox = elem.get('bbox', [0, 0, 0, 0])
|
||
if len(bbox) >= 2:
|
||
# Use top-left corner for sorting
|
||
indexed_elements.append((i, bbox[1], bbox[0])) # (index, y, x)
|
||
else:
|
||
indexed_elements.append((i, 0, 0))
|
||
|
||
# Sort by y first (top to bottom), then x (left to right)
|
||
indexed_elements.sort(key=lambda x: (x[1], x[2]))
|
||
|
||
return [i for i, _, _ in indexed_elements]
|
||
|
||
def _count_element_types(self, elements: List[Dict]) -> Dict[str, int]:
|
||
"""
|
||
Count occurrences of each element type.
|
||
|
||
Args:
|
||
elements: List of elements
|
||
|
||
Returns:
|
||
Dictionary with element type counts
|
||
"""
|
||
type_counts = {}
|
||
for elem in elements:
|
||
elem_type = elem.get('type', ElementType.TEXT)
|
||
type_counts[elem_type] = type_counts.get(elem_type, 0) + 1
|
||
return type_counts
|
||
|
||
def _extract_text_from_html(self, html: str) -> str:
|
||
"""Extract plain text from HTML content."""
|
||
try:
|
||
from bs4 import BeautifulSoup
|
||
soup = BeautifulSoup(html, 'html.parser')
|
||
return soup.get_text(separator=' ', strip=True)
|
||
except:
|
||
# Fallback: just remove HTML tags
|
||
import re
|
||
text = re.sub(r'<[^>]+>', ' ', html)
|
||
text = re.sub(r'\s+', ' ', text)
|
||
return text.strip()
|
||
|
||
def _extract_bbox_from_filename(self, filename: str) -> List[int]:
|
||
"""Extract bbox from filename if it contains coordinate information."""
|
||
import re
|
||
match = re.search(r'box_(\d+)_(\d+)_(\d+)_(\d+)', filename)
|
||
if match:
|
||
return list(map(int, match.groups()))
|
||
return [0, 0, 0, 0]
|
||
|
||
def _save_image(self, img_path: str, output_dir: Path, element_id: str) -> Optional[str]:
|
||
"""Save image file to output directory and return relative path.
|
||
|
||
Args:
|
||
img_path: Path to image file or image data
|
||
output_dir: Base output directory for results
|
||
element_id: Unique identifier for the element
|
||
|
||
Returns:
|
||
Relative path to saved image, or None if save failed
|
||
"""
|
||
import shutil
|
||
import numpy as np
|
||
from PIL import Image
|
||
|
||
try:
|
||
# Create imgs subdirectory
|
||
img_dir = output_dir / "imgs"
|
||
img_dir.mkdir(parents=True, exist_ok=True)
|
||
|
||
# Determine output file path
|
||
dst_path = img_dir / f"{element_id}.png"
|
||
relative_path = f"imgs/{element_id}.png"
|
||
|
||
# Handle different input types
|
||
if isinstance(img_path, str):
|
||
src_path = Path(img_path)
|
||
if src_path.exists() and src_path.is_file():
|
||
# Copy existing file
|
||
shutil.copy2(src_path, dst_path)
|
||
logger.info(f"Copied image from {src_path} to {dst_path}")
|
||
else:
|
||
logger.warning(f"Image file not found: {img_path}")
|
||
return None
|
||
elif isinstance(img_path, np.ndarray):
|
||
# Save numpy array as image
|
||
Image.fromarray(img_path).save(dst_path)
|
||
logger.info(f"Saved numpy array image to {dst_path}")
|
||
else:
|
||
logger.warning(f"Unknown image type: {type(img_path)}")
|
||
return None
|
||
|
||
# Return relative path for reference
|
||
return relative_path
|
||
|
||
except Exception as e:
|
||
logger.error(f"Failed to save image for element {element_id}: {e}")
|
||
return None
|
||
|
||
def _save_pil_image(self, img_obj, output_dir: Path, element_id: str):
|
||
"""Save PIL image object to output directory."""
|
||
try:
|
||
img_dir = output_dir / "imgs"
|
||
img_dir.mkdir(parents=True, exist_ok=True)
|
||
img_path = img_dir / f"{element_id}.png"
|
||
img_obj.save(str(img_path))
|
||
logger.info(f"Saved image to {img_path}")
|
||
except Exception as e:
|
||
logger.warning(f"Failed to save PIL image: {e}")
|
||
|
||
def _crop_and_save_image(
|
||
self,
|
||
source_image_path: Path,
|
||
bbox: List[float],
|
||
output_dir: Path,
|
||
element_id: str
|
||
) -> Optional[str]:
|
||
"""
|
||
Crop image region from source image and save to output directory.
|
||
|
||
Args:
|
||
source_image_path: Path to the source image
|
||
bbox: Bounding box [x1, y1, x2, y2]
|
||
output_dir: Output directory for saving cropped image
|
||
element_id: Element ID for naming
|
||
|
||
Returns:
|
||
Relative filename (not full path) to saved image, consistent with
|
||
Direct Track which stores "filename.png" that gets joined with
|
||
result_dir by pdf_generator_service.
|
||
"""
|
||
try:
|
||
from PIL import Image
|
||
|
||
# Open source image
|
||
with Image.open(source_image_path) as img:
|
||
# Ensure bbox values are integers
|
||
x1, y1, x2, y2 = [int(v) for v in bbox[:4]]
|
||
|
||
# Validate bbox
|
||
img_width, img_height = img.size
|
||
x1 = max(0, min(x1, img_width))
|
||
x2 = max(0, min(x2, img_width))
|
||
y1 = max(0, min(y1, img_height))
|
||
y2 = max(0, min(y2, img_height))
|
||
|
||
if x2 <= x1 or y2 <= y1:
|
||
logger.warning(f"Invalid bbox for cropping: {bbox}")
|
||
return None
|
||
|
||
# Crop the region
|
||
cropped = img.crop((x1, y1, x2, y2))
|
||
|
||
# Save directly to output directory (no subdirectory)
|
||
# Consistent with Direct Track which saves to output_dir directly
|
||
image_filename = f"{element_id}.png"
|
||
img_path = output_dir / image_filename
|
||
cropped.save(str(img_path), "PNG")
|
||
|
||
# Return just the filename (relative to result_dir)
|
||
# PDF generator will join with result_dir to get full path
|
||
logger.info(f"Cropped image saved: {img_path} ({x2-x1}x{y2-y1} pixels)")
|
||
return image_filename
|
||
|
||
except Exception as e:
|
||
logger.error(f"Failed to crop and save image for {element_id}: {e}")
|
||
return None |