feat: simplify layout model selection and archive proposals

Changes:
- Replace PP-Structure 7-slider parameter UI with simple 3-option layout model selector
- Add layout model mapping: chinese (PP-DocLayout-S), default (PubLayNet), cdla
- Add LayoutModelSelector component and zh-TW translations
- Fix "default" model behavior with sentinel value for PubLayNet
- Add gap filling service for OCR track coverage improvement
- Add PP-Structure debug utilities
- Archive completed/incomplete proposals:
  - add-ocr-track-gap-filling (complete)
  - fix-ocr-track-table-rendering (incomplete)
  - simplify-ppstructure-model-selection (22/25 tasks)
- Add new layout model tests, archive old PP-Structure param tests
- Update OpenSpec ocr-processing spec with layout model requirements

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
egg
2025-11-27 13:27:00 +08:00
parent c65df754cf
commit 59206a6ab8
35 changed files with 3621 additions and 658 deletions

View File

@@ -3,7 +3,7 @@ Tool_OCR - Configuration Management
Loads environment variables and provides centralized configuration
"""
from typing import List
from typing import List, Optional
from pydantic_settings import BaseSettings
from pydantic import Field
from pathlib import Path
@@ -99,6 +99,33 @@ class Settings(BaseSettings):
text_det_box_thresh: float = Field(default=0.3) # Lower box threshold for better detection
text_det_unclip_ratio: float = Field(default=1.2) # Smaller unclip for tighter text boxes
# Layout Detection Model Configuration
# Available models:
# - None: Use PP-StructureV3's built-in model (PubLayNet-based); note the configured default below is "PP-DocLayout-S"
# - "PP-DocLayout-S": Better for Chinese docs, papers, contracts, exams (23 categories)
# - "picodet_lcnet_x1_0_fgd_layout_cdla": CDLA-based model for Chinese document layout
layout_detection_model_name: Optional[str] = Field(
default="PP-DocLayout-S",
description="Layout detection model name. Set to 'PP-DocLayout-S' for better Chinese document support."
)
layout_detection_model_dir: Optional[str] = Field(
default=None,
description="Custom layout detection model directory. If None, downloads official model."
)
# ===== Gap Filling Configuration =====
# Supplements PP-StructureV3 output with raw OCR regions when detection is incomplete
gap_filling_enabled: bool = Field(default=True) # Enable gap filling for OCR track
gap_filling_coverage_threshold: float = Field(default=0.7) # Activate when coverage < 70%
gap_filling_iou_threshold: float = Field(default=0.15) # IoU threshold for coverage detection
gap_filling_confidence_threshold: float = Field(default=0.3) # Min confidence for raw OCR regions
gap_filling_dedup_iou_threshold: float = Field(default=0.5) # IoU threshold for deduplication
# ===== Debug Configuration =====
# Enable debug outputs for PP-StructureV3 analysis
pp_structure_debug_enabled: bool = Field(default=True) # Save debug files for PP-StructureV3
pp_structure_debug_visualization: bool = Field(default=True) # Generate visualization images
# Performance tuning
use_fp16_inference: bool = Field(default=False) # Half-precision (if supported)
enable_cudnn_benchmark: bool = Field(default=True) # Optimize convolution algorithms

View File

@@ -68,7 +68,7 @@ def process_task_ocr(
use_dual_track: bool = True,
force_track: Optional[str] = None,
language: str = 'ch',
pp_structure_params: Optional[dict] = None
layout_model: Optional[str] = "chinese"
):
"""
Background task to process OCR for a task with dual-track support.
@@ -84,7 +84,7 @@ def process_task_ocr(
use_dual_track: Enable dual-track processing
force_track: Force specific track ('ocr' or 'direct')
language: OCR language code
pp_structure_params: Optional custom PP-StructureV3 parameters (dict)
layout_model: Layout detection model ('chinese', 'default', 'cdla')
"""
from app.core.database import SessionLocal
from app.models.task import Task
@@ -143,7 +143,7 @@ def process_task_ocr(
output_dir=result_dir,
use_dual_track=use_dual_track,
force_track=force_track,
pp_structure_params=pp_structure_params
layout_model=layout_model
)
else:
# Fall back to traditional processing (no force_track support)
@@ -152,7 +152,7 @@ def process_task_ocr(
lang=language,
detect_layout=True,
output_dir=result_dir,
pp_structure_params=pp_structure_params
layout_model=layout_model
)
# Calculate processing time
@@ -717,14 +717,14 @@ async def start_task(
current_user: User = Depends(get_current_user)
):
"""
Start processing a pending task with dual-track support and optional PP-StructureV3 parameter tuning
Start processing a pending task with dual-track support and layout model selection
- **task_id**: Task UUID
- **options**: Processing options (in request body):
- **use_dual_track**: Enable intelligent track selection (default: true)
- **force_track**: Force specific processing track ('ocr' or 'direct')
- **language**: OCR language code (default: 'ch')
- **pp_structure_params**: Fine-tuning parameters for PP-StructureV3 (OCR track only)
- **layout_model**: Layout detection model ('chinese', 'default', 'cdla')
"""
try:
# Parse processing options with defaults
@@ -735,11 +735,9 @@ async def start_task(
force_track = options.force_track.value if options.force_track else None
language = options.language
# Extract and convert PP-StructureV3 parameters to dict
pp_structure_params = None
if options.pp_structure_params:
pp_structure_params = options.pp_structure_params.model_dump(exclude_none=True)
logger.info(f"Using custom PP-StructureV3 parameters: {pp_structure_params}")
# Extract layout model (default to 'chinese' for best Chinese document support)
layout_model = options.layout_model.value if options.layout_model else "chinese"
logger.info(f"Using layout model: {layout_model}")
# Get task details
task = task_service.get_task_by_id(
@@ -777,7 +775,7 @@ async def start_task(
status=TaskStatus.PROCESSING
)
# Start OCR processing in background with dual-track parameters and custom PP-StructureV3 params
# Start OCR processing in background with dual-track parameters and layout model
background_tasks.add_task(
process_task_ocr,
task_id=task_id,
@@ -787,13 +785,11 @@ async def start_task(
use_dual_track=use_dual_track,
force_track=force_track,
language=language,
pp_structure_params=pp_structure_params
layout_model=layout_model
)
logger.info(f"Started OCR processing task {task_id} for user {current_user.email}")
logger.info(f"Options: dual_track={use_dual_track}, force_track={force_track}, lang={language}")
if pp_structure_params:
logger.info(f"Custom PP-StructureV3 params: {pp_structure_params}")
logger.info(f"Options: dual_track={use_dual_track}, force_track={force_track}, lang={language}, layout_model={layout_model}")
return task
except HTTPException:

View File

@@ -24,6 +24,19 @@ class ProcessingTrackEnum(str, Enum):
AUTO = "auto" # Auto-detect best track
class LayoutModelEnum(str, Enum):
    """User-selectable layout detection model for the OCR track.

    Each member's value is the string accepted in request payloads:
    - ``"chinese"``: PP-DocLayout-S, recommended for Chinese documents
      (forms, contracts, invoices)
    - ``"default"``: PubLayNet-based model, aimed at English documents
      such as academic papers
    - ``"cdla"``: CDLA model, an alternative for Chinese layout analysis
    """
    # PP-DocLayout-S (recommended for Chinese documents)
    CHINESE = "chinese"
    # PubLayNet-based model (English documents)
    DEFAULT = "default"
    # CDLA model (alternative Chinese layout analysis)
    CDLA = "cdla"
class TaskCreate(BaseModel):
"""Task creation request"""
filename: Optional[str] = Field(None, description="Original filename")
@@ -132,7 +145,11 @@ class UploadResponse(BaseModel):
# ===== Dual-Track Processing Schemas =====
class PPStructureV3Params(BaseModel):
"""PP-StructureV3 fine-tuning parameters for OCR track"""
"""PP-StructureV3 fine-tuning parameters for OCR track.
DEPRECATED: This class is deprecated and will be removed in a future version.
Use `layout_model` parameter in ProcessingOptions instead.
"""
layout_detection_threshold: Optional[float] = Field(
None, ge=0, le=1,
description="Layout block detection score threshold (lower=more blocks, higher=high confidence only)"
@@ -172,10 +189,10 @@ class ProcessingOptions(BaseModel):
include_images: bool = Field(default=True, description="Extract and save images")
confidence_threshold: Optional[float] = Field(None, ge=0, le=1, description="OCR confidence threshold")
# PP-StructureV3 fine-tuning parameters (OCR track only)
pp_structure_params: Optional[PPStructureV3Params] = Field(
None,
description="Fine-tuning parameters for PP-StructureV3 (OCR track only)"
# Layout model selection (OCR track only)
layout_model: Optional[LayoutModelEnum] = Field(
default=LayoutModelEnum.CHINESE,
description="Layout detection model: 'chinese' (recommended for Chinese docs), 'default' (English docs), 'cdla' (Chinese layout)"
)

View File

@@ -0,0 +1,649 @@
"""
Gap Filling Service for OCR Track
This service detects and fills gaps in PP-StructureV3 output by supplementing
with Raw OCR text regions when significant content loss is detected.
The hybrid approach uses Raw OCR's comprehensive text detection to compensate
for PP-StructureV3's layout model limitations on certain document types.
"""
import logging
from typing import Dict, List, Optional, Tuple, Set, Any
from dataclasses import dataclass
from app.models.unified_document import (
DocumentElement, BoundingBox, ElementType, Dimensions
)
from app.core.config import settings
logger = logging.getLogger(__name__)
# Structural / non-text element types (preserve structural integrity).
# In this module the set is consulted by GapFillingService.deduplicate_regions():
# PP-StructureV3 elements of these types are ignored when checking whether a
# raw OCR region duplicates an existing element, so only the remaining
# (text-like) elements can cause a region to be dropped as a duplicate.
SKIP_ELEMENT_TYPES: Set[ElementType] = {
ElementType.TABLE,
ElementType.IMAGE,
ElementType.FIGURE,
ElementType.CHART,
ElementType.DIAGRAM,
ElementType.HEADER,
ElementType.FOOTER,
ElementType.FORMULA,
ElementType.CODE,
ElementType.BARCODE,
ElementType.QR_CODE,
ElementType.LOGO,
ElementType.STAMP,
ElementType.SIGNATURE,
}
@dataclass
class TextRegion:
    """A single text detection produced by the raw OCR pass.

    ``bbox`` is accepted in three layouts:
    - flat box ``[x0, y0, x1, y1]``
    - flat polygon ``[x1, y1, x2, y2, ...]`` with eight or more values
    - nested polygon ``[[x1, y1], [x2, y2], ...]`` (common PaddleOCR format)
    """
    text: str
    bbox: List[float]  # flat box, flat polygon, or nested polygon (see class docstring)
    confidence: float
    page: int = 0

    @property
    def normalized_bbox(self) -> Tuple[float, float, float, float]:
        """Return the axis-aligned bounds of ``bbox`` as (x0, y0, x1, y1).

        Any bbox that is empty or in an unrecognised layout yields
        ``(0, 0, 0, 0)``.
        """
        empty = (0, 0, 0, 0)
        coords = self.bbox
        if not coords:
            return empty

        # Nested polygon: each entry is an [x, y] pair; malformed points
        # (fewer than two values) are ignored.
        if len(coords) >= 1 and isinstance(coords[0], (list, tuple)):
            points = [pt for pt in coords if len(pt) >= 2]
            if not points:
                return empty
            xs = [pt[0] for pt in points]
            ys = [pt[1] for pt in points]
            return (min(xs), min(ys), max(xs), max(ys))

        count = len(coords)
        if count == 4:
            # Already a simple box; only coerce the values to float.
            x0, y0, x1, y1 = (float(value) for value in coords)
            return (x0, y0, x1, y1)
        if count >= 8:
            # Flat polygon: even indices are x values, odd indices are y values.
            xs = coords[0::2]
            ys = coords[1::2]
            return (min(xs), min(ys), max(xs), max(ys))
        return empty

    @property
    def center(self) -> Tuple[float, float]:
        """Return the midpoint of ``normalized_bbox`` as (x, y)."""
        left, top, right, bottom = self.normalized_bbox
        return ((left + right) / 2, (top + bottom) / 2)
class GapFillingService:
    """
    Detect and fill gaps in PP-StructureV3 output using raw OCR regions.

    Pipeline implemented by :meth:`fill_gaps`:
    1. Convert raw OCR dicts to ``TextRegion`` objects (optionally rescaled)
    2. Measure how many raw regions are covered by PP-StructureV3 elements
    3. When coverage is below the threshold, collect the uncovered regions
    4. Drop regions that heavily overlap existing non-structural elements
    5. Emit the surviving regions as TEXT ``DocumentElement`` objects

    ``recalculate_reading_order`` and ``merge_adjacent_regions`` are provided
    as separate helpers; ``fill_gaps`` itself does not call them.
    """

    def __init__(
        self,
        coverage_threshold: Optional[float] = None,
        iou_threshold: Optional[float] = None,
        confidence_threshold: Optional[float] = None,
        dedup_iou_threshold: Optional[float] = None,
        enabled: Optional[bool] = None
    ):
        """
        Initialize the gap filling service.

        Every argument left as ``None`` is read from the matching
        ``gap_filling_*`` attribute of ``settings``; if that attribute is
        missing, the fallback listed below is used.

        Args:
            coverage_threshold: Coverage ratio below which gap filling activates (fallback: 0.7)
            iou_threshold: IoU threshold for coverage detection (fallback: 0.15)
            confidence_threshold: Minimum confidence for raw OCR regions (fallback: 0.3)
            dedup_iou_threshold: IoU threshold for deduplication (fallback: 0.5)
            enabled: Whether gap filling is enabled (fallback: True)
        """
        def _resolve(explicit, setting_name, fallback):
            # An explicit argument wins; settings is only consulted otherwise.
            if explicit is not None:
                return explicit
            return getattr(settings, setting_name, fallback)

        self.coverage_threshold = _resolve(
            coverage_threshold, 'gap_filling_coverage_threshold', 0.7
        )
        self.iou_threshold = _resolve(
            iou_threshold, 'gap_filling_iou_threshold', 0.15
        )
        self.confidence_threshold = _resolve(
            confidence_threshold, 'gap_filling_confidence_threshold', 0.3
        )
        self.dedup_iou_threshold = _resolve(
            dedup_iou_threshold, 'gap_filling_dedup_iou_threshold', 0.5
        )
        self.enabled = _resolve(enabled, 'gap_filling_enabled', True)

    def should_activate(
        self,
        raw_ocr_regions: List[TextRegion],
        pp_structure_elements: List[DocumentElement]
    ) -> Tuple[bool, float]:
        """
        Determine if gap filling should be activated.

        Gap filling activates when the fraction of raw OCR regions covered by
        PP-StructureV3 elements is below ``coverage_threshold``. All raw
        regions count towards the ratio, regardless of their confidence.

        Args:
            raw_ocr_regions: List of raw OCR text regions
            pp_structure_elements: List of PP-StructureV3 elements

        Returns:
            Tuple of (should_activate, coverage_ratio). When the service is
            disabled or there are no raw regions, returns ``(False, 1.0)``.
        """
        if not self.enabled:
            return False, 1.0

        if not raw_ocr_regions:
            return False, 1.0

        # Calculate coverage
        covered_count = sum(
            1 for region in raw_ocr_regions
            if self._is_region_covered(region, pp_structure_elements)
        )
        coverage_ratio = covered_count / len(raw_ocr_regions)

        # Check activation condition
        should_activate = coverage_ratio < self.coverage_threshold

        if should_activate:
            logger.info(
                f"Gap filling activated: coverage={coverage_ratio:.2%} < threshold={self.coverage_threshold:.0%}, "
                f"raw_regions={len(raw_ocr_regions)}, pp_elements={len(pp_structure_elements)}"
            )
        else:
            logger.debug(
                f"Gap filling not needed: coverage={coverage_ratio:.2%} >= threshold={self.coverage_threshold:.0%}"
            )

        return should_activate, coverage_ratio

    def find_uncovered_regions(
        self,
        raw_ocr_regions: List[TextRegion],
        pp_structure_elements: List[DocumentElement]
    ) -> List[TextRegion]:
        """
        Find raw OCR regions not covered by PP-StructureV3 elements.

        A region is considered covered if:
        1. Its center point falls inside any PP-StructureV3 element bbox, OR
        2. IoU with any PP-StructureV3 element exceeds iou_threshold

        Regions whose confidence is below ``confidence_threshold`` are
        discarded before the coverage check.

        Args:
            raw_ocr_regions: List of raw OCR text regions
            pp_structure_elements: List of PP-StructureV3 elements

        Returns:
            List of uncovered raw OCR regions
        """
        uncovered = [
            region for region in raw_ocr_regions
            if not (region.confidence < self.confidence_threshold)
            and not self._is_region_covered(region, pp_structure_elements)
        ]

        logger.debug(f"Found {len(uncovered)} uncovered regions out of {len(raw_ocr_regions)}")
        return uncovered

    def _is_region_covered(
        self,
        region: TextRegion,
        pp_structure_elements: List[DocumentElement]
    ) -> bool:
        """
        Check if a raw OCR region is covered by any PP-StructureV3 element.

        Elements of every type are considered here (no type filtering).

        Args:
            region: Raw OCR text region
            pp_structure_elements: List of PP-StructureV3 elements

        Returns:
            True if the region is covered
        """
        center_x, center_y = region.center
        region_bbox = region.normalized_bbox

        for element in pp_structure_elements:
            elem_bbox = (
                element.bbox.x0, element.bbox.y0,
                element.bbox.x1, element.bbox.y1
            )

            # Check 1: Center point falls inside element bbox
            if self._point_in_bbox(center_x, center_y, elem_bbox):
                return True

            # Check 2: IoU exceeds threshold
            if self._calculate_iou(region_bbox, elem_bbox) > self.iou_threshold:
                return True

        return False

    def deduplicate_regions(
        self,
        uncovered_regions: List[TextRegion],
        pp_structure_elements: List[DocumentElement]
    ) -> List[TextRegion]:
        """
        Remove regions that highly overlap with existing PP-StructureV3 elements.

        Only elements whose type is NOT in ``SKIP_ELEMENT_TYPES`` are used for
        the comparison. A region is dropped when its IoU with any such element
        exceeds ``dedup_iou_threshold``.

        Args:
            uncovered_regions: List of uncovered raw OCR regions
            pp_structure_elements: List of PP-StructureV3 elements

        Returns:
            Deduplicated list of regions
        """
        # Structural element types are excluded from the duplicate check
        candidate_elements = [
            e for e in pp_structure_elements
            if e.type not in SKIP_ELEMENT_TYPES
        ]

        deduplicated = []
        for region in uncovered_regions:
            region_bbox = region.normalized_bbox

            for element in candidate_elements:
                elem_bbox = (
                    element.bbox.x0, element.bbox.y0,
                    element.bbox.x1, element.bbox.y1
                )
                iou = self._calculate_iou(region_bbox, elem_bbox)
                if iou > self.dedup_iou_threshold:
                    logger.debug(
                        f"Skipping duplicate region (IoU={iou:.2f}): '{region.text[:30]}...'"
                    )
                    break
            else:
                # No element overlapped enough: keep the region
                deduplicated.append(region)

        removed_count = len(uncovered_regions) - len(deduplicated)
        if removed_count > 0:
            logger.debug(f"Removed {removed_count} duplicate regions")

        return deduplicated

    def convert_regions_to_elements(
        self,
        regions: List[TextRegion],
        page_number: int,
        start_element_id: int = 0
    ) -> List[DocumentElement]:
        """
        Convert raw OCR regions to DocumentElement objects.

        Each element is typed ``ElementType.TEXT``, gets the id
        ``gap_fill_{page_number}_{start_element_id + index}`` and is tagged
        with ``metadata['source'] == 'gap_filling'``.

        Args:
            regions: List of raw OCR regions to convert
            page_number: Page number for the elements
            start_element_id: Starting ID counter for elements

        Returns:
            List of DocumentElement objects
        """
        elements = []
        for idx, region in enumerate(regions):
            x0, y0, x1, y1 = region.normalized_bbox
            elements.append(DocumentElement(
                element_id=f"gap_fill_{page_number}_{start_element_id + idx}",
                type=ElementType.TEXT,
                content=region.text,
                bbox=BoundingBox(x0=x0, y0=y0, x1=x1, y1=y1),
                confidence=region.confidence,
                metadata={
                    'source': 'gap_filling',
                    'original_confidence': region.confidence
                }
            ))
        return elements

    def recalculate_reading_order(
        self,
        elements: List[DocumentElement]
    ) -> List[int]:
        """
        Recalculate reading order for elements based on position.

        Sorts elements by y0 (top to bottom) then x0 (left to right). The
        sort is stable, so elements with identical (y0, x0) keep their
        original relative order.

        Args:
            elements: List of DocumentElement objects

        Returns:
            List of element indices in reading order
        """
        return sorted(
            range(len(elements)),
            key=lambda idx: (elements[idx].bbox.y0, elements[idx].bbox.x0)
        )

    def merge_adjacent_regions(
        self,
        regions: List[TextRegion],
        max_horizontal_gap: float = 20.0,
        max_vertical_gap: float = 5.0
    ) -> List[TextRegion]:
        """
        Merge fragmented adjacent regions on the same line.

        This is optional and can reduce fragmentation from raw OCR. Regions
        are sorted by (y0, x0) and only consecutive regions in that order are
        candidates for merging. Two regions merge when their vertical centers
        differ by less than ``max_vertical_gap`` and the next region starts
        between 0 and ``max_horizontal_gap`` to the right of the current one.
        Merged text is joined with a single space and the lower confidence
        is kept.

        Args:
            regions: List of raw OCR regions
            max_horizontal_gap: Maximum horizontal gap to merge (pixels)
            max_vertical_gap: Maximum vertical center distance to merge (pixels)

        Returns:
            List of merged regions
        """
        if not regions:
            return regions

        # Sort by y0, then x0
        sorted_regions = sorted(
            regions,
            key=lambda r: (r.normalized_bbox[1], r.normalized_bbox[0])
        )

        merged = []
        current = sorted_regions[0]

        for next_region in sorted_regions[1:]:
            curr_bbox = current.normalized_bbox
            next_bbox = next_region.normalized_bbox

            # Same-line test: distance between vertical centers
            curr_y_center = (curr_bbox[1] + curr_bbox[3]) / 2
            next_y_center = (next_bbox[1] + next_bbox[3]) / 2
            vertical_distance = abs(curr_y_center - next_y_center)

            # Horizontal gap from current's right edge to next's left edge
            horizontal_gap = next_bbox[0] - curr_bbox[2]

            if (vertical_distance < max_vertical_gap and
                    0 <= horizontal_gap <= max_horizontal_gap):
                # Merge regions into one covering both boxes
                merged_bbox = [
                    min(curr_bbox[0], next_bbox[0]),
                    min(curr_bbox[1], next_bbox[1]),
                    max(curr_bbox[2], next_bbox[2]),
                    max(curr_bbox[3], next_bbox[3])
                ]
                current = TextRegion(
                    text=current.text + " " + next_region.text,
                    bbox=merged_bbox,
                    confidence=min(current.confidence, next_region.confidence),
                    page=current.page
                )
            else:
                merged.append(current)
                current = next_region

        merged.append(current)

        if len(merged) < len(regions):
            logger.debug(f"Merged {len(regions)} regions into {len(merged)}")

        return merged

    def fill_gaps(
        self,
        raw_ocr_regions: List[Dict[str, Any]],
        pp_structure_elements: List[DocumentElement],
        page_number: int,
        ocr_dimensions: Optional[Dict[str, Any]] = None,
        pp_dimensions: Optional[Dimensions] = None
    ) -> Tuple[List[DocumentElement], Dict[str, Any]]:
        """
        Main entry point: detect gaps and fill with raw OCR regions.

        Only the supplemental elements are returned; the caller is
        responsible for combining them with ``pp_structure_elements``.

        Args:
            raw_ocr_regions: Raw OCR results (list of dicts with text, bbox, confidence)
            pp_structure_elements: PP-StructureV3 elements
            page_number: Current page number
            ocr_dimensions: OCR image dimensions for coordinate alignment
            pp_dimensions: PP-Structure dimensions for coordinate alignment

        Returns:
            Tuple of (supplemented_elements, statistics). The element list is
            empty whenever the service is disabled, not activated, or nothing
            survives filtering.
        """
        statistics = {
            'enabled': self.enabled,
            'activated': False,
            'coverage_ratio': 1.0,
            'raw_ocr_count': len(raw_ocr_regions),
            'pp_structure_count': len(pp_structure_elements),
            'uncovered_count': 0,
            'deduplicated_count': 0,
            'supplemented_count': 0
        }

        if not self.enabled:
            logger.debug("Gap filling is disabled")
            return [], statistics

        # Convert raw OCR regions to TextRegion objects
        text_regions = self._convert_raw_ocr_regions(
            raw_ocr_regions, page_number, ocr_dimensions, pp_dimensions
        )

        if not text_regions:
            logger.debug("No valid text regions to process")
            return [], statistics

        # Check if gap filling should activate
        should_activate, coverage_ratio = self.should_activate(
            text_regions, pp_structure_elements
        )
        statistics['coverage_ratio'] = coverage_ratio
        statistics['activated'] = should_activate

        if not should_activate:
            return [], statistics

        # Find uncovered regions
        uncovered = self.find_uncovered_regions(text_regions, pp_structure_elements)
        statistics['uncovered_count'] = len(uncovered)

        if not uncovered:
            logger.debug("No uncovered regions found")
            return [], statistics

        # Deduplicate against existing non-structural elements
        deduplicated = self.deduplicate_regions(uncovered, pp_structure_elements)
        statistics['deduplicated_count'] = len(deduplicated)  # count remaining after dedup

        if not deduplicated:
            logger.debug("All uncovered regions were duplicates")
            return [], statistics

        # Convert to DocumentElements; ids continue after the existing elements
        start_id = len(pp_structure_elements)
        supplemented = self.convert_regions_to_elements(
            deduplicated, page_number, start_id
        )
        statistics['supplemented_count'] = len(supplemented)

        # text_regions is guaranteed non-empty here (checked above)
        estimated_coverage = coverage_ratio + len(supplemented) / len(text_regions)
        logger.info(
            f"Gap filling complete: supplemented {len(supplemented)} elements "
            f"(coverage: {coverage_ratio:.2%} -> estimated {estimated_coverage:.2%})"
        )

        return supplemented, statistics

    def _convert_raw_ocr_regions(
        self,
        raw_regions: List[Dict[str, Any]],
        page_number: int,
        ocr_dimensions: Optional[Dict[str, Any]] = None,
        pp_dimensions: Optional[Dimensions] = None
    ) -> List[TextRegion]:
        """
        Convert raw OCR region dicts to TextRegion objects.

        Regions with empty/whitespace-only text or a bbox that is neither a
        dict nor a list/tuple are skipped. When both dimension arguments are
        given, coordinates are rescaled from OCR space to PP-Structure space.
        A missing or ``None`` confidence is treated as 0.0 so that later
        threshold comparisons cannot fail on ``None``.

        Args:
            raw_regions: List of raw OCR region dictionaries
            page_number: Current page number
            ocr_dimensions: OCR image dimensions (dict with 'width' / 'height')
            pp_dimensions: PP-Structure dimensions

        Returns:
            List of TextRegion objects
        """
        text_regions = []

        # Calculate scale factors if needed
        scale_x, scale_y = 1.0, 1.0
        if ocr_dimensions and pp_dimensions:
            ocr_width = ocr_dimensions.get('width', 0)
            ocr_height = ocr_dimensions.get('height', 0)

            if ocr_width > 0 and pp_dimensions.width > 0:
                scale_x = pp_dimensions.width / ocr_width
            if ocr_height > 0 and pp_dimensions.height > 0:
                scale_y = pp_dimensions.height / ocr_height

        needs_scaling = scale_x != 1.0 or scale_y != 1.0
        if needs_scaling:
            logger.debug(f"Coordinate scaling: x={scale_x:.3f}, y={scale_y:.3f}")

        for region in raw_regions:
            text = region.get('text', '')
            if not text or not text.strip():
                continue

            # An explicit None would otherwise break `confidence < threshold`
            confidence = region.get('confidence')
            if confidence is None:
                confidence = 0.0

            bbox_raw = region.get('bbox', [])

            # Normalize bbox container
            if isinstance(bbox_raw, dict):
                # Dict format: {x_min, y_min, x_max, y_max}
                bbox = [
                    bbox_raw.get('x_min', 0),
                    bbox_raw.get('y_min', 0),
                    bbox_raw.get('x_max', 0),
                    bbox_raw.get('y_max', 0)
                ]
            elif isinstance(bbox_raw, (list, tuple)):
                bbox = list(bbox_raw)
            else:
                continue

            # Apply scaling if needed
            if needs_scaling:
                if len(bbox) >= 1 and isinstance(bbox[0], (list, tuple)):
                    # Nested list format [[x1,y1], [x2,y2], ...]
                    bbox = [
                        [pt[0] * scale_x, pt[1] * scale_y]
                        for pt in bbox if len(pt) >= 2
                    ]
                elif len(bbox) == 4:
                    # Simple [x0, y0, x1, y1] format
                    bbox = [
                        bbox[0] * scale_x, bbox[1] * scale_y,
                        bbox[2] * scale_x, bbox[3] * scale_y
                    ]
                elif len(bbox) >= 8:
                    # Flat polygon format [x1, y1, x2, y2, ...]
                    bbox = [
                        value * (scale_x if i % 2 == 0 else scale_y)
                        for i, value in enumerate(bbox)
                    ]

            text_regions.append(TextRegion(
                text=text,
                bbox=bbox,
                confidence=confidence,
                page=page_number
            ))

        return text_regions

    @staticmethod
    def _point_in_bbox(
        x: float, y: float,
        bbox: Tuple[float, float, float, float]
    ) -> bool:
        """Check if point (x, y) is inside bbox (x0, y0, x1, y1), edges inclusive."""
        x0, y0, x1, y1 = bbox
        return x0 <= x <= x1 and y0 <= y <= y1

    @staticmethod
    def _calculate_iou(
        bbox1: Tuple[float, float, float, float],
        bbox2: Tuple[float, float, float, float]
    ) -> float:
        """
        Calculate Intersection over Union (IoU) of two bboxes.

        Args:
            bbox1: First bbox (x0, y0, x1, y1)
            bbox2: Second bbox (x0, y0, x1, y1)

        Returns:
            IoU value between 0 and 1; 0.0 when the boxes do not overlap
            (touching edges count as no overlap) or the union is not positive
        """
        # Calculate intersection
        x0 = max(bbox1[0], bbox2[0])
        y0 = max(bbox1[1], bbox2[1])
        x1 = min(bbox1[2], bbox2[2])
        y1 = min(bbox1[3], bbox2[3])

        if x1 <= x0 or y1 <= y0:
            return 0.0

        intersection = (x1 - x0) * (y1 - y0)

        # Calculate union
        area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
        area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
        union = area1 + area2 - intersection

        if union <= 0:
            return 0.0

        return intersection / union

View File

@@ -46,6 +46,19 @@ except ImportError as e:
logger = logging.getLogger(__name__)
# Sentinel meaning "use the PubLayNet default": when a mapping lookup yields
# this value, the engine setup treats it as "no explicit model name" rather
# than as a real model identifier.
_USE_PUBLAYNET_DEFAULT = "__USE_PUBLAYNET_DEFAULT__"
# Layout model mapping: user-friendly names to actual model names
# - "chinese": PP-DocLayout-S - Best for Chinese documents (forms, contracts, invoices)
# - "default": PubLayNet-based default model - Best for English documents
# - "cdla": picodet_lcnet_x1_0_fgd_layout_cdla - Alternative for Chinese layout
# Keys not present here are reported as unknown by the engine setup, which
# then falls back to settings.layout_detection_model_name.
LAYOUT_MODEL_MAPPING = {
"chinese": "PP-DocLayout-S",
"default": _USE_PUBLAYNET_DEFAULT, # Uses default PubLayNet-based model (no custom model)
"cdla": "picodet_lcnet_x1_0_fgd_layout_cdla",
}
class OCRService:
"""
@@ -436,77 +449,45 @@ class OCRService:
return self.ocr_engines[lang]
def _ensure_structure_engine(self, custom_params: Optional[Dict[str, any]] = None) -> PPStructureV3:
def _ensure_structure_engine(self, layout_model: Optional[str] = None) -> PPStructureV3:
"""
Get or create PP-Structure engine for layout analysis with GPU support.
Supports custom parameters that override default settings.
Supports layout model selection for different document types.
Args:
custom_params: Optional dictionary of custom PP-StructureV3 parameters.
If provided, creates a new engine instance (not cached).
Supported keys: layout_detection_threshold, layout_nms_threshold,
layout_merge_bboxes_mode, layout_unclip_ratio, text_det_thresh,
text_det_box_thresh, text_det_unclip_ratio
layout_model: Layout detection model selection:
- "chinese": PP-DocLayout-S (best for Chinese documents)
- "default": PubLayNet-based (best for English documents)
- "cdla": CDLA model (alternative for Chinese layout)
- None: Use config default
Returns:
PPStructure engine instance
"""
# If custom params provided, create a new engine instance (don't use cache)
if custom_params:
logger.info(f"Creating PP-StructureV3 engine with custom parameters (GPU: {self.use_gpu})")
logger.info(f"Custom params: {custom_params}")
# Resolve layout model name from user-friendly name
resolved_model_name = None
use_publaynet_default = False # Flag to explicitly use PubLayNet default (no model param)
try:
# Base configuration from settings
use_chart = settings.enable_chart_recognition
use_formula = settings.enable_formula_recognition
use_table = settings.enable_table_recognition
if layout_model:
resolved_model_name = LAYOUT_MODEL_MAPPING.get(layout_model)
if layout_model not in LAYOUT_MODEL_MAPPING:
logger.warning(f"Unknown layout model '{layout_model}', using config default")
resolved_model_name = settings.layout_detection_model_name
elif resolved_model_name == _USE_PUBLAYNET_DEFAULT:
# User explicitly selected "default" - use PubLayNet without custom model
use_publaynet_default = True
resolved_model_name = None
logger.info(f"Using layout model: {layout_model} -> PubLayNet default (no custom model)")
else:
logger.info(f"Using layout model: {layout_model} -> {resolved_model_name}")
# Parameter priority: custom > settings default
layout_threshold = custom_params.get('layout_detection_threshold', settings.layout_detection_threshold)
layout_nms = custom_params.get('layout_nms_threshold', settings.layout_nms_threshold)
layout_merge = custom_params.get('layout_merge_bboxes_mode', settings.layout_merge_mode)
layout_unclip = custom_params.get('layout_unclip_ratio', settings.layout_unclip_ratio)
text_thresh = custom_params.get('text_det_thresh', settings.text_det_thresh)
text_box_thresh = custom_params.get('text_det_box_thresh', settings.text_det_box_thresh)
text_unclip = custom_params.get('text_det_unclip_ratio', settings.text_det_unclip_ratio)
# Check if we need to recreate the engine due to different model
current_model = getattr(self, '_current_layout_model', None)
if self.structure_engine is not None and layout_model and layout_model != current_model:
logger.info(f"Layout model changed from {current_model} to {layout_model}, recreating engine")
self.structure_engine = None # Force recreation
logger.info(f"PP-StructureV3 config: table={use_table}, formula={use_formula}, chart={use_chart}")
logger.info(f"Layout config: threshold={layout_threshold}, nms={layout_nms}, merge={layout_merge}, unclip={layout_unclip}")
logger.info(f"Text detection: thresh={text_thresh}, box_thresh={text_box_thresh}, unclip={text_unclip}")
# Create temporary engine with custom params (not cached)
custom_engine = PPStructureV3(
use_doc_orientation_classify=False,
use_doc_unwarping=False,
use_textline_orientation=False,
use_table_recognition=use_table,
use_formula_recognition=use_formula,
use_chart_recognition=use_chart,
layout_threshold=layout_threshold,
layout_nms=layout_nms,
layout_unclip_ratio=layout_unclip,
layout_merge_bboxes_mode=layout_merge,
text_det_thresh=text_thresh,
text_det_box_thresh=text_box_thresh,
text_det_unclip_ratio=text_unclip,
)
logger.info(f"PP-StructureV3 engine with custom params ready (PaddlePaddle {paddle.__version__}, {'GPU' if self.use_gpu else 'CPU'} mode)")
# Check GPU memory after loading
if self.use_gpu and settings.enable_memory_optimization:
self._check_gpu_memory_usage()
return custom_engine
except Exception as e:
logger.error(f"Failed to create PP-StructureV3 engine with custom params: {e}")
# Fall back to default cached engine
logger.warning("Falling back to default cached engine")
custom_params = None # Clear custom params to use cached engine
# Use cached default engine
# Use cached engine or create new one
if self.structure_engine is None:
logger.info(f"Initializing PP-StructureV3 engine (GPU: {self.use_gpu})")
@@ -524,28 +505,51 @@ class OCRService:
text_box_thresh = settings.text_det_box_thresh
text_unclip = settings.text_det_unclip_ratio
# Layout model configuration:
# - If use_publaynet_default: don't specify any model (use PubLayNet default)
# - If resolved_model_name: use the specified model
# - Otherwise: use config default
if use_publaynet_default:
layout_model_name = None # Explicitly no model = PubLayNet default
elif resolved_model_name:
layout_model_name = resolved_model_name
else:
layout_model_name = settings.layout_detection_model_name
layout_model_dir = settings.layout_detection_model_dir
logger.info(f"PP-StructureV3 config: table={use_table}, formula={use_formula}, chart={use_chart}")
logger.info(f"Layout model: name={layout_model_name}, dir={layout_model_dir}")
logger.info(f"Layout config: threshold={layout_threshold}, nms={layout_nms}, merge={layout_merge}, unclip={layout_unclip}")
logger.info(f"Text detection: thresh={text_thresh}, box_thresh={text_box_thresh}, unclip={text_unclip}")
self.structure_engine = PPStructureV3(
use_doc_orientation_classify=False,
use_doc_unwarping=False,
use_textline_orientation=False,
use_table_recognition=use_table,
use_formula_recognition=use_formula,
use_chart_recognition=use_chart,
layout_threshold=layout_threshold,
layout_nms=layout_nms,
layout_unclip_ratio=layout_unclip,
layout_merge_bboxes_mode=layout_merge, # Use 'small' to minimize merging
text_det_thresh=text_thresh,
text_det_box_thresh=text_box_thresh,
text_det_unclip_ratio=text_unclip,
)
# Build PPStructureV3 kwargs
pp_kwargs = {
'use_doc_orientation_classify': False,
'use_doc_unwarping': False,
'use_textline_orientation': False,
'use_table_recognition': use_table,
'use_formula_recognition': use_formula,
'use_chart_recognition': use_chart,
'layout_threshold': layout_threshold,
'layout_nms': layout_nms,
'layout_unclip_ratio': layout_unclip,
'layout_merge_bboxes_mode': layout_merge,
'text_det_thresh': text_thresh,
'text_det_box_thresh': text_box_thresh,
'text_det_unclip_ratio': text_unclip,
}
# Add layout model configuration if specified
if layout_model_name:
pp_kwargs['layout_detection_model_name'] = layout_model_name
if layout_model_dir:
pp_kwargs['layout_detection_model_dir'] = layout_model_dir
self.structure_engine = PPStructureV3(**pp_kwargs)
# Track model loading for cache management
self._model_last_used['structure'] = datetime.now()
self._current_layout_model = layout_model # Track current model for recreation check
logger.info(f"PP-StructureV3 engine ready (PaddlePaddle {paddle.__version__}, {'GPU' if self.use_gpu else 'CPU'} mode)")
@@ -565,17 +569,27 @@ class OCRService:
use_formula = settings.enable_formula_recognition
use_table = settings.enable_table_recognition
layout_threshold = settings.layout_detection_threshold
layout_model_name = settings.layout_detection_model_name
layout_model_dir = settings.layout_detection_model_dir
self.structure_engine = PPStructureV3(
use_doc_orientation_classify=False,
use_doc_unwarping=False,
use_textline_orientation=False,
use_table_recognition=use_table,
use_formula_recognition=use_formula,
use_chart_recognition=use_chart,
layout_threshold=layout_threshold,
)
logger.info("PP-StructureV3 engine ready (CPU mode - fallback)")
# Build CPU fallback kwargs
cpu_kwargs = {
'use_doc_orientation_classify': False,
'use_doc_unwarping': False,
'use_textline_orientation': False,
'use_table_recognition': use_table,
'use_formula_recognition': use_formula,
'use_chart_recognition': use_chart,
'layout_threshold': layout_threshold,
}
if layout_model_name:
cpu_kwargs['layout_detection_model_name'] = layout_model_name
if layout_model_dir:
cpu_kwargs['layout_detection_model_dir'] = layout_model_dir
self.structure_engine = PPStructureV3(**cpu_kwargs)
self._current_layout_model = layout_model # Track current model for recreation check
logger.info(f"PP-StructureV3 engine ready (CPU mode - fallback, layout_model={layout_model_name})")
else:
raise
@@ -813,7 +827,7 @@ class OCRService:
confidence_threshold: Optional[float] = None,
output_dir: Optional[Path] = None,
current_page: int = 0,
pp_structure_params: Optional[Dict[str, any]] = None
layout_model: Optional[str] = None
) -> Dict:
"""
Process single image with OCR and layout analysis
@@ -825,7 +839,7 @@ class OCRService:
confidence_threshold: Minimum confidence threshold (uses default if None)
output_dir: Optional output directory for saving extracted images
current_page: Current page number (0-based) for multi-page documents
pp_structure_params: Optional custom PP-StructureV3 parameters
layout_model: Layout detection model ('chinese', 'default', 'cdla')
Returns:
Dictionary with OCR results and metadata
@@ -894,7 +908,7 @@ class OCRService:
confidence_threshold=confidence_threshold,
output_dir=output_dir,
current_page=page_num - 1, # Convert to 0-based page number for layout data
pp_structure_params=pp_structure_params
layout_model=layout_model
)
# Accumulate results
@@ -1040,7 +1054,7 @@ class OCRService:
image_path,
output_dir=output_dir,
current_page=current_page,
pp_structure_params=pp_structure_params
layout_model=layout_model
)
# Generate Markdown
@@ -1078,6 +1092,38 @@ class OCRService:
'height': ocr_height
}]
# Generate PP-StructureV3 debug outputs if enabled
if settings.pp_structure_debug_enabled and output_dir:
try:
from app.services.pp_structure_debug import PPStructureDebug
debug_service = PPStructureDebug(output_dir)
# Save raw results as JSON
debug_service.save_raw_results(
pp_structure_results={
'elements': layout_data.get('elements', []),
'total_elements': layout_data.get('total_elements', 0),
'element_types': layout_data.get('element_types', {}),
'reading_order': layout_data.get('reading_order', []),
'enhanced': True,
'has_parsing_res_list': True
},
raw_ocr_regions=text_regions,
filename_prefix=image_path.stem
)
# Generate visualization if enabled
if settings.pp_structure_debug_visualization:
debug_service.generate_visualization(
image_path=image_path,
pp_structure_elements=layout_data.get('elements', []),
raw_ocr_regions=text_regions,
filename_prefix=image_path.stem
)
logger.info(f"Generated PP-StructureV3 debug outputs for {image_path.name}")
except Exception as debug_error:
logger.warning(f"Failed to generate debug outputs: {debug_error}")
logger.info(
f"OCR completed: {image_path.name} - "
f"{len(text_regions)} regions, "
@@ -1164,7 +1210,7 @@ class OCRService:
image_path: Path,
output_dir: Optional[Path] = None,
current_page: int = 0,
pp_structure_params: Optional[Dict[str, any]] = None
layout_model: Optional[str] = None
) -> Tuple[Optional[Dict], List[Dict]]:
"""
Analyze document layout using PP-StructureV3 with enhanced element extraction
@@ -1173,7 +1219,7 @@ class OCRService:
image_path: Path to image file
output_dir: Optional output directory for saving extracted images (defaults to image_path.parent)
current_page: Current page number (0-based) for multi-page documents
pp_structure_params: Optional custom PP-StructureV3 parameters
layout_model: Layout detection model ('chinese', 'default', 'cdla')
Returns:
Tuple of (layout_data, images_metadata)
@@ -1191,7 +1237,7 @@ class OCRService:
f"Mode: {'CPU fallback' if self._cpu_fallback_active else 'GPU'}"
)
structure_engine = self._ensure_structure_engine(pp_structure_params)
structure_engine = self._ensure_structure_engine(layout_model)
# Try enhanced processing first
try:
@@ -1425,7 +1471,7 @@ class OCRService:
confidence_threshold: Optional[float] = None,
output_dir: Optional[Path] = None,
force_track: Optional[str] = None,
pp_structure_params: Optional[Dict[str, any]] = None
layout_model: Optional[str] = None
) -> Union[UnifiedDocument, Dict]:
"""
Process document using dual-track approach.
@@ -1437,7 +1483,7 @@ class OCRService:
confidence_threshold: Minimum confidence threshold
output_dir: Optional output directory for extracted images
force_track: Force specific track ("ocr" or "direct"), None for auto-detection
pp_structure_params: Optional custom PP-StructureV3 parameters (used for OCR track only)
layout_model: Layout detection model ('chinese', 'default', 'cdla') (used for OCR track only)
Returns:
UnifiedDocument if dual-track is enabled, Dict otherwise
@@ -1445,7 +1491,7 @@ class OCRService:
if not self.dual_track_enabled:
# Fallback to traditional OCR processing
return self.process_file_traditional(
file_path, lang, detect_layout, confidence_threshold, output_dir, pp_structure_params
file_path, lang, detect_layout, confidence_threshold, output_dir, layout_model
)
start_time = datetime.now()
@@ -1517,7 +1563,7 @@ class OCRService:
ocr_result = self.process_file_traditional(
actual_file_path, lang, detect_layout=True,
confidence_threshold=confidence_threshold,
output_dir=output_dir, pp_structure_params=pp_structure_params
output_dir=output_dir, layout_model=layout_model
)
# Convert OCR result to extract images
@@ -1550,7 +1596,7 @@ class OCRService:
# Use OCR for scanned documents, images, etc.
logger.info("Using OCR track (PaddleOCR)")
ocr_result = self.process_file_traditional(
file_path, lang, detect_layout, confidence_threshold, output_dir, pp_structure_params
file_path, lang, detect_layout, confidence_threshold, output_dir, layout_model
)
# Convert OCR result to UnifiedDocument using the converter
@@ -1580,7 +1626,7 @@ class OCRService:
logger.error(f"Error in dual-track processing: {e}")
# Fallback to traditional OCR
return self.process_file_traditional(
file_path, lang, detect_layout, confidence_threshold, output_dir, pp_structure_params
file_path, lang, detect_layout, confidence_threshold, output_dir, layout_model
)
def _merge_ocr_images_into_direct(
@@ -1659,7 +1705,7 @@ class OCRService:
detect_layout: bool = True,
confidence_threshold: Optional[float] = None,
output_dir: Optional[Path] = None,
pp_structure_params: Optional[Dict[str, any]] = None
layout_model: Optional[str] = None
) -> Dict:
"""
Traditional OCR processing (legacy method).
@@ -1670,7 +1716,7 @@ class OCRService:
detect_layout: Whether to perform layout analysis
confidence_threshold: Minimum confidence threshold
output_dir: Optional output directory
pp_structure_params: Optional custom PP-StructureV3 parameters
layout_model: Layout detection model ('chinese', 'default', 'cdla')
Returns:
Dictionary with OCR results in legacy format
@@ -1683,7 +1729,7 @@ class OCRService:
all_results = []
for i, image_path in enumerate(image_paths):
result = self.process_image(
image_path, lang, detect_layout, confidence_threshold, output_dir, i, pp_structure_params
image_path, lang, detect_layout, confidence_threshold, output_dir, i, layout_model
)
all_results.append(result)
@@ -1699,7 +1745,7 @@ class OCRService:
else:
# Single image or other file
return self.process_image(
file_path, lang, detect_layout, confidence_threshold, output_dir, 0, pp_structure_params
file_path, lang, detect_layout, confidence_threshold, output_dir, 0, layout_model
)
def _combine_results(self, results: List[Dict]) -> Dict:
@@ -1784,7 +1830,7 @@ class OCRService:
output_dir: Optional[Path] = None,
use_dual_track: bool = True,
force_track: Optional[str] = None,
pp_structure_params: Optional[Dict[str, any]] = None
layout_model: Optional[str] = None
) -> Union[UnifiedDocument, Dict]:
"""
Main processing method with dual-track support.
@@ -1797,7 +1843,7 @@ class OCRService:
output_dir: Optional output directory
use_dual_track: Whether to use dual-track processing (default True)
force_track: Force specific track ("ocr" or "direct")
pp_structure_params: Optional custom PP-StructureV3 parameters (used for OCR track only)
layout_model: Layout detection model ('chinese', 'default', 'cdla') (used for OCR track only)
Returns:
UnifiedDocument if dual-track is enabled and use_dual_track=True,
@@ -1809,12 +1855,12 @@ class OCRService:
if (use_dual_track or force_track) and self.dual_track_enabled:
# Use dual-track processing (or forced track)
return self.process_with_dual_track(
file_path, lang, detect_layout, confidence_threshold, output_dir, force_track, pp_structure_params
file_path, lang, detect_layout, confidence_threshold, output_dir, force_track, layout_model
)
else:
# Use traditional OCR processing (no force_track support)
return self.process_file_traditional(
file_path, lang, detect_layout, confidence_threshold, output_dir, pp_structure_params
file_path, lang, detect_layout, confidence_threshold, output_dir, layout_model
)
def process_legacy(

View File

@@ -3,6 +3,9 @@ OCR to UnifiedDocument Converter
Converts PP-StructureV3 OCR results to UnifiedDocument format, preserving
all structure information and metadata.
Includes gap filling support to supplement PP-StructureV3 output with raw OCR
regions when significant content loss is detected.
"""
import logging
@@ -16,10 +19,165 @@ from app.models.unified_document import (
BoundingBox, StyleInfo, TableData, ElementType,
ProcessingTrack, TableCell, Dimensions
)
from app.services.gap_filling_service import GapFillingService
logger = logging.getLogger(__name__)
def trim_empty_columns(table_dict: Dict[str, Any]) -> Dict[str, Any]:
    """
    Remove empty columns from a table dictionary.

    A column is empty when every cell covering it has content that is falsy
    or whitespace-only after ``str(content).strip()``. A cell with
    ``col_span > 1`` contributes its content to every column it spans, and a
    column covered by no cell at all also counts as empty.

    This function:
    1. Identifies the empty columns
    2. Renumbers the remaining columns and updates ``cols`` (and ``columns``
       when that key is present)
    3. Recalculates each cell's ``col`` and shrinks ``col_span`` where the
       span crossed removed columns
    4. Drops cells whose complete span falls within removed columns
    5. Leaves every other key (for example ``bbox``) untouched

    The working column count is the larger of the declared ``cols``/``columns``
    value and the right-most column actually reached by a cell, so an
    understated declaration never causes populated cells to be discarded.
    Missing or ``None`` values for ``col`` / ``col_span`` fall back to 0 / 1,
    and a ``col_span`` below 1 is treated as 1.

    The input dictionary and its cells are never mutated. Cells may be dicts
    or objects exposing ``row``/``col``/``row_span``/``col_span``/``content``
    attributes; when columns are removed, object cells are emitted as dicts.

    Args:
        table_dict: Table dictionary with keys: rows, cols/columns, cells

    Returns:
        The same dictionary when it has no cells (or no columns can be
        determined); otherwise a shallow copy with empty columns removed.
    """
    cells = table_dict.get('cells', [])
    if not cells:
        return table_dict

    def _field(cell: Any, name: str, default: Any) -> Any:
        """Read a cell field from a dict or an object, mapping None to default."""
        if isinstance(cell, dict):
            value = cell.get(name, default)
        else:
            value = getattr(cell, name, default)
        return default if value is None else value

    # (start column, span) per cell, computed once and reused below
    spans = [
        (_field(cell, 'col', 0), max(1, _field(cell, 'col_span', 1)))
        for cell in cells
    ]

    # Trust the cells over the declaration when the declaration is too small
    declared_cols = table_dict.get('cols', table_dict.get('columns', 0)) or 0
    cell_extent = max((col + span for col, span in spans), default=0)
    original_cols = max(declared_cols, cell_extent)
    if original_cols <= 0:
        return table_dict

    # column index -> stripped content of every cell that covers the column
    column_contents: Dict[int, List[str]] = {i: [] for i in range(original_cols)}
    for cell, (col, col_span) in zip(cells, spans):
        content = _field(cell, 'content', '')
        text = str(content).strip() if content else ''
        for c in range(col, col + col_span):
            if c in column_contents:
                column_contents[c].append(text)

    # all() over an empty list is True, so uncovered columns count as empty
    empty_columns = {
        col_idx
        for col_idx, contents in column_contents.items()
        if all(text == '' for text in contents)
    }

    if not empty_columns:
        # Nothing to remove: only fill in the column count when it is missing;
        # a non-zero declared value is returned as given.
        result = dict(table_dict)
        if not result.get('cols', result.get('columns', 0)):
            result['cols'] = original_cols
            if 'columns' in result:
                result['columns'] = original_cols
        return result

    logger.debug(f"Removing empty columns: {sorted(empty_columns)} from table with {original_cols} cols")

    # old column index -> new column index (None when the column is removed)
    col_mapping: Dict[int, Optional[int]] = {}
    new_cols = 0
    for old_col in range(original_cols):
        if old_col in empty_columns:
            col_mapping[old_col] = None
        else:
            col_mapping[old_col] = new_cols
            new_cols += 1

    new_cells = []
    for cell, (old_col, old_col_span) in zip(cells, spans):
        # New indices of the columns in this cell's span that survive
        surviving = [
            mapped
            for mapped in (col_mapping.get(c) for c in range(old_col, old_col + old_col_span))
            if mapped is not None
        ]

        if not surviving:
            logger.debug(f"Removing cell at row={_field(cell, 'row', 0)}, "
                         f"col={old_col} (entire span in removed columns)")
            continue

        new_start_col = surviving[0]
        new_col_span = surviving[-1] - new_start_col + 1

        if isinstance(cell, dict):
            new_cell = dict(cell)
            new_cell['col'] = new_start_col
            new_cell['col_span'] = new_col_span
        else:
            # Object cells (e.g. TableCell) are re-emitted as plain dicts
            new_cell = {
                'row': cell.row,
                'col': new_start_col,
                'row_span': cell.row_span,
                'col_span': new_col_span,
                'content': cell.content
            }
            if getattr(cell, 'bbox', None):
                new_cell['bbox'] = cell.bbox
            if getattr(cell, 'style', None):
                new_cell['style'] = cell.style

        new_cells.append(new_cell)

    result = dict(table_dict)
    result['cells'] = new_cells
    result['cols'] = new_cols
    if 'columns' in result:
        result['columns'] = new_cols

    logger.info(f"Trimmed table: {original_cols} -> {new_cols} columns, "
                f"{len(cells)} -> {len(new_cells)} cells")

    return result
class OCRToUnifiedConverter:
"""
Converter for transforming PP-StructureV3 OCR results to UnifiedDocument format.
@@ -30,11 +188,19 @@ class OCRToUnifiedConverter:
- Multi-page document assembly
- Metadata preservation
- Structure relationship mapping
- Gap filling with raw OCR regions (when PP-StructureV3 misses content)
"""
def __init__(self):
"""Initialize the converter."""
def __init__(self, enable_gap_filling: bool = True):
"""
Initialize the converter.
Args:
enable_gap_filling: Whether to enable gap filling with raw OCR regions
"""
self.element_counter = 0
self.gap_filling_service = GapFillingService() if enable_gap_filling else None
self.gap_filling_stats: Dict[str, Any] = {}
def convert(
self,
@@ -120,13 +286,21 @@ class OCRToUnifiedConverter:
Extract pages from OCR results.
Handles both enhanced PP-StructureV3 results (with parsing_res_list)
and traditional markdown results.
and traditional markdown results. Applies gap filling when enabled.
"""
pages = []
# Extract raw OCR text regions for gap filling
raw_text_regions = ocr_results.get('text_regions', [])
ocr_dimensions = ocr_results.get('ocr_dimensions', {})
# Check if we have enhanced results from PPStructureEnhanced
if 'enhanced_results' in ocr_results:
pages = self._extract_from_enhanced_results(ocr_results['enhanced_results'])
pages = self._extract_from_enhanced_results(
ocr_results['enhanced_results'],
raw_text_regions=raw_text_regions,
ocr_dimensions=ocr_dimensions
)
# Check for traditional OCR results with text_regions at top level (from process_file_traditional)
elif 'text_regions' in ocr_results:
pages = self._extract_from_traditional_ocr(ocr_results)
@@ -143,9 +317,21 @@ class OCRToUnifiedConverter:
def _extract_from_enhanced_results(
self,
enhanced_results: List[Dict[str, Any]]
enhanced_results: List[Dict[str, Any]],
raw_text_regions: Optional[List[Dict[str, Any]]] = None,
ocr_dimensions: Optional[Dict[str, Any]] = None
) -> List[Page]:
"""Extract pages from enhanced PP-StructureV3 results."""
"""
Extract pages from enhanced PP-StructureV3 results.
Applies gap filling when enabled to supplement PP-StructureV3 output
with raw OCR regions that were not detected by the layout model.
Args:
enhanced_results: PP-StructureV3 enhanced results
raw_text_regions: Raw OCR text regions for gap filling
ocr_dimensions: OCR image dimensions for coordinate alignment
"""
pages = []
for page_idx, page_result in enumerate(enhanced_results):
@@ -158,15 +344,52 @@ class OCRToUnifiedConverter:
if element:
elements.append(element)
# Get page dimensions
pp_dimensions = Dimensions(
width=page_result.get('width', 0),
height=page_result.get('height', 0)
)
# Apply gap filling if enabled and raw regions available
if self.gap_filling_service and raw_text_regions:
# Filter raw regions for current page
page_raw_regions = [
r for r in raw_text_regions
if r.get('page', 0) == page_idx or r.get('page', 1) == page_idx + 1
]
if page_raw_regions:
supplemented, stats = self.gap_filling_service.fill_gaps(
raw_ocr_regions=page_raw_regions,
pp_structure_elements=elements,
page_number=page_idx + 1,
ocr_dimensions=ocr_dimensions,
pp_dimensions=pp_dimensions
)
# Store statistics
self.gap_filling_stats[f'page_{page_idx + 1}'] = stats
if supplemented:
logger.info(
f"Page {page_idx + 1}: Gap filling added {len(supplemented)} elements "
f"(coverage: {stats.get('coverage_ratio', 0):.2%})"
)
elements.extend(supplemented)
# Recalculate reading order for combined elements
reading_order = self.gap_filling_service.recalculate_reading_order(elements)
page_result['reading_order'] = reading_order
# Create page
page = Page(
page_number=page_idx + 1,
dimensions=Dimensions(
width=page_result.get('width', 0),
height=page_result.get('height', 0)
),
dimensions=pp_dimensions,
elements=elements,
metadata={'reading_order': page_result.get('reading_order', [])}
metadata={
'reading_order': page_result.get('reading_order', []),
'gap_filling': self.gap_filling_stats.get(f'page_{page_idx + 1}', {})
}
)
pages.append(page)
@@ -500,6 +723,9 @@ class OCRToUnifiedConverter:
) -> Optional[DocumentElement]:
"""Convert table data to DocumentElement."""
try:
# Clean up empty columns before building TableData
table_dict = trim_empty_columns(table_dict)
# Extract bbox
bbox_data = table_dict.get('bbox', [0, 0, 0, 0])
bbox = BoundingBox(
@@ -587,14 +813,22 @@ class OCRToUnifiedConverter:
cells = []
headers = []
rows = table.find_all('tr')
num_rows = len(rows)
# Track actual column positions accounting for rowspan/colspan
# This is a simplified approach - complex spanning may need enhancement
# First pass: calculate total columns by finding max column extent
# Track cells that span multiple rows: occupied[row][col] = True
occupied: Dict[int, Dict[int, bool]] = {r: {} for r in range(num_rows)}
# Parse all cells with proper rowspan/colspan handling
for row_idx, row in enumerate(rows):
row_cells = row.find_all(['td', 'th'])
col_idx = 0
for cell in row_cells:
# Skip columns that are occupied by rowspan from previous rows
while occupied[row_idx].get(col_idx, False):
col_idx += 1
cell_content = cell.get_text(strip=True)
rowspan = int(cell.get('rowspan', 1))
colspan = int(cell.get('colspan', 1))
@@ -611,26 +845,66 @@ class OCRToUnifiedConverter:
if cell.name == 'th' or row_idx == 0:
headers.append(cell_content)
# Mark cells as occupied for rowspan/colspan
for r in range(row_idx, min(row_idx + rowspan, num_rows)):
for c in range(col_idx, col_idx + colspan):
if r not in occupied:
occupied[r] = {}
occupied[r][c] = True
# Advance column index by colspan
col_idx += colspan
# Calculate actual dimensions
num_rows = len(rows)
num_cols = max(
sum(int(cell.get('colspan', 1)) for cell in row.find_all(['td', 'th']))
for row in rows
) if rows else 0
# Calculate actual column count from occupied cells
num_cols = 0
for r in range(num_rows):
if occupied[r]:
max_col_in_row = max(occupied[r].keys()) + 1
num_cols = max(num_cols, max_col_in_row)
logger.debug(
f"Parsed HTML table: {num_rows} rows, {num_cols} cols, {len(cells)} cells"
)
# Build table dict for cleanup
table_dict = {
'rows': num_rows,
'cols': num_cols,
'cells': [
{
'row': c.row,
'col': c.col,
'row_span': c.row_span,
'col_span': c.col_span,
'content': c.content
}
for c in cells
],
'headers': headers if headers else None,
'caption': extracted_text if extracted_text else None
}
# Clean up empty columns
table_dict = trim_empty_columns(table_dict)
# Convert cleaned cells back to TableCell objects
cleaned_cells = [
TableCell(
row=c['row'],
col=c['col'],
row_span=c.get('row_span', 1),
col_span=c.get('col_span', 1),
content=c.get('content', '')
)
for c in table_dict.get('cells', [])
]
return TableData(
rows=num_rows,
cols=num_cols,
cells=cells,
headers=headers if headers else None,
caption=extracted_text if extracted_text else None
rows=table_dict.get('rows', num_rows),
cols=table_dict.get('cols', num_cols),
cells=cleaned_cells,
headers=table_dict.get('headers'),
caption=table_dict.get('caption')
)
except ImportError:

View File

@@ -0,0 +1,344 @@
"""
PP-StructureV3 Debug Service
Provides debugging tools for visualizing and saving PP-StructureV3 results:
- Save raw results as JSON for inspection
- Generate visualization images showing detected bboxes
- Compare raw OCR regions with PP-StructureV3 elements
"""
import json
import logging
from pathlib import Path
from typing import Dict, List, Any, Optional, Tuple
from datetime import datetime
from PIL import Image, ImageDraw, ImageFont
logger = logging.getLogger(__name__)
# Color palette for different element types (RGB).
# generate_visualization() looks colors up by the lower-cased element type and
# falls back to the 'default' entry for types not listed here. _draw_legend()
# lists every entry except 'default', in the insertion order used below.
ELEMENT_COLORS: Dict[str, Tuple[int, int, int]] = {
    'text': (0, 128, 0),  # Green
    'title': (0, 0, 255),  # Blue
    'table': (255, 0, 0),  # Red
    'figure': (255, 165, 0),  # Orange
    'image': (255, 165, 0),  # Orange (same as 'figure')
    'header': (128, 0, 128),  # Purple
    'footer': (128, 0, 128),  # Purple (same as 'header')
    'equation': (0, 255, 255),  # Cyan
    'chart': (255, 192, 203),  # Pink
    'list': (139, 69, 19),  # Brown
    'reference': (128, 128, 128),  # Gray
    'default': (255, 0, 255),  # Magenta for unknown types
}
# Color for raw OCR regions (outline, label text and legend swatch)
RAW_OCR_COLOR = (255, 215, 0)  # Gold
class PPStructureDebug:
"""Debug service for PP-StructureV3 analysis results."""
def __init__(self, output_dir: Path):
    """
    Set up the debug service and make sure its output directory exists.

    Args:
        output_dir: Directory to save debug outputs; it is created, along
            with any missing parent directories, if it does not exist yet.
    """
    debug_dir = Path(output_dir)
    self.output_dir = debug_dir
    debug_dir.mkdir(parents=True, exist_ok=True)
def save_raw_results(
    self,
    pp_structure_results: Dict[str, Any],
    raw_ocr_regions: List[Dict[str, Any]],
    filename_prefix: str = "debug"
) -> Dict[str, Path]:
    """
    Dump the PP-StructureV3 results, the raw OCR regions and a comparison
    summary as three JSON files in the output directory.

    The three files are written independently: when one of them fails, the
    error is logged and that entry is left out of the returned mapping while
    the remaining files are still attempted.

    Args:
        pp_structure_results: Raw PP-StructureV3 analysis results
        raw_ocr_regions: Raw OCR text regions
        filename_prefix: Prefix for output files

    Returns:
        Dictionary mapping 'pp_structure', 'raw_ocr' and 'summary' to the
        path of each file that was written successfully
    """
    def _write_json(target: Path, payload: Any) -> None:
        # UTF-8 with ensure_ascii=False keeps non-ASCII text readable on disk
        with open(target, 'w', encoding='utf-8') as handle:
            json.dump(payload, handle, ensure_ascii=False, indent=2)

    written: Dict[str, Path] = {}

    # 1) PP-StructureV3 results, converted to JSON-friendly types first
    structure_path = self.output_dir / f"{filename_prefix}_pp_structure_raw.json"
    try:
        _write_json(structure_path, self._make_serializable(pp_structure_results))
        written['pp_structure'] = structure_path
        logger.info(f"Saved PP-StructureV3 raw results to {structure_path}")
    except Exception as e:
        logger.error(f"Failed to save PP-StructureV3 results: {e}")

    # 2) Raw OCR regions, converted the same way
    regions_path = self.output_dir / f"{filename_prefix}_raw_ocr_regions.json"
    try:
        _write_json(regions_path, self._make_serializable(raw_ocr_regions))
        written['raw_ocr'] = regions_path
        logger.info(f"Saved raw OCR regions to {regions_path}")
    except Exception as e:
        logger.error(f"Failed to save raw OCR regions: {e}")

    # 3) Summary comparing both inputs (written as produced, no conversion)
    summary_path = self.output_dir / f"{filename_prefix}_debug_summary.json"
    try:
        _write_json(summary_path, self._generate_summary(pp_structure_results, raw_ocr_regions))
        written['summary'] = summary_path
        logger.info(f"Saved debug summary to {summary_path}")
    except Exception as e:
        logger.error(f"Failed to save debug summary: {e}")

    return written
def generate_visualization(
    self,
    image_path: Path,
    pp_structure_elements: List[Dict[str, Any]],
    raw_ocr_regions: Optional[List[Dict[str, Any]]] = None,
    filename_prefix: str = "debug",
    show_labels: bool = True,
    show_raw_ocr: bool = True
) -> Optional[Path]:
    """
    Generate visualization image showing detected elements.

    Draws thin outlines for raw OCR regions first and thick, type-colored
    outlines for PP-StructureV3 elements on top, then adds a legend and a
    one-line info footer. The result is written to
    ``<output_dir>/<filename_prefix>_pp_structure_viz.png``.

    This is a best-effort debug helper: any failure is logged (with
    traceback) and reported by returning None instead of raising.

    Args:
        image_path: Path to original image
        pp_structure_elements: PP-StructureV3 detected elements
        raw_ocr_regions: Optional raw OCR regions to overlay
        filename_prefix: Prefix for output file
        show_labels: Whether to show element type labels
        show_raw_ocr: Whether to show raw OCR regions

    Returns:
        Path to generated visualization image, or None on failure
    """
    try:
        # Copy the pixels while the file is open so the handle is released
        # as soon as the copy exists.
        with Image.open(image_path) as source:
            rgb = source if source.mode == 'RGB' else source.convert('RGB')
            viz_img = rgb.copy()
        width, height = viz_img.size
        draw = ImageDraw.Draw(viz_img)

        # Try the known font files in order, fall back to PIL's default.
        # NOTE(review): the second candidate is an absolute path on one
        # developer machine; it should come from configuration instead.
        font_candidates = (
            "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
            "/home/egg/project/Tool_OCR/backend/fonts/NotoSansSC-Regular.ttf",
        )
        for font_path in font_candidates:
            try:
                loaded = (
                    ImageFont.truetype(font_path, 14),
                    ImageFont.truetype(font_path, 10),
                )
            except (IOError, OSError):
                continue
            font, small_font = loaded
            break
        else:
            font = ImageFont.load_default()
            small_font = font

        def _ordered(box: Tuple[float, float, float, float]) -> List[float]:
            # NOTE(review): recent Pillow releases are understood to reject a
            # rectangle whose second corner precedes the first; ordering the
            # corners keeps one odd box from aborting the whole image.
            bx0, by0, bx1, by1 = box
            return [min(bx0, bx1), min(by0, by1), max(bx0, bx1), max(by0, by1)]

        # Draw raw OCR regions first (so PP-Structure boxes are on top)
        if show_raw_ocr and raw_ocr_regions:
            for region in raw_ocr_regions:
                bbox = self._normalize_bbox(region.get('bbox', []))
                if not bbox:
                    continue
                x0, y0, x1, y1 = _ordered(bbox)
                draw.rectangle([x0, y0, x1, y1], outline=RAW_OCR_COLOR, width=1)

                if show_labels:
                    # A present-but-None confidence is shown as 0.00
                    confidence = region.get('confidence', 0) or 0
                    label = f"OCR:{confidence:.2f}"
                    draw.text((x0, y0 - 12), label, fill=RAW_OCR_COLOR, font=small_font)

        # Draw PP-StructureV3 elements
        for idx, elem in enumerate(pp_structure_elements):
            elem_type = elem.get('type', 'default')
            if hasattr(elem_type, 'value'):
                # Enum-like types carry their name in .value
                elem_type = elem_type.value
            elem_type = str(elem_type).lower()

            color = ELEMENT_COLORS.get(elem_type, ELEMENT_COLORS['default'])
            bbox = self._normalize_bbox(elem.get('bbox', []))
            if not bbox:
                continue

            x0, y0, x1, y1 = _ordered(bbox)
            # Thicker outline distinguishes PP-Structure elements from raw OCR
            draw.rectangle([x0, y0, x1, y1], outline=color, width=3)

            if show_labels:
                label = f"{idx}:{elem_type}"
                # White backing box keeps the label readable over content
                text_bbox = draw.textbbox((x0, y0 - 18), label, font=font)
                draw.rectangle(text_bbox, fill=(255, 255, 255, 200))
                draw.text((x0, y0 - 18), label, fill=color, font=font)

        # Add legend
        self._draw_legend(draw, width, font)

        # Add image info
        info_text = f"PP-Structure: {len(pp_structure_elements)} elements"
        if raw_ocr_regions:
            info_text += f" | Raw OCR: {len(raw_ocr_regions)} regions"
        info_text += f" | Size: {width}x{height}"
        draw.text((10, height - 25), info_text, fill=(0, 0, 0), font=font)

        # Save visualization
        viz_path = self.output_dir / f"{filename_prefix}_pp_structure_viz.png"
        viz_img.save(viz_path, 'PNG')
        logger.info(f"Saved visualization to {viz_path}")

        return viz_path

    except Exception as e:
        # logger.exception records the traceback through the logging system
        # instead of printing it straight to stderr.
        logger.exception(f"Failed to generate visualization: {e}")
        return None
def _draw_legend(self, draw: ImageDraw, img_width: int, font: ImageFont):
    """
    Draw the color legend box near the top-right corner of the image.

    Lists every element type in ELEMENT_COLORS except 'default', followed by
    one extra row for raw OCR regions.

    Args:
        draw: Drawing object whose rectangle() and text() methods are used
        img_width: Image width in pixels; the legend starts 150 px left of it
        font: Font used for the legend title and the row labels
    """
    left = img_width - 150
    top = 10
    black = (0, 0, 0)

    # Background panel; its height is derived from the size of the palette
    draw.rectangle(
        [left - 5, top - 5, img_width - 5, top + len(ELEMENT_COLORS) * 18 + 25],
        fill=(255, 255, 255, 230),
        outline=black
    )
    draw.text((left, top), "Legend:", fill=black, font=font)

    # One row per known element type, then the raw OCR entry at the bottom
    rows = [(name, color) for name, color in ELEMENT_COLORS.items() if name != 'default']
    rows.append(("raw_ocr", RAW_OCR_COLOR))

    for position, (name, color) in enumerate(rows):
        row_y = top + 20 + position * 18
        draw.rectangle([left, row_y + 2, left + 12, row_y + 14], fill=color)
        draw.text((left + 18, row_y), name, fill=black, font=font)
def _normalize_bbox(self, bbox: Any) -> Optional[Tuple[float, float, float, float]]:
"""Normalize bbox to (x0, y0, x1, y1) format."""
if not bbox:
return None
try:
# Handle nested list format [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
if isinstance(bbox, (list, tuple)) and len(bbox) >= 1:
if isinstance(bbox[0], (list, tuple)):
xs = [pt[0] for pt in bbox if len(pt) >= 2]
ys = [pt[1] for pt in bbox if len(pt) >= 2]
if xs and ys:
return (min(xs), min(ys), max(xs), max(ys))
# Handle flat list [x0, y0, x1, y1]
if isinstance(bbox, (list, tuple)) and len(bbox) == 4:
return (float(bbox[0]), float(bbox[1]), float(bbox[2]), float(bbox[3]))
# Handle flat polygon [x1, y1, x2, y2, ...]
if isinstance(bbox, (list, tuple)) and len(bbox) >= 8:
xs = [bbox[i] for i in range(0, len(bbox), 2)]
ys = [bbox[i] for i in range(1, len(bbox), 2)]
return (min(xs), min(ys), max(xs), max(ys))
# Handle dict format
if isinstance(bbox, dict):
return (
float(bbox.get('x0', bbox.get('x_min', 0))),
float(bbox.get('y0', bbox.get('y_min', 0))),
float(bbox.get('x1', bbox.get('x_max', 0))),
float(bbox.get('y1', bbox.get('y_max', 0)))
)
except (TypeError, ValueError, IndexError) as e:
logger.warning(f"Failed to normalize bbox {bbox}: {e}")
return None
def _generate_summary(
self,
pp_structure_results: Dict[str, Any],
raw_ocr_regions: List[Dict[str, Any]]
) -> Dict[str, Any]:
"""Generate summary comparing PP-Structure and raw OCR."""
pp_elements = pp_structure_results.get('elements', [])
# Count element types
type_counts = {}
for elem in pp_elements:
elem_type = elem.get('type', 'unknown')
if hasattr(elem_type, 'value'):
elem_type = elem_type.value
type_counts[str(elem_type)] = type_counts.get(str(elem_type), 0) + 1
# Calculate bounding box coverage
pp_bbox_area = 0
ocr_bbox_area = 0
for elem in pp_elements:
bbox = self._normalize_bbox(elem.get('bbox'))
if bbox:
pp_bbox_area += (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
for region in raw_ocr_regions:
bbox = self._normalize_bbox(region.get('bbox'))
if bbox:
ocr_bbox_area += (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
return {
'timestamp': datetime.now().isoformat(),
'pp_structure': {
'total_elements': len(pp_elements),
'element_types': type_counts,
'total_bbox_area': pp_bbox_area,
'has_parsing_res_list': pp_structure_results.get('has_parsing_res_list', False)
},
'raw_ocr': {
'total_regions': len(raw_ocr_regions),
'total_bbox_area': ocr_bbox_area,
'avg_confidence': sum(r.get('confidence', 0) for r in raw_ocr_regions) / len(raw_ocr_regions) if raw_ocr_regions else 0
},
'comparison': {
'element_count_ratio': len(pp_elements) / len(raw_ocr_regions) if raw_ocr_regions else 0,
'area_ratio': pp_bbox_area / ocr_bbox_area if ocr_bbox_area > 0 else 0,
'potential_gap': len(raw_ocr_regions) - len(pp_elements) if raw_ocr_regions else 0
}
}
def _make_serializable(self, obj: Any) -> Any:
"""Convert object to JSON-serializable format."""
if obj is None:
return None
if isinstance(obj, (str, int, float, bool)):
return obj
if isinstance(obj, (list, tuple)):
return [self._make_serializable(item) for item in obj]
if isinstance(obj, dict):
return {str(k): self._make_serializable(v) for k, v in obj.items()}
if hasattr(obj, 'value'):
return obj.value
if hasattr(obj, '__dict__'):
return self._make_serializable(obj.__dict__)
if hasattr(obj, 'tolist'): # numpy array
return obj.tolist()
return str(obj)

View File

@@ -0,0 +1,332 @@
"""
API integration tests for Layout Model Selection feature.
This replaces the deprecated PP-StructureV3 parameter tests.
"""
import pytest
from fastapi.testclient import TestClient
from unittest.mock import patch
from app.main import app
from app.core.database import get_db
from app.models.user import User
from app.models.task import Task, TaskStatus, TaskFile
@pytest.fixture
def client():
    """Provide a ``TestClient`` bound to the FastAPI application under test."""
    test_client = TestClient(app)
    return test_client
@pytest.fixture
def test_user(db_session):
    """Persist an active user through ``db_session`` and return the refreshed row."""
    user_fields = {
        "email": "test@example.com",
        "hashed_password": "test_hash",
        "is_active": True,
    }
    account = User(**user_fields)
    db_session.add(account)
    db_session.commit()
    # Reload so database-generated columns are populated on the instance.
    db_session.refresh(account)
    return account
@pytest.fixture
def test_task(db_session, test_user):
    """Persist a pending task owned by ``test_user`` plus one attached PDF file record."""
    pending_task = Task(
        user_id=test_user.id,
        task_id="test-task-123",
        filename="test.pdf",
        status=TaskStatus.PENDING
    )
    db_session.add(pending_task)
    db_session.commit()
    # Refresh first: the file record below references the task's primary key.
    db_session.refresh(pending_task)

    uploaded_file = TaskFile(
        task_id=pending_task.id,
        original_name="test.pdf",
        stored_path="/tmp/test.pdf",
        file_size=1024,
        mime_type="application/pdf"
    )
    db_session.add(uploaded_file)
    db_session.commit()
    return pending_task
class TestLayoutModelSchema:
    """Schema-level checks for ``LayoutModelEnum`` and ``ProcessingOptions.layout_model``."""

    def test_processing_options_accepts_layout_model(self):
        """An enum member passed as ``layout_model`` is stored unchanged."""
        from app.schemas.task import ProcessingOptions, LayoutModelEnum

        requested = LayoutModelEnum.CHINESE
        options = ProcessingOptions(
            use_dual_track=True,
            language='ch',
            layout_model=requested
        )
        assert options.layout_model == LayoutModelEnum.CHINESE

    def test_layout_model_enum_values(self):
        """Each enum member carries the expected wire value."""
        from app.schemas.task import LayoutModelEnum

        expected_values = (
            (LayoutModelEnum.CHINESE, "chinese"),
            (LayoutModelEnum.DEFAULT, "default"),
            (LayoutModelEnum.CDLA, "cdla"),
        )
        for member, wire_value in expected_values:
            assert member.value == wire_value

    def test_default_layout_model_is_chinese(self):
        """Options built without arguments select the 'chinese' layout model."""
        from app.schemas.task import ProcessingOptions

        defaults = ProcessingOptions()
        assert defaults.layout_model.value == "chinese"

    def test_layout_model_string_values_accepted(self):
        """Plain strings are coerced to the matching enum member."""
        from app.schemas.task import ProcessingOptions

        for raw in ("default", "cdla"):
            parsed = ProcessingOptions(layout_model=raw)
            assert parsed.layout_model.value == raw

    def test_invalid_layout_model_rejected(self):
        """An unknown layout model name fails validation."""
        from app.schemas.task import ProcessingOptions
        from pydantic import ValidationError

        with pytest.raises(ValidationError):
            ProcessingOptions(layout_model="invalid_model")
class TestStartTaskEndpoint:
    """Test /tasks/{task_id}/start endpoint with layout_model parameter"""

    @pytest.fixture(autouse=True)
    def _override_dependencies(self, test_task, db_session):
        """Point the app's DB and auth dependencies at the test fixtures.

        Runs around every test in this class. The overrides are removed in the
        fixture teardown, which pytest executes even when the test body fails,
        so a failed assertion cannot leak overrides into later tests.
        """
        from app.core.deps import get_current_user

        def override_get_db():
            yield db_session

        def override_get_current_user():
            return test_task.user

        app.dependency_overrides[get_db] = override_get_db
        app.dependency_overrides[get_current_user] = override_get_current_user
        yield
        app.dependency_overrides.clear()

    @staticmethod
    def _start_task(client, task, request_body):
        """POST ``request_body`` to the start endpoint of ``task`` and return the response."""
        return client.post(
            f"/api/v2/tasks/{task.task_id}/start",
            json=request_body
        )

    @patch('app.routers.tasks.process_task_ocr')
    def test_start_task_with_layout_model(self, mock_process_ocr, client, test_task, db_session):
        """Verify layout_model is accepted and passed to OCR service"""
        response = self._start_task(client, test_task, {
            "use_dual_track": True,
            "language": "ch",
            "layout_model": "chinese"
        })

        # Verify response
        assert response.status_code == 200
        data = response.json()
        assert data['status'] == 'processing'

        # Verify background task was called with layout_model
        mock_process_ocr.assert_called_once()
        call_kwargs = mock_process_ocr.call_args[1]
        assert 'layout_model' in call_kwargs
        assert call_kwargs['layout_model'] == 'chinese'

    @patch('app.routers.tasks.process_task_ocr')
    def test_start_task_with_default_model(self, mock_process_ocr, client, test_task, db_session):
        """Verify 'default' layout model is accepted"""
        response = self._start_task(client, test_task, {
            "use_dual_track": True,
            "layout_model": "default"
        })

        assert response.status_code == 200
        mock_process_ocr.assert_called_once()
        call_kwargs = mock_process_ocr.call_args[1]
        assert call_kwargs['layout_model'] == 'default'

    @patch('app.routers.tasks.process_task_ocr')
    def test_start_task_with_cdla_model(self, mock_process_ocr, client, test_task, db_session):
        """Verify 'cdla' layout model is accepted"""
        response = self._start_task(client, test_task, {
            "use_dual_track": True,
            "layout_model": "cdla"
        })

        assert response.status_code == 200
        mock_process_ocr.assert_called_once()
        call_kwargs = mock_process_ocr.call_args[1]
        assert call_kwargs['layout_model'] == 'cdla'

    @patch('app.routers.tasks.process_task_ocr')
    def test_start_task_without_layout_model_uses_default(self, mock_process_ocr, client, test_task, db_session):
        """Verify task can start without layout_model (uses 'chinese' as default)"""
        # Request without layout_model
        response = self._start_task(client, test_task, {
            "use_dual_track": True,
            "language": "ch"
        })

        assert response.status_code == 200
        mock_process_ocr.assert_called_once()
        call_kwargs = mock_process_ocr.call_args[1]
        # layout_model should default to 'chinese'
        assert call_kwargs['layout_model'] == 'chinese'

    def test_start_task_with_invalid_layout_model(self, client, test_task, db_session):
        """Verify invalid layout_model returns 422 validation error"""
        response = self._start_task(client, test_task, {
            "use_dual_track": True,
            "layout_model": "invalid_model"
        })

        # Should return validation error
        assert response.status_code == 422
class TestOpenAPISchema:
    """Checks that the generated OpenAPI document exposes the layout model option."""

    def test_openapi_schema_includes_layout_model(self, client):
        """The enum, its three values and the ProcessingOptions field are all documented."""
        response = client.get("/openapi.json")
        assert response.status_code == 200
        component_schemas = response.json()['components']['schemas']

        # The enum itself must be published as a component.
        assert 'LayoutModelEnum' in component_schemas
        documented_values = component_schemas['LayoutModelEnum']['enum']
        for option in ('chinese', 'default', 'cdla'):
            assert option in documented_values

        # ProcessingOptions must expose the field.
        option_properties = component_schemas['ProcessingOptions']['properties']
        assert 'layout_model' in option_properties
# Allow running this test module directly as a script, with verbose pytest output.
if __name__ == '__main__':
    pytest.main([__file__, '-v'])

View File

@@ -0,0 +1,244 @@
"""
Unit tests for Layout Model Selection feature in OCR Service.
This replaces the deprecated PP-StructureV3 parameter tests.
"""
import pytest
import sys
from pathlib import Path
from unittest.mock import Mock, patch, MagicMock
# Mock all external dependencies before importing OCRService.
# These entries must be registered in sys.modules BEFORE the
# `from app.services.ocr_service import ...` line below, so that any import of
# these packages performed during that import resolves to the mocks.
# NOTE(review): the replacements are process-wide and are never restored —
# confirm no other test module in the same session needs the real packages.
sys.modules['paddleocr'] = MagicMock()
sys.modules['PIL'] = MagicMock()
sys.modules['pdf2image'] = MagicMock()
# Mock paddle with a concrete version string and CPU-only device answers
# (presumably read by the service at import/initialisation time — verify).
paddle_mock = MagicMock()
paddle_mock.__version__ = '2.5.0'
paddle_mock.device.get_device.return_value = 'cpu'
paddle_mock.device.get_available_device.return_value = 'cpu'
sys.modules['paddle'] = paddle_mock
# Mock torch and report that CUDA is unavailable.
torch_mock = MagicMock()
torch_mock.cuda.is_available.return_value = False
sys.modules['torch'] = torch_mock
from app.services.ocr_service import OCRService, LAYOUT_MODEL_MAPPING, _USE_PUBLAYNET_DEFAULT
from app.core.config import settings
class TestLayoutModelMapping:
    """Checks on the ``LAYOUT_MODEL_MAPPING`` lookup table."""

    def test_layout_model_mapping_exists(self):
        """All three user-facing model keys are present in the mapping."""
        for key in ('chinese', 'default', 'cdla'):
            assert key in LAYOUT_MODEL_MAPPING

    def test_chinese_model_maps_to_pp_doclayout(self):
        """'chinese' resolves to the PP-DocLayout-S model name."""
        resolved = LAYOUT_MODEL_MAPPING['chinese']
        assert resolved == 'PP-DocLayout-S'

    def test_default_model_maps_to_publaynet_sentinel(self):
        """'default' resolves to the sentinel meaning "no custom model, use PubLayNet"."""
        resolved = LAYOUT_MODEL_MAPPING['default']
        assert resolved == _USE_PUBLAYNET_DEFAULT

    def test_cdla_model_maps_to_picodet(self):
        """'cdla' resolves to the picodet CDLA model name."""
        resolved = LAYOUT_MODEL_MAPPING['cdla']
        assert resolved == 'picodet_lcnet_x1_0_fgd_layout_cdla'
class TestLayoutModelEngine:
    """Engine construction for each selectable layout model."""

    @staticmethod
    def _constructor_kwargs(layout_model):
        """Build an engine for ``layout_model`` and return the kwargs given to PPStructureV3.

        The service starts with no structure engine, and the PPStructureV3
        constructor is patched so that exactly one construction is recorded.
        """
        service = OCRService()
        with patch.object(service, 'structure_engine', None), \
                patch('app.services.ocr_service.PPStructureV3') as engine_cls:
            engine_cls.return_value = Mock()
            service._ensure_structure_engine(layout_model=layout_model)
            engine_cls.assert_called_once()
            return engine_cls.call_args[1]

    def test_chinese_model_creates_engine_with_pp_doclayout(self):
        """'chinese' requests the PP-DocLayout-S detection model."""
        kwargs = self._constructor_kwargs('chinese')
        assert kwargs.get('layout_detection_model_name') == 'PP-DocLayout-S'

    def test_default_model_creates_engine_without_model_name(self):
        """'default' passes no model name (or None), leaving the built-in model in place."""
        kwargs = self._constructor_kwargs('default')
        assert kwargs.get('layout_detection_model_name') is None

    def test_cdla_model_creates_engine_with_picodet(self):
        """'cdla' requests the picodet CDLA detection model."""
        kwargs = self._constructor_kwargs('cdla')
        assert kwargs.get('layout_detection_model_name') == 'picodet_lcnet_x1_0_fgd_layout_cdla'

    def test_none_layout_model_uses_chinese_default(self):
        """Passing None falls back to the 'chinese' model (PP-DocLayout-S)."""
        kwargs = self._constructor_kwargs(None)
        assert kwargs.get('layout_detection_model_name') == 'PP-DocLayout-S'
class TestLayoutModelCaching:
    """Engine reuse and re-creation depending on the requested layout model."""

    def test_same_layout_model_uses_cached_engine(self):
        """Two requests for the same model construct the engine only once."""
        service = OCRService()
        with patch('app.services.ocr_service.PPStructureV3') as engine_cls:
            engine_cls.return_value = Mock()
            first = service._ensure_structure_engine(layout_model='chinese')
            second = service._ensure_structure_engine(layout_model='chinese')

        assert engine_cls.call_count == 1
        assert first is second

    def test_different_layout_model_creates_new_engine(self):
        """Switching from 'chinese' to 'cdla' constructs a second, distinct engine."""
        service = OCRService()
        with patch('app.services.ocr_service.PPStructureV3') as engine_cls:
            # One distinct engine object per construction.
            engine_cls.side_effect = [Mock(), Mock()]
            first = service._ensure_structure_engine(layout_model='chinese')
            second = service._ensure_structure_engine(layout_model='cdla')

        assert engine_cls.call_count == 2
        assert first is not second
class TestLayoutModelFlow:
    """How the ``layout_model`` argument reaches the PPStructureV3 constructor."""

    @staticmethod
    def _requested_model_name(layout_model):
        """Return the ``layout_detection_model_name`` kwarg used for ``layout_model``."""
        service = OCRService()
        with patch('app.services.ocr_service.PPStructureV3') as engine_cls:
            engine_cls.return_value = Mock()
            service._ensure_structure_engine(layout_model=layout_model)
            engine_cls.assert_called_once()
            return engine_cls.call_args[1].get('layout_detection_model_name')

    def test_layout_model_passed_to_engine_creation(self):
        """An explicit 'cdla' request is translated to the picodet model name."""
        requested = self._requested_model_name('cdla')
        assert requested == 'picodet_lcnet_x1_0_fgd_layout_cdla'

    def test_layout_model_default_behavior(self):
        """None falls back to the model name configured in settings."""
        requested = self._requested_model_name(None)
        assert requested == settings.layout_detection_model_name

    def test_layout_model_unknown_value_falls_back(self):
        """An unrecognised model key also falls back to the settings value."""
        requested = self._requested_model_name('unknown_model')
        assert requested == settings.layout_detection_model_name
class TestLayoutModelLogging:
    """Logging emitted while selecting a layout model."""

    def test_layout_model_is_logged(self):
        """At least one info-level log call mentions 'cdla' or 'layout'."""
        service = OCRService()
        with patch('app.services.ocr_service.PPStructureV3') as engine_cls, \
                patch('app.services.ocr_service.logger') as log:
            engine_cls.return_value = Mock()
            service._ensure_structure_engine(layout_model='cdla')

            info_calls = log.info.call_args_list
            assert log.info.call_count >= 1
            # Compare case-insensitively against the rendered call objects.
            rendered = [str(entry).lower() for entry in info_calls]
            assert any('cdla' in text or 'layout' in text for text in rendered)
# Allow running this test module directly as a script, with verbose pytest output.
if __name__ == '__main__':
    pytest.main([__file__, '-v'])

View File

@@ -0,0 +1,503 @@
"""
Tests for Gap Filling Service
Tests the detection and filling of gaps in PP-StructureV3 output
using raw OCR text regions.
"""
import pytest
from typing import List, Dict, Any
from app.services.gap_filling_service import GapFillingService, TextRegion, SKIP_ELEMENT_TYPES
from app.models.unified_document import DocumentElement, BoundingBox, ElementType, Dimensions
class TestGapFillingService:
"""Tests for GapFillingService class."""
@pytest.fixture
def service(self) -> GapFillingService:
"""Create a GapFillingService instance with default settings."""
return GapFillingService(
coverage_threshold=0.7,
iou_threshold=0.15,
confidence_threshold=0.3,
dedup_iou_threshold=0.5,
enabled=True
)
@pytest.fixture
def disabled_service(self) -> GapFillingService:
"""Create a disabled GapFillingService instance."""
return GapFillingService(enabled=False)
@pytest.fixture
def sample_raw_regions(self) -> List[TextRegion]:
"""Create sample raw OCR text regions."""
return [
TextRegion(text="Header text", bbox=[100, 50, 300, 80], confidence=0.95, page=1),
TextRegion(text="Title of document", bbox=[100, 100, 500, 150], confidence=0.92, page=1),
TextRegion(text="First paragraph", bbox=[100, 200, 500, 250], confidence=0.90, page=1),
TextRegion(text="Second paragraph", bbox=[100, 300, 500, 350], confidence=0.88, page=1),
TextRegion(text="Footer note", bbox=[100, 900, 300, 930], confidence=0.85, page=1),
# Low confidence region (should be filtered)
TextRegion(text="Noise", bbox=[50, 50, 80, 80], confidence=0.1, page=1),
]
@pytest.fixture
def sample_pp_elements(self) -> List[DocumentElement]:
"""Create sample PP-StructureV3 elements that cover only some regions."""
return [
DocumentElement(
element_id="pp_1",
type=ElementType.TITLE,
content="Title of document",
bbox=BoundingBox(x0=100, y0=100, x1=500, y1=150),
confidence=0.95
),
DocumentElement(
element_id="pp_2",
type=ElementType.TEXT,
content="First paragraph",
bbox=BoundingBox(x0=100, y0=200, x1=500, y1=250),
confidence=0.90
),
# Note: Header, Second paragraph, and Footer are NOT covered
]
def test_service_initialization(self, service: GapFillingService):
"""Test service initializes with correct parameters."""
assert service.enabled is True
assert service.coverage_threshold == 0.7
assert service.iou_threshold == 0.15
assert service.confidence_threshold == 0.3
assert service.dedup_iou_threshold == 0.5
def test_disabled_service(self, disabled_service: GapFillingService):
"""Test disabled service does not activate."""
regions = [TextRegion(text="Test", bbox=[0, 0, 100, 100], confidence=0.9, page=1)]
elements = []
should_activate, coverage = disabled_service.should_activate(regions, elements)
assert should_activate is False
assert coverage == 1.0
def test_should_activate_low_coverage(
self,
service: GapFillingService,
sample_raw_regions: List[TextRegion],
sample_pp_elements: List[DocumentElement]
):
"""Test activation when coverage is below threshold."""
# Filter out low confidence regions
valid_regions = [r for r in sample_raw_regions if r.confidence >= 0.3]
should_activate, coverage = service.should_activate(valid_regions, sample_pp_elements)
# Only 2 out of 5 valid regions are covered (Title, First paragraph)
assert should_activate is True
assert coverage < 0.7 # Below threshold
def test_should_not_activate_high_coverage(self, service: GapFillingService):
"""Test no activation when coverage is above threshold."""
# All regions covered
regions = [
TextRegion(text="Text 1", bbox=[100, 100, 200, 150], confidence=0.9, page=1),
TextRegion(text="Text 2", bbox=[100, 200, 200, 250], confidence=0.9, page=1),
]
elements = [
DocumentElement(
element_id="pp_1",
type=ElementType.TEXT,
content="Text 1",
bbox=BoundingBox(x0=50, y0=50, x1=250, y1=200), # Covers first region
confidence=0.95
),
DocumentElement(
element_id="pp_2",
type=ElementType.TEXT,
content="Text 2",
bbox=BoundingBox(x0=50, y0=180, x1=250, y1=300), # Covers second region
confidence=0.95
),
]
should_activate, coverage = service.should_activate(regions, elements)
assert should_activate is False
assert coverage >= 0.7
def test_find_uncovered_regions(
self,
service: GapFillingService,
sample_raw_regions: List[TextRegion],
sample_pp_elements: List[DocumentElement]
):
"""Test finding uncovered regions."""
uncovered = service.find_uncovered_regions(sample_raw_regions, sample_pp_elements)
# Should find Header, Second paragraph, Footer (not Title, First paragraph, or low-confidence Noise)
assert len(uncovered) == 3
uncovered_texts = [r.text for r in uncovered]
assert "Header text" in uncovered_texts
assert "Second paragraph" in uncovered_texts
assert "Footer note" in uncovered_texts
assert "Title of document" not in uncovered_texts # Covered
assert "First paragraph" not in uncovered_texts # Covered
assert "Noise" not in uncovered_texts # Low confidence
def test_coverage_by_center_point(self, service: GapFillingService):
"""Test coverage detection via center point."""
region = TextRegion(text="Test", bbox=[150, 150, 250, 200], confidence=0.9, page=1)
element = DocumentElement(
element_id="pp_1",
type=ElementType.TEXT,
content="Container",
bbox=BoundingBox(x0=100, y0=100, x1=300, y1=250), # Contains region's center
confidence=0.95
)
is_covered = service._is_region_covered(region, [element])
assert is_covered is True
def test_coverage_by_iou(self, service: GapFillingService):
"""Test coverage detection via IoU threshold."""
region = TextRegion(text="Test", bbox=[100, 100, 200, 150], confidence=0.9, page=1)
element = DocumentElement(
element_id="pp_1",
type=ElementType.TEXT,
content="Overlap",
bbox=BoundingBox(x0=150, y0=100, x1=250, y1=150), # Partial overlap
confidence=0.95
)
# Calculate expected IoU
# Intersection: (150-200) x (100-150) = 50 x 50 = 2500
# Union: 100x50 + 100x50 - 2500 = 7500
# IoU = 2500/7500 = 0.33 > 0.15 threshold
is_covered = service._is_region_covered(region, [element])
assert is_covered is True
def test_deduplication(
self,
service: GapFillingService,
sample_pp_elements: List[DocumentElement]
):
"""Test deduplication removes high-overlap regions."""
uncovered = [
# High overlap with pp_2 (First paragraph)
TextRegion(text="First paragraph variant", bbox=[100, 200, 500, 250], confidence=0.9, page=1),
# No overlap
TextRegion(text="Unique region", bbox=[100, 500, 300, 550], confidence=0.9, page=1),
]
deduplicated = service.deduplicate_regions(uncovered, sample_pp_elements)
assert len(deduplicated) == 1
assert deduplicated[0].text == "Unique region"
def test_convert_regions_to_elements(self, service: GapFillingService):
"""Test conversion of TextRegions to DocumentElements."""
regions = [
TextRegion(text="Test text 1", bbox=[100, 100, 200, 150], confidence=0.85, page=1),
TextRegion(text="Test text 2", bbox=[100, 200, 200, 250], confidence=0.90, page=1),
]
elements = service.convert_regions_to_elements(regions, page_number=1, start_element_id=0)
assert len(elements) == 2
assert elements[0].element_id == "gap_fill_1_0"
assert elements[0].type == ElementType.TEXT
assert elements[0].content == "Test text 1"
assert elements[0].confidence == 0.85
assert elements[0].metadata.get('source') == 'gap_filling'
assert elements[1].element_id == "gap_fill_1_1"
assert elements[1].content == "Test text 2"
def test_recalculate_reading_order(self, service: GapFillingService):
"""Test reading order recalculation."""
elements = [
DocumentElement(
element_id="e3",
type=ElementType.TEXT,
content="Bottom",
bbox=BoundingBox(x0=100, y0=300, x1=200, y1=350),
confidence=0.9
),
DocumentElement(
element_id="e1",
type=ElementType.TEXT,
content="Top",
bbox=BoundingBox(x0=100, y0=100, x1=200, y1=150),
confidence=0.9
),
DocumentElement(
element_id="e2",
type=ElementType.TEXT,
content="Middle",
bbox=BoundingBox(x0=100, y0=200, x1=200, y1=250),
confidence=0.9
),
]
reading_order = service.recalculate_reading_order(elements)
# Should be sorted by y0: Top (100), Middle (200), Bottom (300)
assert reading_order == [1, 2, 0] # Indices of elements in reading order
def test_fill_gaps_integration(
self,
service: GapFillingService,
):
"""Integration test for fill_gaps method."""
# Raw OCR regions (dict format as received from OCR service)
raw_regions = [
{'text': 'Header', 'bbox': [100, 50, 300, 80], 'confidence': 0.95, 'page': 1},
{'text': 'Title', 'bbox': [100, 100, 500, 150], 'confidence': 0.92, 'page': 1},
{'text': 'Paragraph 1', 'bbox': [100, 200, 500, 250], 'confidence': 0.90, 'page': 1},
{'text': 'Paragraph 2', 'bbox': [100, 300, 500, 350], 'confidence': 0.88, 'page': 1},
{'text': 'Paragraph 3', 'bbox': [100, 400, 500, 450], 'confidence': 0.86, 'page': 1},
{'text': 'Footer', 'bbox': [100, 900, 300, 930], 'confidence': 0.85, 'page': 1},
]
# PP-StructureV3 only detected Title (missing 5 out of 6 regions = 16.7% coverage)
pp_elements = [
DocumentElement(
element_id="pp_1",
type=ElementType.TITLE,
content="Title",
bbox=BoundingBox(x0=100, y0=100, x1=500, y1=150),
confidence=0.95
),
]
supplemented, stats = service.fill_gaps(
raw_ocr_regions=raw_regions,
pp_structure_elements=pp_elements,
page_number=1
)
# Should have activated and supplemented missing regions
assert stats['activated'] is True
assert stats['coverage_ratio'] < 0.7
assert len(supplemented) == 5 # Header, Paragraph 1, 2, 3, Footer
def test_fill_gaps_no_activation_when_coverage_high(self, service: GapFillingService):
"""Test fill_gaps does not activate when coverage is high."""
raw_regions = [
{'text': 'Text 1', 'bbox': [100, 100, 200, 150], 'confidence': 0.9, 'page': 1},
]
pp_elements = [
DocumentElement(
element_id="pp_1",
type=ElementType.TEXT,
content="Text 1",
bbox=BoundingBox(x0=50, y0=50, x1=250, y1=200), # Fully covers
confidence=0.95
),
]
supplemented, stats = service.fill_gaps(
raw_ocr_regions=raw_regions,
pp_structure_elements=pp_elements,
page_number=1
)
assert stats['activated'] is False
assert len(supplemented) == 0
def test_skip_element_types_not_supplemented(self, service: GapFillingService):
"""Test that TABLE/IMAGE/etc. elements are not supplemented over."""
raw_regions = [
{'text': 'Table cell text', 'bbox': [100, 100, 200, 150], 'confidence': 0.9, 'page': 1},
]
# PP-StructureV3 has a table covering this region
pp_elements = [
DocumentElement(
element_id="pp_1",
type=ElementType.TABLE,
content="<table>...</table>",
bbox=BoundingBox(x0=50, y0=50, x1=250, y1=200),
confidence=0.95
),
]
# The region should be considered covered by the table
supplemented, stats = service.fill_gaps(
raw_ocr_regions=raw_regions,
pp_structure_elements=pp_elements,
page_number=1
)
# Should not supplement because the table covers it
assert len(supplemented) == 0
def test_coordinate_scaling(self, service: GapFillingService):
"""Test coordinate alignment with different dimensions."""
# OCR was done at 2000x3000, PP-Structure at 1000x1500
ocr_dimensions = {'width': 2000, 'height': 3000}
pp_dimensions = Dimensions(width=1000, height=1500)
raw_regions = [
# At OCR scale: (200, 300) to (400, 450) -> at PP scale: (100, 150) to (200, 225)
{'text': 'Scaled text', 'bbox': [200, 300, 400, 450], 'confidence': 0.9, 'page': 1},
]
pp_elements = [
DocumentElement(
element_id="pp_1",
type=ElementType.TEXT,
content="Scaled text",
bbox=BoundingBox(x0=100, y0=150, x1=200, y1=225), # Should cover after scaling
confidence=0.95
),
]
supplemented, stats = service.fill_gaps(
raw_ocr_regions=raw_regions,
pp_structure_elements=pp_elements,
page_number=1,
ocr_dimensions=ocr_dimensions,
pp_dimensions=pp_dimensions
)
# After scaling, the region should be covered
assert stats['coverage_ratio'] >= 0.7 or len(supplemented) == 0
def test_iou_calculation(self, service: GapFillingService):
"""Test IoU calculation accuracy."""
# Two identical boxes
bbox1 = (0, 0, 100, 100)
bbox2 = (0, 0, 100, 100)
assert service._calculate_iou(bbox1, bbox2) == 1.0
# No overlap
bbox1 = (0, 0, 100, 100)
bbox2 = (200, 200, 300, 300)
assert service._calculate_iou(bbox1, bbox2) == 0.0
# 50% overlap
bbox1 = (0, 0, 100, 100)
bbox2 = (50, 0, 150, 100) # Shifted right by 50
# Intersection: 50x100 = 5000
# Union: 10000 + 10000 - 5000 = 15000
# IoU = 5000/15000 = 0.333...
iou = service._calculate_iou(bbox1, bbox2)
assert abs(iou - 1/3) < 0.01
def test_point_in_bbox(self, service: GapFillingService):
    """Verify ``_point_in_bbox`` for interior, boundary and exterior points."""
    box = (100, 100, 200, 200)

    # One interior point, then the two corner points lying on the boundary.
    for px, py in [(150, 150), (100, 100), (200, 200)]:
        assert service._point_in_bbox(px, py, box) is True

    # One point with x below the box's x-range and one with x above it.
    for px, py in [(50, 150), (250, 150)]:
        assert service._point_in_bbox(px, py, box) is False
def test_merge_adjacent_regions(self, service: GapFillingService):
    """Verify that neighbouring text regions merge and distant ones do not."""
    hello = TextRegion(text="Hello", bbox=[100, 100, 150, 130], confidence=0.9, page=1)
    # Same y-range as "Hello"; its x-range starts 10 units after "Hello" ends.
    world = TextRegion(text="World", bbox=[160, 100, 210, 130], confidence=0.85, page=1)
    # y-range starts at 300, far beyond the 10-unit vertical gap allowed below.
    distant = TextRegion(text="Far away", bbox=[100, 300, 200, 330], confidence=0.9, page=1)

    merged = service.merge_adjacent_regions(
        [hello, world, distant],
        max_horizontal_gap=20,
        max_vertical_gap=10,
    )

    assert len(merged) == 2

    # The first result must carry the text of both adjacent regions.
    assert "Hello" in merged[0].text
    assert "World" in merged[0].text

    # The isolated region must come through unchanged.
    assert merged[1].text == "Far away"
class TestTextRegion:
    """Tests for the ``normalized_bbox`` and ``center`` accessors of TextRegion."""

    @staticmethod
    def _region(bbox):
        """Build a TextRegion with fixed text/confidence/page and the given bbox."""
        return TextRegion(text="Test", bbox=bbox, confidence=0.9, page=1)

    def test_normalized_bbox_4_values(self):
        """A 4-value bbox normalizes to the same four values."""
        flat_box = self._region([100, 200, 300, 400])
        assert flat_box.normalized_bbox == (100, 200, 300, 400)

    def test_normalized_bbox_polygon_flat(self):
        """A flat 8-value polygon normalizes to its bounding rectangle."""
        # Four corner points flattened as [x1, y1, x2, y2, x3, y3, x4, y4].
        corners = [100, 200, 300, 200, 300, 400, 100, 400]
        assert self._region(corners).normalized_bbox == (100, 200, 300, 400)

    def test_normalized_bbox_polygon_nested(self):
        """A nested polygon (PaddleOCR format) normalizes to its bounding rectangle."""
        # Four corner points as [[x1, y1], [x2, y2], [x3, y3], [x4, y4]].
        corners = [[100, 200], [300, 200], [300, 400], [100, 400]]
        assert self._region(corners).normalized_bbox == (100, 200, 300, 400)

    def test_normalized_bbox_numpy_polygon(self):
        """A nested polygon with float coordinates keeps its float values."""
        # Mirrors numpy arrays that were converted to plain lists of floats.
        corners = [[100.5, 200.5], [300.5, 200.5], [300.5, 400.5], [100.5, 400.5]]
        normalized = self._region(corners).normalized_bbox
        for index, expected in enumerate((100.5, 200.5, 300.5, 400.5)):
            assert normalized[index] == expected

    def test_center_calculation(self):
        """The center of a 4-value bbox is the midpoint on both axes."""
        flat_box = self._region([100, 200, 300, 400])
        assert flat_box.center == (200, 300)

    def test_center_calculation_nested_bbox(self):
        """The center of a nested polygon bbox is the midpoint on both axes."""
        corners = [[100, 200], [300, 200], [300, 400], [100, 400]]
        assert self._region(corners).center == (200, 300)
class TestOCRToUnifiedConverterIntegration:
    """Integration tests for OCRToUnifiedConverter with gap filling."""

    @staticmethod
    def _build_converter(gap_filling):
        """Import the converter inside the call and build it with the given flag."""
        # The import stays local so it only runs when one of these tests runs.
        from app.services.ocr_to_unified_converter import OCRToUnifiedConverter
        return OCRToUnifiedConverter(enable_gap_filling=gap_filling)

    def test_converter_with_gap_filling_enabled(self):
        """With gap filling enabled the converter exposes a gap filling service."""
        converter = self._build_converter(True)
        assert converter.gap_filling_service is not None

    def test_converter_with_gap_filling_disabled(self):
        """With gap filling disabled the converter has no gap filling service."""
        converter = self._build_converter(False)
        assert converter.gap_filling_service is None