def _compute_table_grid_from_cell_boxes(
    self,
    cell_boxes: List[List[float]],
    table_bbox: List[float],
    num_rows: int,
    num_cols: int,
    tolerance: float = 5.0,
) -> Tuple[Optional[List[float]], Optional[List[float]]]:
    """
    Compute column widths and row heights from cell bounding boxes.

    The distinct X and Y edges of the cell boxes (made relative to the table
    origin) are sorted, padded with the table's own edges when the outermost
    cell edge is further than ``tolerance`` away, and de-duplicated by
    dropping edges that lie within ``tolerance`` of the previously kept edge.
    The gaps between consecutive edges are the column widths / row heights.

    If the resulting grid is off by at most one column and one row from the
    expected shape, it is padded (repeating the last size) or truncated to
    the expected count and then rescaled so the sizes still add up to the
    span that was detected. Without the rescale a padded grid would be wider
    than the detected table and a truncated one narrower.

    Args:
        cell_boxes: List of cell bboxes [[x1, y1, x2, y2], ...] in the same
            coordinate space as ``table_bbox``. Entries with fewer than four
            values are ignored.
        table_bbox: Table bounding box [x1, y1, x2, y2].
        num_rows: Expected number of rows in the table.
        num_cols: Expected number of columns in the table.
        tolerance: Distance below which two edges are treated as the same
            grid line. Defaults to 5.0.

    Returns:
        Tuple of (col_widths, row_heights), or (None, None) when there are
        fewer than two cell boxes, the expected shape or the table bbox is
        degenerate, the detected grid differs too much from the expected
        shape, or any error occurs during the computation.
    """
    if not cell_boxes or len(cell_boxes) < 2:
        return None, None
    if num_rows <= 0 or num_cols <= 0:
        # Nothing to fit the grid to; also avoids dividing by zero below.
        return None, None

    def _edges_to_sizes(edges, extent):
        """Turn a set of edge positions along one axis into consecutive gap sizes."""
        ordered = sorted(edges)
        if not ordered:
            return []
        # Make sure the grid reaches the table edges.
        if ordered[0] > tolerance:
            ordered.insert(0, 0)
        if ordered[-1] < extent - tolerance:
            ordered.append(extent)
        # Collapse edges that are too close to the previously kept one.
        merged = [ordered[0]]
        for edge in ordered[1:]:
            if edge - merged[-1] > tolerance:
                merged.append(edge)
        return [end - start for start, end in zip(merged, merged[1:])]

    def _fit_to_count(sizes, expected, extent):
        """Pad/trim ``sizes`` to ``expected`` entries while keeping their total span."""
        if len(sizes) == expected:
            return sizes
        span = sum(sizes) or extent
        fitted = sizes[:expected]
        filler = fitted[-1] if fitted else extent / expected
        fitted.extend([filler] * (expected - len(fitted)))
        total = sum(fitted)
        if total <= 0:
            return fitted
        return [size * span / total for size in fitted]

    try:
        table_x1, table_y1, table_x2, table_y2 = table_bbox
        table_width = table_x2 - table_x1
        table_height = table_y2 - table_y1
        if table_width <= 0 or table_height <= 0:
            # A degenerate table box cannot anchor a meaningful grid.
            return None, None

        # Collect all unique X and Y boundaries, relative to the table origin.
        x_edges = set()
        y_edges = set()
        for box in cell_boxes:
            if len(box) >= 4:
                x1, y1, x2, y2 = box[:4]
                x_edges.update((x1 - table_x1, x2 - table_x1))
                y_edges.update((y1 - table_y1, y2 - table_y1))

        col_widths = _edges_to_sizes(x_edges, table_width)
        row_heights = _edges_to_sizes(y_edges, table_height)

        # Validate: number of columns/rows should match expected
        if len(col_widths) == num_cols and len(row_heights) == num_rows:
            logger.info(f"[TABLE] Cell boxes grid: {num_cols} cols, {num_rows} rows")
            logger.debug(f"[TABLE] Col widths from cell_boxes: {[f'{w:.1f}' for w in col_widths]}")
            logger.debug(f"[TABLE] Row heights from cell_boxes: {[f'{h:.1f}' for h in row_heights]}")
            return col_widths, row_heights

        # Grid doesn't match, might be due to merged cells
        logger.debug(
            f"[TABLE] Cell boxes grid mismatch: "
            f"got {len(col_widths)}x{len(row_heights)}, expected {num_cols}x{num_rows}"
        )
        # Still return the widths/heights if counts are close
        if abs(len(col_widths) - num_cols) <= 1 and abs(len(row_heights) - num_rows) <= 1:
            return (
                _fit_to_count(col_widths, num_cols, table_width),
                _fit_to_count(row_heights, num_rows, table_height),
            )

        return None, None

    except Exception as e:
        # Best-effort helper: callers fall back to equal distribution on None.
        logger.warning(f"[TABLE] Failed to compute grid from cell boxes: {e}")
        return None, None
