feat: simplify layout model selection and archive proposals
Changes: - Replace PP-Structure 7-slider parameter UI with simple 3-option layout model selector - Add layout model mapping: chinese (PP-DocLayout-S), default (PubLayNet), cdla - Add LayoutModelSelector component and zh-TW translations - Fix "default" model behavior with sentinel value for PubLayNet - Add gap filling service for OCR track coverage improvement - Add PP-Structure debug utilities - Archive completed/incomplete proposals: - add-ocr-track-gap-filling (complete) - fix-ocr-track-table-rendering (incomplete) - simplify-ppstructure-model-selection (22/25 tasks) - Add new layout model tests, archive old PP-Structure param tests - Update OpenSpec ocr-processing spec with layout model requirements 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
649
backend/app/services/gap_filling_service.py
Normal file
649
backend/app/services/gap_filling_service.py
Normal file
@@ -0,0 +1,649 @@
|
||||
"""
|
||||
Gap Filling Service for OCR Track
|
||||
|
||||
This service detects and fills gaps in PP-StructureV3 output by supplementing
|
||||
with Raw OCR text regions when significant content loss is detected.
|
||||
|
||||
The hybrid approach uses Raw OCR's comprehensive text detection to compensate
|
||||
for PP-StructureV3's layout model limitations on certain document types.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Tuple, Set, Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
from app.models.unified_document import (
|
||||
DocumentElement, BoundingBox, ElementType, Dimensions
|
||||
)
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Element types that should NOT be supplemented (preserve structural integrity)
|
||||
SKIP_ELEMENT_TYPES: Set[ElementType] = {
|
||||
ElementType.TABLE,
|
||||
ElementType.IMAGE,
|
||||
ElementType.FIGURE,
|
||||
ElementType.CHART,
|
||||
ElementType.DIAGRAM,
|
||||
ElementType.HEADER,
|
||||
ElementType.FOOTER,
|
||||
ElementType.FORMULA,
|
||||
ElementType.CODE,
|
||||
ElementType.BARCODE,
|
||||
ElementType.QR_CODE,
|
||||
ElementType.LOGO,
|
||||
ElementType.STAMP,
|
||||
ElementType.SIGNATURE,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextRegion:
|
||||
"""Represents a raw OCR text region."""
|
||||
text: str
|
||||
bbox: List[float] # [x0, y0, x1, y1] or polygon format
|
||||
confidence: float
|
||||
page: int = 0
|
||||
|
||||
@property
|
||||
def normalized_bbox(self) -> Tuple[float, float, float, float]:
|
||||
"""Get normalized bbox as (x0, y0, x1, y1)."""
|
||||
if not self.bbox:
|
||||
return (0, 0, 0, 0)
|
||||
|
||||
# Check if bbox is nested list format [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
|
||||
# This is common PaddleOCR polygon format
|
||||
if len(self.bbox) >= 1 and isinstance(self.bbox[0], (list, tuple)):
|
||||
# Nested format: extract all x and y coordinates
|
||||
xs = [pt[0] for pt in self.bbox if len(pt) >= 2]
|
||||
ys = [pt[1] for pt in self.bbox if len(pt) >= 2]
|
||||
if xs and ys:
|
||||
return (min(xs), min(ys), max(xs), max(ys))
|
||||
return (0, 0, 0, 0)
|
||||
|
||||
# Flat format
|
||||
if len(self.bbox) == 4:
|
||||
# Simple [x0, y0, x1, y1] format
|
||||
return (float(self.bbox[0]), float(self.bbox[1]),
|
||||
float(self.bbox[2]), float(self.bbox[3]))
|
||||
elif len(self.bbox) >= 8:
|
||||
# Flat polygon format: [x1, y1, x2, y2, x3, y3, x4, y4]
|
||||
xs = [self.bbox[i] for i in range(0, len(self.bbox), 2)]
|
||||
ys = [self.bbox[i] for i in range(1, len(self.bbox), 2)]
|
||||
return (min(xs), min(ys), max(xs), max(ys))
|
||||
|
||||
return (0, 0, 0, 0)
|
||||
|
||||
@property
|
||||
def center(self) -> Tuple[float, float]:
|
||||
"""Get center point of the bbox."""
|
||||
x0, y0, x1, y1 = self.normalized_bbox
|
||||
return ((x0 + x1) / 2, (y0 + y1) / 2)
|
||||
|
||||
|
||||
class GapFillingService:
|
||||
"""
|
||||
Service for detecting and filling gaps in PP-StructureV3 output.
|
||||
|
||||
This service:
|
||||
1. Calculates coverage of PP-StructureV3 elements over raw OCR regions
|
||||
2. Identifies uncovered raw OCR regions
|
||||
3. Supplements uncovered regions as TEXT elements
|
||||
4. Deduplicates against existing PP-StructureV3 TEXT elements
|
||||
5. Recalculates reading order for the combined result
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
coverage_threshold: float = None,
|
||||
iou_threshold: float = None,
|
||||
confidence_threshold: float = None,
|
||||
dedup_iou_threshold: float = None,
|
||||
enabled: bool = None
|
||||
):
|
||||
"""
|
||||
Initialize the gap filling service.
|
||||
|
||||
Args:
|
||||
coverage_threshold: Coverage ratio below which gap filling activates (default: 0.7)
|
||||
iou_threshold: IoU threshold for coverage detection (default: 0.15)
|
||||
confidence_threshold: Minimum confidence for raw OCR regions (default: 0.3)
|
||||
dedup_iou_threshold: IoU threshold for deduplication (default: 0.5)
|
||||
enabled: Whether gap filling is enabled (default: True)
|
||||
"""
|
||||
self.coverage_threshold = coverage_threshold if coverage_threshold is not None else getattr(
|
||||
settings, 'gap_filling_coverage_threshold', 0.7
|
||||
)
|
||||
self.iou_threshold = iou_threshold if iou_threshold is not None else getattr(
|
||||
settings, 'gap_filling_iou_threshold', 0.15
|
||||
)
|
||||
self.confidence_threshold = confidence_threshold if confidence_threshold is not None else getattr(
|
||||
settings, 'gap_filling_confidence_threshold', 0.3
|
||||
)
|
||||
self.dedup_iou_threshold = dedup_iou_threshold if dedup_iou_threshold is not None else getattr(
|
||||
settings, 'gap_filling_dedup_iou_threshold', 0.5
|
||||
)
|
||||
self.enabled = enabled if enabled is not None else getattr(
|
||||
settings, 'gap_filling_enabled', True
|
||||
)
|
||||
|
||||
def should_activate(
|
||||
self,
|
||||
raw_ocr_regions: List[TextRegion],
|
||||
pp_structure_elements: List[DocumentElement]
|
||||
) -> Tuple[bool, float]:
|
||||
"""
|
||||
Determine if gap filling should be activated.
|
||||
|
||||
Gap filling activates when:
|
||||
1. Coverage ratio is below threshold (default: 70%)
|
||||
2. OR element count disparity is significant
|
||||
|
||||
Args:
|
||||
raw_ocr_regions: List of raw OCR text regions
|
||||
pp_structure_elements: List of PP-StructureV3 elements
|
||||
|
||||
Returns:
|
||||
Tuple of (should_activate, coverage_ratio)
|
||||
"""
|
||||
if not self.enabled:
|
||||
return False, 1.0
|
||||
|
||||
if not raw_ocr_regions:
|
||||
return False, 1.0
|
||||
|
||||
# Calculate coverage
|
||||
covered_count = 0
|
||||
for region in raw_ocr_regions:
|
||||
if self._is_region_covered(region, pp_structure_elements):
|
||||
covered_count += 1
|
||||
|
||||
coverage_ratio = covered_count / len(raw_ocr_regions)
|
||||
|
||||
# Check activation conditions
|
||||
should_activate = coverage_ratio < self.coverage_threshold
|
||||
|
||||
if should_activate:
|
||||
logger.info(
|
||||
f"Gap filling activated: coverage={coverage_ratio:.2%} < threshold={self.coverage_threshold:.0%}, "
|
||||
f"raw_regions={len(raw_ocr_regions)}, pp_elements={len(pp_structure_elements)}"
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
f"Gap filling not needed: coverage={coverage_ratio:.2%} >= threshold={self.coverage_threshold:.0%}"
|
||||
)
|
||||
|
||||
return should_activate, coverage_ratio
|
||||
|
||||
def find_uncovered_regions(
|
||||
self,
|
||||
raw_ocr_regions: List[TextRegion],
|
||||
pp_structure_elements: List[DocumentElement]
|
||||
) -> List[TextRegion]:
|
||||
"""
|
||||
Find raw OCR regions not covered by PP-StructureV3 elements.
|
||||
|
||||
A region is considered covered if:
|
||||
1. Its center point falls inside any PP-StructureV3 element bbox, OR
|
||||
2. IoU with any PP-StructureV3 element exceeds iou_threshold
|
||||
|
||||
Args:
|
||||
raw_ocr_regions: List of raw OCR text regions
|
||||
pp_structure_elements: List of PP-StructureV3 elements
|
||||
|
||||
Returns:
|
||||
List of uncovered raw OCR regions
|
||||
"""
|
||||
uncovered = []
|
||||
|
||||
for region in raw_ocr_regions:
|
||||
# Skip low confidence regions
|
||||
if region.confidence < self.confidence_threshold:
|
||||
continue
|
||||
|
||||
if not self._is_region_covered(region, pp_structure_elements):
|
||||
uncovered.append(region)
|
||||
|
||||
logger.debug(f"Found {len(uncovered)} uncovered regions out of {len(raw_ocr_regions)}")
|
||||
return uncovered
|
||||
|
||||
def _is_region_covered(
|
||||
self,
|
||||
region: TextRegion,
|
||||
pp_structure_elements: List[DocumentElement]
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a raw OCR region is covered by any PP-StructureV3 element.
|
||||
|
||||
Args:
|
||||
region: Raw OCR text region
|
||||
pp_structure_elements: List of PP-StructureV3 elements
|
||||
|
||||
Returns:
|
||||
True if the region is covered
|
||||
"""
|
||||
center_x, center_y = region.center
|
||||
region_bbox = region.normalized_bbox
|
||||
|
||||
for element in pp_structure_elements:
|
||||
elem_bbox = (
|
||||
element.bbox.x0, element.bbox.y0,
|
||||
element.bbox.x1, element.bbox.y1
|
||||
)
|
||||
|
||||
# Check 1: Center point falls inside element bbox
|
||||
if self._point_in_bbox(center_x, center_y, elem_bbox):
|
||||
return True
|
||||
|
||||
# Check 2: IoU exceeds threshold
|
||||
iou = self._calculate_iou(region_bbox, elem_bbox)
|
||||
if iou > self.iou_threshold:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def deduplicate_regions(
|
||||
self,
|
||||
uncovered_regions: List[TextRegion],
|
||||
pp_structure_elements: List[DocumentElement]
|
||||
) -> List[TextRegion]:
|
||||
"""
|
||||
Remove regions that highly overlap with existing PP-StructureV3 TEXT elements.
|
||||
|
||||
Args:
|
||||
uncovered_regions: List of uncovered raw OCR regions
|
||||
pp_structure_elements: List of PP-StructureV3 elements
|
||||
|
||||
Returns:
|
||||
Deduplicated list of regions
|
||||
"""
|
||||
# Get TEXT elements only for deduplication
|
||||
text_elements = [
|
||||
e for e in pp_structure_elements
|
||||
if e.type not in SKIP_ELEMENT_TYPES
|
||||
]
|
||||
|
||||
deduplicated = []
|
||||
for region in uncovered_regions:
|
||||
region_bbox = region.normalized_bbox
|
||||
is_duplicate = False
|
||||
|
||||
for element in text_elements:
|
||||
elem_bbox = (
|
||||
element.bbox.x0, element.bbox.y0,
|
||||
element.bbox.x1, element.bbox.y1
|
||||
)
|
||||
|
||||
iou = self._calculate_iou(region_bbox, elem_bbox)
|
||||
if iou > self.dedup_iou_threshold:
|
||||
logger.debug(
|
||||
f"Skipping duplicate region (IoU={iou:.2f}): '{region.text[:30]}...'"
|
||||
)
|
||||
is_duplicate = True
|
||||
break
|
||||
|
||||
if not is_duplicate:
|
||||
deduplicated.append(region)
|
||||
|
||||
removed_count = len(uncovered_regions) - len(deduplicated)
|
||||
if removed_count > 0:
|
||||
logger.debug(f"Removed {removed_count} duplicate regions")
|
||||
|
||||
return deduplicated
|
||||
|
||||
def convert_regions_to_elements(
|
||||
self,
|
||||
regions: List[TextRegion],
|
||||
page_number: int,
|
||||
start_element_id: int = 0
|
||||
) -> List[DocumentElement]:
|
||||
"""
|
||||
Convert raw OCR regions to DocumentElement objects.
|
||||
|
||||
Args:
|
||||
regions: List of raw OCR regions to convert
|
||||
page_number: Page number for the elements
|
||||
start_element_id: Starting ID counter for elements
|
||||
|
||||
Returns:
|
||||
List of DocumentElement objects
|
||||
"""
|
||||
elements = []
|
||||
|
||||
for idx, region in enumerate(regions):
|
||||
x0, y0, x1, y1 = region.normalized_bbox
|
||||
|
||||
element = DocumentElement(
|
||||
element_id=f"gap_fill_{page_number}_{start_element_id + idx}",
|
||||
type=ElementType.TEXT,
|
||||
content=region.text,
|
||||
bbox=BoundingBox(x0=x0, y0=y0, x1=x1, y1=y1),
|
||||
confidence=region.confidence,
|
||||
metadata={
|
||||
'source': 'gap_filling',
|
||||
'original_confidence': region.confidence
|
||||
}
|
||||
)
|
||||
elements.append(element)
|
||||
|
||||
return elements
|
||||
|
||||
def recalculate_reading_order(
|
||||
self,
|
||||
elements: List[DocumentElement]
|
||||
) -> List[int]:
|
||||
"""
|
||||
Recalculate reading order for elements based on position.
|
||||
|
||||
Sorts elements by y0 (top to bottom) then x0 (left to right).
|
||||
|
||||
Args:
|
||||
elements: List of DocumentElement objects
|
||||
|
||||
Returns:
|
||||
List of element indices in reading order
|
||||
"""
|
||||
# Create indexed list with position info
|
||||
indexed_elements = [
|
||||
(idx, e.bbox.y0, e.bbox.x0)
|
||||
for idx, e in enumerate(elements)
|
||||
]
|
||||
|
||||
# Sort by y0 then x0
|
||||
indexed_elements.sort(key=lambda x: (x[1], x[2]))
|
||||
|
||||
# Return indices in reading order
|
||||
return [idx for idx, _, _ in indexed_elements]
|
||||
|
||||
def merge_adjacent_regions(
|
||||
self,
|
||||
regions: List[TextRegion],
|
||||
max_horizontal_gap: float = 20.0,
|
||||
max_vertical_gap: float = 5.0
|
||||
) -> List[TextRegion]:
|
||||
"""
|
||||
Merge fragmented adjacent regions on the same line.
|
||||
|
||||
This is optional and can reduce fragmentation from raw OCR.
|
||||
|
||||
Args:
|
||||
regions: List of raw OCR regions
|
||||
max_horizontal_gap: Maximum horizontal gap to merge (pixels)
|
||||
max_vertical_gap: Maximum vertical gap to merge (pixels)
|
||||
|
||||
Returns:
|
||||
List of merged regions
|
||||
"""
|
||||
if not regions:
|
||||
return regions
|
||||
|
||||
# Sort by y0, then x0
|
||||
sorted_regions = sorted(
|
||||
regions,
|
||||
key=lambda r: (r.normalized_bbox[1], r.normalized_bbox[0])
|
||||
)
|
||||
|
||||
merged = []
|
||||
current = sorted_regions[0]
|
||||
|
||||
for next_region in sorted_regions[1:]:
|
||||
curr_bbox = current.normalized_bbox
|
||||
next_bbox = next_region.normalized_bbox
|
||||
|
||||
# Check if on same line (vertical overlap)
|
||||
curr_y_center = (curr_bbox[1] + curr_bbox[3]) / 2
|
||||
next_y_center = (next_bbox[1] + next_bbox[3]) / 2
|
||||
vertical_distance = abs(curr_y_center - next_y_center)
|
||||
|
||||
# Check horizontal gap
|
||||
horizontal_gap = next_bbox[0] - curr_bbox[2]
|
||||
|
||||
if (vertical_distance < max_vertical_gap and
|
||||
0 <= horizontal_gap <= max_horizontal_gap):
|
||||
# Merge regions
|
||||
merged_bbox = [
|
||||
min(curr_bbox[0], next_bbox[0]),
|
||||
min(curr_bbox[1], next_bbox[1]),
|
||||
max(curr_bbox[2], next_bbox[2]),
|
||||
max(curr_bbox[3], next_bbox[3])
|
||||
]
|
||||
current = TextRegion(
|
||||
text=current.text + " " + next_region.text,
|
||||
bbox=merged_bbox,
|
||||
confidence=min(current.confidence, next_region.confidence),
|
||||
page=current.page
|
||||
)
|
||||
else:
|
||||
merged.append(current)
|
||||
current = next_region
|
||||
|
||||
merged.append(current)
|
||||
|
||||
if len(merged) < len(regions):
|
||||
logger.debug(f"Merged {len(regions)} regions into {len(merged)}")
|
||||
|
||||
return merged
|
||||
|
||||
def fill_gaps(
|
||||
self,
|
||||
raw_ocr_regions: List[Dict[str, Any]],
|
||||
pp_structure_elements: List[DocumentElement],
|
||||
page_number: int,
|
||||
ocr_dimensions: Optional[Dict[str, Any]] = None,
|
||||
pp_dimensions: Optional[Dimensions] = None
|
||||
) -> Tuple[List[DocumentElement], Dict[str, Any]]:
|
||||
"""
|
||||
Main entry point: detect gaps and fill with raw OCR regions.
|
||||
|
||||
Args:
|
||||
raw_ocr_regions: Raw OCR results (list of dicts with text, bbox, confidence)
|
||||
pp_structure_elements: PP-StructureV3 elements
|
||||
page_number: Current page number
|
||||
ocr_dimensions: OCR image dimensions for coordinate alignment
|
||||
pp_dimensions: PP-Structure dimensions for coordinate alignment
|
||||
|
||||
Returns:
|
||||
Tuple of (supplemented_elements, statistics)
|
||||
"""
|
||||
statistics = {
|
||||
'enabled': self.enabled,
|
||||
'activated': False,
|
||||
'coverage_ratio': 1.0,
|
||||
'raw_ocr_count': len(raw_ocr_regions),
|
||||
'pp_structure_count': len(pp_structure_elements),
|
||||
'uncovered_count': 0,
|
||||
'deduplicated_count': 0,
|
||||
'supplemented_count': 0
|
||||
}
|
||||
|
||||
if not self.enabled:
|
||||
logger.debug("Gap filling is disabled")
|
||||
return [], statistics
|
||||
|
||||
# Convert raw OCR regions to TextRegion objects
|
||||
text_regions = self._convert_raw_ocr_regions(
|
||||
raw_ocr_regions, page_number, ocr_dimensions, pp_dimensions
|
||||
)
|
||||
|
||||
if not text_regions:
|
||||
logger.debug("No valid text regions to process")
|
||||
return [], statistics
|
||||
|
||||
# Check if gap filling should activate
|
||||
should_activate, coverage_ratio = self.should_activate(
|
||||
text_regions, pp_structure_elements
|
||||
)
|
||||
statistics['coverage_ratio'] = coverage_ratio
|
||||
statistics['activated'] = should_activate
|
||||
|
||||
if not should_activate:
|
||||
return [], statistics
|
||||
|
||||
# Find uncovered regions
|
||||
uncovered = self.find_uncovered_regions(text_regions, pp_structure_elements)
|
||||
statistics['uncovered_count'] = len(uncovered)
|
||||
|
||||
if not uncovered:
|
||||
logger.debug("No uncovered regions found")
|
||||
return [], statistics
|
||||
|
||||
# Deduplicate against existing TEXT elements
|
||||
deduplicated = self.deduplicate_regions(uncovered, pp_structure_elements)
|
||||
statistics['deduplicated_count'] = len(deduplicated)
|
||||
|
||||
if not deduplicated:
|
||||
logger.debug("All uncovered regions were duplicates")
|
||||
return [], statistics
|
||||
|
||||
# Optional: Merge adjacent regions
|
||||
# merged = self.merge_adjacent_regions(deduplicated)
|
||||
|
||||
# Convert to DocumentElements
|
||||
start_id = len(pp_structure_elements)
|
||||
supplemented = self.convert_regions_to_elements(
|
||||
deduplicated, page_number, start_id
|
||||
)
|
||||
statistics['supplemented_count'] = len(supplemented)
|
||||
|
||||
logger.info(
|
||||
f"Gap filling complete: supplemented {len(supplemented)} elements "
|
||||
f"(coverage: {coverage_ratio:.2%} -> estimated {(coverage_ratio + len(supplemented)/len(text_regions) if text_regions else 0):.2%})"
|
||||
)
|
||||
|
||||
return supplemented, statistics
|
||||
|
||||
def _convert_raw_ocr_regions(
|
||||
self,
|
||||
raw_regions: List[Dict[str, Any]],
|
||||
page_number: int,
|
||||
ocr_dimensions: Optional[Dict[str, Any]] = None,
|
||||
pp_dimensions: Optional[Dimensions] = None
|
||||
) -> List[TextRegion]:
|
||||
"""
|
||||
Convert raw OCR region dicts to TextRegion objects.
|
||||
|
||||
Handles coordinate alignment if dimensions are provided.
|
||||
|
||||
Args:
|
||||
raw_regions: List of raw OCR region dictionaries
|
||||
page_number: Current page number
|
||||
ocr_dimensions: OCR image dimensions
|
||||
pp_dimensions: PP-Structure dimensions
|
||||
|
||||
Returns:
|
||||
List of TextRegion objects
|
||||
"""
|
||||
text_regions = []
|
||||
|
||||
# Calculate scale factors if needed
|
||||
scale_x, scale_y = 1.0, 1.0
|
||||
if ocr_dimensions and pp_dimensions:
|
||||
ocr_width = ocr_dimensions.get('width', 0)
|
||||
ocr_height = ocr_dimensions.get('height', 0)
|
||||
|
||||
if ocr_width > 0 and pp_dimensions.width > 0:
|
||||
scale_x = pp_dimensions.width / ocr_width
|
||||
if ocr_height > 0 and pp_dimensions.height > 0:
|
||||
scale_y = pp_dimensions.height / ocr_height
|
||||
|
||||
if scale_x != 1.0 or scale_y != 1.0:
|
||||
logger.debug(f"Coordinate scaling: x={scale_x:.3f}, y={scale_y:.3f}")
|
||||
|
||||
for region in raw_regions:
|
||||
text = region.get('text', '')
|
||||
if not text or not text.strip():
|
||||
continue
|
||||
|
||||
confidence = region.get('confidence', 0.0)
|
||||
bbox_raw = region.get('bbox', [])
|
||||
|
||||
# Normalize bbox
|
||||
if isinstance(bbox_raw, dict):
|
||||
# Dict format: {x_min, y_min, x_max, y_max}
|
||||
bbox = [
|
||||
bbox_raw.get('x_min', 0),
|
||||
bbox_raw.get('y_min', 0),
|
||||
bbox_raw.get('x_max', 0),
|
||||
bbox_raw.get('y_max', 0)
|
||||
]
|
||||
elif isinstance(bbox_raw, (list, tuple)):
|
||||
bbox = list(bbox_raw)
|
||||
else:
|
||||
continue
|
||||
|
||||
# Apply scaling if needed
|
||||
if scale_x != 1.0 or scale_y != 1.0:
|
||||
# Check if nested list format [[x1,y1], [x2,y2], ...]
|
||||
if len(bbox) >= 1 and isinstance(bbox[0], (list, tuple)):
|
||||
bbox = [
|
||||
[pt[0] * scale_x, pt[1] * scale_y]
|
||||
for pt in bbox if len(pt) >= 2
|
||||
]
|
||||
elif len(bbox) == 4 and not isinstance(bbox[0], (list, tuple)):
|
||||
# Simple [x0, y0, x1, y1] format
|
||||
bbox = [
|
||||
bbox[0] * scale_x, bbox[1] * scale_y,
|
||||
bbox[2] * scale_x, bbox[3] * scale_y
|
||||
]
|
||||
elif len(bbox) >= 8:
|
||||
# Flat polygon format [x1, y1, x2, y2, ...]
|
||||
bbox = [
|
||||
bbox[i] * (scale_x if i % 2 == 0 else scale_y)
|
||||
for i in range(len(bbox))
|
||||
]
|
||||
|
||||
text_regions.append(TextRegion(
|
||||
text=text,
|
||||
bbox=bbox,
|
||||
confidence=confidence,
|
||||
page=page_number
|
||||
))
|
||||
|
||||
return text_regions
|
||||
|
||||
@staticmethod
|
||||
def _point_in_bbox(
|
||||
x: float, y: float,
|
||||
bbox: Tuple[float, float, float, float]
|
||||
) -> bool:
|
||||
"""Check if point (x, y) is inside bbox (x0, y0, x1, y1)."""
|
||||
x0, y0, x1, y1 = bbox
|
||||
return x0 <= x <= x1 and y0 <= y <= y1
|
||||
|
||||
@staticmethod
|
||||
def _calculate_iou(
|
||||
bbox1: Tuple[float, float, float, float],
|
||||
bbox2: Tuple[float, float, float, float]
|
||||
) -> float:
|
||||
"""
|
||||
Calculate Intersection over Union (IoU) of two bboxes.
|
||||
|
||||
Args:
|
||||
bbox1: First bbox (x0, y0, x1, y1)
|
||||
bbox2: Second bbox (x0, y0, x1, y1)
|
||||
|
||||
Returns:
|
||||
IoU value between 0 and 1
|
||||
"""
|
||||
# Calculate intersection
|
||||
x0 = max(bbox1[0], bbox2[0])
|
||||
y0 = max(bbox1[1], bbox2[1])
|
||||
x1 = min(bbox1[2], bbox2[2])
|
||||
y1 = min(bbox1[3], bbox2[3])
|
||||
|
||||
if x1 <= x0 or y1 <= y0:
|
||||
return 0.0
|
||||
|
||||
intersection = (x1 - x0) * (y1 - y0)
|
||||
|
||||
# Calculate union
|
||||
area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
|
||||
area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
|
||||
union = area1 + area2 - intersection
|
||||
|
||||
if union <= 0:
|
||||
return 0.0
|
||||
|
||||
return intersection / union
|
||||
@@ -46,6 +46,19 @@ except ImportError as e:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Sentinel value for "use PubLayNet default" - explicitly NO model specification
|
||||
_USE_PUBLAYNET_DEFAULT = "__USE_PUBLAYNET_DEFAULT__"
|
||||
|
||||
# Layout model mapping: user-friendly names to actual model names
|
||||
# - "chinese": PP-DocLayout-S - Best for Chinese documents (forms, contracts, invoices)
|
||||
# - "default": PubLayNet-based default model - Best for English documents
|
||||
# - "cdla": picodet_lcnet_x1_0_fgd_layout_cdla - Alternative for Chinese layout
|
||||
LAYOUT_MODEL_MAPPING = {
|
||||
"chinese": "PP-DocLayout-S",
|
||||
"default": _USE_PUBLAYNET_DEFAULT, # Uses default PubLayNet-based model (no custom model)
|
||||
"cdla": "picodet_lcnet_x1_0_fgd_layout_cdla",
|
||||
}
|
||||
|
||||
|
||||
class OCRService:
|
||||
"""
|
||||
@@ -436,77 +449,45 @@ class OCRService:
|
||||
|
||||
return self.ocr_engines[lang]
|
||||
|
||||
def _ensure_structure_engine(self, custom_params: Optional[Dict[str, any]] = None) -> PPStructureV3:
|
||||
def _ensure_structure_engine(self, layout_model: Optional[str] = None) -> PPStructureV3:
|
||||
"""
|
||||
Get or create PP-Structure engine for layout analysis with GPU support.
|
||||
Supports custom parameters that override default settings.
|
||||
Supports layout model selection for different document types.
|
||||
|
||||
Args:
|
||||
custom_params: Optional dictionary of custom PP-StructureV3 parameters.
|
||||
If provided, creates a new engine instance (not cached).
|
||||
Supported keys: layout_detection_threshold, layout_nms_threshold,
|
||||
layout_merge_bboxes_mode, layout_unclip_ratio, text_det_thresh,
|
||||
text_det_box_thresh, text_det_unclip_ratio
|
||||
layout_model: Layout detection model selection:
|
||||
- "chinese": PP-DocLayout-S (best for Chinese documents)
|
||||
- "default": PubLayNet-based (best for English documents)
|
||||
- "cdla": CDLA model (alternative for Chinese layout)
|
||||
- None: Use config default
|
||||
|
||||
Returns:
|
||||
PPStructure engine instance
|
||||
"""
|
||||
# If custom params provided, create a new engine instance (don't use cache)
|
||||
if custom_params:
|
||||
logger.info(f"Creating PP-StructureV3 engine with custom parameters (GPU: {self.use_gpu})")
|
||||
logger.info(f"Custom params: {custom_params}")
|
||||
# Resolve layout model name from user-friendly name
|
||||
resolved_model_name = None
|
||||
use_publaynet_default = False # Flag to explicitly use PubLayNet default (no model param)
|
||||
|
||||
try:
|
||||
# Base configuration from settings
|
||||
use_chart = settings.enable_chart_recognition
|
||||
use_formula = settings.enable_formula_recognition
|
||||
use_table = settings.enable_table_recognition
|
||||
if layout_model:
|
||||
resolved_model_name = LAYOUT_MODEL_MAPPING.get(layout_model)
|
||||
if layout_model not in LAYOUT_MODEL_MAPPING:
|
||||
logger.warning(f"Unknown layout model '{layout_model}', using config default")
|
||||
resolved_model_name = settings.layout_detection_model_name
|
||||
elif resolved_model_name == _USE_PUBLAYNET_DEFAULT:
|
||||
# User explicitly selected "default" - use PubLayNet without custom model
|
||||
use_publaynet_default = True
|
||||
resolved_model_name = None
|
||||
logger.info(f"Using layout model: {layout_model} -> PubLayNet default (no custom model)")
|
||||
else:
|
||||
logger.info(f"Using layout model: {layout_model} -> {resolved_model_name}")
|
||||
|
||||
# Parameter priority: custom > settings default
|
||||
layout_threshold = custom_params.get('layout_detection_threshold', settings.layout_detection_threshold)
|
||||
layout_nms = custom_params.get('layout_nms_threshold', settings.layout_nms_threshold)
|
||||
layout_merge = custom_params.get('layout_merge_bboxes_mode', settings.layout_merge_mode)
|
||||
layout_unclip = custom_params.get('layout_unclip_ratio', settings.layout_unclip_ratio)
|
||||
text_thresh = custom_params.get('text_det_thresh', settings.text_det_thresh)
|
||||
text_box_thresh = custom_params.get('text_det_box_thresh', settings.text_det_box_thresh)
|
||||
text_unclip = custom_params.get('text_det_unclip_ratio', settings.text_det_unclip_ratio)
|
||||
# Check if we need to recreate the engine due to different model
|
||||
current_model = getattr(self, '_current_layout_model', None)
|
||||
if self.structure_engine is not None and layout_model and layout_model != current_model:
|
||||
logger.info(f"Layout model changed from {current_model} to {layout_model}, recreating engine")
|
||||
self.structure_engine = None # Force recreation
|
||||
|
||||
logger.info(f"PP-StructureV3 config: table={use_table}, formula={use_formula}, chart={use_chart}")
|
||||
logger.info(f"Layout config: threshold={layout_threshold}, nms={layout_nms}, merge={layout_merge}, unclip={layout_unclip}")
|
||||
logger.info(f"Text detection: thresh={text_thresh}, box_thresh={text_box_thresh}, unclip={text_unclip}")
|
||||
|
||||
# Create temporary engine with custom params (not cached)
|
||||
custom_engine = PPStructureV3(
|
||||
use_doc_orientation_classify=False,
|
||||
use_doc_unwarping=False,
|
||||
use_textline_orientation=False,
|
||||
use_table_recognition=use_table,
|
||||
use_formula_recognition=use_formula,
|
||||
use_chart_recognition=use_chart,
|
||||
layout_threshold=layout_threshold,
|
||||
layout_nms=layout_nms,
|
||||
layout_unclip_ratio=layout_unclip,
|
||||
layout_merge_bboxes_mode=layout_merge,
|
||||
text_det_thresh=text_thresh,
|
||||
text_det_box_thresh=text_box_thresh,
|
||||
text_det_unclip_ratio=text_unclip,
|
||||
)
|
||||
|
||||
logger.info(f"PP-StructureV3 engine with custom params ready (PaddlePaddle {paddle.__version__}, {'GPU' if self.use_gpu else 'CPU'} mode)")
|
||||
|
||||
# Check GPU memory after loading
|
||||
if self.use_gpu and settings.enable_memory_optimization:
|
||||
self._check_gpu_memory_usage()
|
||||
|
||||
return custom_engine
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create PP-StructureV3 engine with custom params: {e}")
|
||||
# Fall back to default cached engine
|
||||
logger.warning("Falling back to default cached engine")
|
||||
custom_params = None # Clear custom params to use cached engine
|
||||
|
||||
# Use cached default engine
|
||||
# Use cached engine or create new one
|
||||
if self.structure_engine is None:
|
||||
logger.info(f"Initializing PP-StructureV3 engine (GPU: {self.use_gpu})")
|
||||
|
||||
@@ -524,28 +505,51 @@ class OCRService:
|
||||
text_box_thresh = settings.text_det_box_thresh
|
||||
text_unclip = settings.text_det_unclip_ratio
|
||||
|
||||
# Layout model configuration:
|
||||
# - If use_publaynet_default: don't specify any model (use PubLayNet default)
|
||||
# - If resolved_model_name: use the specified model
|
||||
# - Otherwise: use config default
|
||||
if use_publaynet_default:
|
||||
layout_model_name = None # Explicitly no model = PubLayNet default
|
||||
elif resolved_model_name:
|
||||
layout_model_name = resolved_model_name
|
||||
else:
|
||||
layout_model_name = settings.layout_detection_model_name
|
||||
layout_model_dir = settings.layout_detection_model_dir
|
||||
|
||||
logger.info(f"PP-StructureV3 config: table={use_table}, formula={use_formula}, chart={use_chart}")
|
||||
logger.info(f"Layout model: name={layout_model_name}, dir={layout_model_dir}")
|
||||
logger.info(f"Layout config: threshold={layout_threshold}, nms={layout_nms}, merge={layout_merge}, unclip={layout_unclip}")
|
||||
logger.info(f"Text detection: thresh={text_thresh}, box_thresh={text_box_thresh}, unclip={text_unclip}")
|
||||
|
||||
self.structure_engine = PPStructureV3(
|
||||
use_doc_orientation_classify=False,
|
||||
use_doc_unwarping=False,
|
||||
use_textline_orientation=False,
|
||||
use_table_recognition=use_table,
|
||||
use_formula_recognition=use_formula,
|
||||
use_chart_recognition=use_chart,
|
||||
layout_threshold=layout_threshold,
|
||||
layout_nms=layout_nms,
|
||||
layout_unclip_ratio=layout_unclip,
|
||||
layout_merge_bboxes_mode=layout_merge, # Use 'small' to minimize merging
|
||||
text_det_thresh=text_thresh,
|
||||
text_det_box_thresh=text_box_thresh,
|
||||
text_det_unclip_ratio=text_unclip,
|
||||
)
|
||||
# Build PPStructureV3 kwargs
|
||||
pp_kwargs = {
|
||||
'use_doc_orientation_classify': False,
|
||||
'use_doc_unwarping': False,
|
||||
'use_textline_orientation': False,
|
||||
'use_table_recognition': use_table,
|
||||
'use_formula_recognition': use_formula,
|
||||
'use_chart_recognition': use_chart,
|
||||
'layout_threshold': layout_threshold,
|
||||
'layout_nms': layout_nms,
|
||||
'layout_unclip_ratio': layout_unclip,
|
||||
'layout_merge_bboxes_mode': layout_merge,
|
||||
'text_det_thresh': text_thresh,
|
||||
'text_det_box_thresh': text_box_thresh,
|
||||
'text_det_unclip_ratio': text_unclip,
|
||||
}
|
||||
|
||||
# Add layout model configuration if specified
|
||||
if layout_model_name:
|
||||
pp_kwargs['layout_detection_model_name'] = layout_model_name
|
||||
if layout_model_dir:
|
||||
pp_kwargs['layout_detection_model_dir'] = layout_model_dir
|
||||
|
||||
self.structure_engine = PPStructureV3(**pp_kwargs)
|
||||
|
||||
# Track model loading for cache management
|
||||
self._model_last_used['structure'] = datetime.now()
|
||||
self._current_layout_model = layout_model # Track current model for recreation check
|
||||
|
||||
logger.info(f"PP-StructureV3 engine ready (PaddlePaddle {paddle.__version__}, {'GPU' if self.use_gpu else 'CPU'} mode)")
|
||||
|
||||
@@ -565,17 +569,27 @@ class OCRService:
|
||||
use_formula = settings.enable_formula_recognition
|
||||
use_table = settings.enable_table_recognition
|
||||
layout_threshold = settings.layout_detection_threshold
|
||||
layout_model_name = settings.layout_detection_model_name
|
||||
layout_model_dir = settings.layout_detection_model_dir
|
||||
|
||||
self.structure_engine = PPStructureV3(
|
||||
use_doc_orientation_classify=False,
|
||||
use_doc_unwarping=False,
|
||||
use_textline_orientation=False,
|
||||
use_table_recognition=use_table,
|
||||
use_formula_recognition=use_formula,
|
||||
use_chart_recognition=use_chart,
|
||||
layout_threshold=layout_threshold,
|
||||
)
|
||||
logger.info("PP-StructureV3 engine ready (CPU mode - fallback)")
|
||||
# Build CPU fallback kwargs
|
||||
cpu_kwargs = {
|
||||
'use_doc_orientation_classify': False,
|
||||
'use_doc_unwarping': False,
|
||||
'use_textline_orientation': False,
|
||||
'use_table_recognition': use_table,
|
||||
'use_formula_recognition': use_formula,
|
||||
'use_chart_recognition': use_chart,
|
||||
'layout_threshold': layout_threshold,
|
||||
}
|
||||
if layout_model_name:
|
||||
cpu_kwargs['layout_detection_model_name'] = layout_model_name
|
||||
if layout_model_dir:
|
||||
cpu_kwargs['layout_detection_model_dir'] = layout_model_dir
|
||||
|
||||
self.structure_engine = PPStructureV3(**cpu_kwargs)
|
||||
self._current_layout_model = layout_model # Track current model for recreation check
|
||||
logger.info(f"PP-StructureV3 engine ready (CPU mode - fallback, layout_model={layout_model_name})")
|
||||
else:
|
||||
raise
|
||||
|
||||
@@ -813,7 +827,7 @@ class OCRService:
|
||||
confidence_threshold: Optional[float] = None,
|
||||
output_dir: Optional[Path] = None,
|
||||
current_page: int = 0,
|
||||
pp_structure_params: Optional[Dict[str, any]] = None
|
||||
layout_model: Optional[str] = None
|
||||
) -> Dict:
|
||||
"""
|
||||
Process single image with OCR and layout analysis
|
||||
@@ -825,7 +839,7 @@ class OCRService:
|
||||
confidence_threshold: Minimum confidence threshold (uses default if None)
|
||||
output_dir: Optional output directory for saving extracted images
|
||||
current_page: Current page number (0-based) for multi-page documents
|
||||
pp_structure_params: Optional custom PP-StructureV3 parameters
|
||||
layout_model: Layout detection model ('chinese', 'default', 'cdla')
|
||||
|
||||
Returns:
|
||||
Dictionary with OCR results and metadata
|
||||
@@ -894,7 +908,7 @@ class OCRService:
|
||||
confidence_threshold=confidence_threshold,
|
||||
output_dir=output_dir,
|
||||
current_page=page_num - 1, # Convert to 0-based page number for layout data
|
||||
pp_structure_params=pp_structure_params
|
||||
layout_model=layout_model
|
||||
)
|
||||
|
||||
# Accumulate results
|
||||
@@ -1040,7 +1054,7 @@ class OCRService:
|
||||
image_path,
|
||||
output_dir=output_dir,
|
||||
current_page=current_page,
|
||||
pp_structure_params=pp_structure_params
|
||||
layout_model=layout_model
|
||||
)
|
||||
|
||||
# Generate Markdown
|
||||
@@ -1078,6 +1092,38 @@ class OCRService:
|
||||
'height': ocr_height
|
||||
}]
|
||||
|
||||
# Generate PP-StructureV3 debug outputs if enabled
|
||||
if settings.pp_structure_debug_enabled and output_dir:
|
||||
try:
|
||||
from app.services.pp_structure_debug import PPStructureDebug
|
||||
debug_service = PPStructureDebug(output_dir)
|
||||
|
||||
# Save raw results as JSON
|
||||
debug_service.save_raw_results(
|
||||
pp_structure_results={
|
||||
'elements': layout_data.get('elements', []),
|
||||
'total_elements': layout_data.get('total_elements', 0),
|
||||
'element_types': layout_data.get('element_types', {}),
|
||||
'reading_order': layout_data.get('reading_order', []),
|
||||
'enhanced': True,
|
||||
'has_parsing_res_list': True
|
||||
},
|
||||
raw_ocr_regions=text_regions,
|
||||
filename_prefix=image_path.stem
|
||||
)
|
||||
|
||||
# Generate visualization if enabled
|
||||
if settings.pp_structure_debug_visualization:
|
||||
debug_service.generate_visualization(
|
||||
image_path=image_path,
|
||||
pp_structure_elements=layout_data.get('elements', []),
|
||||
raw_ocr_regions=text_regions,
|
||||
filename_prefix=image_path.stem
|
||||
)
|
||||
logger.info(f"Generated PP-StructureV3 debug outputs for {image_path.name}")
|
||||
except Exception as debug_error:
|
||||
logger.warning(f"Failed to generate debug outputs: {debug_error}")
|
||||
|
||||
logger.info(
|
||||
f"OCR completed: {image_path.name} - "
|
||||
f"{len(text_regions)} regions, "
|
||||
@@ -1164,7 +1210,7 @@ class OCRService:
|
||||
image_path: Path,
|
||||
output_dir: Optional[Path] = None,
|
||||
current_page: int = 0,
|
||||
pp_structure_params: Optional[Dict[str, any]] = None
|
||||
layout_model: Optional[str] = None
|
||||
) -> Tuple[Optional[Dict], List[Dict]]:
|
||||
"""
|
||||
Analyze document layout using PP-StructureV3 with enhanced element extraction
|
||||
@@ -1173,7 +1219,7 @@ class OCRService:
|
||||
image_path: Path to image file
|
||||
output_dir: Optional output directory for saving extracted images (defaults to image_path.parent)
|
||||
current_page: Current page number (0-based) for multi-page documents
|
||||
pp_structure_params: Optional custom PP-StructureV3 parameters
|
||||
layout_model: Layout detection model ('chinese', 'default', 'cdla')
|
||||
|
||||
Returns:
|
||||
Tuple of (layout_data, images_metadata)
|
||||
@@ -1191,7 +1237,7 @@ class OCRService:
|
||||
f"Mode: {'CPU fallback' if self._cpu_fallback_active else 'GPU'}"
|
||||
)
|
||||
|
||||
structure_engine = self._ensure_structure_engine(pp_structure_params)
|
||||
structure_engine = self._ensure_structure_engine(layout_model)
|
||||
|
||||
# Try enhanced processing first
|
||||
try:
|
||||
@@ -1425,7 +1471,7 @@ class OCRService:
|
||||
confidence_threshold: Optional[float] = None,
|
||||
output_dir: Optional[Path] = None,
|
||||
force_track: Optional[str] = None,
|
||||
pp_structure_params: Optional[Dict[str, any]] = None
|
||||
layout_model: Optional[str] = None
|
||||
) -> Union[UnifiedDocument, Dict]:
|
||||
"""
|
||||
Process document using dual-track approach.
|
||||
@@ -1437,7 +1483,7 @@ class OCRService:
|
||||
confidence_threshold: Minimum confidence threshold
|
||||
output_dir: Optional output directory for extracted images
|
||||
force_track: Force specific track ("ocr" or "direct"), None for auto-detection
|
||||
pp_structure_params: Optional custom PP-StructureV3 parameters (used for OCR track only)
|
||||
layout_model: Layout detection model ('chinese', 'default', 'cdla') (used for OCR track only)
|
||||
|
||||
Returns:
|
||||
UnifiedDocument if dual-track is enabled, Dict otherwise
|
||||
@@ -1445,7 +1491,7 @@ class OCRService:
|
||||
if not self.dual_track_enabled:
|
||||
# Fallback to traditional OCR processing
|
||||
return self.process_file_traditional(
|
||||
file_path, lang, detect_layout, confidence_threshold, output_dir, pp_structure_params
|
||||
file_path, lang, detect_layout, confidence_threshold, output_dir, layout_model
|
||||
)
|
||||
|
||||
start_time = datetime.now()
|
||||
@@ -1517,7 +1563,7 @@ class OCRService:
|
||||
ocr_result = self.process_file_traditional(
|
||||
actual_file_path, lang, detect_layout=True,
|
||||
confidence_threshold=confidence_threshold,
|
||||
output_dir=output_dir, pp_structure_params=pp_structure_params
|
||||
output_dir=output_dir, layout_model=layout_model
|
||||
)
|
||||
|
||||
# Convert OCR result to extract images
|
||||
@@ -1550,7 +1596,7 @@ class OCRService:
|
||||
# Use OCR for scanned documents, images, etc.
|
||||
logger.info("Using OCR track (PaddleOCR)")
|
||||
ocr_result = self.process_file_traditional(
|
||||
file_path, lang, detect_layout, confidence_threshold, output_dir, pp_structure_params
|
||||
file_path, lang, detect_layout, confidence_threshold, output_dir, layout_model
|
||||
)
|
||||
|
||||
# Convert OCR result to UnifiedDocument using the converter
|
||||
@@ -1580,7 +1626,7 @@ class OCRService:
|
||||
logger.error(f"Error in dual-track processing: {e}")
|
||||
# Fallback to traditional OCR
|
||||
return self.process_file_traditional(
|
||||
file_path, lang, detect_layout, confidence_threshold, output_dir, pp_structure_params
|
||||
file_path, lang, detect_layout, confidence_threshold, output_dir, layout_model
|
||||
)
|
||||
|
||||
def _merge_ocr_images_into_direct(
|
||||
@@ -1659,7 +1705,7 @@ class OCRService:
|
||||
detect_layout: bool = True,
|
||||
confidence_threshold: Optional[float] = None,
|
||||
output_dir: Optional[Path] = None,
|
||||
pp_structure_params: Optional[Dict[str, any]] = None
|
||||
layout_model: Optional[str] = None
|
||||
) -> Dict:
|
||||
"""
|
||||
Traditional OCR processing (legacy method).
|
||||
@@ -1670,7 +1716,7 @@ class OCRService:
|
||||
detect_layout: Whether to perform layout analysis
|
||||
confidence_threshold: Minimum confidence threshold
|
||||
output_dir: Optional output directory
|
||||
pp_structure_params: Optional custom PP-StructureV3 parameters
|
||||
layout_model: Layout detection model ('chinese', 'default', 'cdla')
|
||||
|
||||
Returns:
|
||||
Dictionary with OCR results in legacy format
|
||||
@@ -1683,7 +1729,7 @@ class OCRService:
|
||||
all_results = []
|
||||
for i, image_path in enumerate(image_paths):
|
||||
result = self.process_image(
|
||||
image_path, lang, detect_layout, confidence_threshold, output_dir, i, pp_structure_params
|
||||
image_path, lang, detect_layout, confidence_threshold, output_dir, i, layout_model
|
||||
)
|
||||
all_results.append(result)
|
||||
|
||||
@@ -1699,7 +1745,7 @@ class OCRService:
|
||||
else:
|
||||
# Single image or other file
|
||||
return self.process_image(
|
||||
file_path, lang, detect_layout, confidence_threshold, output_dir, 0, pp_structure_params
|
||||
file_path, lang, detect_layout, confidence_threshold, output_dir, 0, layout_model
|
||||
)
|
||||
|
||||
def _combine_results(self, results: List[Dict]) -> Dict:
|
||||
@@ -1784,7 +1830,7 @@ class OCRService:
|
||||
output_dir: Optional[Path] = None,
|
||||
use_dual_track: bool = True,
|
||||
force_track: Optional[str] = None,
|
||||
pp_structure_params: Optional[Dict[str, any]] = None
|
||||
layout_model: Optional[str] = None
|
||||
) -> Union[UnifiedDocument, Dict]:
|
||||
"""
|
||||
Main processing method with dual-track support.
|
||||
@@ -1797,7 +1843,7 @@ class OCRService:
|
||||
output_dir: Optional output directory
|
||||
use_dual_track: Whether to use dual-track processing (default True)
|
||||
force_track: Force specific track ("ocr" or "direct")
|
||||
pp_structure_params: Optional custom PP-StructureV3 parameters (used for OCR track only)
|
||||
layout_model: Layout detection model ('chinese', 'default', 'cdla') (used for OCR track only)
|
||||
|
||||
Returns:
|
||||
UnifiedDocument if dual-track is enabled and use_dual_track=True,
|
||||
@@ -1809,12 +1855,12 @@ class OCRService:
|
||||
if (use_dual_track or force_track) and self.dual_track_enabled:
|
||||
# Use dual-track processing (or forced track)
|
||||
return self.process_with_dual_track(
|
||||
file_path, lang, detect_layout, confidence_threshold, output_dir, force_track, pp_structure_params
|
||||
file_path, lang, detect_layout, confidence_threshold, output_dir, force_track, layout_model
|
||||
)
|
||||
else:
|
||||
# Use traditional OCR processing (no force_track support)
|
||||
return self.process_file_traditional(
|
||||
file_path, lang, detect_layout, confidence_threshold, output_dir, pp_structure_params
|
||||
file_path, lang, detect_layout, confidence_threshold, output_dir, layout_model
|
||||
)
|
||||
|
||||
def process_legacy(
|
||||
|
||||
@@ -3,6 +3,9 @@ OCR to UnifiedDocument Converter
|
||||
|
||||
Converts PP-StructureV3 OCR results to UnifiedDocument format, preserving
|
||||
all structure information and metadata.
|
||||
|
||||
Includes gap filling support to supplement PP-StructureV3 output with raw OCR
|
||||
regions when significant content loss is detected.
|
||||
"""
|
||||
|
||||
import logging
|
||||
@@ -16,10 +19,165 @@ from app.models.unified_document import (
|
||||
BoundingBox, StyleInfo, TableData, ElementType,
|
||||
ProcessingTrack, TableCell, Dimensions
|
||||
)
|
||||
from app.services.gap_filling_service import GapFillingService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def trim_empty_columns(table_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Remove empty columns from a table dictionary.
|
||||
|
||||
A column is considered empty if ALL cells in that column have content that is
|
||||
empty or whitespace-only (using .strip() to determine emptiness).
|
||||
|
||||
This function:
|
||||
1. Identifies columns where every cell's content is empty/whitespace
|
||||
2. Removes identified empty columns
|
||||
3. Updates cols/columns value
|
||||
4. Recalculates each cell's col index
|
||||
5. Adjusts col_span when spans cross removed columns
|
||||
6. Removes cells entirely when their complete span falls within removed columns
|
||||
7. Preserves original bbox (no layout drift)
|
||||
|
||||
Args:
|
||||
table_dict: Table dictionary with keys: rows, cols/columns, cells
|
||||
|
||||
Returns:
|
||||
Cleaned table dictionary with empty columns removed
|
||||
"""
|
||||
cells = table_dict.get('cells', [])
|
||||
if not cells:
|
||||
return table_dict
|
||||
|
||||
# Get original column count
|
||||
original_cols = table_dict.get('cols', table_dict.get('columns', 0))
|
||||
if original_cols == 0:
|
||||
# Calculate from cells if not provided
|
||||
max_col = 0
|
||||
for cell in cells:
|
||||
cell_col = cell.get('col', 0) if isinstance(cell, dict) else getattr(cell, 'col', 0)
|
||||
cell_span = cell.get('col_span', 1) if isinstance(cell, dict) else getattr(cell, 'col_span', 1)
|
||||
max_col = max(max_col, cell_col + cell_span)
|
||||
original_cols = max_col
|
||||
|
||||
if original_cols == 0:
|
||||
return table_dict
|
||||
|
||||
# Build a map: column_index -> list of cell contents
|
||||
# For cells with col_span > 1, we only check their primary column
|
||||
column_contents: Dict[int, List[str]] = {i: [] for i in range(original_cols)}
|
||||
|
||||
for cell in cells:
|
||||
if isinstance(cell, dict):
|
||||
col = cell.get('col', 0)
|
||||
col_span = cell.get('col_span', 1)
|
||||
content = cell.get('content', '')
|
||||
else:
|
||||
col = getattr(cell, 'col', 0)
|
||||
col_span = getattr(cell, 'col_span', 1)
|
||||
content = getattr(cell, 'content', '')
|
||||
|
||||
# Mark content for each column this cell spans
|
||||
for c in range(col, min(col + col_span, original_cols)):
|
||||
if c in column_contents:
|
||||
column_contents[c].append(str(content).strip() if content else '')
|
||||
|
||||
# Identify empty columns (all content is empty/whitespace)
|
||||
empty_columns = set()
|
||||
for col_idx, contents in column_contents.items():
|
||||
# A column is empty if ALL cells in it have empty content
|
||||
# Note: If a column has no cells at all, it's considered empty
|
||||
if all(c == '' for c in contents):
|
||||
empty_columns.add(col_idx)
|
||||
|
||||
if not empty_columns:
|
||||
# No empty columns to remove, just ensure cols is set
|
||||
result = dict(table_dict)
|
||||
if result.get('cols', result.get('columns', 0)) == 0:
|
||||
result['cols'] = original_cols
|
||||
if 'columns' in result:
|
||||
result['columns'] = original_cols
|
||||
return result
|
||||
|
||||
logger.debug(f"Removing empty columns: {sorted(empty_columns)} from table with {original_cols} cols")
|
||||
|
||||
# Build column mapping: old_col -> new_col (or None if removed)
|
||||
col_mapping: Dict[int, Optional[int]] = {}
|
||||
new_col = 0
|
||||
for old_col in range(original_cols):
|
||||
if old_col in empty_columns:
|
||||
col_mapping[old_col] = None
|
||||
else:
|
||||
col_mapping[old_col] = new_col
|
||||
new_col += 1
|
||||
|
||||
new_cols = new_col
|
||||
|
||||
# Process cells
|
||||
new_cells = []
|
||||
for cell in cells:
|
||||
if isinstance(cell, dict):
|
||||
old_col = cell.get('col', 0)
|
||||
old_col_span = cell.get('col_span', 1)
|
||||
else:
|
||||
old_col = getattr(cell, 'col', 0)
|
||||
old_col_span = getattr(cell, 'col_span', 1)
|
||||
|
||||
# Calculate new col and col_span
|
||||
# Find the first non-removed column in this cell's span
|
||||
new_start_col = None
|
||||
new_end_col = None
|
||||
|
||||
for c in range(old_col, min(old_col + old_col_span, original_cols)):
|
||||
mapped = col_mapping.get(c)
|
||||
if mapped is not None:
|
||||
if new_start_col is None:
|
||||
new_start_col = mapped
|
||||
new_end_col = mapped
|
||||
|
||||
# If entire span falls within removed columns, skip this cell
|
||||
if new_start_col is None:
|
||||
logger.debug(f"Removing cell at row={cell.get('row', 0) if isinstance(cell, dict) else cell.row}, "
|
||||
f"col={old_col} (entire span in removed columns)")
|
||||
continue
|
||||
|
||||
new_col_span = new_end_col - new_start_col + 1
|
||||
|
||||
# Create new cell
|
||||
if isinstance(cell, dict):
|
||||
new_cell = dict(cell)
|
||||
new_cell['col'] = new_start_col
|
||||
new_cell['col_span'] = new_col_span
|
||||
else:
|
||||
# Handle TableCell objects
|
||||
new_cell = {
|
||||
'row': cell.row,
|
||||
'col': new_start_col,
|
||||
'row_span': cell.row_span,
|
||||
'col_span': new_col_span,
|
||||
'content': cell.content
|
||||
}
|
||||
if hasattr(cell, 'bbox') and cell.bbox:
|
||||
new_cell['bbox'] = cell.bbox
|
||||
if hasattr(cell, 'style') and cell.style:
|
||||
new_cell['style'] = cell.style
|
||||
|
||||
new_cells.append(new_cell)
|
||||
|
||||
# Build result
|
||||
result = dict(table_dict)
|
||||
result['cells'] = new_cells
|
||||
result['cols'] = new_cols
|
||||
if 'columns' in result:
|
||||
result['columns'] = new_cols
|
||||
|
||||
logger.info(f"Trimmed table: {original_cols} -> {new_cols} columns, "
|
||||
f"{len(cells)} -> {len(new_cells)} cells")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class OCRToUnifiedConverter:
|
||||
"""
|
||||
Converter for transforming PP-StructureV3 OCR results to UnifiedDocument format.
|
||||
@@ -30,11 +188,19 @@ class OCRToUnifiedConverter:
|
||||
- Multi-page document assembly
|
||||
- Metadata preservation
|
||||
- Structure relationship mapping
|
||||
- Gap filling with raw OCR regions (when PP-StructureV3 misses content)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the converter."""
|
||||
def __init__(self, enable_gap_filling: bool = True):
|
||||
"""
|
||||
Initialize the converter.
|
||||
|
||||
Args:
|
||||
enable_gap_filling: Whether to enable gap filling with raw OCR regions
|
||||
"""
|
||||
self.element_counter = 0
|
||||
self.gap_filling_service = GapFillingService() if enable_gap_filling else None
|
||||
self.gap_filling_stats: Dict[str, Any] = {}
|
||||
|
||||
def convert(
|
||||
self,
|
||||
@@ -120,13 +286,21 @@ class OCRToUnifiedConverter:
|
||||
Extract pages from OCR results.
|
||||
|
||||
Handles both enhanced PP-StructureV3 results (with parsing_res_list)
|
||||
and traditional markdown results.
|
||||
and traditional markdown results. Applies gap filling when enabled.
|
||||
"""
|
||||
pages = []
|
||||
|
||||
# Extract raw OCR text regions for gap filling
|
||||
raw_text_regions = ocr_results.get('text_regions', [])
|
||||
ocr_dimensions = ocr_results.get('ocr_dimensions', {})
|
||||
|
||||
# Check if we have enhanced results from PPStructureEnhanced
|
||||
if 'enhanced_results' in ocr_results:
|
||||
pages = self._extract_from_enhanced_results(ocr_results['enhanced_results'])
|
||||
pages = self._extract_from_enhanced_results(
|
||||
ocr_results['enhanced_results'],
|
||||
raw_text_regions=raw_text_regions,
|
||||
ocr_dimensions=ocr_dimensions
|
||||
)
|
||||
# Check for traditional OCR results with text_regions at top level (from process_file_traditional)
|
||||
elif 'text_regions' in ocr_results:
|
||||
pages = self._extract_from_traditional_ocr(ocr_results)
|
||||
@@ -143,9 +317,21 @@ class OCRToUnifiedConverter:
|
||||
|
||||
def _extract_from_enhanced_results(
|
||||
self,
|
||||
enhanced_results: List[Dict[str, Any]]
|
||||
enhanced_results: List[Dict[str, Any]],
|
||||
raw_text_regions: Optional[List[Dict[str, Any]]] = None,
|
||||
ocr_dimensions: Optional[Dict[str, Any]] = None
|
||||
) -> List[Page]:
|
||||
"""Extract pages from enhanced PP-StructureV3 results."""
|
||||
"""
|
||||
Extract pages from enhanced PP-StructureV3 results.
|
||||
|
||||
Applies gap filling when enabled to supplement PP-StructureV3 output
|
||||
with raw OCR regions that were not detected by the layout model.
|
||||
|
||||
Args:
|
||||
enhanced_results: PP-StructureV3 enhanced results
|
||||
raw_text_regions: Raw OCR text regions for gap filling
|
||||
ocr_dimensions: OCR image dimensions for coordinate alignment
|
||||
"""
|
||||
pages = []
|
||||
|
||||
for page_idx, page_result in enumerate(enhanced_results):
|
||||
@@ -158,15 +344,52 @@ class OCRToUnifiedConverter:
|
||||
if element:
|
||||
elements.append(element)
|
||||
|
||||
# Get page dimensions
|
||||
pp_dimensions = Dimensions(
|
||||
width=page_result.get('width', 0),
|
||||
height=page_result.get('height', 0)
|
||||
)
|
||||
|
||||
# Apply gap filling if enabled and raw regions available
|
||||
if self.gap_filling_service and raw_text_regions:
|
||||
# Filter raw regions for current page
|
||||
page_raw_regions = [
|
||||
r for r in raw_text_regions
|
||||
if r.get('page', 0) == page_idx or r.get('page', 1) == page_idx + 1
|
||||
]
|
||||
|
||||
if page_raw_regions:
|
||||
supplemented, stats = self.gap_filling_service.fill_gaps(
|
||||
raw_ocr_regions=page_raw_regions,
|
||||
pp_structure_elements=elements,
|
||||
page_number=page_idx + 1,
|
||||
ocr_dimensions=ocr_dimensions,
|
||||
pp_dimensions=pp_dimensions
|
||||
)
|
||||
|
||||
# Store statistics
|
||||
self.gap_filling_stats[f'page_{page_idx + 1}'] = stats
|
||||
|
||||
if supplemented:
|
||||
logger.info(
|
||||
f"Page {page_idx + 1}: Gap filling added {len(supplemented)} elements "
|
||||
f"(coverage: {stats.get('coverage_ratio', 0):.2%})"
|
||||
)
|
||||
elements.extend(supplemented)
|
||||
|
||||
# Recalculate reading order for combined elements
|
||||
reading_order = self.gap_filling_service.recalculate_reading_order(elements)
|
||||
page_result['reading_order'] = reading_order
|
||||
|
||||
# Create page
|
||||
page = Page(
|
||||
page_number=page_idx + 1,
|
||||
dimensions=Dimensions(
|
||||
width=page_result.get('width', 0),
|
||||
height=page_result.get('height', 0)
|
||||
),
|
||||
dimensions=pp_dimensions,
|
||||
elements=elements,
|
||||
metadata={'reading_order': page_result.get('reading_order', [])}
|
||||
metadata={
|
||||
'reading_order': page_result.get('reading_order', []),
|
||||
'gap_filling': self.gap_filling_stats.get(f'page_{page_idx + 1}', {})
|
||||
}
|
||||
)
|
||||
|
||||
pages.append(page)
|
||||
@@ -500,6 +723,9 @@ class OCRToUnifiedConverter:
|
||||
) -> Optional[DocumentElement]:
|
||||
"""Convert table data to DocumentElement."""
|
||||
try:
|
||||
# Clean up empty columns before building TableData
|
||||
table_dict = trim_empty_columns(table_dict)
|
||||
|
||||
# Extract bbox
|
||||
bbox_data = table_dict.get('bbox', [0, 0, 0, 0])
|
||||
bbox = BoundingBox(
|
||||
@@ -587,14 +813,22 @@ class OCRToUnifiedConverter:
|
||||
cells = []
|
||||
headers = []
|
||||
rows = table.find_all('tr')
|
||||
num_rows = len(rows)
|
||||
|
||||
# Track actual column positions accounting for rowspan/colspan
|
||||
# This is a simplified approach - complex spanning may need enhancement
|
||||
# First pass: calculate total columns by finding max column extent
|
||||
# Track cells that span multiple rows: occupied[row][col] = True
|
||||
occupied: Dict[int, Dict[int, bool]] = {r: {} for r in range(num_rows)}
|
||||
|
||||
# Parse all cells with proper rowspan/colspan handling
|
||||
for row_idx, row in enumerate(rows):
|
||||
row_cells = row.find_all(['td', 'th'])
|
||||
col_idx = 0
|
||||
|
||||
for cell in row_cells:
|
||||
# Skip columns that are occupied by rowspan from previous rows
|
||||
while occupied[row_idx].get(col_idx, False):
|
||||
col_idx += 1
|
||||
|
||||
cell_content = cell.get_text(strip=True)
|
||||
rowspan = int(cell.get('rowspan', 1))
|
||||
colspan = int(cell.get('colspan', 1))
|
||||
@@ -611,26 +845,66 @@ class OCRToUnifiedConverter:
|
||||
if cell.name == 'th' or row_idx == 0:
|
||||
headers.append(cell_content)
|
||||
|
||||
# Mark cells as occupied for rowspan/colspan
|
||||
for r in range(row_idx, min(row_idx + rowspan, num_rows)):
|
||||
for c in range(col_idx, col_idx + colspan):
|
||||
if r not in occupied:
|
||||
occupied[r] = {}
|
||||
occupied[r][c] = True
|
||||
|
||||
# Advance column index by colspan
|
||||
col_idx += colspan
|
||||
|
||||
# Calculate actual dimensions
|
||||
num_rows = len(rows)
|
||||
num_cols = max(
|
||||
sum(int(cell.get('colspan', 1)) for cell in row.find_all(['td', 'th']))
|
||||
for row in rows
|
||||
) if rows else 0
|
||||
# Calculate actual column count from occupied cells
|
||||
num_cols = 0
|
||||
for r in range(num_rows):
|
||||
if occupied[r]:
|
||||
max_col_in_row = max(occupied[r].keys()) + 1
|
||||
num_cols = max(num_cols, max_col_in_row)
|
||||
|
||||
logger.debug(
|
||||
f"Parsed HTML table: {num_rows} rows, {num_cols} cols, {len(cells)} cells"
|
||||
)
|
||||
|
||||
# Build table dict for cleanup
|
||||
table_dict = {
|
||||
'rows': num_rows,
|
||||
'cols': num_cols,
|
||||
'cells': [
|
||||
{
|
||||
'row': c.row,
|
||||
'col': c.col,
|
||||
'row_span': c.row_span,
|
||||
'col_span': c.col_span,
|
||||
'content': c.content
|
||||
}
|
||||
for c in cells
|
||||
],
|
||||
'headers': headers if headers else None,
|
||||
'caption': extracted_text if extracted_text else None
|
||||
}
|
||||
|
||||
# Clean up empty columns
|
||||
table_dict = trim_empty_columns(table_dict)
|
||||
|
||||
# Convert cleaned cells back to TableCell objects
|
||||
cleaned_cells = [
|
||||
TableCell(
|
||||
row=c['row'],
|
||||
col=c['col'],
|
||||
row_span=c.get('row_span', 1),
|
||||
col_span=c.get('col_span', 1),
|
||||
content=c.get('content', '')
|
||||
)
|
||||
for c in table_dict.get('cells', [])
|
||||
]
|
||||
|
||||
return TableData(
|
||||
rows=num_rows,
|
||||
cols=num_cols,
|
||||
cells=cells,
|
||||
headers=headers if headers else None,
|
||||
caption=extracted_text if extracted_text else None
|
||||
rows=table_dict.get('rows', num_rows),
|
||||
cols=table_dict.get('cols', num_cols),
|
||||
cells=cleaned_cells,
|
||||
headers=table_dict.get('headers'),
|
||||
caption=table_dict.get('caption')
|
||||
)
|
||||
|
||||
except ImportError:
|
||||
|
||||
344
backend/app/services/pp_structure_debug.py
Normal file
344
backend/app/services/pp_structure_debug.py
Normal file
@@ -0,0 +1,344 @@
|
||||
"""
|
||||
PP-StructureV3 Debug Service
|
||||
|
||||
Provides debugging tools for visualizing and saving PP-StructureV3 results:
|
||||
- Save raw results as JSON for inspection
|
||||
- Generate visualization images showing detected bboxes
|
||||
- Compare raw OCR regions with PP-StructureV3 elements
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Any, Optional, Tuple
|
||||
from datetime import datetime
|
||||
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Color palette for different element types (RGB)
|
||||
ELEMENT_COLORS: Dict[str, Tuple[int, int, int]] = {
|
||||
'text': (0, 128, 0), # Green
|
||||
'title': (0, 0, 255), # Blue
|
||||
'table': (255, 0, 0), # Red
|
||||
'figure': (255, 165, 0), # Orange
|
||||
'image': (255, 165, 0), # Orange
|
||||
'header': (128, 0, 128), # Purple
|
||||
'footer': (128, 0, 128), # Purple
|
||||
'equation': (0, 255, 255), # Cyan
|
||||
'chart': (255, 192, 203), # Pink
|
||||
'list': (139, 69, 19), # Brown
|
||||
'reference': (128, 128, 128), # Gray
|
||||
'default': (255, 0, 255), # Magenta for unknown types
|
||||
}
|
||||
|
||||
# Color for raw OCR regions
|
||||
RAW_OCR_COLOR = (255, 215, 0) # Gold
|
||||
|
||||
|
||||
class PPStructureDebug:
|
||||
"""Debug service for PP-StructureV3 analysis results."""
|
||||
|
||||
def __init__(self, output_dir: Path):
|
||||
"""
|
||||
Initialize debug service.
|
||||
|
||||
Args:
|
||||
output_dir: Directory to save debug outputs
|
||||
"""
|
||||
self.output_dir = Path(output_dir)
|
||||
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def save_raw_results(
|
||||
self,
|
||||
pp_structure_results: Dict[str, Any],
|
||||
raw_ocr_regions: List[Dict[str, Any]],
|
||||
filename_prefix: str = "debug"
|
||||
) -> Dict[str, Path]:
|
||||
"""
|
||||
Save raw PP-StructureV3 results and OCR regions as JSON files.
|
||||
|
||||
Args:
|
||||
pp_structure_results: Raw PP-StructureV3 analysis results
|
||||
raw_ocr_regions: Raw OCR text regions
|
||||
filename_prefix: Prefix for output files
|
||||
|
||||
Returns:
|
||||
Dictionary with paths to saved files
|
||||
"""
|
||||
saved_files = {}
|
||||
|
||||
# Save PP-StructureV3 results
|
||||
pp_json_path = self.output_dir / f"{filename_prefix}_pp_structure_raw.json"
|
||||
try:
|
||||
# Convert any non-serializable types
|
||||
serializable_results = self._make_serializable(pp_structure_results)
|
||||
with open(pp_json_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(serializable_results, f, ensure_ascii=False, indent=2)
|
||||
saved_files['pp_structure'] = pp_json_path
|
||||
logger.info(f"Saved PP-StructureV3 raw results to {pp_json_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save PP-StructureV3 results: {e}")
|
||||
|
||||
# Save raw OCR regions
|
||||
ocr_json_path = self.output_dir / f"{filename_prefix}_raw_ocr_regions.json"
|
||||
try:
|
||||
serializable_ocr = self._make_serializable(raw_ocr_regions)
|
||||
with open(ocr_json_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(serializable_ocr, f, ensure_ascii=False, indent=2)
|
||||
saved_files['raw_ocr'] = ocr_json_path
|
||||
logger.info(f"Saved raw OCR regions to {ocr_json_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save raw OCR regions: {e}")
|
||||
|
||||
# Save summary comparison
|
||||
summary_path = self.output_dir / f"{filename_prefix}_debug_summary.json"
|
||||
try:
|
||||
summary = self._generate_summary(pp_structure_results, raw_ocr_regions)
|
||||
with open(summary_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(summary, f, ensure_ascii=False, indent=2)
|
||||
saved_files['summary'] = summary_path
|
||||
logger.info(f"Saved debug summary to {summary_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save debug summary: {e}")
|
||||
|
||||
return saved_files
|
||||
|
||||
def generate_visualization(
|
||||
self,
|
||||
image_path: Path,
|
||||
pp_structure_elements: List[Dict[str, Any]],
|
||||
raw_ocr_regions: Optional[List[Dict[str, Any]]] = None,
|
||||
filename_prefix: str = "debug",
|
||||
show_labels: bool = True,
|
||||
show_raw_ocr: bool = True
|
||||
) -> Optional[Path]:
|
||||
"""
|
||||
Generate visualization image showing detected elements.
|
||||
|
||||
Args:
|
||||
image_path: Path to original image
|
||||
pp_structure_elements: PP-StructureV3 detected elements
|
||||
raw_ocr_regions: Optional raw OCR regions to overlay
|
||||
filename_prefix: Prefix for output file
|
||||
show_labels: Whether to show element type labels
|
||||
show_raw_ocr: Whether to show raw OCR regions
|
||||
|
||||
Returns:
|
||||
Path to generated visualization image
|
||||
"""
|
||||
try:
|
||||
# Load original image
|
||||
img = Image.open(image_path)
|
||||
if img.mode != 'RGB':
|
||||
img = img.convert('RGB')
|
||||
|
||||
# Create copy for drawing
|
||||
viz_img = img.copy()
|
||||
draw = ImageDraw.Draw(viz_img)
|
||||
|
||||
# Try to load a font, fall back to default
|
||||
try:
|
||||
font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 14)
|
||||
small_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 10)
|
||||
except (IOError, OSError):
|
||||
try:
|
||||
font = ImageFont.truetype("/home/egg/project/Tool_OCR/backend/fonts/NotoSansSC-Regular.ttf", 14)
|
||||
small_font = ImageFont.truetype("/home/egg/project/Tool_OCR/backend/fonts/NotoSansSC-Regular.ttf", 10)
|
||||
except (IOError, OSError):
|
||||
font = ImageFont.load_default()
|
||||
small_font = font
|
||||
|
||||
# Draw raw OCR regions first (so PP-Structure boxes are on top)
|
||||
if show_raw_ocr and raw_ocr_regions:
|
||||
for idx, region in enumerate(raw_ocr_regions):
|
||||
bbox = self._normalize_bbox(region.get('bbox', []))
|
||||
if bbox:
|
||||
# Draw with dashed style simulation (draw thin lines)
|
||||
x0, y0, x1, y1 = bbox
|
||||
draw.rectangle([x0, y0, x1, y1], outline=RAW_OCR_COLOR, width=1)
|
||||
|
||||
# Add small label
|
||||
if show_labels:
|
||||
confidence = region.get('confidence', 0)
|
||||
label = f"OCR:{confidence:.2f}"
|
||||
draw.text((x0, y0 - 12), label, fill=RAW_OCR_COLOR, font=small_font)
|
||||
|
||||
# Draw PP-StructureV3 elements
|
||||
for idx, elem in enumerate(pp_structure_elements):
|
||||
elem_type = elem.get('type', 'default')
|
||||
if hasattr(elem_type, 'value'):
|
||||
elem_type = elem_type.value
|
||||
elem_type = str(elem_type).lower()
|
||||
|
||||
color = ELEMENT_COLORS.get(elem_type, ELEMENT_COLORS['default'])
|
||||
bbox = self._normalize_bbox(elem.get('bbox', []))
|
||||
|
||||
if bbox:
|
||||
x0, y0, x1, y1 = bbox
|
||||
# Draw thicker rectangle for PP-Structure elements
|
||||
draw.rectangle([x0, y0, x1, y1], outline=color, width=3)
|
||||
|
||||
# Add label
|
||||
if show_labels:
|
||||
label = f"{idx}:{elem_type}"
|
||||
# Draw label background
|
||||
text_bbox = draw.textbbox((x0, y0 - 18), label, font=font)
|
||||
draw.rectangle(text_bbox, fill=(255, 255, 255, 200))
|
||||
draw.text((x0, y0 - 18), label, fill=color, font=font)
|
||||
|
||||
# Add legend
|
||||
self._draw_legend(draw, img.width, font)
|
||||
|
||||
# Add image info
|
||||
info_text = f"PP-Structure: {len(pp_structure_elements)} elements"
|
||||
if raw_ocr_regions:
|
||||
info_text += f" | Raw OCR: {len(raw_ocr_regions)} regions"
|
||||
info_text += f" | Size: {img.width}x{img.height}"
|
||||
draw.text((10, img.height - 25), info_text, fill=(0, 0, 0), font=font)
|
||||
|
||||
# Save visualization
|
||||
viz_path = self.output_dir / f"{filename_prefix}_pp_structure_viz.png"
|
||||
viz_img.save(viz_path, 'PNG')
|
||||
logger.info(f"Saved visualization to {viz_path}")
|
||||
|
||||
return viz_path
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate visualization: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
def _draw_legend(self, draw: ImageDraw, img_width: int, font: ImageFont):
|
||||
"""Draw a legend showing element type colors."""
|
||||
legend_x = img_width - 150
|
||||
legend_y = 10
|
||||
|
||||
# Draw legend background
|
||||
draw.rectangle(
|
||||
[legend_x - 5, legend_y - 5, img_width - 5, legend_y + len(ELEMENT_COLORS) * 18 + 25],
|
||||
fill=(255, 255, 255, 230),
|
||||
outline=(0, 0, 0)
|
||||
)
|
||||
|
||||
draw.text((legend_x, legend_y), "Legend:", fill=(0, 0, 0), font=font)
|
||||
legend_y += 20
|
||||
|
||||
for elem_type, color in ELEMENT_COLORS.items():
|
||||
if elem_type == 'default':
|
||||
continue
|
||||
draw.rectangle([legend_x, legend_y + 2, legend_x + 12, legend_y + 14], fill=color)
|
||||
draw.text((legend_x + 18, legend_y), elem_type, fill=(0, 0, 0), font=font)
|
||||
legend_y += 18
|
||||
|
||||
# Add raw OCR legend entry
|
||||
draw.rectangle([legend_x, legend_y + 2, legend_x + 12, legend_y + 14], fill=RAW_OCR_COLOR)
|
||||
draw.text((legend_x + 18, legend_y), "raw_ocr", fill=(0, 0, 0), font=font)
|
||||
|
||||
def _normalize_bbox(self, bbox: Any) -> Optional[Tuple[float, float, float, float]]:
|
||||
"""Normalize bbox to (x0, y0, x1, y1) format."""
|
||||
if not bbox:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Handle nested list format [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
|
||||
if isinstance(bbox, (list, tuple)) and len(bbox) >= 1:
|
||||
if isinstance(bbox[0], (list, tuple)):
|
||||
xs = [pt[0] for pt in bbox if len(pt) >= 2]
|
||||
ys = [pt[1] for pt in bbox if len(pt) >= 2]
|
||||
if xs and ys:
|
||||
return (min(xs), min(ys), max(xs), max(ys))
|
||||
|
||||
# Handle flat list [x0, y0, x1, y1]
|
||||
if isinstance(bbox, (list, tuple)) and len(bbox) == 4:
|
||||
return (float(bbox[0]), float(bbox[1]), float(bbox[2]), float(bbox[3]))
|
||||
|
||||
# Handle flat polygon [x1, y1, x2, y2, ...]
|
||||
if isinstance(bbox, (list, tuple)) and len(bbox) >= 8:
|
||||
xs = [bbox[i] for i in range(0, len(bbox), 2)]
|
||||
ys = [bbox[i] for i in range(1, len(bbox), 2)]
|
||||
return (min(xs), min(ys), max(xs), max(ys))
|
||||
|
||||
# Handle dict format
|
||||
if isinstance(bbox, dict):
|
||||
return (
|
||||
float(bbox.get('x0', bbox.get('x_min', 0))),
|
||||
float(bbox.get('y0', bbox.get('y_min', 0))),
|
||||
float(bbox.get('x1', bbox.get('x_max', 0))),
|
||||
float(bbox.get('y1', bbox.get('y_max', 0)))
|
||||
)
|
||||
|
||||
except (TypeError, ValueError, IndexError) as e:
|
||||
logger.warning(f"Failed to normalize bbox {bbox}: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def _generate_summary(
|
||||
self,
|
||||
pp_structure_results: Dict[str, Any],
|
||||
raw_ocr_regions: List[Dict[str, Any]]
|
||||
) -> Dict[str, Any]:
|
||||
"""Generate summary comparing PP-Structure and raw OCR."""
|
||||
pp_elements = pp_structure_results.get('elements', [])
|
||||
|
||||
# Count element types
|
||||
type_counts = {}
|
||||
for elem in pp_elements:
|
||||
elem_type = elem.get('type', 'unknown')
|
||||
if hasattr(elem_type, 'value'):
|
||||
elem_type = elem_type.value
|
||||
type_counts[str(elem_type)] = type_counts.get(str(elem_type), 0) + 1
|
||||
|
||||
# Calculate bounding box coverage
|
||||
pp_bbox_area = 0
|
||||
ocr_bbox_area = 0
|
||||
|
||||
for elem in pp_elements:
|
||||
bbox = self._normalize_bbox(elem.get('bbox'))
|
||||
if bbox:
|
||||
pp_bbox_area += (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
|
||||
|
||||
for region in raw_ocr_regions:
|
||||
bbox = self._normalize_bbox(region.get('bbox'))
|
||||
if bbox:
|
||||
ocr_bbox_area += (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
|
||||
|
||||
return {
|
||||
'timestamp': datetime.now().isoformat(),
|
||||
'pp_structure': {
|
||||
'total_elements': len(pp_elements),
|
||||
'element_types': type_counts,
|
||||
'total_bbox_area': pp_bbox_area,
|
||||
'has_parsing_res_list': pp_structure_results.get('has_parsing_res_list', False)
|
||||
},
|
||||
'raw_ocr': {
|
||||
'total_regions': len(raw_ocr_regions),
|
||||
'total_bbox_area': ocr_bbox_area,
|
||||
'avg_confidence': sum(r.get('confidence', 0) for r in raw_ocr_regions) / len(raw_ocr_regions) if raw_ocr_regions else 0
|
||||
},
|
||||
'comparison': {
|
||||
'element_count_ratio': len(pp_elements) / len(raw_ocr_regions) if raw_ocr_regions else 0,
|
||||
'area_ratio': pp_bbox_area / ocr_bbox_area if ocr_bbox_area > 0 else 0,
|
||||
'potential_gap': len(raw_ocr_regions) - len(pp_elements) if raw_ocr_regions else 0
|
||||
}
|
||||
}
|
||||
|
||||
def _make_serializable(self, obj: Any) -> Any:
|
||||
"""Convert object to JSON-serializable format."""
|
||||
if obj is None:
|
||||
return None
|
||||
if isinstance(obj, (str, int, float, bool)):
|
||||
return obj
|
||||
if isinstance(obj, (list, tuple)):
|
||||
return [self._make_serializable(item) for item in obj]
|
||||
if isinstance(obj, dict):
|
||||
return {str(k): self._make_serializable(v) for k, v in obj.items()}
|
||||
if hasattr(obj, 'value'):
|
||||
return obj.value
|
||||
if hasattr(obj, '__dict__'):
|
||||
return self._make_serializable(obj.__dict__)
|
||||
if hasattr(obj, 'tolist'): # numpy array
|
||||
return obj.tolist()
|
||||
return str(obj)
|
||||
Reference in New Issue
Block a user