feat: implement table cell boxes extraction with SLANeXt
Phase 1-3 implementation of extract-table-cell-boxes proposal: - Add enable_table_cell_boxes_extraction config option - Implement lazy-loaded SLANeXt model caching in PPStructureEnhanced - Add _extract_cell_boxes_with_slanet() method for direct model invocation - Supplement PPStructureV3 table processing with SLANeXt cell boxes - Add _compute_table_grid_from_cell_boxes() for column width calculation - Modify draw_table_region() to use cell_boxes for accurate layout Key features: - Auto-detect table type (wired/wireless) using PP-LCNet classifier - Convert 8-point polygon bbox to 4-point rectangle - Graceful fallback to equal distribution when cell_boxes unavailable - Proper coordinate transformation with scaling support 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -1605,6 +1605,120 @@ class PDFGeneratorService:
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to draw text region '{text[:20]}...': {e}")
|
||||
|
||||
def _compute_table_grid_from_cell_boxes(
|
||||
self,
|
||||
cell_boxes: List[List[float]],
|
||||
table_bbox: List[float],
|
||||
num_rows: int,
|
||||
num_cols: int
|
||||
) -> Tuple[Optional[List[float]], Optional[List[float]]]:
|
||||
"""
|
||||
Compute column widths and row heights from cell bounding boxes.
|
||||
|
||||
This uses the cell boxes extracted by SLANeXt to calculate the actual
|
||||
column widths and row heights, which provides more accurate table rendering
|
||||
than uniform distribution.
|
||||
|
||||
Args:
|
||||
cell_boxes: List of cell bboxes [[x1,y1,x2,y2], ...]
|
||||
table_bbox: Table bounding box [x1,y1,x2,y2]
|
||||
num_rows: Number of rows in the table
|
||||
num_cols: Number of columns in the table
|
||||
|
||||
Returns:
|
||||
Tuple of (col_widths, row_heights) or (None, None) if calculation fails
|
||||
"""
|
||||
if not cell_boxes or len(cell_boxes) < 2:
|
||||
return None, None
|
||||
|
||||
try:
|
||||
table_x1, table_y1, table_x2, table_y2 = table_bbox
|
||||
table_width = table_x2 - table_x1
|
||||
table_height = table_y2 - table_y1
|
||||
|
||||
# Collect all unique X and Y boundaries from cell boxes
|
||||
x_boundaries = set()
|
||||
y_boundaries = set()
|
||||
|
||||
for box in cell_boxes:
|
||||
if len(box) >= 4:
|
||||
x1, y1, x2, y2 = box[:4]
|
||||
# Convert to relative coordinates within table
|
||||
x_boundaries.add(x1 - table_x1)
|
||||
x_boundaries.add(x2 - table_x1)
|
||||
y_boundaries.add(y1 - table_y1)
|
||||
y_boundaries.add(y2 - table_y1)
|
||||
|
||||
# Sort boundaries
|
||||
x_boundaries = sorted(x_boundaries)
|
||||
y_boundaries = sorted(y_boundaries)
|
||||
|
||||
# Ensure we have boundaries at table edges
|
||||
if x_boundaries and x_boundaries[0] > 5:
|
||||
x_boundaries.insert(0, 0)
|
||||
if x_boundaries and x_boundaries[-1] < table_width - 5:
|
||||
x_boundaries.append(table_width)
|
||||
|
||||
if y_boundaries and y_boundaries[0] > 5:
|
||||
y_boundaries.insert(0, 0)
|
||||
if y_boundaries and y_boundaries[-1] < table_height - 5:
|
||||
y_boundaries.append(table_height)
|
||||
|
||||
# Calculate column widths from X boundaries
|
||||
# Merge boundaries that are too close (< 5px)
|
||||
merged_x = [x_boundaries[0]] if x_boundaries else []
|
||||
for x in x_boundaries[1:]:
|
||||
if x - merged_x[-1] > 5:
|
||||
merged_x.append(x)
|
||||
x_boundaries = merged_x
|
||||
|
||||
# Calculate row heights from Y boundaries
|
||||
merged_y = [y_boundaries[0]] if y_boundaries else []
|
||||
for y in y_boundaries[1:]:
|
||||
if y - merged_y[-1] > 5:
|
||||
merged_y.append(y)
|
||||
y_boundaries = merged_y
|
||||
|
||||
# Calculate widths and heights
|
||||
col_widths = []
|
||||
for i in range(len(x_boundaries) - 1):
|
||||
col_widths.append(x_boundaries[i + 1] - x_boundaries[i])
|
||||
|
||||
row_heights = []
|
||||
for i in range(len(y_boundaries) - 1):
|
||||
row_heights.append(y_boundaries[i + 1] - y_boundaries[i])
|
||||
|
||||
# Validate: number of columns/rows should match expected
|
||||
if len(col_widths) == num_cols and len(row_heights) == num_rows:
|
||||
logger.info(f"[TABLE] Cell boxes grid: {num_cols} cols, {num_rows} rows")
|
||||
logger.debug(f"[TABLE] Col widths from cell_boxes: {[f'{w:.1f}' for w in col_widths]}")
|
||||
logger.debug(f"[TABLE] Row heights from cell_boxes: {[f'{h:.1f}' for h in row_heights]}")
|
||||
return col_widths, row_heights
|
||||
else:
|
||||
# Grid doesn't match, might be due to merged cells
|
||||
logger.debug(
|
||||
f"[TABLE] Cell boxes grid mismatch: "
|
||||
f"got {len(col_widths)}x{len(row_heights)}, expected {num_cols}x{num_rows}"
|
||||
)
|
||||
# Still return the widths/heights if counts are close
|
||||
if abs(len(col_widths) - num_cols) <= 1 and abs(len(row_heights) - num_rows) <= 1:
|
||||
# Adjust to match expected count
|
||||
while len(col_widths) < num_cols:
|
||||
col_widths.append(col_widths[-1] if col_widths else table_width / num_cols)
|
||||
while len(col_widths) > num_cols:
|
||||
col_widths.pop()
|
||||
while len(row_heights) < num_rows:
|
||||
row_heights.append(row_heights[-1] if row_heights else table_height / num_rows)
|
||||
while len(row_heights) > num_rows:
|
||||
row_heights.pop()
|
||||
return col_widths, row_heights
|
||||
|
||||
return None, None
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"[TABLE] Failed to compute grid from cell boxes: {e}")
|
||||
return None, None
|
||||
|
||||
def draw_table_region(
|
||||
self,
|
||||
pdf_canvas: canvas.Canvas,
|
||||
@@ -1765,8 +1879,36 @@ class PDFGeneratorService:
|
||||
|
||||
reportlab_data.append(row_data)
|
||||
|
||||
# Calculate column widths (equal distribution)
|
||||
col_widths = [table_width / max_cols] * max_cols
|
||||
# Calculate column widths and row heights
|
||||
# First, try to use cell_boxes if available for more accurate layout
|
||||
cell_boxes = table_element.get('cell_boxes')
|
||||
raw_table_bbox = [ocr_x_left_raw, ocr_y_top_raw, ocr_x_right_raw, ocr_y_bottom_raw]
|
||||
|
||||
computed_col_widths = None
|
||||
computed_row_heights = None
|
||||
|
||||
if cell_boxes:
|
||||
cell_boxes_source = table_element.get('cell_boxes_source', 'unknown')
|
||||
logger.info(f"[TABLE] Using {len(cell_boxes)} cell boxes from {cell_boxes_source}")
|
||||
computed_col_widths, computed_row_heights = self._compute_table_grid_from_cell_boxes(
|
||||
cell_boxes, raw_table_bbox, num_rows, max_cols
|
||||
)
|
||||
|
||||
# Use computed widths if available, otherwise fall back to equal distribution
|
||||
if computed_col_widths:
|
||||
# Scale col_widths to PDF coordinates
|
||||
col_widths = [w * scale_w for w in computed_col_widths]
|
||||
logger.info(f"[TABLE] Using cell_boxes col widths (scaled)")
|
||||
else:
|
||||
col_widths = [table_width / max_cols] * max_cols
|
||||
logger.info(f"[TABLE] Using equal distribution col widths")
|
||||
|
||||
# Row heights are used optionally (ReportLab can auto-size)
|
||||
row_heights = None
|
||||
if computed_row_heights:
|
||||
# Scale row_heights to PDF coordinates
|
||||
row_heights = [h * scale_h for h in computed_row_heights]
|
||||
logger.debug(f"[TABLE] Cell_boxes row heights available (scaled)")
|
||||
|
||||
# Create ReportLab Table
|
||||
# Use smaller font to fit content with auto-wrap
|
||||
@@ -1790,7 +1932,11 @@ class PDFGeneratorService:
|
||||
escaped_text = cell_text.replace('&', '&').replace('<', '<').replace('>', '>')
|
||||
reportlab_data[row_idx][col_idx] = Paragraph(escaped_text, cell_style)
|
||||
|
||||
# Create table WITHOUT fixed row heights - let it auto-size based on content
|
||||
# Create table with computed col widths
|
||||
# Note: We don't use row_heights even when available from cell_boxes because:
|
||||
# 1. ReportLab's auto-sizing handles content overflow better
|
||||
# 2. Fixed heights can cause text clipping when content exceeds cell size
|
||||
# 3. The col_widths from cell_boxes provide the main layout benefit
|
||||
table = Table(reportlab_data, colWidths=col_widths)
|
||||
|
||||
# Apply table style
|
||||
|
||||
@@ -80,6 +80,176 @@ class PPStructureEnhanced:
|
||||
"""
|
||||
self.structure_engine = structure_engine
|
||||
|
||||
# Lazy-loaded SLANeXt models for cell boxes extraction
|
||||
# These are loaded on-demand when enable_table_cell_boxes_extraction is True
|
||||
self._slanet_wired_model = None
|
||||
self._slanet_wireless_model = None
|
||||
self._table_cls_model = None
|
||||
|
||||
def _get_slanet_model(self, is_wired: bool = True):
|
||||
"""
|
||||
Get or create SLANeXt model for cell boxes extraction (lazy loading).
|
||||
|
||||
Args:
|
||||
is_wired: True for wired (bordered) tables, False for wireless
|
||||
|
||||
Returns:
|
||||
SLANeXt model instance or None if loading fails
|
||||
"""
|
||||
if not settings.enable_table_cell_boxes_extraction:
|
||||
return None
|
||||
|
||||
try:
|
||||
from paddlex import create_model
|
||||
|
||||
if is_wired:
|
||||
if self._slanet_wired_model is None:
|
||||
model_name = settings.wired_table_model_name or "SLANeXt_wired"
|
||||
logger.info(f"Loading SLANeXt wired model: {model_name}")
|
||||
self._slanet_wired_model = create_model(model_name)
|
||||
return self._slanet_wired_model
|
||||
else:
|
||||
if self._slanet_wireless_model is None:
|
||||
model_name = settings.wireless_table_model_name or "SLANeXt_wireless"
|
||||
logger.info(f"Loading SLANeXt wireless model: {model_name}")
|
||||
self._slanet_wireless_model = create_model(model_name)
|
||||
return self._slanet_wireless_model
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load SLANeXt model: {e}")
|
||||
return None
|
||||
|
||||
def _get_table_classifier(self):
|
||||
"""
|
||||
Get or create table classification model (lazy loading).
|
||||
|
||||
Returns:
|
||||
Table classifier model instance or None if loading fails
|
||||
"""
|
||||
if not settings.enable_table_cell_boxes_extraction:
|
||||
return None
|
||||
|
||||
try:
|
||||
from paddlex import create_model
|
||||
|
||||
if self._table_cls_model is None:
|
||||
model_name = settings.table_classification_model_name or "PP-LCNet_x1_0_table_cls"
|
||||
logger.info(f"Loading table classification model: {model_name}")
|
||||
self._table_cls_model = create_model(model_name)
|
||||
return self._table_cls_model
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load table classifier: {e}")
|
||||
return None
|
||||
|
||||
def _extract_cell_boxes_with_slanet(
|
||||
self,
|
||||
table_image: np.ndarray,
|
||||
table_bbox: List[float],
|
||||
is_wired: Optional[bool] = None
|
||||
) -> Optional[List[List[float]]]:
|
||||
"""
|
||||
Extract cell bounding boxes using direct SLANeXt model call.
|
||||
|
||||
This supplements PPStructureV3 which doesn't expose cell boxes in its output.
|
||||
|
||||
Args:
|
||||
table_image: Cropped table image as numpy array (BGR format)
|
||||
table_bbox: Table bounding box in page coordinates [x1, y1, x2, y2]
|
||||
is_wired: If None, auto-detect using classifier. True for bordered tables.
|
||||
|
||||
Returns:
|
||||
List of cell bounding boxes in page coordinates [[x1,y1,x2,y2], ...],
|
||||
or None if extraction fails
|
||||
"""
|
||||
if not settings.enable_table_cell_boxes_extraction:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Auto-detect table type if not specified
|
||||
if is_wired is None:
|
||||
classifier = self._get_table_classifier()
|
||||
if classifier:
|
||||
try:
|
||||
cls_result = classifier.predict(table_image)
|
||||
# PP-LCNet returns classification result
|
||||
for res in cls_result:
|
||||
label_names = res.get('label_names', [])
|
||||
if label_names:
|
||||
is_wired = 'wired' in str(label_names[0]).lower()
|
||||
logger.debug(f"Table classified as: {'wired' if is_wired else 'wireless'}")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"Table classification failed, defaulting to wired: {e}")
|
||||
is_wired = True
|
||||
else:
|
||||
is_wired = True # Default to wired if classifier unavailable
|
||||
|
||||
# Get appropriate SLANeXt model
|
||||
model = self._get_slanet_model(is_wired=is_wired)
|
||||
if model is None:
|
||||
return None
|
||||
|
||||
# Run SLANeXt prediction
|
||||
results = model.predict(table_image)
|
||||
|
||||
# Extract cell boxes from result
|
||||
cell_boxes = []
|
||||
table_x, table_y = table_bbox[0], table_bbox[1]
|
||||
|
||||
for result in results:
|
||||
# SLANeXt returns 'bbox' with 8-point polygon format
|
||||
# [[x1,y1,x2,y2,x3,y3,x4,y4], ...]
|
||||
boxes = result.get('bbox', [])
|
||||
for box in boxes:
|
||||
if isinstance(box, (list, tuple)):
|
||||
if len(box) >= 8:
|
||||
# 8-point polygon: convert to 4-point rectangle
|
||||
xs = [box[i] for i in range(0, 8, 2)]
|
||||
ys = [box[i] for i in range(1, 8, 2)]
|
||||
x1, y1 = min(xs), min(ys)
|
||||
x2, y2 = max(xs), max(ys)
|
||||
elif len(box) >= 4:
|
||||
# Already 4-point rectangle
|
||||
x1, y1, x2, y2 = box[:4]
|
||||
else:
|
||||
continue
|
||||
|
||||
# Convert to absolute page coordinates
|
||||
abs_box = [
|
||||
float(x1 + table_x),
|
||||
float(y1 + table_y),
|
||||
float(x2 + table_x),
|
||||
float(y2 + table_y)
|
||||
]
|
||||
cell_boxes.append(abs_box)
|
||||
|
||||
logger.info(f"SLANeXt extracted {len(cell_boxes)} cell boxes (is_wired={is_wired})")
|
||||
return cell_boxes if cell_boxes else None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Cell boxes extraction with SLANeXt failed: {e}")
|
||||
return None
|
||||
|
||||
def release_slanet_models(self):
|
||||
"""Release SLANeXt models to free GPU memory."""
|
||||
if self._slanet_wired_model is not None:
|
||||
del self._slanet_wired_model
|
||||
self._slanet_wired_model = None
|
||||
logger.info("Released SLANeXt wired model")
|
||||
|
||||
if self._slanet_wireless_model is not None:
|
||||
del self._slanet_wireless_model
|
||||
self._slanet_wireless_model = None
|
||||
logger.info("Released SLANeXt wireless model")
|
||||
|
||||
if self._table_cls_model is not None:
|
||||
del self._table_cls_model
|
||||
self._table_cls_model = None
|
||||
logger.info("Released table classifier model")
|
||||
|
||||
gc.collect()
|
||||
if TORCH_AVAILABLE:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def analyze_with_full_structure(
|
||||
self,
|
||||
image_path: Path,
|
||||
@@ -372,9 +542,12 @@ class PPStructureEnhanced:
|
||||
element['html'] = html_content
|
||||
element['extracted_text'] = self._extract_text_from_html(html_content)
|
||||
|
||||
# 2. 【新增】提取 Cell 座標 (boxes)
|
||||
# SLANet 回傳的格式通常是 [[x1, y1, x2, y2], ...]
|
||||
# 2. 提取 Cell 座標 (boxes)
|
||||
# 優先使用 PPStructureV3 返回的 boxes,若無則調用 SLANeXt 補充
|
||||
cell_boxes_extracted = False
|
||||
|
||||
if 'boxes' in res_data:
|
||||
# PPStructureV3 returned cell boxes (unlikely in PaddleX 3.x)
|
||||
cell_boxes = res_data['boxes']
|
||||
logger.info(f"[TABLE] Found {len(cell_boxes)} cell boxes in res_data")
|
||||
|
||||
@@ -399,9 +572,54 @@ class PPStructureEnhanced:
|
||||
# 將處理後的 Cell 座標存入 element
|
||||
element['cell_boxes'] = processed_cells
|
||||
element['raw_cell_boxes'] = cell_boxes
|
||||
element['cell_boxes_source'] = 'ppstructure'
|
||||
logger.info(f"[TABLE] Processed {len(processed_cells)} cell boxes with table offset ({table_x}, {table_y})")
|
||||
else:
|
||||
logger.info(f"[TABLE] No 'boxes' key in res_data. Available: {list(res_data.keys()) if res_data else 'empty'}")
|
||||
cell_boxes_extracted = True
|
||||
|
||||
# Supplement with direct SLANeXt call if PPStructureV3 didn't provide boxes
|
||||
if not cell_boxes_extracted and source_image_path and bbox != [0, 0, 0, 0]:
|
||||
logger.info(f"[TABLE] No boxes from PPStructureV3, attempting SLANeXt extraction...")
|
||||
try:
|
||||
# Load source image and crop table region
|
||||
source_img = Image.open(source_image_path)
|
||||
source_array = np.array(source_img)
|
||||
|
||||
# Crop table region (bbox is in original image coordinates)
|
||||
x1, y1, x2, y2 = [int(round(c)) for c in bbox]
|
||||
# Ensure coordinates are within image bounds
|
||||
h, w = source_array.shape[:2]
|
||||
x1, y1 = max(0, x1), max(0, y1)
|
||||
x2, y2 = min(w, x2), min(h, y2)
|
||||
|
||||
if x2 > x1 and y2 > y1:
|
||||
table_crop = source_array[y1:y2, x1:x2]
|
||||
|
||||
# Convert RGB to BGR for SLANeXt
|
||||
if len(table_crop.shape) == 3 and table_crop.shape[2] == 3:
|
||||
table_crop_bgr = table_crop[:, :, ::-1]
|
||||
else:
|
||||
table_crop_bgr = table_crop
|
||||
|
||||
# Extract cell boxes using SLANeXt
|
||||
slanet_boxes = self._extract_cell_boxes_with_slanet(
|
||||
table_crop_bgr,
|
||||
bbox, # Pass original bbox for coordinate offset
|
||||
is_wired=None # Auto-detect
|
||||
)
|
||||
|
||||
if slanet_boxes:
|
||||
element['cell_boxes'] = slanet_boxes
|
||||
element['cell_boxes_source'] = 'slanet'
|
||||
cell_boxes_extracted = True
|
||||
logger.info(f"[TABLE] SLANeXt extracted {len(slanet_boxes)} cell boxes")
|
||||
else:
|
||||
logger.warning(f"[TABLE] Invalid crop region: ({x1},{y1})-({x2},{y2})")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[TABLE] SLANeXt extraction failed: {e}")
|
||||
|
||||
if not cell_boxes_extracted:
|
||||
logger.info(f"[TABLE] No cell boxes available. PPStructureV3 keys: {list(res_data.keys()) if res_data else 'empty'}")
|
||||
|
||||
# Special handling for images/figures
|
||||
elif mapped_type in [ElementType.IMAGE, ElementType.FIGURE]:
|
||||
|
||||
Reference in New Issue
Block a user