feat: implement table cell boxes extraction with SLANeXt

Phase 1-3 implementation of extract-table-cell-boxes proposal:

- Add enable_table_cell_boxes_extraction config option
- Implement lazy-loaded SLANeXt model caching in PPStructureEnhanced
- Add _extract_cell_boxes_with_slanet() method for direct model invocation
- Supplement PPStructureV3 table processing with SLANeXt cell boxes
- Add _compute_table_grid_from_cell_boxes() for column width calculation
- Modify draw_table_region() to use cell_boxes for accurate layout

Key features:
- Auto-detect table type (wired/wireless) using PP-LCNet classifier
- Convert 8-value polygon bbox (four corner points) to 4-value rectangle [x1, y1, x2, y2]
- Graceful fallback to equal distribution when cell_boxes unavailable
- Proper coordinate transformation with scaling support

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
egg
2025-11-28 12:20:32 +08:00
parent 801ee9c4b6
commit 715805b3b8
3 changed files with 379 additions and 7 deletions

View File

@@ -161,6 +161,14 @@ class Settings(BaseSettings):
description="Cell detection model for borderless tables. RT-DETR-L provides best accuracy." description="Cell detection model for borderless tables. RT-DETR-L provides best accuracy."
) )
# Table Cell Boxes Extraction - supplement PPStructureV3 with direct SLANeXt calls
# When enabled, directly invokes SLANeXt models to extract cell bounding boxes
# which are not exposed by the PPStructureV3 high-level API
enable_table_cell_boxes_extraction: bool = Field(
default=True,
description="Enable direct SLANeXt model calls to extract table cell bounding boxes for accurate PDF layout."
)
# Formula Recognition Model Configuration (Stage 4) # Formula Recognition Model Configuration (Stage 4)
# Available models: # Available models:
# - "PP-FormulaNet_plus-L": Best for Chinese formulas (90.64% Chinese, 92.22% English BLEU) # - "PP-FormulaNet_plus-L": Best for Chinese formulas (90.64% Chinese, 92.22% English BLEU)

View File

@@ -1605,6 +1605,120 @@ class PDFGeneratorService:
except Exception as e: except Exception as e:
logger.warning(f"Failed to draw text region '{text[:20]}...': {e}") logger.warning(f"Failed to draw text region '{text[:20]}...': {e}")
def _compute_table_grid_from_cell_boxes(
    self,
    cell_boxes: List[List[float]],
    table_bbox: List[float],
    num_rows: int,
    num_cols: int,
    tolerance: float = 5.0
) -> Tuple[Optional[List[float]], Optional[List[float]]]:
    """
    Compute column widths and row heights from cell bounding boxes.

    The left/right edges of every cell box are projected onto the X axis and
    the top/bottom edges onto the Y axis (relative to the table origin).
    Edges closer together than ``tolerance`` are merged, and the gaps between
    consecutive surviving edges become the column widths / row heights.

    Args:
        cell_boxes: List of cell bboxes [[x1,y1,x2,y2], ...] in the same
            coordinate space as ``table_bbox``. Entries with fewer than four
            values are ignored.
        table_bbox: Table bounding box [x1,y1,x2,y2].
        num_rows: Number of rows the caller expects.
        num_cols: Number of columns the caller expects.
        tolerance: Distance (in input units) below which two edges are
            treated as the same grid line, and within which an outermost
            edge is considered to already sit on the table border.

    Returns:
        Tuple of (col_widths, row_heights). (None, None) is returned when
        fewer than two cell boxes are given, when the table bbox has no
        positive area, when the detected grid differs from the expected
        row/column count by more than one on either axis, or when any
        error occurs during the computation.
    """
    if not cell_boxes or len(cell_boxes) < 2:
        return None, None

    def _segment_sizes(offsets, extent):
        """Turn raw edge offsets along one axis into consecutive segment sizes."""
        # Clamp to the table extent so boxes that overshoot the table bbox
        # cannot create negative offsets or a grid wider than the table.
        edges = sorted({min(max(value, 0.0), extent) for value in offsets})
        if not edges:
            return []
        # Make sure the grid starts and ends on the table border.
        if edges[0] > tolerance:
            edges.insert(0, 0.0)
        if edges[-1] < extent - tolerance:
            edges.append(extent)
        # Collapse edges that are closer than the tolerance (keep the first).
        merged = [edges[0]]
        for edge in edges[1:]:
            if edge - merged[-1] > tolerance:
                merged.append(edge)
        return [hi - lo for lo, hi in zip(merged, merged[1:])]

    def _fit_to_count(sizes, expected, extent):
        """Pad/trim sizes to the expected count, then rescale to the extent."""
        if len(sizes) == expected:
            return sizes
        fitted = list(sizes)
        while len(fitted) < expected:
            fitted.append(fitted[-1] if fitted else extent / expected)
        del fitted[expected:]
        # Padding or trimming changes the total; rescale so the segments
        # still span exactly the table extent instead of over/underflowing.
        total = sum(fitted)
        if total > 0:
            factor = extent / total
            fitted = [size * factor for size in fitted]
        return fitted

    try:
        table_x1, table_y1, table_x2, table_y2 = table_bbox
        table_width = table_x2 - table_x1
        table_height = table_y2 - table_y1

        # A table without positive area cannot yield a meaningful grid.
        if table_width <= 0 or table_height <= 0:
            logger.debug(f"[TABLE] Degenerate table bbox for cell boxes grid: {table_bbox}")
            return None, None

        # Collect every cell edge as an offset from the table origin.
        x_offsets = []
        y_offsets = []
        for box in cell_boxes:
            if len(box) >= 4:
                x1, y1, x2, y2 = box[:4]
                x_offsets.extend((x1 - table_x1, x2 - table_x1))
                y_offsets.extend((y1 - table_y1, y2 - table_y1))

        col_widths = _segment_sizes(x_offsets, table_width)
        row_heights = _segment_sizes(y_offsets, table_height)

        # Validate: number of columns/rows should match expected
        if len(col_widths) == num_cols and len(row_heights) == num_rows:
            logger.info(f"[TABLE] Cell boxes grid: {num_cols} cols, {num_rows} rows")
            logger.debug(f"[TABLE] Col widths from cell_boxes: {[f'{w:.1f}' for w in col_widths]}")
            logger.debug(f"[TABLE] Row heights from cell_boxes: {[f'{h:.1f}' for h in row_heights]}")
            return col_widths, row_heights

        # Grid doesn't match, might be due to merged cells
        logger.debug(
            f"[TABLE] Cell boxes grid mismatch: "
            f"got {len(col_widths)}x{len(row_heights)}, expected {num_cols}x{num_rows}"
        )

        # Still usable when each axis is off by at most one segment.
        if abs(len(col_widths) - num_cols) <= 1 and abs(len(row_heights) - num_rows) <= 1:
            return (
                _fit_to_count(col_widths, num_cols, table_width),
                _fit_to_count(row_heights, num_rows, table_height),
            )

        return None, None

    except Exception as e:
        # Best-effort: the caller falls back to equal distribution on None.
        logger.warning(f"[TABLE] Failed to compute grid from cell boxes: {e}")
        return None, None
def draw_table_region( def draw_table_region(
self, self,
pdf_canvas: canvas.Canvas, pdf_canvas: canvas.Canvas,
@@ -1765,8 +1879,36 @@ class PDFGeneratorService:
reportlab_data.append(row_data) reportlab_data.append(row_data)
# Calculate column widths (equal distribution) # Calculate column widths and row heights
# First, try to use cell_boxes if available for more accurate layout
cell_boxes = table_element.get('cell_boxes')
raw_table_bbox = [ocr_x_left_raw, ocr_y_top_raw, ocr_x_right_raw, ocr_y_bottom_raw]
computed_col_widths = None
computed_row_heights = None
if cell_boxes:
cell_boxes_source = table_element.get('cell_boxes_source', 'unknown')
logger.info(f"[TABLE] Using {len(cell_boxes)} cell boxes from {cell_boxes_source}")
computed_col_widths, computed_row_heights = self._compute_table_grid_from_cell_boxes(
cell_boxes, raw_table_bbox, num_rows, max_cols
)
# Use computed widths if available, otherwise fall back to equal distribution
if computed_col_widths:
# Scale col_widths to PDF coordinates
col_widths = [w * scale_w for w in computed_col_widths]
logger.info(f"[TABLE] Using cell_boxes col widths (scaled)")
else:
col_widths = [table_width / max_cols] * max_cols col_widths = [table_width / max_cols] * max_cols
logger.info(f"[TABLE] Using equal distribution col widths")
# Row heights are used optionally (ReportLab can auto-size)
row_heights = None
if computed_row_heights:
# Scale row_heights to PDF coordinates
row_heights = [h * scale_h for h in computed_row_heights]
logger.debug(f"[TABLE] Cell_boxes row heights available (scaled)")
# Create ReportLab Table # Create ReportLab Table
# Use smaller font to fit content with auto-wrap # Use smaller font to fit content with auto-wrap
@@ -1790,7 +1932,11 @@ class PDFGeneratorService:
escaped_text = cell_text.replace('&', '&amp;').replace('<', '&lt;').replace('>', '&gt;') escaped_text = cell_text.replace('&', '&amp;').replace('<', '&lt;').replace('>', '&gt;')
reportlab_data[row_idx][col_idx] = Paragraph(escaped_text, cell_style) reportlab_data[row_idx][col_idx] = Paragraph(escaped_text, cell_style)
# Create table WITHOUT fixed row heights - let it auto-size based on content # Create table with computed col widths
# Note: We don't use row_heights even when available from cell_boxes because:
# 1. ReportLab's auto-sizing handles content overflow better
# 2. Fixed heights can cause text clipping when content exceeds cell size
# 3. The col_widths from cell_boxes provide the main layout benefit
table = Table(reportlab_data, colWidths=col_widths) table = Table(reportlab_data, colWidths=col_widths)
# Apply table style # Apply table style

View File

@@ -80,6 +80,176 @@ class PPStructureEnhanced:
""" """
self.structure_engine = structure_engine self.structure_engine = structure_engine
# Lazy-loaded SLANeXt models for cell boxes extraction
# These are loaded on-demand when enable_table_cell_boxes_extraction is True
self._slanet_wired_model = None
self._slanet_wireless_model = None
self._table_cls_model = None
def _get_slanet_model(self, is_wired: bool = True):
    """
    Return the cached SLANeXt table-structure model, creating it on first use.

    Args:
        is_wired: Truthy selects the wired (bordered-table) model, falsy
            selects the wireless one.

    Returns:
        The model instance, or None when cell-box extraction is disabled in
        settings or the model could not be created.
    """
    # The feature flag gates every direct SLANeXt call.
    if not settings.enable_table_cell_boxes_extraction:
        return None

    try:
        # Imported inside the function, so paddlex is only touched once the
        # feature is actually used.
        from paddlex import create_model

        if not is_wired:
            if self._slanet_wireless_model is None:
                wireless_name = settings.wireless_table_model_name or "SLANeXt_wireless"
                logger.info(f"Loading SLANeXt wireless model: {wireless_name}")
                self._slanet_wireless_model = create_model(wireless_name)
            return self._slanet_wireless_model

        if self._slanet_wired_model is None:
            wired_name = settings.wired_table_model_name or "SLANeXt_wired"
            logger.info(f"Loading SLANeXt wired model: {wired_name}")
            self._slanet_wired_model = create_model(wired_name)
        return self._slanet_wired_model

    except Exception as exc:
        # Import or model creation failed; callers treat None as "unavailable".
        logger.error(f"Failed to load SLANeXt model: {exc}")
        return None
def _get_table_classifier(self):
    """
    Return the cached table classification model, creating it on first use.

    Returns:
        The classifier instance, or None when cell-box extraction is disabled
        in settings or the model could not be created.
    """
    # The feature flag gates the classifier as well as the SLANeXt models.
    if not settings.enable_table_cell_boxes_extraction:
        return None

    try:
        from paddlex import create_model

        # Reuse the instance created by an earlier call.
        if self._table_cls_model is not None:
            return self._table_cls_model

        cls_name = settings.table_classification_model_name or "PP-LCNet_x1_0_table_cls"
        logger.info(f"Loading table classification model: {cls_name}")
        self._table_cls_model = create_model(cls_name)
        return self._table_cls_model

    except Exception as exc:
        # Import or model creation failed; callers treat None as "unavailable".
        logger.error(f"Failed to load table classifier: {exc}")
        return None
def _extract_cell_boxes_with_slanet(
    self,
    table_image: np.ndarray,
    table_bbox: List[float],
    is_wired: Optional[bool] = None
) -> Optional[List[List[float]]]:
    """
    Extract cell bounding boxes using a direct SLANeXt model call.

    Args:
        table_image: Cropped table image as numpy array. The caller is
            expected to pass BGR data (per the original docstring).
        table_bbox: Table bounding box in page coordinates [x1, y1, x2, y2];
            only the top-left corner is used, as the offset added to every
            cell box.
        is_wired: True for bordered tables, False for borderless ones. When
            None, the table classifier decides; if the classifier is
            unavailable, fails, or returns no label, wired is assumed.

    Returns:
        List of cell bounding boxes in page coordinates [[x1,y1,x2,y2], ...],
        or None if the feature is disabled, no model is available, no box
        was produced, or extraction fails.
    """
    if not settings.enable_table_cell_boxes_extraction:
        return None

    try:
        # Auto-detect table type if not specified
        if is_wired is None:
            classifier = self._get_table_classifier()
            if classifier:
                try:
                    for res in classifier.predict(table_image):
                        label_names = res.get('label_names', [])
                        if label_names:
                            is_wired = 'wired' in str(label_names[0]).lower()
                            logger.debug(f"Table classified as: {'wired' if is_wired else 'wireless'}")
                            break
                except Exception as e:
                    logger.warning(f"Table classification failed, defaulting to wired: {e}")

            # Single fallback for every "no verdict" case: classifier
            # unavailable, classifier raised, or it yielded no label. Without
            # this, an empty classification left is_wired as None and the
            # wireless model was picked silently.
            if is_wired is None:
                is_wired = True

        # Get appropriate SLANeXt model
        model = self._get_slanet_model(is_wired=is_wired)
        if model is None:
            return None

        table_x, table_y = table_bbox[0], table_bbox[1]
        cell_boxes = []

        for result in model.predict(table_image):
            # NOTE(review): per the original comment, 'bbox' holds 8-value
            # polygons [[x1,y1,x2,y2,x3,y3,x4,y4], ...] - confirm against the
            # installed PaddleX version.
            boxes = result.get('bbox', [])
            if boxes is None:
                continue

            for box in boxes:
                # Accept numpy rows as well as plain sequences; some result
                # objects hand back arrays rather than lists.
                if not isinstance(box, (list, tuple, np.ndarray)):
                    continue

                try:
                    # Flatten so both [x1,y1,...] and [[x1,y1],...] layouts work.
                    coords = np.asarray(box, dtype=float).ravel()
                except (TypeError, ValueError):
                    # Ragged or non-numeric entry: skip it, keep the rest.
                    continue

                if coords.size >= 8:
                    # Polygon: take its axis-aligned bounding rectangle.
                    xs = coords[0:8:2]
                    ys = coords[1:8:2]
                    x1, y1 = xs.min(), ys.min()
                    x2, y2 = xs.max(), ys.max()
                elif coords.size >= 4:
                    # Already a rectangle [x1, y1, x2, y2]
                    x1, y1, x2, y2 = coords[:4]
                else:
                    continue

                # Shift from crop-local to absolute page coordinates
                cell_boxes.append([
                    float(x1 + table_x),
                    float(y1 + table_y),
                    float(x2 + table_x),
                    float(y2 + table_y)
                ])

        logger.info(f"SLANeXt extracted {len(cell_boxes)} cell boxes (is_wired={is_wired})")
        return cell_boxes if cell_boxes else None

    except Exception as e:
        # Best-effort supplement: any failure just means "no cell boxes".
        logger.error(f"Cell boxes extraction with SLANeXt failed: {e}")
        return None
def release_slanet_models(self):
    """
    Drop the cached SLANeXt and table-classifier models.

    Every cached model that is currently set is reset to None (with an info
    log per model), then garbage collection is forced and, when torch is
    available, the torch CUDA cache is emptied.
    """
    # (attribute holding the cached model, message logged once it is dropped)
    cached_models = (
        ("_slanet_wired_model", "Released SLANeXt wired model"),
        ("_slanet_wireless_model", "Released SLANeXt wireless model"),
        ("_table_cls_model", "Released table classifier model"),
    )

    for attr_name, message in cached_models:
        if getattr(self, attr_name) is not None:
            setattr(self, attr_name, None)
            logger.info(message)

    gc.collect()
    # NOTE(review): the models come from paddlex; confirm that emptying the
    # torch CUDA cache actually frees their GPU memory.
    if TORCH_AVAILABLE:
        torch.cuda.empty_cache()
def analyze_with_full_structure( def analyze_with_full_structure(
self, self,
image_path: Path, image_path: Path,
@@ -372,9 +542,12 @@ class PPStructureEnhanced:
element['html'] = html_content element['html'] = html_content
element['extracted_text'] = self._extract_text_from_html(html_content) element['extracted_text'] = self._extract_text_from_html(html_content)
# 2. 【新增】提取 Cell 座標 (boxes) # 2. 提取 Cell 座標 (boxes)
# SLANet 回傳的格式通常是 [[x1, y1, x2, y2], ...] # 優先使用 PPStructureV3 返回的 boxes若無則調用 SLANeXt 補充
cell_boxes_extracted = False
if 'boxes' in res_data: if 'boxes' in res_data:
# PPStructureV3 returned cell boxes (unlikely in PaddleX 3.x)
cell_boxes = res_data['boxes'] cell_boxes = res_data['boxes']
logger.info(f"[TABLE] Found {len(cell_boxes)} cell boxes in res_data") logger.info(f"[TABLE] Found {len(cell_boxes)} cell boxes in res_data")
@@ -399,9 +572,54 @@ class PPStructureEnhanced:
# 將處理後的 Cell 座標存入 element # 將處理後的 Cell 座標存入 element
element['cell_boxes'] = processed_cells element['cell_boxes'] = processed_cells
element['raw_cell_boxes'] = cell_boxes element['raw_cell_boxes'] = cell_boxes
element['cell_boxes_source'] = 'ppstructure'
logger.info(f"[TABLE] Processed {len(processed_cells)} cell boxes with table offset ({table_x}, {table_y})") logger.info(f"[TABLE] Processed {len(processed_cells)} cell boxes with table offset ({table_x}, {table_y})")
cell_boxes_extracted = True
# Supplement with direct SLANeXt call if PPStructureV3 didn't provide boxes
if not cell_boxes_extracted and source_image_path and bbox != [0, 0, 0, 0]:
logger.info(f"[TABLE] No boxes from PPStructureV3, attempting SLANeXt extraction...")
try:
# Load source image and crop table region
source_img = Image.open(source_image_path)
source_array = np.array(source_img)
# Crop table region (bbox is in original image coordinates)
x1, y1, x2, y2 = [int(round(c)) for c in bbox]
# Ensure coordinates are within image bounds
h, w = source_array.shape[:2]
x1, y1 = max(0, x1), max(0, y1)
x2, y2 = min(w, x2), min(h, y2)
if x2 > x1 and y2 > y1:
table_crop = source_array[y1:y2, x1:x2]
# Convert RGB to BGR for SLANeXt
if len(table_crop.shape) == 3 and table_crop.shape[2] == 3:
table_crop_bgr = table_crop[:, :, ::-1]
else: else:
logger.info(f"[TABLE] No 'boxes' key in res_data. Available: {list(res_data.keys()) if res_data else 'empty'}") table_crop_bgr = table_crop
# Extract cell boxes using SLANeXt
slanet_boxes = self._extract_cell_boxes_with_slanet(
table_crop_bgr,
bbox, # Pass original bbox for coordinate offset
is_wired=None # Auto-detect
)
if slanet_boxes:
element['cell_boxes'] = slanet_boxes
element['cell_boxes_source'] = 'slanet'
cell_boxes_extracted = True
logger.info(f"[TABLE] SLANeXt extracted {len(slanet_boxes)} cell boxes")
else:
logger.warning(f"[TABLE] Invalid crop region: ({x1},{y1})-({x2},{y2})")
except Exception as e:
logger.error(f"[TABLE] SLANeXt extraction failed: {e}")
if not cell_boxes_extracted:
logger.info(f"[TABLE] No cell boxes available. PPStructureV3 keys: {list(res_data.keys()) if res_data else 'empty'}")
# Special handling for images/figures # Special handling for images/figures
elif mapped_type in [ElementType.IMAGE, ElementType.FIGURE]: elif mapped_type in [ElementType.IMAGE, ElementType.FIGURE]: