feat: enable document orientation detection for scanned PDFs

- Enable PP-StructureV3's use_doc_orientation_classify feature
- Detect rotation angle from doc_preprocessor_res.angle
- Swap page dimensions (width <-> height) for 90°/270° rotations
- Output PDF now correctly displays landscape-scanned content

Also includes:
- Archive completed openspec proposals
- Add simplify-frontend-ocr-config proposal (pending)
- Code cleanup and frontend simplification

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
egg
2025-12-11 17:13:46 +08:00
parent 57070af307
commit cfe65158a3
58 changed files with 1271 additions and 3048 deletions

View File

@@ -90,11 +90,16 @@ class Settings(BaseSettings):
max_concurrent_pages: int = Field(default=2) # Process 2 pages concurrently
# PP-StructureV3 optimization
# Strategy: Use raw OCR positioning (simple-text-positioning) instead of table structure reconstruction
# - Layout Detection: ON (detect regions)
# - General OCR: ON (text recognition)
# - Table Recognition: OFF (no cell/structure parsing - use raw OCR bbox instead)
# - Seal/Formula/Chart: ON (specialized recognition)
enable_chart_recognition: bool = Field(default=True) # Chart/diagram recognition
enable_formula_recognition: bool = Field(default=True) # Math formula recognition
enable_table_recognition: bool = Field(default=True) # Table structure recognition
enable_table_recognition: bool = Field(default=False) # Table structure recognition - DISABLED (use raw OCR)
enable_seal_recognition: bool = Field(default=True) # Seal/stamp recognition
enable_region_detection: bool = Field(default=True) # Region detection for better table structure
enable_region_detection: bool = Field(default=True) # Region detection for layout
enable_text_recognition: bool = Field(default=True) # General text recognition
# Table Parsing Mode - Controls how aggressively tables are parsed
@@ -116,57 +121,6 @@ class Settings(BaseSettings):
description="Layout threshold for ALL element detection. Higher values = fewer elements detected."
)
# Cell Validation (filter over-detected table cells)
# DISABLED: This is a patch behavior - focus on getting PP-Structure output right first
cell_validation_enabled: bool = Field(
default=False,
description="Enable cell validation to filter over-detected tables"
)
cell_validation_max_density: float = Field(
default=3.0,
description="Max cells per 10,000px². Tables exceeding this are reclassified as TEXT."
)
cell_validation_min_cell_area: float = Field(
default=3000.0,
description="Min average cell area in px². Tables below this are reclassified as TEXT."
)
cell_validation_min_cell_height: float = Field(
default=10.0,
description="Min average cell height in px. Tables below this are reclassified as TEXT."
)
# Table Content Rebuilder (rebuild table HTML from raw OCR)
# DISABLED: This is a patch behavior - focus on getting PP-Structure output right first
table_content_rebuilder_enabled: bool = Field(
default=False,
description="Enable table content rebuilder to fix PP-Structure table HTML"
)
# Table Quality Check (determines rendering strategy based on cell_boxes overlap)
# When enabled, tables with overlapping cell_boxes are marked as 'bad' quality
# and rendered with border-only mode instead of full cell_boxes rendering.
# Disable this to always use cell_boxes rendering regardless of quality.
table_quality_check_enabled: bool = Field(
default=False,
description="Enable cell_boxes quality check. When disabled, all tables use cell_boxes rendering."
)
# Table Rendering: cell_boxes-first approach
# When enabled, uses cell_boxes coordinates as the primary source for table structure
# instead of relying on HTML colspan/rowspan, which often causes grid mismatch issues
# DISABLED: Algorithm needs improvement - clustering produces incorrect grid dimensions
table_rendering_prefer_cellboxes: bool = Field(
default=False,
description="Use cell_boxes coordinates as primary table structure source for PDF rendering"
)
table_cellboxes_row_threshold: float = Field(
default=15.0,
description="Y-coordinate threshold for row clustering when inferring grid from cell_boxes"
)
table_cellboxes_col_threshold: float = Field(
default=15.0,
description="X-coordinate threshold for column clustering when inferring grid from cell_boxes"
)
# Table Column Alignment Correction (Header-Anchor Algorithm)
# Corrects PP-Structure's column assignment errors using header row X-coordinates as reference
@@ -204,7 +158,10 @@ class Settings(BaseSettings):
)
# PP-StructureV3 Preprocessing (Stage 1)
use_doc_orientation_classify: bool = Field(default=True) # Auto-detect and correct document rotation
# NOTE: doc_orientation_classify ENABLED - detects and corrects document orientation
# for scanned PDFs where content orientation differs from PDF page metadata.
# When rotation is detected (90°/270°), page dimensions are swapped accordingly.
use_doc_orientation_classify: bool = Field(default=True) # Enabled: auto-detect and correct page orientation
use_doc_unwarping: bool = Field(default=False) # Disabled: can cause document distortion/skewing
use_textline_orientation: bool = Field(default=True) # Detect textline orientation
@@ -417,10 +374,6 @@ class Settings(BaseSettings):
description="Use PP-StructureV3's internal OCR results instead of separate inference."
)
# Legacy IoU threshold (deprecated, kept for backward compatibility)
gap_filling_iou_threshold: float = Field(default=0.15) # Deprecated: use IoA thresholds
gap_filling_dedup_iou_threshold: float = Field(default=0.5) # Deprecated: use gap_filling_dedup_ioa_threshold
# ===== Debug Configuration =====
# Enable debug outputs for PP-StructureV3 analysis
pp_structure_debug_enabled: bool = Field(default=True) # Save debug files for PP-StructureV3

View File

@@ -40,10 +40,6 @@ from app.schemas.task import (
PreprocessingPreviewRequest,
PreprocessingPreviewResponse,
ImageQualityMetrics,
TableDetectionConfig,
OCRPresetEnum,
OCRConfig,
OCR_PRESET_CONFIGS,
)
from app.services.task_service import task_service
from app.services.file_access_service import file_access_service
@@ -79,10 +75,7 @@ def process_task_ocr(
language: str = 'ch',
layout_model: Optional[str] = "chinese",
preprocessing_mode: Optional[str] = "auto",
preprocessing_config: Optional[dict] = None,
table_detection_config: Optional[dict] = None,
ocr_preset: Optional[str] = None,
ocr_config: Optional[dict] = None
preprocessing_config: Optional[dict] = None
):
"""
Background task to process OCR for a task with dual-track support.
@@ -101,9 +94,6 @@ def process_task_ocr(
layout_model: Layout detection model ('chinese', 'default', 'cdla')
preprocessing_mode: Preprocessing mode ('auto', 'manual', 'disabled')
preprocessing_config: Manual preprocessing config dict (contrast, sharpen, binarize)
table_detection_config: Table detection config dict (enable_wired_table, enable_wireless_table, enable_region_detection)
ocr_preset: OCR processing preset (text_heavy, datasheet, table_heavy, form, mixed, custom)
ocr_config: Custom OCR config dict (overrides preset values)
"""
from app.core.database import SessionLocal
from app.models.task import Task
@@ -116,7 +106,6 @@ def process_task_ocr(
logger.info(f"Starting OCR processing for task {task_id}, file: {filename}")
logger.info(f"Processing options: dual_track={use_dual_track}, force_track={force_track}, lang={language}")
logger.info(f"Preprocessing options: mode={preprocessing_mode}, config={preprocessing_config}")
logger.info(f"Table detection options: {table_detection_config}")
# Convert preprocessing parameters to proper types
preprocess_mode_enum = None
@@ -133,35 +122,6 @@ def process_task_ocr(
binarize=preprocessing_config.get("binarize", False)
)
# Convert table detection config to object
table_det_config_obj = None
if table_detection_config:
table_det_config_obj = TableDetectionConfig(
enable_wired_table=table_detection_config.get("enable_wired_table", True),
enable_wireless_table=table_detection_config.get("enable_wireless_table", True),
enable_region_detection=table_detection_config.get("enable_region_detection", True)
)
# Convert OCR preset and config to proper objects
from app.schemas.task import OCRPresetEnum, OCRConfig, OCR_PRESET_CONFIGS, TableParsingModeEnum
ocr_config_obj = None
if ocr_preset:
preset_enum = OCRPresetEnum(ocr_preset)
# Get preset config as base
if preset_enum in OCR_PRESET_CONFIGS:
ocr_config_obj = OCR_PRESET_CONFIGS[preset_enum].model_copy()
else:
# CUSTOM preset - use provided config or defaults
ocr_config_obj = OCRConfig()
# Override with custom config values if provided
if ocr_config:
for key, value in ocr_config.items():
if hasattr(ocr_config_obj, key) and value is not None:
setattr(ocr_config_obj, key, value)
logger.info(f"OCR config resolved: preset={ocr_preset}, config={ocr_config_obj.model_dump() if ocr_config_obj else None}")
# Get task directly by database ID (bypass user isolation for background task)
task = db.query(Task).filter(Task.id == task_db_id).first()
if not task:
@@ -210,9 +170,7 @@ def process_task_ocr(
force_track=force_track,
layout_model=layout_model,
preprocessing_mode=preprocess_mode_enum,
preprocessing_config=preprocess_config_obj,
table_detection_config=table_det_config_obj,
ocr_config=ocr_config_obj
preprocessing_config=preprocess_config_obj
)
else:
# Fall back to traditional processing (no force_track support)
@@ -223,9 +181,7 @@ def process_task_ocr(
output_dir=result_dir,
layout_model=layout_model,
preprocessing_mode=preprocess_mode_enum,
preprocessing_config=preprocess_config_obj,
table_detection_config=table_det_config_obj,
ocr_config=ocr_config_obj
preprocessing_config=preprocess_config_obj
)
# Calculate processing time
@@ -818,7 +774,7 @@ async def start_task(
- **force_track**: Force specific processing track ('ocr' or 'direct')
- **language**: OCR language code (default: 'ch')
- **layout_model**: Layout detection model ('chinese', 'default', 'cdla')
- **table_detection**: Table detection config (enable_wired_table, enable_wireless_table, enable_region_detection)
- **preprocessing_mode**: Preprocessing mode ('auto', 'manual', 'disabled')
"""
try:
# Parse processing options with defaults
@@ -846,23 +802,6 @@ async def start_task(
}
logger.info(f"Preprocessing: mode={preprocessing_mode}, config={preprocessing_config}")
# Extract table detection options
table_detection_config = None
if options.table_detection:
table_detection_config = {
"enable_wired_table": options.table_detection.enable_wired_table,
"enable_wireless_table": options.table_detection.enable_wireless_table,
"enable_region_detection": options.table_detection.enable_region_detection
}
logger.info(f"Table detection: {table_detection_config}")
# Extract OCR preset and config
ocr_preset = options.ocr_preset.value if options.ocr_preset else "datasheet"
ocr_config_dict = None
if options.ocr_config:
ocr_config_dict = options.ocr_config.model_dump()
logger.info(f"OCR preset: {ocr_preset}, config: {ocr_config_dict}")
# Get task details
task = task_service.get_task_by_id(
db=db,
@@ -911,14 +850,11 @@ async def start_task(
language=language,
layout_model=layout_model,
preprocessing_mode=preprocessing_mode,
preprocessing_config=preprocessing_config,
table_detection_config=table_detection_config,
ocr_preset=ocr_preset,
ocr_config=ocr_config_dict
preprocessing_config=preprocessing_config
)
logger.info(f"Started OCR processing task {task_id} for user {current_user.email}")
logger.info(f"Options: dual_track={use_dual_track}, force_track={force_track}, lang={language}, layout_model={layout_model}, preprocessing={preprocessing_mode}, table_detection={table_detection_config}, ocr_preset={ocr_preset}")
logger.info(f"Options: dual_track={use_dual_track}, force_track={force_track}, lang={language}, layout_model={layout_model}, preprocessing={preprocessing_mode}")
return task
except HTTPException:

View File

@@ -65,139 +65,6 @@ class PreprocessingContrastEnum(str, Enum):
DOCUMENT = "document"
class OCRPresetEnum(str, Enum):
    """Named OCR processing presets, one per broad document category.

    Each preset stands for a tuned PP-Structure configuration:

    * ``TEXT_HEAVY``  - reports, articles, manuals (table recognition disabled)
    * ``DATASHEET``   - technical datasheets / TDS (conservative table parsing)
    * ``TABLE_HEAVY`` - financial reports, spreadsheets (full table recognition)
    * ``FORM``        - applications, surveys (conservative table parsing)
    * ``MIXED``       - general documents (table classification only)
    * ``CUSTOM``      - caller-supplied settings (taken from ``ocr_config``)

    Members are ``str`` subclasses, so they compare equal to their raw values.
    """

    # Reports, articles, manuals
    TEXT_HEAVY = "text_heavy"
    # Technical datasheets, TDS
    DATASHEET = "datasheet"
    # Financial reports, spreadsheets
    TABLE_HEAVY = "table_heavy"
    # Applications, surveys
    FORM = "form"
    # General documents
    MIXED = "mixed"
    # User-defined settings
    CUSTOM = "custom"
class TableParsingModeEnum(str, Enum):
    """How aggressively table regions are parsed.

    * ``FULL``                - full recognition including cell segmentation
      (the most aggressive mode)
    * ``CONSERVATIVE``        - wireless tables disabled to prevent cell explosion
    * ``CLASSIFICATION_ONLY`` - table regions are classified, cells are not segmented
    * ``DISABLED``            - table recognition is switched off entirely

    Members are ``str`` subclasses, so they compare equal to their raw values.
    """

    FULL = "full"
    CONSERVATIVE = "conservative"
    CLASSIFICATION_ONLY = "classification_only"
    DISABLED = "disabled"
class OCRConfig(BaseModel):
    """Fine-grained PP-Structure settings for one OCR run.

    Every field carries a default, so ``OCRConfig()`` is directly usable.
    Meant to be supplied together with ``ocr_preset=CUSTOM``, or to override
    individual values of another preset.
    """

    # ----- Table processing -----
    table_parsing_mode: TableParsingModeEnum = Field(
        description="Table parsing mode: full, conservative, classification_only, disabled",
        default=TableParsingModeEnum.CONSERVATIVE,
    )
    enable_wired_table: bool = Field(
        description="Enable wired (bordered) table detection",
        default=True,
    )
    enable_wireless_table: bool = Field(
        description="Enable wireless (borderless) table detection. Can cause cell explosion.",
        default=False,
    )

    # ----- Layout detection (None defers to the default threshold) -----
    layout_threshold: Optional[float] = Field(
        description="Layout detection threshold. Higher = stricter. None uses default.",
        default=None,
        ge=0.0,
        le=1.0,
    )
    layout_nms_threshold: Optional[float] = Field(
        description="Layout NMS threshold. None uses default.",
        default=None,
        ge=0.0,
        le=1.0,
    )

    # ----- Preprocessing -----
    use_doc_orientation_classify: bool = Field(
        description="Auto-detect and correct document rotation",
        default=True,
    )
    use_doc_unwarping: bool = Field(
        description="Correct document warping. Can cause distortion.",
        default=False,
    )
    use_textline_orientation: bool = Field(
        description="Detect textline orientation",
        default=True,
    )

    # ----- Recognition modules -----
    enable_chart_recognition: bool = Field(
        description="Enable chart/diagram recognition",
        default=True,
    )
    enable_formula_recognition: bool = Field(
        description="Enable math formula recognition",
        default=True,
    )
    enable_seal_recognition: bool = Field(
        description="Enable seal/stamp recognition",
        default=False,
    )
    enable_region_detection: bool = Field(
        description="Enable region detection for better structure",
        default=True,
    )
# Lookup table: preset -> ready-made OCRConfig.
# OCRPresetEnum.CUSTOM is deliberately absent: it has no canned configuration
# and relies on the config supplied by the user instead.
OCR_PRESET_CONFIGS = {
    # Prose documents: no table parsing, no chart/formula recognition.
    OCRPresetEnum.TEXT_HEAVY: OCRConfig(
        enable_wired_table=False,
        enable_wireless_table=False,
        enable_chart_recognition=False,
        enable_formula_recognition=False,
        table_parsing_mode=TableParsingModeEnum.DISABLED,
    ),
    # Datasheets: bordered tables only, conservative parsing.
    OCRPresetEnum.DATASHEET: OCRConfig(
        enable_wired_table=True,
        enable_wireless_table=False,
        table_parsing_mode=TableParsingModeEnum.CONSERVATIVE,
    ),
    # Table-dominated documents: both table detectors, full parsing.
    OCRPresetEnum.TABLE_HEAVY: OCRConfig(
        enable_wired_table=True,
        enable_wireless_table=True,
        table_parsing_mode=TableParsingModeEnum.FULL,
    ),
    # Forms: same table settings as DATASHEET, kept as a separate instance.
    OCRPresetEnum.FORM: OCRConfig(
        enable_wired_table=True,
        enable_wireless_table=False,
        table_parsing_mode=TableParsingModeEnum.CONSERVATIVE,
    ),
    # General documents: classify table regions without cell segmentation.
    OCRPresetEnum.MIXED: OCRConfig(
        enable_wired_table=True,
        enable_wireless_table=False,
        table_parsing_mode=TableParsingModeEnum.CLASSIFICATION_ONLY,
    ),
}
class PreprocessingConfig(BaseModel):
"""Preprocessing configuration for layout detection enhancement.
@@ -235,31 +102,6 @@ class PreprocessingConfig(BaseModel):
)
class TableDetectionConfig(BaseModel):
    """Switches for the table detection modes of PP-StructureV3.

    PP-StructureV3 uses specialised models per table type:

    * wired (bordered)      - tables with visible cell borders / grid lines
    * wireless (borderless) - tables without visible borders, relying on alignment
    * region detection      - detects table-like regions for better cell structure

    The flags are independent; several may be enabled at once. All three
    default to ``True``.
    """

    enable_wired_table: bool = Field(
        description="Enable wired (bordered) table detection. Best for tables with visible grid lines.",
        default=True,
    )
    enable_wireless_table: bool = Field(
        description="Enable wireless (borderless) table detection. Best for tables without visible borders.",
        default=True,
    )
    enable_region_detection: bool = Field(
        description="Enable region detection for better table structure inference.",
        default=True,
    )
class ImageQualityMetrics(BaseModel):
"""Image quality metrics from auto-analysis."""
contrast: float = Field(..., description="Contrast level (std dev of grayscale)")
@@ -456,23 +298,6 @@ class ProcessingOptions(BaseModel):
description="Manual preprocessing config (only used when preprocessing_mode='manual')"
)
# Table detection configuration (OCR track only)
table_detection: Optional[TableDetectionConfig] = Field(
None,
description="Table detection config. If None, all table detection modes are enabled."
)
# OCR Processing Preset (OCR track only)
# Use presets for optimized configurations or CUSTOM with ocr_config for fine-tuning
ocr_preset: Optional[OCRPresetEnum] = Field(
default=OCRPresetEnum.DATASHEET,
description="OCR processing preset: text_heavy, datasheet, table_heavy, form, mixed, custom"
)
ocr_config: Optional[OCRConfig] = Field(
None,
description="Custom OCR config. Used when ocr_preset=custom or to override preset values."
)
class AnalyzeRequest(BaseModel):
"""Document analysis request"""

View File

@@ -1,583 +0,0 @@
"""
Cell Validation Engine
Validates PP-StructureV3 table detections using metric-based heuristics
to filter over-detected cells and reclassify invalid tables as TEXT elements.
Metrics used:
- Cell density: cells per 10,000 px² (normal: 0.4-1.0, over-detected: 6+)
- Average cell area: px² per cell (normal: 10,000-25,000, over-detected: ~1,600)
- Cell height: table_height / cell_count (minimum: 10px for readable text)
"""
import logging
from dataclasses import dataclass
from typing import List, Dict, Any, Optional, Tuple
from html.parser import HTMLParser
import re
logger = logging.getLogger(__name__)
@dataclass
class CellValidationConfig:
    """Threshold settings consumed by the cell validation engine."""

    # Upper bound on cell density, in cells per 10,000 px².
    max_cell_density: float = 3.0
    # Lower bound on the average area per cell, in px².
    min_avg_cell_area: float = 3000.0
    # Lower bound on the average height per cell row, in px.
    min_cell_height: float = 10.0
    # Master switch for validation.
    enabled: bool = True
@dataclass
class TableValidationResult:
    """Outcome of validating one table element."""

    # True when the table passed validation (or validation was skipped).
    is_valid: bool
    # The table element that was examined.
    table_element: Dict[str, Any]
    # Human-readable explanation; None when nothing needs explaining.
    reason: Optional[str] = None
    # Metrics computed for the table; None when none were computed.
    metrics: Optional[Dict[str, float]] = None
class CellValidationEngine:
"""
Validates table elements from PP-StructureV3 output.
Over-detected tables are identified by abnormal metrics and
reclassified as TEXT elements while preserving content.
"""
def __init__(self, config: Optional[CellValidationConfig] = None):
self.config = config or CellValidationConfig()
def calculate_table_metrics(
self,
bbox: List[float],
cell_boxes: List[List[float]]
) -> Dict[str, float]:
"""
Calculate validation metrics for a table.
Args:
bbox: Table bounding box [x0, y0, x1, y1]
cell_boxes: List of cell bounding boxes
Returns:
Dictionary with calculated metrics
"""
if len(bbox) < 4:
return {"cell_count": 0, "cell_density": 0, "avg_cell_area": 0, "avg_cell_height": 0}
cell_count = len(cell_boxes)
if cell_count == 0:
return {"cell_count": 0, "cell_density": 0, "avg_cell_area": 0, "avg_cell_height": 0}
# Calculate table dimensions
table_width = bbox[2] - bbox[0]
table_height = bbox[3] - bbox[1]
table_area = table_width * table_height
if table_area <= 0:
return {"cell_count": cell_count, "cell_density": 0, "avg_cell_area": 0, "avg_cell_height": 0}
# Cell density: cells per 10,000 px²
cell_density = (cell_count / table_area) * 10000
# Average cell area
avg_cell_area = table_area / cell_count
# Average cell height (table height / cell count)
avg_cell_height = table_height / cell_count
return {
"cell_count": cell_count,
"table_width": table_width,
"table_height": table_height,
"table_area": table_area,
"cell_density": cell_density,
"avg_cell_area": avg_cell_area,
"avg_cell_height": avg_cell_height
}
def validate_table(
self,
element: Dict[str, Any]
) -> TableValidationResult:
"""
Validate a single table element.
Args:
element: Table element from PP-StructureV3 output
Returns:
TableValidationResult with validation status and metrics
"""
if not self.config.enabled:
return TableValidationResult(is_valid=True, table_element=element)
# Extract bbox and cell_boxes
bbox = element.get("bbox", [])
cell_boxes = element.get("cell_boxes", [])
# Tables without cells pass validation (structure-only tables)
if not cell_boxes:
return TableValidationResult(
is_valid=True,
table_element=element,
reason="No cells to validate"
)
# Calculate metrics
metrics = self.calculate_table_metrics(bbox, cell_boxes)
# Check cell density
if metrics["cell_density"] > self.config.max_cell_density:
return TableValidationResult(
is_valid=False,
table_element=element,
reason=f"Cell density {metrics['cell_density']:.2f} exceeds threshold {self.config.max_cell_density}",
metrics=metrics
)
# Check average cell area
if metrics["avg_cell_area"] < self.config.min_avg_cell_area:
return TableValidationResult(
is_valid=False,
table_element=element,
reason=f"Avg cell area {metrics['avg_cell_area']:.0f}px² below threshold {self.config.min_avg_cell_area}px²",
metrics=metrics
)
# Check cell height
if metrics["avg_cell_height"] < self.config.min_cell_height:
return TableValidationResult(
is_valid=False,
table_element=element,
reason=f"Avg cell height {metrics['avg_cell_height']:.1f}px below threshold {self.config.min_cell_height}px",
metrics=metrics
)
# Content-based validation: check if content looks like prose vs tabular data
content_check = self._validate_table_content(element)
if not content_check["is_tabular"]:
return TableValidationResult(
is_valid=False,
table_element=element,
reason=content_check["reason"],
metrics=metrics
)
return TableValidationResult(
is_valid=True,
table_element=element,
metrics=metrics
)
def _validate_table_content(self, element: Dict[str, Any]) -> Dict[str, Any]:
"""
Validate table content to detect false positive tables.
Checks:
1. Sparsity: text coverage ratio (text area / table area)
2. Header: does table have proper header structure
3. Key-Value: for 2-col tables, is it a key-value list or random layout
4. Prose: are cells containing long prose text
Returns:
Dict with is_tabular (bool) and reason (str)
"""
html_content = element.get("content", "")
bbox = element.get("bbox", [])
cell_boxes = element.get("cell_boxes", [])
if not html_content or '<table' not in html_content.lower():
return {"is_tabular": True, "reason": "no_html_content"}
try:
from bs4 import BeautifulSoup
soup = BeautifulSoup(html_content, 'html.parser')
table = soup.find('table')
if not table:
return {"is_tabular": True, "reason": "no_table_element"}
rows = table.find_all('tr')
if not rows:
return {"is_tabular": True, "reason": "no_rows"}
# Extract cell contents with row structure
row_data = []
all_cells = []
for row_idx, row in enumerate(rows):
cells = row.find_all(['td', 'th'])
row_cells = []
for cell in cells:
text = cell.get_text(strip=True)
colspan = int(cell.get('colspan', 1))
is_header = cell.name == 'th'
cell_info = {
"text": text,
"length": len(text),
"colspan": colspan,
"is_header": is_header,
"row": row_idx
}
row_cells.append(cell_info)
all_cells.append(cell_info)
row_data.append(row_cells)
if not all_cells:
return {"is_tabular": True, "reason": "no_cells"}
num_rows = len(row_data)
num_cols = max(len(r) for r in row_data) if row_data else 0
# === Check 1: Sparsity (text coverage) ===
sparsity_result = self._check_sparsity(bbox, cell_boxes, all_cells)
if not sparsity_result["is_valid"]:
return {"is_tabular": False, "reason": sparsity_result["reason"]}
# === Check 2: Header structure ===
header_result = self._check_header_structure(row_data, num_cols)
if not header_result["has_header"] and num_rows > 3:
# Large table without header is suspicious
logger.debug(f"Table has no header structure with {num_rows} rows")
# === Check 3: Key-Value pattern for 2-column tables ===
if num_cols == 2:
kv_result = self._check_key_value_pattern(row_data)
if kv_result["is_kv_list"] and kv_result["confidence"] > 0.7:
# High confidence key-value list - keep as table but log
logger.debug(f"Table identified as key-value list (conf={kv_result['confidence']:.2f})")
elif not kv_result["is_kv_list"] and kv_result["is_random_layout"]:
# Random 2-column layout, not a real table
return {
"is_tabular": False,
"reason": f"random_two_column_layout (not key-value)"
}
# === Check 4: Prose content ===
long_cells = [c for c in all_cells if c["length"] > 80]
prose_ratio = len(long_cells) / len(all_cells) if all_cells else 0
if prose_ratio > 0.3:
return {
"is_tabular": False,
"reason": f"prose_content ({len(long_cells)}/{len(all_cells)} cells > 80 chars)"
}
# === Check 5: Section header as table ===
if num_rows <= 2 and num_cols <= 2:
first_row = row_data[0] if row_data else []
if len(first_row) == 1:
text = first_row[0]["text"]
if text.isupper() and len(text) < 50:
return {
"is_tabular": False,
"reason": f"section_header_only ({text[:30]})"
}
return {"is_tabular": True, "reason": "content_valid"}
except Exception as e:
logger.warning(f"Content validation failed: {e}")
return {"is_tabular": True, "reason": f"validation_error: {e}"}
def _check_sparsity(
self,
bbox: List[float],
cell_boxes: List[List[float]],
all_cells: List[Dict]
) -> Dict[str, Any]:
"""
Check text coverage ratio (sparsity).
Two-column layouts have large empty gaps in the middle.
Real tables have more uniform cell distribution.
"""
if len(bbox) < 4:
return {"is_valid": True, "reason": "no_bbox"}
table_width = bbox[2] - bbox[0]
table_height = bbox[3] - bbox[1]
table_area = table_width * table_height
if table_area <= 0:
return {"is_valid": True, "reason": "invalid_area"}
# Calculate text area from cell_boxes
if cell_boxes:
text_area = 0
for cb in cell_boxes:
if len(cb) >= 4:
w = abs(cb[2] - cb[0])
h = abs(cb[3] - cb[1])
text_area += w * h
coverage = text_area / table_area
else:
# Estimate from cell content length
total_chars = sum(c["length"] for c in all_cells)
# Rough estimate: 1 char ≈ 8x12 pixels = 96 px²
estimated_text_area = total_chars * 96
coverage = min(estimated_text_area / table_area, 1.0)
# Very sparse table (< 15% coverage) is suspicious
if coverage < 0.15:
return {
"is_valid": False,
"reason": f"sparse_content (coverage={coverage:.1%})"
}
return {"is_valid": True, "coverage": coverage}
def _check_header_structure(
self,
row_data: List[List[Dict]],
num_cols: int
) -> Dict[str, Any]:
"""
Check if table has proper header structure.
Real tables usually have:
- First row with <th> elements
- Or first row with different content pattern (labels vs values)
"""
if not row_data:
return {"has_header": False}
first_row = row_data[0]
# Check for <th> elements
th_count = sum(1 for c in first_row if c.get("is_header", False))
if th_count > 0 and th_count >= len(first_row) * 0.5:
return {"has_header": True, "type": "th_elements"}
# Check for header-like content (short, distinct from body)
if len(row_data) > 1:
first_row_avg_len = sum(c["length"] for c in first_row) / len(first_row) if first_row else 0
body_rows = row_data[1:]
body_cells = [c for row in body_rows for c in row]
body_avg_len = sum(c["length"] for c in body_cells) / len(body_cells) if body_cells else 0
# Header row should be shorter (labels) than body (data)
if first_row_avg_len < body_avg_len * 0.7:
return {"has_header": True, "type": "short_labels"}
return {"has_header": False}
def _check_key_value_pattern(
self,
row_data: List[List[Dict]]
) -> Dict[str, Any]:
"""
For 2-column tables, check if it's a key-value list.
Key-value characteristics:
- Left column: short labels (< 30 chars)
- Right column: values (can be longer)
- Consistent pattern across rows
Random layout characteristics:
- Both columns have similar length distribution
- No clear label-value relationship
"""
if not row_data:
return {"is_kv_list": False, "is_random_layout": False, "confidence": 0}
left_lengths = []
right_lengths = []
kv_rows = 0
total_rows = 0
for row in row_data:
if len(row) != 2:
continue
total_rows += 1
left = row[0]
right = row[1]
left_lengths.append(left["length"])
right_lengths.append(right["length"])
# Key-value pattern: left is short label, right is value
if left["length"] < 40 and left["length"] < right["length"] * 2:
kv_rows += 1
if total_rows == 0:
return {"is_kv_list": False, "is_random_layout": False, "confidence": 0}
kv_ratio = kv_rows / total_rows
avg_left = sum(left_lengths) / len(left_lengths) if left_lengths else 0
avg_right = sum(right_lengths) / len(right_lengths) if right_lengths else 0
# High KV ratio and left column is shorter = key-value list
if kv_ratio > 0.6 and avg_left < avg_right:
return {
"is_kv_list": True,
"is_random_layout": False,
"confidence": kv_ratio,
"avg_left": avg_left,
"avg_right": avg_right
}
# Similar lengths on both sides = random layout
if avg_left > 0 and 0.5 < avg_right / avg_left < 2.0:
# Both columns have similar content length
return {
"is_kv_list": False,
"is_random_layout": True,
"confidence": 1 - kv_ratio,
"avg_left": avg_left,
"avg_right": avg_right
}
return {
"is_kv_list": False,
"is_random_layout": False,
"confidence": 0,
"avg_left": avg_left,
"avg_right": avg_right
}
def extract_text_from_table_html(self, html_content: str) -> str:
"""
Extract plain text from table HTML content.
Args:
html_content: HTML string containing table structure
Returns:
Plain text extracted from table cells
"""
if not html_content:
return ""
try:
class TableTextExtractor(HTMLParser):
def __init__(self):
super().__init__()
self.text_parts = []
self.in_cell = False
def handle_starttag(self, tag, attrs):
if tag in ('td', 'th'):
self.in_cell = True
def handle_endtag(self, tag):
if tag in ('td', 'th'):
self.in_cell = False
def handle_data(self, data):
if self.in_cell:
stripped = data.strip()
if stripped:
self.text_parts.append(stripped)
parser = TableTextExtractor()
parser.feed(html_content)
return ' '.join(parser.text_parts)
except Exception as e:
logger.warning(f"Failed to parse table HTML: {e}")
# Fallback: strip HTML tags with regex
text = re.sub(r'<[^>]+>', ' ', html_content)
text = re.sub(r'\s+', ' ', text).strip()
return text
def reclassify_as_text(self, element: Dict[str, Any]) -> Dict[str, Any]:
"""
Convert an over-detected table element to a TEXT element.
Args:
element: Table element to reclassify
Returns:
New TEXT element with preserved content
"""
# Extract text content from HTML
html_content = element.get("content", "")
text_content = self.extract_text_from_table_html(html_content)
# Create new TEXT element
text_element = {
"element_id": element.get("element_id", ""),
"type": "text",
"original_type": "table_reclassified", # Mark as reclassified
"content": text_content,
"page": element.get("page", 0),
"bbox": element.get("bbox", []),
"index": element.get("index", 0),
"confidence": element.get("confidence", 1.0),
"reclassified_from": "table",
"reclassification_reason": "over_detection"
}
return text_element
def validate_and_filter_elements(
    self,
    elements: List[Dict[str, Any]]
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
    """
    Run table validation over a page's elements and demote failing tables.

    Elements whose ``type`` is not ``"table"`` are kept untouched. Each
    table is checked with ``validate_table``; tables that fail are replaced
    by the result of ``reclassify_as_text`` and recorded in the statistics.
    The surviving elements are finally re-sorted into reading order.

    Args:
        elements: List of elements from PP-StructureV3 output

    Returns:
        Tuple of (filtered_elements, statistics)
    """
    kept: List[Dict[str, Any]] = []
    details: List[Dict[str, Any]] = []
    seen_tables = 0
    accepted_tables = 0
    demoted_tables = 0

    for candidate in elements:
        if candidate.get("type") == "table":
            seen_tables += 1
            verdict = self.validate_table(candidate)

            if not verdict.is_valid:
                # Over-detected table: swap in a TEXT element and log why.
                demoted_tables += 1
                kept.append(self.reclassify_as_text(candidate))
                details.append({
                    "element_id": candidate.get("element_id"),
                    "reason": verdict.reason,
                    "metrics": verdict.metrics
                })
                logger.info(
                    f"Reclassified table {candidate.get('element_id')} as TEXT: {verdict.reason}"
                )
            else:
                accepted_tables += 1
                kept.append(candidate)
        else:
            # Anything that is not a table passes through unchanged.
            kept.append(candidate)

    summary = {
        "total_tables": seen_tables,
        "valid_tables": accepted_tables,
        "reclassified_tables": demoted_tables,
        "reclassification_details": details
    }

    # Restore reading order (y0 then x0) after the substitutions.
    ordered = self._sort_by_reading_order(kept)
    return ordered, summary
def _sort_by_reading_order(
    self,
    elements: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """Sort elements by reading order (top-to-bottom, left-to-right).

    The sort key is ``(y0, x0)`` taken from each element's ``bbox``, which
    may be a dict with ``x0``/``y0`` keys, a flat ``[x0, y0, ...]`` list or
    tuple, or a nested polygon ``[[x, y], ...]`` (the minimum x and y of
    its points are used). Elements with a missing or unrecognised bbox sort
    as ``(0, 0)``. The sort is stable, so ties keep their input order.

    Args:
        elements: Elements to order; the input list is not modified.

    Returns:
        A new list with the elements in reading order.
    """
    def sort_key(elem):
        bbox = elem.get("bbox")
        if isinstance(bbox, dict):
            return (bbox.get("y0", 0), bbox.get("x0", 0))
        if isinstance(bbox, (list, tuple)) and len(bbox) >= 2:
            if isinstance(bbox[0], (list, tuple)):
                # Polygon format: reduce to the top-left corner so the key
                # stays numeric and comparable with the other formats.
                points = [
                    pt for pt in bbox
                    if isinstance(pt, (list, tuple)) and len(pt) >= 2
                ]
                if not points:
                    return (0, 0)
                return (min(pt[1] for pt in points), min(pt[0] for pt in points))
            return (bbox[1], bbox[0])
        return (0, 0)

    return sorted(elements, key=sort_key)

View File

@@ -16,6 +16,7 @@ from app.models.unified_document import (
DocumentElement, BoundingBox, ElementType, Dimensions
)
from app.core.config import settings
from app.utils.bbox_utils import normalize_bbox as _normalize_bbox
logger = logging.getLogger(__name__)
@@ -49,32 +50,9 @@ class TextRegion:
@property
def normalized_bbox(self) -> Tuple[float, float, float, float]:
"""Get normalized bbox as (x0, y0, x1, y1)."""
if not self.bbox:
return (0, 0, 0, 0)
# Check if bbox is nested list format [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
# This is common PaddleOCR polygon format
if len(self.bbox) >= 1 and isinstance(self.bbox[0], (list, tuple)):
# Nested format: extract all x and y coordinates
xs = [pt[0] for pt in self.bbox if len(pt) >= 2]
ys = [pt[1] for pt in self.bbox if len(pt) >= 2]
if xs and ys:
return (min(xs), min(ys), max(xs), max(ys))
return (0, 0, 0, 0)
# Flat format
if len(self.bbox) == 4:
# Simple [x0, y0, x1, y1] format
return (float(self.bbox[0]), float(self.bbox[1]),
float(self.bbox[2]), float(self.bbox[3]))
elif len(self.bbox) >= 8:
# Flat polygon format: [x1, y1, x2, y2, x3, y3, x4, y4]
xs = [self.bbox[i] for i in range(0, len(self.bbox), 2)]
ys = [self.bbox[i] for i in range(1, len(self.bbox), 2)]
return (min(xs), min(ys), max(xs), max(ys))
return (0, 0, 0, 0)
"""Get normalized bbox as (x0, y0, x1, y1). Uses shared bbox utility."""
result = _normalize_bbox(self.bbox)
return result if result else (0, 0, 0, 0)
@property
def center(self) -> Tuple[float, float]:
@@ -171,10 +149,6 @@ class GapFillingService:
settings, 'gap_filling_enabled', True
)
# Legacy compatibility
self.iou_threshold = getattr(settings, 'gap_filling_iou_threshold', 0.15)
self.dedup_iou_threshold = getattr(settings, 'gap_filling_dedup_iou_threshold', 0.5)
def should_activate(
self,
raw_ocr_regions: List[TextRegion],

View File

@@ -10,6 +10,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
from datetime import datetime
import uuid
import gc # For garbage collection
import warnings # For suppressing PaddleX deprecation warnings
from paddleocr import PaddleOCR, PPStructureV3
from PIL import Image
@@ -34,7 +35,21 @@ from app.services.layout_preprocessing_service import (
get_layout_preprocessing_service,
LayoutPreprocessingService,
)
from app.schemas.task import PreprocessingModeEnum, PreprocessingConfig, TableDetectionConfig
from app.schemas.task import PreprocessingModeEnum, PreprocessingConfig
from dataclasses import dataclass
@dataclass
class TableDetectionConfig:
    """Internal table detection configuration for OCR service.

    Note: This was previously in app.schemas.task but is now internal to OCR service
    as frontend no longer configures these options.

    Every flag defaults to True, so a default-constructed instance enables
    all three options.
    """
    # NOTE(review): presumably toggles recognition of bordered ("wired")
    # tables - confirm against the PP-StructureV3 call site.
    enable_wired_table: bool = True
    # NOTE(review): presumably toggles recognition of borderless
    # ("wireless") tables - confirm against the PP-StructureV3 call site.
    enable_wireless_table: bool = True
    # NOTE(review): presumably mirrors the `enable_region_detection` setting
    # (layout region detection) - confirm how the two interact.
    enable_region_detection: bool = True
# Import dual-track components
try:
@@ -798,7 +813,12 @@ class OCRService:
if textline_ori_model:
pp_kwargs['textline_orientation_model_name'] = textline_ori_model
self.structure_engine = PPStructureV3(**pp_kwargs)
# Suppress DeprecationWarning during PPStructureV3 initialization
# Workaround for PaddleX bug: it incorrectly treats Python's datetime.utcnow()
# deprecation warning as a model loading error in PP-Chart2Table
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=DeprecationWarning)
self.structure_engine = PPStructureV3(**pp_kwargs)
# Track model loading for cache management
self._model_last_used['structure'] = datetime.now()
@@ -881,7 +901,10 @@ class OCRService:
if settings.textline_orientation_model_name:
cpu_kwargs['textline_orientation_model_name'] = settings.textline_orientation_model_name
self.structure_engine = PPStructureV3(**cpu_kwargs)
# Suppress DeprecationWarning during PPStructureV3 initialization (CPU fallback)
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=DeprecationWarning)
self.structure_engine = PPStructureV3(**cpu_kwargs)
self._current_layout_model = layout_model # Track current model for recreation check
# Track table detection config for recreation check
if table_detection_config:
@@ -1429,6 +1452,22 @@ class OCRService:
raw_ocr_regions=text_regions # For table content rebuilding
)
# Get detected rotation from layout analysis (default: "0" = no rotation)
detected_rotation = "0"
if layout_data:
detected_rotation = layout_data.get('detected_rotation', '0')
# Adjust page dimensions based on detected rotation
# When rotation is 90° or 270°, the page orientation changes (portrait <-> landscape)
# PP-StructureV3 returns coordinates based on the rotated image, so we need to swap dimensions
if detected_rotation in ['90', '270']:
original_width, original_height = ocr_width, ocr_height
ocr_width, ocr_height = original_height, original_width
logger.info(
f"Page dimensions adjusted for {detected_rotation}° rotation: "
f"{original_width}x{original_height} -> {ocr_width}x{ocr_height}"
)
# Generate Markdown
markdown_content = self.generate_markdown(text_regions, layout_data)
@@ -1450,7 +1489,8 @@ class OCRService:
'ocr_dimensions': {
'width': ocr_width,
'height': ocr_height
}
},
'detected_rotation': detected_rotation # Document orientation: "0", "90", "180", "270"
}
# If layout data is enhanced, add enhanced results for converter
@@ -1705,7 +1745,8 @@ class OCRService:
'total_elements': result['total_elements'],
'reading_order': result['reading_order'],
'element_types': result.get('element_types', {}),
'enhanced': True
'enhanced': True,
'detected_rotation': result.get('detected_rotation', '0') # Document orientation
}
# Extract images metadata

View File

@@ -1,507 +0,0 @@
"""
Tool_OCR - PDF Generator Service
Converts Markdown to layout-preserved PDFs using Pandoc + WeasyPrint
"""
import logging
import subprocess
from pathlib import Path
from typing import Optional, Dict
from datetime import datetime
from weasyprint import HTML, CSS
from markdown import markdown
from app.core.config import settings
logger = logging.getLogger(__name__)
class PDFGenerationError(Exception):
    """Signal that a Markdown-to-PDF conversion could not be completed."""
class PDFGenerator:
"""
PDF generation service with layout preservation
Supports two generation methods:
1. Pandoc (preferred): Markdown → HTML → PDF via pandoc command
2. WeasyPrint (fallback): Direct Python-based HTML → PDF conversion
"""
# Default CSS template for layout preservation
DEFAULT_CSS = """
@page {
size: A4;
margin: 2cm;
}
body {
font-family: "Noto Sans CJK SC", "Noto Sans CJK TC", "Microsoft YaHei", "SimSun", sans-serif;
font-size: 11pt;
line-height: 1.6;
color: #333;
}
h1 {
font-size: 24pt;
font-weight: bold;
margin-top: 0;
margin-bottom: 12pt;
color: #000;
page-break-after: avoid;
}
h2 {
font-size: 18pt;
font-weight: bold;
margin-top: 18pt;
margin-bottom: 10pt;
color: #000;
page-break-after: avoid;
}
h3 {
font-size: 14pt;
font-weight: bold;
margin-top: 14pt;
margin-bottom: 8pt;
color: #000;
page-break-after: avoid;
}
p {
margin: 0 0 10pt 0;
text-align: justify;
}
table {
width: 100%;
border-collapse: collapse;
margin: 12pt 0;
page-break-inside: avoid;
}
table th {
background-color: #f0f0f0;
border: 1px solid #ccc;
padding: 8pt;
text-align: left;
font-weight: bold;
}
table td {
border: 1px solid #ccc;
padding: 8pt;
text-align: left;
}
code {
font-family: "Courier New", monospace;
font-size: 10pt;
background-color: #f5f5f5;
padding: 2pt 4pt;
border-radius: 3px;
}
pre {
background-color: #f5f5f5;
border: 1px solid #ddd;
border-radius: 5px;
padding: 10pt;
overflow-x: auto;
page-break-inside: avoid;
}
pre code {
background-color: transparent;
padding: 0;
}
img {
max-width: 100%;
height: auto;
display: block;
margin: 12pt auto;
page-break-inside: avoid;
}
blockquote {
border-left: 4px solid #ddd;
padding-left: 12pt;
margin: 12pt 0;
color: #666;
font-style: italic;
}
ul, ol {
margin: 10pt 0;
padding-left: 20pt;
}
li {
margin: 5pt 0;
}
hr {
border: none;
border-top: 1px solid #ccc;
margin: 20pt 0;
}
.page-break {
page-break-after: always;
}
"""
# Academic paper template
ACADEMIC_CSS = """
@page {
size: A4;
margin: 2.5cm;
}
body {
font-family: "Times New Roman", "Noto Serif CJK SC", serif;
font-size: 12pt;
line-height: 1.8;
color: #000;
}
h1 {
font-size: 20pt;
text-align: center;
margin-bottom: 24pt;
page-break-after: avoid;
}
h2 {
font-size: 16pt;
margin-top: 20pt;
margin-bottom: 12pt;
page-break-after: avoid;
}
h3 {
font-size: 14pt;
margin-top: 16pt;
margin-bottom: 10pt;
page-break-after: avoid;
}
p {
text-indent: 2em;
text-align: justify;
margin: 0 0 12pt 0;
}
table {
width: 100%;
border-collapse: collapse;
margin: 16pt auto;
page-break-inside: avoid;
}
table caption {
font-weight: bold;
margin-bottom: 8pt;
}
"""
# Business report template
BUSINESS_CSS = """
@page {
size: A4;
margin: 2cm 2.5cm;
}
body {
font-family: "Arial", "Noto Sans CJK SC", sans-serif;
font-size: 11pt;
line-height: 1.5;
color: #333;
}
h1 {
font-size: 22pt;
color: #0066cc;
border-bottom: 3px solid #0066cc;
padding-bottom: 8pt;
margin-bottom: 20pt;
page-break-after: avoid;
}
h2 {
font-size: 16pt;
color: #0066cc;
margin-top: 20pt;
margin-bottom: 12pt;
page-break-after: avoid;
}
table {
width: 100%;
border-collapse: collapse;
margin: 16pt 0;
}
table th {
background-color: #0066cc;
color: white;
padding: 10pt;
font-weight: bold;
}
table td {
border: 1px solid #ddd;
padding: 10pt;
}
table tr:nth-child(even) {
background-color: #f9f9f9;
}
"""
def __init__(self):
    """Register the built-in CSS templates under their lookup names."""
    self.css_templates = dict(
        default=self.DEFAULT_CSS,
        academic=self.ACADEMIC_CSS,
        business=self.BUSINESS_CSS,
    )
def check_pandoc_available(self) -> bool:
    """
    Check if Pandoc is installed and available

    Runs ``pandoc --version`` with a 5 second timeout. Any failure to
    launch the executable or a timeout is reported as "not available"
    rather than raised.

    Returns:
        bool: True if pandoc exited with status 0, False otherwise
    """
    try:
        result = subprocess.run(
            ["pandoc", "--version"],
            capture_output=True,
            text=True,
            timeout=5
        )
    except (subprocess.TimeoutExpired, OSError):
        # OSError covers FileNotFoundError (binary missing) as well as
        # PermissionError and other launch failures, which previously
        # escaped this probe as unhandled exceptions.
        logger.warning("Pandoc not found or timed out")
        return False
    return result.returncode == 0
def generate_pdf_pandoc(
    self,
    markdown_path: Path,
    output_path: Path,
    css_template: str = "default",
    metadata: Optional[Dict] = None
) -> Path:
    """
    Generate PDF using Pandoc (preferred method)

    The selected CSS is written to a temporary file next to the output
    PDF, passed to pandoc, and removed again whether or not pandoc
    succeeds.

    Args:
        markdown_path: Path to input Markdown file
        output_path: Path to output PDF file
        css_template: CSS template name or custom CSS string
        metadata: Optional metadata dict (title, author, date)

    Returns:
        Path: Path to generated PDF file

    Raises:
        PDFGenerationError: If pandoc times out, cannot be run, exits with
            a non-zero status, or does not produce the output file
    """
    # Bound before the try block so cleanup never touches an unassigned name.
    css_file: Optional[Path] = None
    try:
        # Create temporary CSS file
        css_content = self.css_templates.get(css_template, css_template)
        css_file = output_path.parent / f"temp_{datetime.now().timestamp()}.css"
        css_file.write_text(css_content, encoding="utf-8")

        # Build pandoc command
        pandoc_cmd = [
            "pandoc",
            str(markdown_path),
            "-o", str(output_path),
            "--pdf-engine=weasyprint",
            "--css", str(css_file),
            "--standalone",
            "--from=markdown+tables+fenced_code_blocks+footnotes",
        ]

        # Add metadata if provided (empty values are skipped)
        if metadata:
            for key in ("title", "author", "date"):
                if metadata.get(key):
                    pandoc_cmd.extend(["--metadata", f"{key}={metadata[key]}"])

        # Execute pandoc
        logger.info(f"Executing pandoc: {' '.join(pandoc_cmd)}")
        result = subprocess.run(
            pandoc_cmd,
            capture_output=True,
            text=True,
            timeout=60  # 60 second timeout for large documents
        )
    except subprocess.TimeoutExpired as e:
        raise PDFGenerationError("Pandoc execution timed out") from e
    except Exception as e:
        raise PDFGenerationError(f"Pandoc PDF generation failed: {str(e)}") from e
    finally:
        # Clean up temporary CSS file on every exit path
        if css_file is not None:
            css_file.unlink(missing_ok=True)

    # These checks sit outside the try block so their PDFGenerationError is
    # not caught and re-wrapped by the generic handler above.
    if result.returncode != 0:
        error_msg = f"Pandoc failed: {result.stderr}"
        logger.error(error_msg)
        raise PDFGenerationError(error_msg)

    if not output_path.exists():
        raise PDFGenerationError(f"PDF file not created: {output_path}")

    logger.info(f"PDF generated successfully via Pandoc: {output_path}")
    return output_path
def generate_pdf_weasyprint(
    self,
    markdown_path: Path,
    output_path: Path,
    css_template: str = "default",
    metadata: Optional[Dict] = None
) -> Path:
    """
    Generate PDF using WeasyPrint directly (fallback method)

    The Markdown file is converted to HTML, wrapped in a minimal document
    whose title comes from ``metadata["title"]`` (or the file stem), and
    rendered with the selected CSS.

    Args:
        markdown_path: Path to input Markdown file
        output_path: Path to output PDF file
        css_template: CSS template name or custom CSS string
        metadata: Optional metadata dict (title, author, date)

    Returns:
        Path: Path to generated PDF file

    Raises:
        PDFGenerationError: If conversion fails or no PDF file is produced
    """
    # Function-scope import: the module does not import this stdlib helper
    # at file level.
    from html import escape

    try:
        # Read Markdown content
        markdown_content = markdown_path.read_text(encoding="utf-8")

        # Convert Markdown to HTML
        html_content = markdown(
            markdown_content,
            extensions=[
                'tables',
                'fenced_code',
                'codehilite',
                'nl2br',
                'sane_lists',
            ]
        )

        # Wrap HTML with proper structure. The title is caller-supplied
        # text, so it is escaped before being placed inside <title>.
        title = metadata.get("title", markdown_path.stem) if metadata else markdown_path.stem
        safe_title = escape(str(title))
        full_html = f"""
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<title>{safe_title}</title>
</head>
<body>
{html_content}
</body>
</html>
"""

        # Get CSS content
        css_content = self.css_templates.get(css_template, css_template)

        # Generate PDF
        logger.info(f"Generating PDF via WeasyPrint: {output_path}")
        document = HTML(string=full_html, base_url=str(markdown_path.parent))
        stylesheet = CSS(string=css_content)
        document.write_pdf(str(output_path), stylesheets=[stylesheet])
    except Exception as e:
        raise PDFGenerationError(f"WeasyPrint PDF generation failed: {str(e)}") from e

    # Checked outside the try block so this error is not re-wrapped above.
    if not output_path.exists():
        raise PDFGenerationError(f"PDF file not created: {output_path}")

    logger.info(f"PDF generated successfully via WeasyPrint: {output_path}")
    return output_path
def generate_pdf(
    self,
    markdown_path: Path,
    output_path: Path,
    css_template: str = "default",
    metadata: Optional[Dict] = None,
    prefer_pandoc: bool = True
) -> Path:
    """
    Render a Markdown file to PDF, trying Pandoc before WeasyPrint.

    Args:
        markdown_path: Path to input Markdown file
        output_path: Path to output PDF file
        css_template: CSS template name ("default", "academic", "business") or custom CSS
        metadata: Optional metadata dict (title, author, date)
        prefer_pandoc: Use Pandoc if available, fallback to WeasyPrint

    Returns:
        Path: Path to generated PDF file

    Raises:
        PDFGenerationError: If the Markdown file is missing, or if the
            WeasyPrint path fails
    """
    if not markdown_path.exists():
        raise PDFGenerationError(f"Markdown file not found: {markdown_path}")

    # The output directory is created up front for both back ends.
    output_path.parent.mkdir(parents=True, exist_ok=True)

    render_args = (markdown_path, output_path, css_template, metadata)

    # Pandoc is only probed when the caller prefers it.
    pandoc_usable = prefer_pandoc and self.check_pandoc_available()
    if pandoc_usable:
        try:
            return self.generate_pdf_pandoc(*render_args)
        except PDFGenerationError as e:
            logger.warning(f"Pandoc failed, falling back to WeasyPrint: {e}")

    # Either Pandoc was skipped or it failed: use WeasyPrint.
    return self.generate_pdf_weasyprint(*render_args)
def get_available_templates(self) -> Dict[str, str]:
    """
    Describe the built-in CSS templates.

    Returns:
        Dict mapping template names to descriptions
    """
    descriptions = (
        ("default", "通用排版模板,適合大多數文檔"),
        ("academic", "學術論文模板,適合研究報告"),
        ("business", "商業報告模板,適合企業文檔"),
    )
    return dict(descriptions)
def save_custom_template(self, template_name: str, css_content: str) -> None:
    """
    Register a CSS template under the given name.

    An existing template with the same name is replaced.

    Args:
        template_name: Template name
        css_content: CSS content
    """
    self.css_templates.update({template_name: css_content})
    logger.info(f"Custom CSS template saved: {template_name}")

View File

@@ -25,6 +25,7 @@ from PIL import Image
from html.parser import HTMLParser
from app.core.config import settings
from app.utils.bbox_utils import normalize_bbox
# Import table column corrector for column alignment fix
try:
@@ -1258,8 +1259,44 @@ class PDFGeneratorService:
else:
logger.warning(f"Image file not found: {saved_path}")
# Also check for embedded images in table elements
# These are images detected inside table regions by PP-Structure
elif elem_type == 'table':
metadata = elem.metadata if hasattr(elem, 'metadata') else elem.get('metadata', {})
embedded_images = metadata.get('embedded_images', []) if metadata else []
for emb_img in embedded_images:
emb_bbox = emb_img.get('bbox', [])
if emb_bbox and len(emb_bbox) >= 4:
ex0, ey0, ex1, ey1 = emb_bbox[0], emb_bbox[1], emb_bbox[2], emb_bbox[3]
exclusion_zones.append((ex0, ey0, ex1, ey1))
# Also render the embedded image
saved_path = emb_img.get('saved_path', '')
if saved_path:
image_path = result_dir / saved_path
if not image_path.exists():
image_path = result_dir / Path(saved_path).name
if image_path.exists():
try:
pdf_x = ex0
pdf_y = current_height - ey1
img_width = ex1 - ex0
img_height = ey1 - ey0
pdf_canvas.drawImage(
str(image_path),
pdf_x, pdf_y,
width=img_width,
height=img_height,
preserveAspectRatio=True,
mask='auto'
)
image_elements_rendered += 1
logger.debug(f"Rendered embedded image: {saved_path} at ({pdf_x:.1f}, {pdf_y:.1f})")
except Exception as e:
logger.warning(f"Failed to render embedded image {saved_path}: {e}")
if image_elements_rendered > 0:
logger.info(f"Rendered {image_elements_rendered} image elements (figures/charts/seals/formulas)")
logger.info(f"Rendered {image_elements_rendered} image elements (figures/charts/seals/formulas/embedded)")
if exclusion_zones:
logger.info(f"Collected {len(exclusion_zones)} exclusion zones for text avoidance")
@@ -1857,38 +1894,8 @@ class PDFGeneratorService:
return None
def _get_bbox_coords(self, bbox: Union[Dict, List[List[float]], List[float]]) -> Optional[Tuple[float, float, float, float]]:
"""將任何 bbox 格式 (dict, 多邊形或 [x1,y1,x2,y2]) 轉換為 [x_min, y_min, x_max, y_max]"""
try:
if bbox is None:
return None
# Dict format from UnifiedDocument: {"x0": ..., "y0": ..., "x1": ..., "y1": ...}
if isinstance(bbox, dict):
if 'x0' in bbox and 'y0' in bbox and 'x1' in bbox and 'y1' in bbox:
return float(bbox['x0']), float(bbox['y0']), float(bbox['x1']), float(bbox['y1'])
else:
logger.warning(f"Dict bbox 缺少必要欄位: {bbox}")
return None
if not isinstance(bbox, (list, tuple)) or len(bbox) < 4:
return None
if isinstance(bbox[0], (list, tuple)):
# 處理多邊形 [[x, y], ...]
x_coords = [p[0] for p in bbox if isinstance(p, (list, tuple)) and len(p) >= 2]
y_coords = [p[1] for p in bbox if isinstance(p, (list, tuple)) and len(p) >= 2]
if not x_coords or not y_coords:
return None
return min(x_coords), min(y_coords), max(x_coords), max(y_coords)
elif isinstance(bbox[0], (int, float)) and len(bbox) == 4:
# 處理 [x1, y1, x2, y2]
return float(bbox[0]), float(bbox[1]), float(bbox[2]), float(bbox[3])
else:
logger.warning(f"未知的 bbox 格式: {bbox}")
return None
except Exception as e:
logger.error(f"解析 bbox {bbox} 時出錯: {e}")
return None
"""將任何 bbox 格式 (dict, 多邊形或 [x1,y1,x2,y2]) 轉換為 [x_min, y_min, x_max, y_max]. Uses shared bbox utility."""
return normalize_bbox(bbox)
def _is_bbox_inside(self, inner_bbox_data: Dict, outer_bbox_data: Dict, tolerance: float = 5.0) -> bool:
"""
@@ -2463,29 +2470,7 @@ class PDFGeneratorService:
else:
logger.info("[TABLE] cell_boxes rendering failed, using ReportLab Table with borders")
else:
# Grid mismatch: try cellboxes-first rendering if enabled
if settings.table_rendering_prefer_cellboxes:
logger.info(f"[TABLE] Grid mismatch, trying cellboxes-first rendering")
from app.services.pdf_table_renderer import TableRenderer, TableRenderConfig
renderer = TableRenderer(TableRenderConfig())
success = renderer.render_from_cellboxes_grid(
pdf_canvas,
cell_boxes,
html_content,
tuple(raw_bbox),
page_height,
scale_w,
scale_h,
row_threshold=settings.table_cellboxes_row_threshold,
col_threshold=settings.table_cellboxes_col_threshold
)
if success:
logger.info("[TABLE] cellboxes-first rendering succeeded, skipping HTML-based rendering")
return # Table fully rendered, exit early
else:
logger.info("[TABLE] cellboxes-first rendering failed, falling back to HTML-based")
else:
logger.info(f"[TABLE] Grid validation failed (mismatch), using ReportLab Table with borders")
logger.info(f"[TABLE] Grid validation failed (mismatch), using ReportLab Table with borders")
else:
logger.info("[TABLE] No valid bbox for grid validation, using ReportLab Table with borders")
@@ -2942,47 +2927,16 @@ class PDFGeneratorService:
"""
Check the quality of cell_boxes to determine rendering strategy.
Always returns 'good' to use pure PP-Structure output (quality check removed).
Args:
cell_boxes: List of cell bounding boxes
element_id: Optional element ID for logging
Returns:
'good' if cell_boxes form a proper grid, 'bad' otherwise
'good' - always use cell_boxes rendering
"""
# If quality check is disabled, always return 'good' to use pure PP-Structure output
if not settings.table_quality_check_enabled:
logger.debug(f"[TABLE QUALITY] {element_id}: good - quality check disabled (pure PP-Structure mode)")
return 'good'
if not cell_boxes or len(cell_boxes) < 2:
logger.debug(f"[TABLE QUALITY] {element_id}: bad - too few cells ({len(cell_boxes) if cell_boxes else 0})")
return 'bad' # No cell_boxes or too few
# Count overlapping cell pairs
overlap_count = 0
for i, box1 in enumerate(cell_boxes):
for j, box2 in enumerate(cell_boxes):
if i >= j:
continue
if not isinstance(box1, (list, tuple)) or len(box1) < 4:
continue
if not isinstance(box2, (list, tuple)) or len(box2) < 4:
continue
x_overlap = box1[0] < box2[2] and box1[2] > box2[0]
y_overlap = box1[1] < box2[3] and box1[3] > box2[1]
if x_overlap and y_overlap:
overlap_count += 1
total_pairs = len(cell_boxes) * (len(cell_boxes) - 1) // 2
overlap_ratio = overlap_count / total_pairs if total_pairs > 0 else 0
# Relaxed threshold: 20% overlap instead of 10% to allow more tables through
# This is because PP-StructureV3's cell detection sometimes has slight overlaps
if overlap_ratio > 0.20:
logger.info(f"[TABLE QUALITY] {element_id}: bad - overlap ratio {overlap_ratio:.2%} > 20%")
return 'bad'
logger.debug(f"[TABLE QUALITY] {element_id}: good - {len(cell_boxes)} cells, overlap {overlap_ratio:.2%}")
logger.debug(f"[TABLE QUALITY] {element_id}: good - pure PP-Structure mode")
return 'good'
def _draw_table_with_cell_boxes(

View File

@@ -15,6 +15,8 @@ from datetime import datetime
from PIL import Image, ImageDraw, ImageFont
from app.utils.bbox_utils import normalize_bbox
logger = logging.getLogger(__name__)
# Color palette for different element types (RGB)
@@ -238,42 +240,8 @@ class PPStructureDebug:
draw.text((legend_x + 18, legend_y), "raw_ocr", fill=(0, 0, 0), font=font)
def _normalize_bbox(self, bbox: Any) -> Optional[Tuple[float, float, float, float]]:
"""Normalize bbox to (x0, y0, x1, y1) format."""
if not bbox:
return None
try:
# Handle nested list format [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
if isinstance(bbox, (list, tuple)) and len(bbox) >= 1:
if isinstance(bbox[0], (list, tuple)):
xs = [pt[0] for pt in bbox if len(pt) >= 2]
ys = [pt[1] for pt in bbox if len(pt) >= 2]
if xs and ys:
return (min(xs), min(ys), max(xs), max(ys))
# Handle flat list [x0, y0, x1, y1]
if isinstance(bbox, (list, tuple)) and len(bbox) == 4:
return (float(bbox[0]), float(bbox[1]), float(bbox[2]), float(bbox[3]))
# Handle flat polygon [x1, y1, x2, y2, ...]
if isinstance(bbox, (list, tuple)) and len(bbox) >= 8:
xs = [bbox[i] for i in range(0, len(bbox), 2)]
ys = [bbox[i] for i in range(1, len(bbox), 2)]
return (min(xs), min(ys), max(xs), max(ys))
# Handle dict format
if isinstance(bbox, dict):
return (
float(bbox.get('x0', bbox.get('x_min', 0))),
float(bbox.get('y0', bbox.get('y_min', 0))),
float(bbox.get('x1', bbox.get('x_max', 0))),
float(bbox.get('y1', bbox.get('y_max', 0)))
)
except (TypeError, ValueError, IndexError) as e:
logger.warning(f"Failed to normalize bbox {bbox}: {e}")
return None
"""Normalize bbox to (x0, y0, x1, y1) format. Uses shared bbox utility."""
return normalize_bbox(bbox)
def _generate_summary(
self,

View File

@@ -28,11 +28,9 @@ from PIL import Image
import numpy as np
import cv2
from app.models.unified_document import ElementType
from app.services.cell_validation_engine import CellValidationEngine, CellValidationConfig
from app.core.config import settings
from app.services.memory_manager import prediction_context
from app.services.cv_table_detector import CVTableDetector
from app.services.table_content_rebuilder import TableContentRebuilder
logger = logging.getLogger(__name__)
@@ -159,6 +157,7 @@ class PPStructureEnhanced:
all_images = []
all_tables = []
visualization_dir = None
detected_rotation = "0" # Default: no rotation
# Process each page result
for page_idx, page_result in enumerate(results):
@@ -247,6 +246,56 @@ class PPStructureEnhanced:
ocr_count = len(overall_ocr_res.get('rec_texts', []))
logger.info(f"Found overall_ocr_res with {ocr_count} text regions")
# Extract doc_preprocessor_res for orientation detection
# When use_doc_orientation_classify=True, this contains the detected rotation angle
# Note: doc_preprocessor_res may be at top-level result_json OR inside 'res'
doc_preprocessor_res = None
# First, check result_dict (might be result_json['res'])
if 'doc_preprocessor_res' in result_dict:
doc_preprocessor_res = result_dict['doc_preprocessor_res']
logger.info("Found doc_preprocessor_res in result_dict")
# Also check top-level result_json if it exists and differs from result_dict
elif hasattr(page_result, 'json') and isinstance(page_result.json, dict):
if 'doc_preprocessor_res' in page_result.json:
doc_preprocessor_res = page_result.json['doc_preprocessor_res']
logger.info("Found doc_preprocessor_res at top-level result_json")
# Debug: Log available keys to help diagnose structure issues
if doc_preprocessor_res is None:
logger.warning(f"doc_preprocessor_res NOT found. result_dict keys: {list(result_dict.keys()) if result_dict else 'None'}")
if hasattr(page_result, 'json') and isinstance(page_result.json, dict):
logger.warning(f"result_json keys: {list(page_result.json.keys())}")
if doc_preprocessor_res:
# Debug: Log the complete structure of doc_preprocessor_res
logger.info(f"doc_preprocessor_res keys: {list(doc_preprocessor_res.keys()) if isinstance(doc_preprocessor_res, dict) else type(doc_preprocessor_res)}")
logger.info(f"doc_preprocessor_res content: {doc_preprocessor_res}")
# Try multiple possible key names for rotation info
# PaddleOCR may use different structures depending on version
label_names = doc_preprocessor_res.get('label_names', [])
class_ids = doc_preprocessor_res.get('class_ids', [])
labels = doc_preprocessor_res.get('labels', [])
angle = doc_preprocessor_res.get('angle', None)
# Determine rotation from available data
detected_rotation = "0"
if label_names:
detected_rotation = str(label_names[0])
elif class_ids:
# class_ids: 0=0°, 1=90°, 2=180°, 3=270°
rotation_map = {0: "0", 1: "90", 2: "180", 3: "270"}
detected_rotation = rotation_map.get(class_ids[0], "0")
elif labels:
detected_rotation = str(labels[0])
elif angle is not None:
detected_rotation = str(angle)
logger.info(f"Document orientation detected: {detected_rotation}° (label_names={label_names}, class_ids={class_ids}, labels={labels}, angle={angle})")
else:
detected_rotation = "0" # Default: no rotation
# Process parsing_res_list if found
if parsing_res_list:
elements = self._process_parsing_res_list(
@@ -295,7 +344,8 @@ class PPStructureEnhanced:
'tables': all_tables,
'images': all_images,
'element_types': self._count_element_types(all_elements),
'has_parsing_res_list': parsing_res_list is not None
'has_parsing_res_list': parsing_res_list is not None,
'detected_rotation': detected_rotation # Document orientation: "0", "90", "180", "270"
}
# Add visualization directory if available
@@ -653,42 +703,6 @@ class PPStructureEnhanced:
element['embedded_images'] = embedded_images
logger.info(f"[TABLE] Embedded {len(embedded_images)} images into table")
# 4. Table content rebuilding from raw OCR regions
# When cell_boxes have boundary issues, rebuild table content from raw OCR
# Only if table_content_rebuilder is enabled (disabled by default as it's a patch behavior)
logger.info(f"[TABLE] raw_ocr_regions available: {raw_ocr_regions is not None and len(raw_ocr_regions) if raw_ocr_regions else 0}")
logger.info(f"[TABLE] cell_boxes available: {len(element.get('cell_boxes', []))}")
if settings.table_content_rebuilder_enabled and raw_ocr_regions and element.get('cell_boxes'):
rebuilder = TableContentRebuilder()
should_rebuild, rebuild_reason = rebuilder.should_rebuild(
element['cell_boxes'],
bbox,
element.get('html', '')
)
if should_rebuild:
logger.info(f"[TABLE] Triggering table rebuild: {rebuild_reason}")
rebuilt_table, rebuild_stats = rebuilder.rebuild_table(
cell_boxes=element['cell_boxes'],
table_bbox=bbox,
raw_ocr_regions=raw_ocr_regions,
original_html=element.get('html', '')
)
if rebuilt_table:
# Update element with rebuilt content
element['html'] = rebuilt_table['html']
element['rebuilt_table'] = rebuilt_table
element['rebuild_stats'] = rebuild_stats
element['extracted_text'] = self._extract_text_from_html(rebuilt_table['html'])
logger.info(
f"[TABLE] Rebuilt table: {rebuilt_table['rows']}x{rebuilt_table['cols']} "
f"with {len(rebuilt_table['cells'])} cells"
)
else:
logger.warning(f"[TABLE] Rebuild failed: {rebuild_stats.get('reason', 'unknown')}")
element['rebuild_stats'] = rebuild_stats
# Special handling for images/figures/charts/stamps (visual elements that need cropping)
elif mapped_type in [ElementType.IMAGE, ElementType.FIGURE, ElementType.CHART, ElementType.DIAGRAM, ElementType.STAMP, ElementType.LOGO]:
# Save image if path provided
@@ -718,21 +732,6 @@ class PPStructureEnhanced:
elements.append(element)
logger.debug(f"Processed element {idx}: type={mapped_type}, bbox={bbox}")
# Apply cell validation to filter over-detected tables
if settings.cell_validation_enabled:
cell_validator = CellValidationEngine(CellValidationConfig(
max_cell_density=settings.cell_validation_max_density,
min_avg_cell_area=settings.cell_validation_min_cell_area,
min_cell_height=settings.cell_validation_min_cell_height,
enabled=True
))
elements, validation_stats = cell_validator.validate_and_filter_elements(elements)
if validation_stats['reclassified_tables'] > 0:
logger.info(
f"Cell validation: {validation_stats['reclassified_tables']}/{validation_stats['total_tables']} "
f"tables reclassified as TEXT due to over-detection"
)
return elements
def _embed_images_in_table(

View File

@@ -1,806 +0,0 @@
"""
Table Content Rebuilder
Rebuilds table content from raw OCR regions when PP-StructureV3's HTML output
is incorrect due to cell merge errors or boundary detection issues.
This module addresses the key problem: PP-StructureV3's ML-based table recognition
often merges multiple cells incorrectly, especially for borderless tables.
The solution uses:
1. cell_boxes validation (filter out-of-bounds cells)
2. Raw OCR regions to rebuild accurate cell content
3. Grid-based row/col position calculation
"""
import logging
from dataclasses import dataclass
from typing import List, Dict, Any, Optional, Tuple
from collections import defaultdict
logger = logging.getLogger(__name__)
@dataclass
class CellBox:
    """Axis-aligned bounding box of one validated table cell.

    Attributes:
        x0: Left edge.
        y0: Top edge.
        x1: Right edge.
        y1: Bottom edge.
        original_index: Index of this box in the unfiltered cell_boxes list
            it was validated from.
    """
    x0: float
    y0: float
    x1: float
    y1: float
    original_index: int

    @property
    def center_y(self) -> float:
        """Vertical midpoint of the box."""
        return (self.y0 + self.y1) / 2

    @property
    def center_x(self) -> float:
        """Horizontal midpoint of the box."""
        return (self.x0 + self.x1) / 2

    @property
    def area(self) -> float:
        """Area of the box; zero when the box is inverted on either axis."""
        # Clamp each side separately: clamping only the product would report a
        # positive area for a box that is inverted on BOTH axes.
        return max(0, self.x1 - self.x0) * max(0, self.y1 - self.y0)
@dataclass
class OCRTextRegion:
    """One recognized text snippet together with its bounding rectangle.

    Attributes:
        text: The recognized text.
        x0, y0, x1, y1: Rectangle corners (min x, min y, max x, max y).
        confidence: Recognition confidence; defaults to 1.0.
    """
    text: str
    x0: float
    y0: float
    x1: float
    y1: float
    confidence: float = 1.0

    @property
    def center_x(self) -> float:
        """Horizontal midpoint of the rectangle."""
        left, right = self.x0, self.x1
        return (left + right) / 2

    @property
    def center_y(self) -> float:
        """Vertical midpoint of the rectangle."""
        top, bottom = self.y0, self.y1
        return (top + bottom) / 2
@dataclass
class RebuiltCell:
    """A single cell of a rebuilt table.

    Attributes:
        row, col: Grid position of the cell.
        row_span, col_span: Number of rows/columns the cell covers.
        content: Text content of the cell.
        bbox: Optional cell rectangle as a list of floats.
        ocr_regions: OCR regions that contributed to the cell; an omitted or
            ``None`` value is replaced by a fresh empty list.
    """
    row: int
    col: int
    row_span: int
    col_span: int
    content: str
    bbox: Optional[List[float]] = None
    ocr_regions: List[OCRTextRegion] = None

    def __post_init__(self):
        # A None default plus this hook gives every instance its own list
        # instead of one list shared across instances.
        self.ocr_regions = [] if self.ocr_regions is None else self.ocr_regions
class TableContentRebuilder:
"""
Rebuilds table content from raw OCR regions and validated cell_boxes.
This class solves the problem where PP-StructureV3's HTML output incorrectly
merges multiple cells. Instead of relying on the ML-generated HTML, it:
1. Validates cell_boxes against table bbox
2. Groups cell_boxes into rows/columns by coordinate clustering
3. Fills each cell with matching raw OCR text
4. Generates correct table structure
"""
def __init__(
self,
boundary_tolerance: float = 20.0,
row_clustering_threshold: float = 15.0,
col_clustering_threshold: float = 15.0,
iou_threshold_for_ocr_match: float = 0.3,
min_text_coverage: float = 0.5
):
"""
Initialize the rebuilder.
Args:
boundary_tolerance: Tolerance for cell_boxes boundary check (pixels)
row_clustering_threshold: Max Y-distance for cells in same row (pixels)
col_clustering_threshold: Max X-distance for cells in same column (pixels)
iou_threshold_for_ocr_match: Min IoU to consider OCR region inside cell
min_text_coverage: Min overlap ratio for OCR text to be assigned to cell
"""
self.boundary_tolerance = boundary_tolerance
self.row_clustering_threshold = row_clustering_threshold
self.col_clustering_threshold = col_clustering_threshold
self.iou_threshold = iou_threshold_for_ocr_match
self.min_text_coverage = min_text_coverage
def validate_cell_boxes(
self,
cell_boxes: List[List[float]],
table_bbox: List[float]
) -> Tuple[List[CellBox], Dict[str, Any]]:
"""
Validate cell_boxes against table bbox, filtering invalid ones.
Args:
cell_boxes: List of cell bounding boxes [[x0, y0, x1, y1], ...]
table_bbox: Table bounding box [x0, y0, x1, y1]
Returns:
Tuple of (valid_cells, validation_stats)
"""
if not cell_boxes or len(table_bbox) < 4:
return [], {"total": 0, "valid": 0, "invalid": 0, "reason": "empty_input"}
table_x0, table_y0, table_x1, table_y1 = table_bbox[:4]
table_height = table_y1 - table_y0
table_width = table_x1 - table_x0
# Expanded table bounds with tolerance
expanded_y1 = table_y1 + self.boundary_tolerance
expanded_x1 = table_x1 + self.boundary_tolerance
expanded_y0 = table_y0 - self.boundary_tolerance
expanded_x0 = table_x0 - self.boundary_tolerance
valid_cells = []
invalid_reasons = defaultdict(int)
for idx, box in enumerate(cell_boxes):
if not box or len(box) < 4:
invalid_reasons["invalid_format"] += 1
continue
x0, y0, x1, y1 = box[:4]
# Check if cell is significantly outside table bounds
# Cell's bottom (y1) shouldn't exceed table's bottom + tolerance
if y1 > expanded_y1:
invalid_reasons["y1_exceeds_table"] += 1
continue
# Cell's top (y0) shouldn't be above table's top - tolerance
if y0 < expanded_y0:
invalid_reasons["y0_above_table"] += 1
continue
# Cell's right (x1) shouldn't exceed table's right + tolerance
if x1 > expanded_x1:
invalid_reasons["x1_exceeds_table"] += 1
continue
# Cell's left (x0) shouldn't be left of table - tolerance
if x0 < expanded_x0:
invalid_reasons["x0_left_of_table"] += 1
continue
# Check for inverted coordinates
if x0 >= x1 or y0 >= y1:
invalid_reasons["inverted_coords"] += 1
continue
# Check cell height is reasonable (at least 8px for readable text)
cell_height = y1 - y0
if cell_height < 8:
invalid_reasons["too_small"] += 1
continue
valid_cells.append(CellBox(
x0=x0, y0=y0, x1=x1, y1=y1,
original_index=idx
))
stats = {
"total": len(cell_boxes),
"valid": len(valid_cells),
"invalid": len(cell_boxes) - len(valid_cells),
"invalid_reasons": dict(invalid_reasons),
"validity_ratio": len(valid_cells) / len(cell_boxes) if cell_boxes else 0
}
logger.info(
f"Cell box validation: {stats['valid']}/{stats['total']} valid "
f"(ratio={stats['validity_ratio']:.2%})"
)
if invalid_reasons:
logger.debug(f"Invalid reasons: {dict(invalid_reasons)}")
return valid_cells, stats
def parse_raw_ocr_regions(
self,
raw_regions: List[Dict[str, Any]],
table_bbox: List[float]
) -> List[OCRTextRegion]:
"""
Parse raw OCR regions and filter to those within/near table bbox.
Args:
raw_regions: List of raw OCR region dicts with 'text', 'bbox', 'confidence'
table_bbox: Table bounding box [x0, y0, x1, y1]
Returns:
List of OCRTextRegion objects within table area
"""
if not raw_regions or len(table_bbox) < 4:
return []
table_x0, table_y0, table_x1, table_y1 = table_bbox[:4]
# Expand table area slightly to catch edge text
margin = 10
result = []
for region in raw_regions:
text = region.get('text', '').strip()
if not text:
continue
bbox = region.get('bbox', [])
confidence = region.get('confidence', 1.0)
# Parse bbox (handle both nested and flat formats)
if not bbox:
continue
if isinstance(bbox[0], (list, tuple)):
# Nested format: [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
xs = [pt[0] for pt in bbox if len(pt) >= 2]
ys = [pt[1] for pt in bbox if len(pt) >= 2]
if xs and ys:
x0, y0, x1, y1 = min(xs), min(ys), max(xs), max(ys)
else:
continue
elif len(bbox) == 4:
x0, y0, x1, y1 = bbox
else:
continue
# Check if region overlaps with table area
if (x1 < table_x0 - margin or x0 > table_x1 + margin or
y1 < table_y0 - margin or y0 > table_y1 + margin):
continue
result.append(OCRTextRegion(
text=text,
x0=float(x0), y0=float(y0),
x1=float(x1), y1=float(y1),
confidence=confidence
))
logger.debug(f"Parsed {len(result)} OCR regions within table area")
return result
def cluster_cells_into_grid(
self,
cells: List[CellBox]
) -> Tuple[List[float], List[float], Dict[Tuple[int, int], CellBox]]:
"""
Cluster cells into rows and columns based on coordinates.
Args:
cells: List of validated CellBox objects
Returns:
Tuple of (row_boundaries, col_boundaries, cell_grid)
- row_boundaries: Y coordinates for row divisions
- col_boundaries: X coordinates for column divisions
- cell_grid: Dict mapping (row, col) to CellBox
"""
if not cells:
return [], [], {}
# Collect all unique Y boundaries (top and bottom of cells)
y_coords = set()
x_coords = set()
for cell in cells:
y_coords.add(round(cell.y0, 1))
y_coords.add(round(cell.y1, 1))
x_coords.add(round(cell.x0, 1))
x_coords.add(round(cell.x1, 1))
# Cluster nearby coordinates
row_boundaries = self._cluster_coordinates(sorted(y_coords), self.row_clustering_threshold)
col_boundaries = self._cluster_coordinates(sorted(x_coords), self.col_clustering_threshold)
logger.debug(f"Found {len(row_boundaries)} row boundaries, {len(col_boundaries)} col boundaries")
# Map cells to grid positions
cell_grid = {}
for cell in cells:
# Find row (based on cell's top Y coordinate)
row = self._find_position(cell.y0, row_boundaries)
# Find column (based on cell's left X coordinate)
col = self._find_position(cell.x0, col_boundaries)
if row is not None and col is not None:
# Check for span (if cell extends across multiple rows/cols)
row_end = self._find_position(cell.y1, row_boundaries)
col_end = self._find_position(cell.x1, col_boundaries)
# Store with potential span info
if (row, col) not in cell_grid:
cell_grid[(row, col)] = cell
return row_boundaries, col_boundaries, cell_grid
def _cluster_coordinates(
self,
coords: List[float],
threshold: float
) -> List[float]:
"""Cluster nearby coordinates into distinct values."""
if not coords:
return []
clustered = [coords[0]]
for coord in coords[1:]:
if coord - clustered[-1] > threshold:
clustered.append(coord)
return clustered
def _find_position(
self,
value: float,
boundaries: List[float]
) -> Optional[int]:
"""Find which position (index) a value falls into."""
for i, boundary in enumerate(boundaries):
if value <= boundary + self.row_clustering_threshold:
return i
return len(boundaries) - 1 if boundaries else None
def assign_ocr_to_cells(
self,
cells: List[CellBox],
ocr_regions: List[OCRTextRegion],
row_boundaries: List[float],
col_boundaries: List[float]
) -> Dict[Tuple[int, int], List[OCRTextRegion]]:
"""
Assign OCR text regions to cells based on spatial overlap.
Args:
cells: List of validated CellBox objects
ocr_regions: List of OCRTextRegion objects
row_boundaries: Y coordinates for row divisions
col_boundaries: X coordinates for column divisions
Returns:
Dict mapping (row, col) to list of OCR regions in that cell
"""
cell_ocr_map: Dict[Tuple[int, int], List[OCRTextRegion]] = defaultdict(list)
for ocr in ocr_regions:
best_cell = None
best_overlap = 0
for cell in cells:
overlap = self._calculate_overlap_ratio(
(ocr.x0, ocr.y0, ocr.x1, ocr.y1),
(cell.x0, cell.y0, cell.x1, cell.y1)
)
if overlap > best_overlap and overlap >= self.min_text_coverage:
best_overlap = overlap
best_cell = cell
if best_cell:
row = self._find_position(best_cell.y0, row_boundaries)
col = self._find_position(best_cell.x0, col_boundaries)
if row is not None and col is not None:
cell_ocr_map[(row, col)].append(ocr)
return cell_ocr_map
def _calculate_overlap_ratio(
self,
box1: Tuple[float, float, float, float],
box2: Tuple[float, float, float, float]
) -> float:
"""Calculate overlap ratio of box1 with box2."""
x0_1, y0_1, x1_1, y1_1 = box1
x0_2, y0_2, x1_2, y1_2 = box2
# Calculate intersection
inter_x0 = max(x0_1, x0_2)
inter_y0 = max(y0_1, y0_2)
inter_x1 = min(x1_1, x1_2)
inter_y1 = min(y1_1, y1_2)
if inter_x0 >= inter_x1 or inter_y0 >= inter_y1:
return 0.0
inter_area = (inter_x1 - inter_x0) * (inter_y1 - inter_y0)
box1_area = (x1_1 - x0_1) * (y1_1 - y0_1)
return inter_area / box1_area if box1_area > 0 else 0.0
def rebuild_table(
self,
cell_boxes: List[List[float]],
table_bbox: List[float],
raw_ocr_regions: List[Dict[str, Any]],
original_html: str = ""
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""
Rebuild table content from cell_boxes and raw OCR regions.
This is the main entry point. It:
1. Validates cell_boxes
2. If validity ratio is low, uses pure OCR-based rebuild
3. Otherwise, uses cell_boxes + OCR hybrid rebuild
Args:
cell_boxes: List of cell bounding boxes from PP-StructureV3
table_bbox: Table bounding box [x0, y0, x1, y1]
raw_ocr_regions: List of raw OCR region dicts
original_html: Original HTML from PP-StructureV3 (for fallback)
Returns:
Tuple of (rebuilt_table_dict, rebuild_stats)
"""
stats = {
"action": "none",
"reason": "",
"original_cell_count": len(cell_boxes) if cell_boxes else 0,
"valid_cell_count": 0,
"ocr_regions_in_table": 0,
"rebuilt_rows": 0,
"rebuilt_cols": 0
}
# Step 1: Validate cell_boxes
valid_cells, validation_stats = self.validate_cell_boxes(cell_boxes, table_bbox)
stats["valid_cell_count"] = validation_stats["valid"]
stats["validation"] = validation_stats
# Step 2: Parse raw OCR regions in table area
ocr_regions = self.parse_raw_ocr_regions(raw_ocr_regions, table_bbox)
stats["ocr_regions_in_table"] = len(ocr_regions)
if not ocr_regions:
stats["action"] = "skip"
stats["reason"] = "no_ocr_regions_in_table"
return None, stats
# Step 3: Choose rebuild strategy based on cell_boxes validity
# If validity ratio is too low (< 50%), use pure OCR-based rebuild
if validation_stats["validity_ratio"] < 0.5 or len(valid_cells) < 2:
logger.info(
f"Using pure OCR-based rebuild (validity={validation_stats['validity_ratio']:.2%})"
)
return self._rebuild_from_ocr_only(ocr_regions, table_bbox, stats)
# Otherwise, use hybrid cell_boxes + OCR rebuild
return self._rebuild_with_cell_boxes(valid_cells, ocr_regions, stats, table_bbox)
def _rebuild_from_ocr_only(
self,
ocr_regions: List[OCRTextRegion],
table_bbox: List[float],
stats: Dict[str, Any]
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""
Rebuild table using only OCR regions (when cell_boxes are unreliable).
Strategy:
1. Detect column boundary from OCR x-coordinates
2. Cluster OCR regions by Y coordinate into rows
3. Split each row into left/right columns
"""
if not ocr_regions:
stats["action"] = "skip"
stats["reason"] = "no_ocr_regions"
return None, stats
# Get table bounds
table_x0, table_y0, table_x1, table_y1 = table_bbox[:4]
table_width = table_x1 - table_x0
# Step 1: Detect column split point by analyzing x-coordinates
# Look for the gap between left column (x0 < 250) and right column (x0 >= 250)
col_split_x = self._detect_column_split(ocr_regions, table_bbox)
logger.debug(f"Detected column split at x={col_split_x}")
# Step 2: Cluster OCR regions by Y coordinate into rows
# Use smaller threshold (12px) to properly separate rows
row_threshold = 12.0
sorted_ocr = sorted(ocr_regions, key=lambda r: r.center_y)
rows = []
current_row = [sorted_ocr[0]]
for ocr in sorted_ocr[1:]:
if ocr.center_y - current_row[-1].center_y <= row_threshold:
current_row.append(ocr)
else:
rows.append(current_row)
current_row = [ocr]
rows.append(current_row)
logger.debug(f"Detected {len(rows)} rows")
# Step 3: Analyze column structure
left_regions = [r for r in ocr_regions if r.x0 < col_split_x]
right_regions = [r for r in ocr_regions if r.x0 >= col_split_x]
num_cols = 2 if len(left_regions) >= 2 and len(right_regions) >= 2 else 1
# Step 4: Build cells for each row
rebuilt_cells = []
for row_idx, row_ocrs in enumerate(rows):
row_ocrs_sorted = sorted(row_ocrs, key=lambda r: r.center_x)
if num_cols == 2:
# Split into left and right columns using x0
left_ocrs = [r for r in row_ocrs_sorted if r.x0 < col_split_x]
right_ocrs = [r for r in row_ocrs_sorted if r.x0 >= col_split_x]
# Left column cell
if left_ocrs:
left_content = " ".join(r.text for r in left_ocrs)
left_bbox = [
min(r.x0 for r in left_ocrs),
min(r.y0 for r in left_ocrs),
max(r.x1 for r in left_ocrs),
max(r.y1 for r in left_ocrs)
]
rebuilt_cells.append({
"row": row_idx,
"col": 0,
"row_span": 1,
"col_span": 1,
"content": left_content,
"bbox": left_bbox
})
# Right column cell
if right_ocrs:
right_content = " ".join(r.text for r in right_ocrs)
right_bbox = [
min(r.x0 for r in right_ocrs),
min(r.y0 for r in right_ocrs),
max(r.x1 for r in right_ocrs),
max(r.y1 for r in right_ocrs)
]
rebuilt_cells.append({
"row": row_idx,
"col": 1,
"row_span": 1,
"col_span": 1,
"content": right_content,
"bbox": right_bbox
})
else:
# Single column - merge all OCR in row
row_content = " ".join(r.text for r in row_ocrs_sorted)
row_bbox = [
min(r.x0 for r in row_ocrs_sorted),
min(r.y0 for r in row_ocrs_sorted),
max(r.x1 for r in row_ocrs_sorted),
max(r.y1 for r in row_ocrs_sorted)
]
rebuilt_cells.append({
"row": row_idx,
"col": 0,
"row_span": 1,
"col_span": 1,
"content": row_content,
"bbox": row_bbox
})
num_rows = len(rows)
stats["rebuilt_rows"] = num_rows
stats["rebuilt_cols"] = num_cols
# Build result
rebuilt_table = {
"rows": num_rows,
"cols": num_cols,
"cells": rebuilt_cells,
"html": self._generate_html(rebuilt_cells, num_rows, num_cols),
"rebuild_source": "pure_ocr"
}
stats["action"] = "rebuilt"
stats["reason"] = "pure_ocr_success"
stats["rebuilt_cell_count"] = len(rebuilt_cells)
logger.info(
f"Table rebuilt (pure OCR): {num_rows}x{num_cols} with {len(rebuilt_cells)} cells"
)
return rebuilt_table, stats
def _detect_column_split(
self,
ocr_regions: List[OCRTextRegion],
table_bbox: List[float]
) -> float:
"""
Detect the column split point by analyzing x-coordinates.
For tables with left/right structure (e.g., property-value tables),
there's usually a gap between left column text and right column text.
"""
if not ocr_regions:
return (table_bbox[0] + table_bbox[2]) / 2
# Collect all x0 values (left edge of each text region)
x0_values = sorted(set(round(r.x0) for r in ocr_regions))
if len(x0_values) < 2:
return (table_bbox[0] + table_bbox[2]) / 2
# Find the largest gap between consecutive x0 values
# This usually indicates the column boundary
max_gap = 0
split_point = (table_bbox[0] + table_bbox[2]) / 2
for i in range(len(x0_values) - 1):
gap = x0_values[i + 1] - x0_values[i]
if gap > max_gap and gap > 50: # Require minimum 50px gap
max_gap = gap
split_point = (x0_values[i] + x0_values[i + 1]) / 2
# If no clear gap found, use table center
if max_gap < 50:
split_point = (table_bbox[0] + table_bbox[2]) / 2
return split_point
def _rebuild_with_cell_boxes(
self,
valid_cells: List[CellBox],
ocr_regions: List[OCRTextRegion],
stats: Dict[str, Any],
table_bbox: Optional[List[float]] = None
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""Rebuild table using cell_boxes structure + OCR content."""
# Step 3: Cluster cells into grid
row_boundaries, col_boundaries, cell_grid = self.cluster_cells_into_grid(valid_cells)
num_rows = len(row_boundaries) - 1 if len(row_boundaries) > 1 else 1
num_cols = len(col_boundaries) - 1 if len(col_boundaries) > 1 else 1
# Quality check: if hybrid produces too many columns or sparse grid, fall back to pure OCR
# A well-formed table typically has 2-5 columns. Too many columns indicates poor clustering.
total_expected_cells = num_rows * num_cols
if num_cols > 5 or total_expected_cells > 100:
logger.info(
f"Hybrid mode produced {num_rows}x{num_cols} grid (too sparse), "
f"falling back to pure OCR mode"
)
if table_bbox:
return self._rebuild_from_ocr_only(ocr_regions, table_bbox, stats)
stats["rebuilt_rows"] = num_rows
stats["rebuilt_cols"] = num_cols
# Step 4: Assign OCR text to cells
cell_ocr_map = self.assign_ocr_to_cells(
valid_cells, ocr_regions, row_boundaries, col_boundaries
)
# Step 5: Build rebuilt cells
rebuilt_cells = []
for (row, col), ocr_list in cell_ocr_map.items():
# Sort OCR regions by position (top to bottom, left to right)
sorted_ocr = sorted(ocr_list, key=lambda r: (r.center_y, r.center_x))
content = " ".join(r.text for r in sorted_ocr)
# Find the cell bbox for this position
cell_bbox = None
for cell in valid_cells:
cell_row = self._find_position(cell.y0, row_boundaries)
cell_col = self._find_position(cell.x0, col_boundaries)
if cell_row == row and cell_col == col:
cell_bbox = [cell.x0, cell.y0, cell.x1, cell.y1]
break
rebuilt_cells.append({
"row": row,
"col": col,
"row_span": 1,
"col_span": 1,
"content": content,
"bbox": cell_bbox
})
# Quality check: if too few cells have content compared to grid size, fall back to pure OCR
content_ratio = len(rebuilt_cells) / total_expected_cells if total_expected_cells > 0 else 0
if content_ratio < 0.3 and table_bbox:
logger.info(
f"Hybrid mode has low content ratio ({content_ratio:.2%}), "
f"falling back to pure OCR mode"
)
return self._rebuild_from_ocr_only(ocr_regions, table_bbox, stats)
# Build result
rebuilt_table = {
"rows": num_rows,
"cols": num_cols,
"cells": rebuilt_cells,
"html": self._generate_html(rebuilt_cells, num_rows, num_cols),
"rebuild_source": "cell_boxes_hybrid"
}
stats["action"] = "rebuilt"
stats["reason"] = "hybrid_success"
stats["rebuilt_cell_count"] = len(rebuilt_cells)
logger.info(
f"Table rebuilt (hybrid): {num_rows}x{num_cols} with {len(rebuilt_cells)} cells "
f"(from {len(ocr_regions)} OCR regions)"
)
return rebuilt_table, stats
def _generate_html(
self,
cells: List[Dict[str, Any]],
num_rows: int,
num_cols: int
) -> str:
"""Generate HTML table from rebuilt cells."""
# Create grid
grid = [[None for _ in range(num_cols)] for _ in range(num_rows)]
for cell in cells:
row, col = cell["row"], cell["col"]
if 0 <= row < num_rows and 0 <= col < num_cols:
grid[row][col] = cell["content"]
# Build HTML
html_parts = ["<html><body><table>"]
for row_idx in range(num_rows):
html_parts.append("<tr>")
for col_idx in range(num_cols):
content = grid[row_idx][col_idx] or ""
tag = "th" if row_idx == 0 else "td"
html_parts.append(f"<{tag}>{content}</{tag}>")
html_parts.append("</tr>")
html_parts.append("</table></body></html>")
return "".join(html_parts)
def should_rebuild(
self,
cell_boxes: List[List[float]],
table_bbox: List[float],
original_html: str = ""
) -> Tuple[bool, str]:
"""
Determine if table should be rebuilt based on cell_boxes validity.
Args:
cell_boxes: List of cell bounding boxes
table_bbox: Table bounding box
original_html: Original HTML from PP-StructureV3
Returns:
Tuple of (should_rebuild, reason)
"""
if not cell_boxes:
return False, "no_cell_boxes"
_, validation_stats = self.validate_cell_boxes(cell_boxes, table_bbox)
# Always rebuild if ANY cells are invalid - PP-Structure HTML often merges cells incorrectly
# even when most cell_boxes are valid
if validation_stats["invalid"] > 0:
return True, f"invalid_cells_{validation_stats['invalid']}/{validation_stats['total']}"
# Rebuild if there are boundary violations
invalid_reasons = validation_stats.get("invalid_reasons", {})
boundary_violations = (
invalid_reasons.get("y1_exceeds_table", 0) +
invalid_reasons.get("y0_above_table", 0) +
invalid_reasons.get("x1_exceeds_table", 0) +
invalid_reasons.get("x0_left_of_table", 0)
)
if boundary_violations > 0:
return True, f"boundary_violations_{boundary_violations}"
# Also rebuild to ensure OCR-based content is used instead of PP-Structure HTML
# PP-Structure's HTML often has incorrect cell merging
return True, "ocr_content_preferred"

View File

@@ -15,6 +15,8 @@ from typing import Dict, List, Optional, Set, Tuple
from reportlab.pdfgen import canvas
from reportlab.lib.colors import black
from app.utils.bbox_utils import normalize_bbox
logger = logging.getLogger(__name__)
@@ -162,6 +164,7 @@ class TextRegionRenderer:
def get_bbox_as_rect(self, bbox: List[List[float]]) -> Tuple[float, float, float, float]:
"""
Convert quadrilateral bbox to axis-aligned rectangle (x0, y0, x1, y1).
Uses shared bbox utility.
Args:
bbox: List of 4 [x, y] coordinate pairs
@@ -169,12 +172,8 @@ class TextRegionRenderer:
Returns:
Tuple of (x0, y0, x1, y1) - min/max coordinates
"""
if len(bbox) < 4:
return (0.0, 0.0, 0.0, 0.0)
x_coords = [p[0] for p in bbox]
y_coords = [p[1] for p in bbox]
return (min(x_coords), min(y_coords), max(x_coords), max(y_coords))
result = normalize_bbox(bbox)
return result if result else (0.0, 0.0, 0.0, 0.0)
def get_bbox_left_baseline(
self,
@@ -646,19 +645,26 @@ def load_raw_ocr_regions(result_dir: str, task_id: str, page_num: int) -> List[D
from pathlib import Path
import json
# Construct filename pattern
filename = f"{task_id}_edit_page_{page_num}_raw_ocr_regions.json"
file_path = Path(result_dir) / filename
result_path = Path(result_dir)
if not file_path.exists():
logger.warning(f"Raw OCR regions file not found: {file_path}")
return []
# Use glob pattern to find raw OCR regions file
# Filename format: {task_id}_{original_filename}_page_{page_num}_raw_ocr_regions.json
# The original_filename varies based on uploaded file (e.g., scan, document, etc.)
glob_pattern = f"{task_id}_*_page_{page_num}_raw_ocr_regions.json"
matching_files = list(result_path.glob(glob_pattern))
try:
with open(file_path, 'r', encoding='utf-8') as f:
regions = json.load(f)
logger.info(f"Loaded {len(regions)} raw OCR regions from {filename}")
return regions
except Exception as e:
logger.error(f"Failed to load raw OCR regions: {e}")
return []
if matching_files:
# Use the first matching file (there should only be one per page)
file_path = matching_files[0]
try:
with open(file_path, 'r', encoding='utf-8') as f:
regions = json.load(f)
logger.info(f"Loaded {len(regions)} raw OCR regions from {file_path.name}")
return regions
except Exception as e:
logger.error(f"Failed to load raw OCR regions from {file_path}: {e}")
return []
logger.warning(f"Raw OCR regions file not found for task {task_id} page {page_num}. "
f"Glob pattern: {glob_pattern}")
return []

View File

@@ -0,0 +1,5 @@
"""Utility modules for the OCR application."""
from .bbox_utils import normalize_bbox, get_bbox_center, calculate_ioa
__all__ = ['normalize_bbox', 'get_bbox_center', 'calculate_ioa']

View File

@@ -0,0 +1,265 @@
"""
Unified bounding box utilities for consistent bbox handling across services.
Supports multiple bbox formats:
- Nested polygon: [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
- Flat rectangle: [x0, y0, x1, y1]
- Flat polygon: [x1, y1, x2, y2, x3, y3, x4, y4]
- Dict format: {"x0": ..., "y0": ..., "x1": ..., "y1": ...}
"""
import logging
from typing import Any, Dict, List, Optional, Tuple, Union
logger = logging.getLogger(__name__)
BboxCoords = Tuple[float, float, float, float] # (x0, y0, x1, y1)
def normalize_bbox(
    bbox: Union[Dict, List, Tuple, None]
) -> Optional[BboxCoords]:
    """
    Normalize any bbox format to (x0, y0, x1, y1) tuple.

    Handles:
    - Nested polygon: [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
    - Flat rectangle: [x0, y0, x1, y1]
    - Flat polygon: [x1, y1, x2, y2, x3, y3, x4, y4]
    - Dict format: {"x0": ..., "y0": ..., "x1": ..., "y1": ...}
    - Dict format: {"x_min": ..., "y_min": ..., "x_max": ..., "y_max": ...}
      (an x0-style key may stand in for a missing x_min-style key)

    Polygons are reduced to their axis-aligned bounding rectangle. A flat
    rectangle is returned in the order given, without sorting its corners.

    Args:
        bbox: Bounding box in any supported format

    Returns:
        Normalized (x0, y0, x1, y1) tuple of floats, or None if the input is
        None, empty, incomplete or of an unknown shape
    """
    if bbox is None:
        return None

    try:
        # Dict formats
        if isinstance(bbox, dict):
            if 'x0' in bbox and 'y0' in bbox and 'x1' in bbox and 'y1' in bbox:
                return (
                    float(bbox['x0']),
                    float(bbox['y0']),
                    float(bbox['x1']),
                    float(bbox['y1'])
                )
            # Alternative dict keys. Every coordinate must be resolvable;
            # padding a missing one with 0 would fabricate a box.
            if 'x_min' in bbox or 'y_min' in bbox:
                resolved = [
                    bbox.get(primary, bbox.get(fallback))
                    for primary, fallback in (
                        ('x_min', 'x0'), ('y_min', 'y0'),
                        ('x_max', 'x1'), ('y_max', 'y1')
                    )
                ]
                if all(value is not None for value in resolved):
                    x0, y0, x1, y1 = resolved
                    return (float(x0), float(y0), float(x1), float(y1))
            logger.warning(f"Dict bbox missing required fields: {bbox}")
            return None

        # List/tuple formats
        if isinstance(bbox, (list, tuple)):
            if len(bbox) == 0:
                return None

            # Nested polygon format: [[x1,y1], [x2,y2], ...]
            if isinstance(bbox[0], (list, tuple)):
                xs = [pt[0] for pt in bbox if len(pt) >= 2]
                ys = [pt[1] for pt in bbox if len(pt) >= 2]
                if xs and ys:
                    return (
                        float(min(xs)),
                        float(min(ys)),
                        float(max(xs)),
                        float(max(ys))
                    )
                return None

            # Flat rectangle: [x0, y0, x1, y1]
            if len(bbox) == 4:
                return (
                    float(bbox[0]),
                    float(bbox[1]),
                    float(bbox[2]),
                    float(bbox[3])
                )

            # Flat polygon: [x1, y1, x2, y2, x3, y3, x4, y4, ...]
            if len(bbox) >= 8 and len(bbox) % 2 == 0:
                xs = bbox[0::2]
                ys = bbox[1::2]
                return (
                    float(min(xs)),
                    float(min(ys)),
                    float(max(xs)),
                    float(max(ys))
                )

        logger.warning(f"Unknown bbox format: {type(bbox).__name__}, value: {bbox}")
        return None

    except (TypeError, ValueError, IndexError) as e:
        logger.warning(f"Failed to normalize bbox {bbox}: {e}")
        return None
def get_bbox_center(bbox: Union[Dict, List, Tuple, BboxCoords, None]) -> Optional[Tuple[float, float]]:
    """
    Get the center point of a bounding box.

    Args:
        bbox: Bounding box in any supported format or already normalized

    Returns:
        (center_x, center_y) tuple or None if invalid
    """
    # Fast path only for a flat tuple of four numbers. A tuple of four
    # (x, y) points also has length 4 and must go through normalize_bbox,
    # otherwise the points themselves would be unpacked as coordinates.
    if (
        isinstance(bbox, tuple)
        and len(bbox) == 4
        and all(isinstance(value, (int, float)) for value in bbox)
    ):
        coords = bbox
    else:
        coords = normalize_bbox(bbox)

    if coords is None:
        return None

    x0, y0, x1, y1 = coords
    return ((x0 + x1) / 2, (y0 + y1) / 2)
def get_bbox_area(bbox: Union[Dict, List, Tuple, BboxCoords, None]) -> float:
    """
    Calculate the area of a bounding box.

    Args:
        bbox: Bounding box in any supported format

    Returns:
        Area (width * height), 0 if the bbox is invalid or inverted on
        either axis
    """
    # Fast path only for a flat tuple of four numbers. A tuple of four
    # (x, y) points also has length 4 and must go through normalize_bbox.
    if (
        isinstance(bbox, tuple)
        and len(bbox) == 4
        and all(isinstance(value, (int, float)) for value in bbox)
    ):
        coords = bbox
    else:
        coords = normalize_bbox(bbox)

    if coords is None:
        return 0.0

    x0, y0, x1, y1 = coords
    return max(0, x1 - x0) * max(0, y1 - y0)
def calculate_ioa(
    inner_bbox: Union[Dict, List, Tuple, BboxCoords, None],
    outer_bbox: Union[Dict, List, Tuple, BboxCoords, None]
) -> float:
    """
    Calculate Intersection over Area (IoA) of inner bbox with respect to outer bbox.

    IoA = intersection_area / inner_area

    Args:
        inner_bbox: The bbox to check (its area is the denominator)
        outer_bbox: The reference bbox

    Returns:
        IoA ratio (0.0 to 1.0); 0.0 if either bbox is invalid, the boxes do
        not overlap, or the inner bbox has no positive area
    """
    def as_coords(box):
        # Fast path only for a flat tuple of four numbers. A tuple of four
        # (x, y) points also has length 4 and must go through normalize_bbox.
        if (
            isinstance(box, tuple)
            and len(box) == 4
            and all(isinstance(value, (int, float)) for value in box)
        ):
            return box
        return normalize_bbox(box)

    inner_coords = as_coords(inner_bbox)
    outer_coords = as_coords(outer_bbox)

    if inner_coords is None or outer_coords is None:
        return 0.0

    inner_x0, inner_y0, inner_x1, inner_y1 = inner_coords
    outer_x0, outer_y0, outer_x1, outer_y1 = outer_coords

    # Calculate intersection
    inter_x0 = max(inner_x0, outer_x0)
    inter_y0 = max(inner_y0, outer_y0)
    inter_x1 = min(inner_x1, outer_x1)
    inter_y1 = min(inner_y1, outer_y1)

    if inter_x1 <= inter_x0 or inter_y1 <= inter_y0:
        return 0.0

    intersection_area = (inter_x1 - inter_x0) * (inter_y1 - inter_y0)
    inner_area = (inner_x1 - inner_x0) * (inner_y1 - inner_y0)

    if inner_area <= 0:
        return 0.0

    return intersection_area / inner_area
def calculate_iou(
    bbox1: Union[Dict, List, Tuple, BboxCoords, None],
    bbox2: Union[Dict, List, Tuple, BboxCoords, None]
) -> float:
    """
    Calculate Intersection over Union (IoU) of two bounding boxes.

    Args:
        bbox1: First bounding box
        bbox2: Second bounding box

    Returns:
        IoU ratio (0.0 to 1.0); 0.0 if either bbox is invalid or the boxes
        do not overlap
    """
    def as_coords(box):
        # Fast path only for a flat tuple of four numbers. A tuple of four
        # (x, y) points also has length 4 and must go through normalize_bbox.
        if (
            isinstance(box, tuple)
            and len(box) == 4
            and all(isinstance(value, (int, float)) for value in box)
        ):
            return box
        return normalize_bbox(box)

    coords1 = as_coords(bbox1)
    coords2 = as_coords(bbox2)

    if coords1 is None or coords2 is None:
        return 0.0

    x0_1, y0_1, x1_1, y1_1 = coords1
    x0_2, y0_2, x1_2, y1_2 = coords2

    # Calculate intersection
    inter_x0 = max(x0_1, x0_2)
    inter_y0 = max(y0_1, y0_2)
    inter_x1 = min(x1_1, x1_2)
    inter_y1 = min(y1_1, y1_2)

    if inter_x1 <= inter_x0 or inter_y1 <= inter_y0:
        return 0.0

    intersection_area = (inter_x1 - inter_x0) * (inter_y1 - inter_y0)
    area1 = (x1_1 - x0_1) * (y1_1 - y0_1)
    area2 = (x1_2 - x0_2) * (y1_2 - y0_2)
    # Count the shared region once
    union_area = area1 + area2 - intersection_area

    if union_area <= 0:
        return 0.0

    return intersection_area / union_area
def is_bbox_inside(
    inner_bbox: Union[Dict, List, Tuple, BboxCoords, None],
    outer_bbox: Union[Dict, List, Tuple, BboxCoords, None],
    tolerance: float = 0.0
) -> bool:
    """
    Check if inner_bbox is completely inside outer_bbox (with optional tolerance).

    Args:
        inner_bbox: The bbox to check
        outer_bbox: The containing bbox
        tolerance: Allowed overflow on each side

    Returns:
        True if inner is inside outer (within tolerance); False if it is not
        or if either bbox is invalid
    """
    def as_coords(box):
        # Fast path only for a flat tuple of four numbers. A tuple of four
        # (x, y) points also has length 4 and must go through normalize_bbox.
        if (
            isinstance(box, tuple)
            and len(box) == 4
            and all(isinstance(value, (int, float)) for value in box)
        ):
            return box
        return normalize_bbox(box)

    inner_coords = as_coords(inner_bbox)
    outer_coords = as_coords(outer_bbox)

    if inner_coords is None or outer_coords is None:
        return False

    inner_x0, inner_y0, inner_x1, inner_y1 = inner_coords
    outer_x0, outer_y0, outer_x1, outer_y1 = outer_coords

    return (
        inner_x0 >= outer_x0 - tolerance and
        inner_y0 >= outer_y0 - tolerance and
        inner_x1 <= outer_x1 + tolerance and
        inner_y1 <= outer_y1 + tolerance
    )