cell boxes from {cell_boxes_source}") + computed_col_widths, computed_row_heights = self._compute_table_grid_from_cell_boxes( + cell_boxes, raw_table_bbox, num_rows, max_cols + ) + + # Use computed widths if available, otherwise fall back to equal distribution + if computed_col_widths: + # Scale col_widths to PDF coordinates + col_widths = [w * scale_w for w in computed_col_widths] + logger.info(f"[TABLE] Using cell_boxes col widths (scaled)") + else: + col_widths = [table_width / max_cols] * max_cols + logger.info(f"[TABLE] Using equal distribution col widths") + + # Row heights are used optionally (ReportLab can auto-size) + row_heights = None + if computed_row_heights: + # Scale row_heights to PDF coordinates + row_heights = [h * scale_h for h in computed_row_heights] + logger.debug(f"[TABLE] Cell_boxes row heights available (scaled)") # Create ReportLab Table # Use smaller font to fit content with auto-wrap @@ -1790,7 +1932,11 @@ class PDFGeneratorService: escaped_text = cell_text.replace('&', '&').replace('<', '<').replace('>', '>') reportlab_data[row_idx][col_idx] = Paragraph(escaped_text, cell_style) - # Create table WITHOUT fixed row heights - let it auto-size based on content + # Create table with computed col widths + # Note: We don't use row_heights even when available from cell_boxes because: + # 1. ReportLab's auto-sizing handles content overflow better + # 2. Fixed heights can cause text clipping when content exceeds cell size + # 3. 
def _get_slanet_model(self, is_wired: bool = True):
    """
    Get or create the SLANeXt model for cell boxes extraction (lazy loading).

    The model is created on first request and cached on the instance in
    ``_slanet_wired_model`` / ``_slanet_wireless_model``.

    Args:
        is_wired: True for wired (bordered) tables, False for wireless.

    Returns:
        SLANeXt model instance, or None if extraction is disabled in
        settings or the model cannot be imported/created.
    """
    if not settings.enable_table_cell_boxes_extraction:
        return None

    try:
        # Imported lazily so paddlex is only needed when extraction is enabled.
        from paddlex import create_model

        if is_wired:
            if self._slanet_wired_model is None:
                model_name = settings.wired_table_model_name or "SLANeXt_wired"
                logger.info(f"Loading SLANeXt wired model: {model_name}")
                self._slanet_wired_model = create_model(model_name)
            return self._slanet_wired_model

        if self._slanet_wireless_model is None:
            model_name = settings.wireless_table_model_name or "SLANeXt_wireless"
            logger.info(f"Loading SLANeXt wireless model: {model_name}")
            self._slanet_wireless_model = create_model(model_name)
        return self._slanet_wireless_model
    except Exception as e:
        logger.error(f"Failed to load SLANeXt model: {e}")
        return None


def _get_table_classifier(self):
    """
    Get or create the table classification model (lazy loading).

    The model is cached on the instance in ``_table_cls_model``.

    Returns:
        Table classifier model instance, or None if extraction is disabled
        in settings or the model cannot be imported/created.
    """
    if not settings.enable_table_cell_boxes_extraction:
        return None

    try:
        from paddlex import create_model

        if self._table_cls_model is None:
            model_name = settings.table_classification_model_name or "PP-LCNet_x1_0_table_cls"
            logger.info(f"Loading table classification model: {model_name}")
            self._table_cls_model = create_model(model_name)
        return self._table_cls_model
    except Exception as e:
        logger.error(f"Failed to load table classifier: {e}")
        return None


def _extract_cell_boxes_with_slanet(
    self,
    table_image: np.ndarray,
    table_bbox: List[float],
    is_wired: Optional[bool] = None
) -> Optional[List[List[float]]]:
    """
    Extract cell bounding boxes using a direct SLANeXt model call.

    This supplements PPStructureV3, which doesn't expose cell boxes in its output.

    Args:
        table_image: Cropped table image as numpy array (BGR format).
        table_bbox: Table bounding box in page coordinates [x1, y1, x2, y2];
            its top-left corner is added to every extracted box.
        is_wired: If None, auto-detect using the classifier and fall back to
            wired when no decision can be made. True for bordered tables.

    Returns:
        List of cell bounding boxes in page coordinates [[x1, y1, x2, y2], ...],
        or None if extraction is disabled, no model is available, no boxes
        were found, or extraction fails.
    """
    if not settings.enable_table_cell_boxes_extraction:
        return None

    try:
        # Auto-detect table type if not specified
        if is_wired is None:
            classifier = self._get_table_classifier()
            if classifier:
                try:
                    # Use the first result that carries a label.
                    for res in classifier.predict(table_image):
                        label_names = res.get('label_names', [])
                        if label_names is not None and len(label_names) > 0:
                            is_wired = 'wired' in str(label_names[0]).lower()
                            logger.debug(f"Table classified as: {'wired' if is_wired else 'wireless'}")
                            break
                except Exception as e:
                    logger.warning(f"Table classification failed, defaulting to wired: {e}")
            if is_wired is None:
                # Classifier unavailable, failed, or produced no label:
                # default to wired rather than letting None select wireless.
                is_wired = True

        # Get appropriate SLANeXt model
        model = self._get_slanet_model(is_wired=is_wired)
        if model is None:
            return None

        # Run SLANeXt prediction
        results = model.predict(table_image)

        cell_boxes = []
        table_x, table_y = table_bbox[0], table_bbox[1]

        for result in results:
            # Each entry of 'bbox' is expected to be an 8-value polygon
            # [x1,y1,x2,y2,x3,y3,x4,y4] or a 4-value rectangle.
            boxes = result.get('bbox')
            if boxes is None:
                continue
            for box in boxes:
                # Accept lists, tuples and numpy arrays alike; skip anything
                # that is not numeric.
                try:
                    coords = np.asarray(box, dtype=float).ravel()
                except (TypeError, ValueError):
                    continue

                if coords.size >= 8:
                    # 8-point polygon: convert to its axis-aligned rectangle
                    xs = coords[0:8:2]
                    ys = coords[1:8:2]
                    x1, y1 = xs.min(), ys.min()
                    x2, y2 = xs.max(), ys.max()
                elif coords.size >= 4:
                    # Already a 4-value rectangle
                    x1, y1, x2, y2 = coords[:4]
                else:
                    continue

                # Convert to absolute page coordinates
                cell_boxes.append([
                    float(x1 + table_x),
                    float(y1 + table_y),
                    float(x2 + table_x),
                    float(y2 + table_y),
                ])

        logger.info(f"SLANeXt extracted {len(cell_boxes)} cell boxes (is_wired={is_wired})")
        return cell_boxes if cell_boxes else None

    except Exception as e:
        logger.error(f"Cell boxes extraction with SLANeXt failed: {e}")
        return None


def release_slanet_models(self):
    """Drop the cached SLANeXt and classifier models, then collect garbage and clear the torch CUDA cache."""
    if self._slanet_wired_model is not None:
        self._slanet_wired_model = None
        logger.info("Released SLANeXt wired model")

    if self._slanet_wireless_model is not None:
        self._slanet_wireless_model = None
        logger.info("Released SLANeXt wireless model")

    if self._table_cls_model is not None:
        self._table_cls_model = None
        logger.info("Released table classifier model")

    gc.collect()
    if TORCH_AVAILABLE:
        torch.cuda.empty_cache()
提取 Cell 座標 (boxes) + # 優先使用 PPStructureV3 返回的 boxes,若無則調用 SLANeXt 補充 + cell_boxes_extracted = False + if 'boxes' in res_data: + # PPStructureV3 returned cell boxes (unlikely in PaddleX 3.x) cell_boxes = res_data['boxes'] logger.info(f"[TABLE] Found {len(cell_boxes)} cell boxes in res_data") @@ -399,9 +572,54 @@ class PPStructureEnhanced: # 將處理後的 Cell 座標存入 element element['cell_boxes'] = processed_cells element['raw_cell_boxes'] = cell_boxes + element['cell_boxes_source'] = 'ppstructure' logger.info(f"[TABLE] Processed {len(processed_cells)} cell boxes with table offset ({table_x}, {table_y})") - else: - logger.info(f"[TABLE] No 'boxes' key in res_data. Available: {list(res_data.keys()) if res_data else 'empty'}") + cell_boxes_extracted = True + + # Supplement with direct SLANeXt call if PPStructureV3 didn't provide boxes + if not cell_boxes_extracted and source_image_path and bbox != [0, 0, 0, 0]: + logger.info(f"[TABLE] No boxes from PPStructureV3, attempting SLANeXt extraction...") + try: + # Load source image and crop table region + source_img = Image.open(source_image_path) + source_array = np.array(source_img) + + # Crop table region (bbox is in original image coordinates) + x1, y1, x2, y2 = [int(round(c)) for c in bbox] + # Ensure coordinates are within image bounds + h, w = source_array.shape[:2] + x1, y1 = max(0, x1), max(0, y1) + x2, y2 = min(w, x2), min(h, y2) + + if x2 > x1 and y2 > y1: + table_crop = source_array[y1:y2, x1:x2] + + # Convert RGB to BGR for SLANeXt + if len(table_crop.shape) == 3 and table_crop.shape[2] == 3: + table_crop_bgr = table_crop[:, :, ::-1] + else: + table_crop_bgr = table_crop + + # Extract cell boxes using SLANeXt + slanet_boxes = self._extract_cell_boxes_with_slanet( + table_crop_bgr, + bbox, # Pass original bbox for coordinate offset + is_wired=None # Auto-detect + ) + + if slanet_boxes: + element['cell_boxes'] = slanet_boxes + element['cell_boxes_source'] = 'slanet' + cell_boxes_extracted = True + logger.info(f"[TABLE] 
SLANeXt extracted {len(slanet_boxes)} cell boxes") + else: + logger.warning(f"[TABLE] Invalid crop region: ({x1},{y1})-({x2},{y2})") + + except Exception as e: + logger.error(f"[TABLE] SLANeXt extraction failed: {e}") + + if not cell_boxes_extracted: + logger.info(f"[TABLE] No cell boxes available. PPStructureV3 keys: {list(res_data.keys()) if res_data else 'empty'}") # Special handling for images/figures elif mapped_type in [ElementType.IMAGE, ElementType.FIGURE]: