diff --git a/backend/app/core/config.py b/backend/app/core/config.py index 9ffc019..8ae7508 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -3,7 +3,7 @@ Tool_OCR - Configuration Management Loads environment variables and provides centralized configuration """ -from typing import List +from typing import List, Optional from pydantic_settings import BaseSettings from pydantic import Field from pathlib import Path @@ -99,6 +99,33 @@ class Settings(BaseSettings): text_det_box_thresh: float = Field(default=0.3) # Lower box threshold for better detection text_det_unclip_ratio: float = Field(default=1.2) # Smaller unclip for tighter text boxes + # Layout Detection Model Configuration + # Available models: + # - None (default): Use PP-StructureV3's built-in model (PubLayNet-based) + # - "PP-DocLayout-S": Better for Chinese docs, papers, contracts, exams (23 categories) + # - "picodet_lcnet_x1_0_fgd_layout_cdla": CDLA-based model for Chinese document layout + layout_detection_model_name: Optional[str] = Field( + default="PP-DocLayout-S", + description="Layout detection model name. Set to 'PP-DocLayout-S' for better Chinese document support." + ) + layout_detection_model_dir: Optional[str] = Field( + default=None, + description="Custom layout detection model directory. If None, downloads official model." + ) + + # ===== Gap Filling Configuration ===== + # Supplements PP-StructureV3 output with raw OCR regions when detection is incomplete + gap_filling_enabled: bool = Field(default=True) # Enable gap filling for OCR track + gap_filling_coverage_threshold: float = Field(default=0.7) # Activate when coverage < 70% + gap_filling_iou_threshold: float = Field(default=0.15) # IoU threshold for coverage detection + gap_filling_confidence_threshold: float = Field(default=0.3) # Min confidence for raw OCR regions + gap_filling_dedup_iou_threshold: float = Field(default=0.5) # IoU threshold for deduplication + + # ===== Debug Configuration ===== + # Enable debug outputs for PP-StructureV3 analysis + pp_structure_debug_enabled: bool = Field(default=True) # Save debug files for PP-StructureV3 + pp_structure_debug_visualization: bool = Field(default=True) # Generate visualization images + # Performance tuning use_fp16_inference: bool = Field(default=False) # Half-precision (if supported) enable_cudnn_benchmark: bool = Field(default=True) # Optimize convolution algorithms diff --git a/backend/app/routers/tasks.py b/backend/app/routers/tasks.py index 38a170b..6ed2672 100644 --- a/backend/app/routers/tasks.py +++ b/backend/app/routers/tasks.py @@ -68,7 +68,7 @@ def process_task_ocr( use_dual_track: bool = True, force_track: Optional[str] = None, language: str = 'ch', - pp_structure_params: Optional[dict] = None + layout_model: Optional[str] = "chinese" ): """ Background task to process OCR for a task with dual-track support. @@ -84,7 +84,7 @@ def process_task_ocr( use_dual_track: Enable dual-track processing force_track: Force specific track ('ocr' or 'direct') language: OCR language code - pp_structure_params: Optional custom PP-StructureV3 parameters (dict) + layout_model: Layout detection model ('chinese', 'default', 'cdla') """ from app.core.database import SessionLocal from app.models.task import Task @@ -143,7 +143,7 @@ def process_task_ocr( output_dir=result_dir, use_dual_track=use_dual_track, force_track=force_track, - pp_structure_params=pp_structure_params + layout_model=layout_model ) else: # Fall back to traditional processing (no force_track support) @@ -152,7 +152,7 @@ def process_task_ocr( lang=language, detect_layout=True, output_dir=result_dir, - pp_structure_params=pp_structure_params + layout_model=layout_model ) # Calculate processing time @@ -717,14 +717,14 @@ async def start_task( current_user: User = Depends(get_current_user) ): """ - Start processing a pending task with dual-track support and optional PP-StructureV3 parameter tuning + Start processing a pending task with dual-track support and layout model selection - **task_id**: Task UUID - **options**: Processing options (in request body): - **use_dual_track**: Enable intelligent track selection (default: true) - **force_track**: Force specific processing track ('ocr' or 'direct') - **language**: OCR language code (default: 'ch') - - **pp_structure_params**: Fine-tuning parameters for PP-StructureV3 (OCR track only) + - **layout_model**: Layout detection model ('chinese', 'default', 'cdla') """ try: # Parse processing options with defaults @@ -735,11 +735,9 @@ async def start_task( force_track = options.force_track.value if options.force_track else None language = options.language - # Extract and convert PP-StructureV3 parameters to dict - pp_structure_params = None - if options.pp_structure_params: - pp_structure_params = options.pp_structure_params.model_dump(exclude_none=True) - logger.info(f"Using custom PP-StructureV3 parameters: {pp_structure_params}") + # Extract layout model (default to 'chinese' for best Chinese document support) + layout_model = options.layout_model.value if options.layout_model else "chinese" + logger.info(f"Using layout model: {layout_model}") # Get task details task = task_service.get_task_by_id( @@ -777,7 +775,7 @@ async def start_task( status=TaskStatus.PROCESSING ) - # Start OCR processing in background with dual-track parameters and custom PP-StructureV3 params + # Start OCR processing in background with dual-track parameters and layout model background_tasks.add_task( process_task_ocr, task_id=task_id, @@ -787,13 +785,11 @@ async def start_task( use_dual_track=use_dual_track, force_track=force_track, language=language, - pp_structure_params=pp_structure_params + layout_model=layout_model ) logger.info(f"Started OCR processing task {task_id} for user {current_user.email}") - logger.info(f"Options: dual_track={use_dual_track}, force_track={force_track}, lang={language}") - if pp_structure_params: - logger.info(f"Custom PP-StructureV3 params: {pp_structure_params}") + logger.info(f"Options: dual_track={use_dual_track}, force_track={force_track}, lang={language}, layout_model={layout_model}") return task except HTTPException: diff --git a/backend/app/schemas/task.py b/backend/app/schemas/task.py index 0ecdf87..85705c1 100644 --- a/backend/app/schemas/task.py +++ b/backend/app/schemas/task.py @@ -24,6 +24,19 @@ class ProcessingTrackEnum(str, Enum): AUTO = "auto" # Auto-detect best track +class LayoutModelEnum(str, Enum): + """Layout detection model selection for OCR track. + + Different models are optimized for different document types: + - CHINESE: PP-DocLayout-S, optimized for Chinese documents (forms, contracts, invoices) + - DEFAULT: PubLayNet-based, optimized for English academic papers + - CDLA: CDLA model, specialized Chinese document layout analysis + """ + CHINESE = "chinese" # PP-DocLayout-S - Best for Chinese documents (recommended) + DEFAULT = "default" # PubLayNet-based - Best for English documents + CDLA = "cdla" # CDLA model - Alternative for Chinese layout + + class TaskCreate(BaseModel): """Task creation request""" filename: Optional[str] = Field(None, description="Original filename") @@ -132,7 +145,11 @@ class UploadResponse(BaseModel): # ===== Dual-Track Processing Schemas ===== class PPStructureV3Params(BaseModel): - """PP-StructureV3 fine-tuning parameters for OCR track""" + """PP-StructureV3 fine-tuning parameters for OCR track. + + DEPRECATED: This class is deprecated and will be removed in a future version. + Use `layout_model` parameter in ProcessingOptions instead. + """ layout_detection_threshold: Optional[float] = Field( None, ge=0, le=1, description="Layout block detection score threshold (lower=more blocks, higher=high confidence only)" @@ -172,10 +189,10 @@ class ProcessingOptions(BaseModel): include_images: bool = Field(default=True, description="Extract and save images") confidence_threshold: Optional[float] = Field(None, ge=0, le=1, description="OCR confidence threshold") - # PP-StructureV3 fine-tuning parameters (OCR track only) - pp_structure_params: Optional[PPStructureV3Params] = Field( - None, - description="Fine-tuning parameters for PP-StructureV3 (OCR track only)" + # Layout model selection (OCR track only) + layout_model: Optional[LayoutModelEnum] = Field( + default=LayoutModelEnum.CHINESE, + description="Layout detection model: 'chinese' (recommended for Chinese docs), 'default' (English docs), 'cdla' (Chinese layout)" ) diff --git a/backend/app/services/gap_filling_service.py b/backend/app/services/gap_filling_service.py new file mode 100644 index 0000000..e616124 --- /dev/null +++ b/backend/app/services/gap_filling_service.py @@ -0,0 +1,649 @@ +""" +Gap Filling Service for OCR Track + +This service detects and fills gaps in PP-StructureV3 output by supplementing +with Raw OCR text regions when significant content loss is detected. + +The hybrid approach uses Raw OCR's comprehensive text detection to compensate +for PP-StructureV3's layout model limitations on certain document types. +""" + +import logging +from typing import Dict, List, Optional, Tuple, Set, Any +from dataclasses import dataclass + +from app.models.unified_document import ( + DocumentElement, BoundingBox, ElementType, Dimensions +) +from app.core.config import settings + +logger = logging.getLogger(__name__) + + +# Element types that should NOT be supplemented (preserve structural integrity) +SKIP_ELEMENT_TYPES: Set[ElementType] = { + ElementType.TABLE, + ElementType.IMAGE, + ElementType.FIGURE, + ElementType.CHART, + ElementType.DIAGRAM, + ElementType.HEADER, + ElementType.FOOTER, + ElementType.FORMULA, + ElementType.CODE, + ElementType.BARCODE, + ElementType.QR_CODE, + ElementType.LOGO, + ElementType.STAMP, + ElementType.SIGNATURE, +} + + +@dataclass +class TextRegion: + """Represents a raw OCR text region.""" + text: str + bbox: List[float] # [x0, y0, x1, y1] or polygon format + confidence: float + page: int = 0 + + @property + def normalized_bbox(self) -> Tuple[float, float, float, float]: + """Get normalized bbox as (x0, y0, x1, y1).""" + if not self.bbox: + return (0, 0, 0, 0) + + # Check if bbox is nested list format [[x1,y1], [x2,y2], [x3,y3], [x4,y4]] + # This is common PaddleOCR polygon format + if len(self.bbox) >= 1 and isinstance(self.bbox[0], (list, tuple)): + # Nested format: extract all x and y coordinates + xs = [pt[0] for pt in self.bbox if len(pt) >= 2] + ys = [pt[1] for pt in self.bbox if len(pt) >= 2] + if xs and ys: + return (min(xs), min(ys), max(xs), max(ys)) + return (0, 0, 0, 0) + + # Flat format + if len(self.bbox) == 4: + # Simple [x0, y0, x1, y1] format + return (float(self.bbox[0]), float(self.bbox[1]), + float(self.bbox[2]), float(self.bbox[3])) + elif len(self.bbox) >= 8: + # Flat polygon format: [x1, y1, x2, y2, x3, y3, x4, y4] + xs = [self.bbox[i] for i in range(0, len(self.bbox), 2)] + ys = [self.bbox[i] for i in range(1, len(self.bbox), 2)] + return (min(xs), min(ys), max(xs), max(ys)) + + return (0, 0, 0, 0) + + @property + def center(self) -> Tuple[float, float]: + """Get center point of the bbox.""" + x0, y0, x1, y1 = self.normalized_bbox + return ((x0 + x1) / 2, (y0 + y1) / 2) + + +class GapFillingService: + """ + Service for detecting and filling gaps in PP-StructureV3 output. + + This service: + 1. Calculates coverage of PP-StructureV3 elements over raw OCR regions + 2. Identifies uncovered raw OCR regions + 3. Supplements uncovered regions as TEXT elements + 4. Deduplicates against existing PP-StructureV3 TEXT elements + 5. Recalculates reading order for the combined result + """ + + def __init__( + self, + coverage_threshold: float = None, + iou_threshold: float = None, + confidence_threshold: float = None, + dedup_iou_threshold: float = None, + enabled: bool = None + ): + """ + Initialize the gap filling service. + + Args: + coverage_threshold: Coverage ratio below which gap filling activates (default: 0.7) + iou_threshold: IoU threshold for coverage detection (default: 0.15) + confidence_threshold: Minimum confidence for raw OCR regions (default: 0.3) + dedup_iou_threshold: IoU threshold for deduplication (default: 0.5) + enabled: Whether gap filling is enabled (default: True) + """ + self.coverage_threshold = coverage_threshold if coverage_threshold is not None else getattr( + settings, 'gap_filling_coverage_threshold', 0.7 + ) + self.iou_threshold = iou_threshold if iou_threshold is not None else getattr( + settings, 'gap_filling_iou_threshold', 0.15 + ) + self.confidence_threshold = confidence_threshold if confidence_threshold is not None else getattr( + settings, 'gap_filling_confidence_threshold', 0.3 + ) + self.dedup_iou_threshold = dedup_iou_threshold if dedup_iou_threshold is not None else getattr( + settings, 'gap_filling_dedup_iou_threshold', 0.5 + ) + self.enabled = enabled if enabled is not None else getattr( + settings, 'gap_filling_enabled', True + ) + + def should_activate( + self, + raw_ocr_regions: List[TextRegion], + pp_structure_elements: List[DocumentElement] + ) -> Tuple[bool, float]: + """ + Determine if gap filling should be activated. + + Gap filling activates when: + 1. Coverage ratio is below threshold (default: 70%) + 2. OR element count disparity is significant + + Args: + raw_ocr_regions: List of raw OCR text regions + pp_structure_elements: List of PP-StructureV3 elements + + Returns: + Tuple of (should_activate, coverage_ratio) + """ + if not self.enabled: + return False, 1.0 + + if not raw_ocr_regions: + return False, 1.0 + + # Calculate coverage + covered_count = 0 + for region in raw_ocr_regions: + if self._is_region_covered(region, pp_structure_elements): + covered_count += 1 + + coverage_ratio = covered_count / len(raw_ocr_regions) + + # Check activation conditions + should_activate = coverage_ratio < self.coverage_threshold + + if should_activate: + logger.info( + f"Gap filling activated: coverage={coverage_ratio:.2%} < threshold={self.coverage_threshold:.0%}, " + f"raw_regions={len(raw_ocr_regions)}, pp_elements={len(pp_structure_elements)}" + ) + else: + logger.debug( + f"Gap filling not needed: coverage={coverage_ratio:.2%} >= threshold={self.coverage_threshold:.0%}" + ) + + return should_activate, coverage_ratio + + def find_uncovered_regions( + self, + raw_ocr_regions: List[TextRegion], + pp_structure_elements: List[DocumentElement] + ) -> List[TextRegion]: + """ + Find raw OCR regions not covered by PP-StructureV3 elements. + + A region is considered covered if: + 1. Its center point falls inside any PP-StructureV3 element bbox, OR + 2. IoU with any PP-StructureV3 element exceeds iou_threshold + + Args: + raw_ocr_regions: List of raw OCR text regions + pp_structure_elements: List of PP-StructureV3 elements + + Returns: + List of uncovered raw OCR regions + """ + uncovered = [] + + for region in raw_ocr_regions: + # Skip low confidence regions + if region.confidence < self.confidence_threshold: + continue + + if not self._is_region_covered(region, pp_structure_elements): + uncovered.append(region) + + logger.debug(f"Found {len(uncovered)} uncovered regions out of {len(raw_ocr_regions)}") + return uncovered + + def _is_region_covered( + self, + region: TextRegion, + pp_structure_elements: List[DocumentElement] + ) -> bool: + """ + Check if a raw OCR region is covered by any PP-StructureV3 element. + + Args: + region: Raw OCR text region + pp_structure_elements: List of PP-StructureV3 elements + + Returns: + True if the region is covered + """ + center_x, center_y = region.center + region_bbox = region.normalized_bbox + + for element in pp_structure_elements: + elem_bbox = ( + element.bbox.x0, element.bbox.y0, + element.bbox.x1, element.bbox.y1 + ) + + # Check 1: Center point falls inside element bbox + if self._point_in_bbox(center_x, center_y, elem_bbox): + return True + + # Check 2: IoU exceeds threshold + iou = self._calculate_iou(region_bbox, elem_bbox) + if iou > self.iou_threshold: + return True + + return False + + def deduplicate_regions( + self, + uncovered_regions: List[TextRegion], + pp_structure_elements: List[DocumentElement] + ) -> List[TextRegion]: + """ + Remove regions that highly overlap with existing PP-StructureV3 TEXT elements. + + Args: + uncovered_regions: List of uncovered raw OCR regions + pp_structure_elements: List of PP-StructureV3 elements + + Returns: + Deduplicated list of regions + """ + # Get TEXT elements only for deduplication + text_elements = [ + e for e in pp_structure_elements + if e.type not in SKIP_ELEMENT_TYPES + ] + + deduplicated = [] + for region in uncovered_regions: + region_bbox = region.normalized_bbox + is_duplicate = False + + for element in text_elements: + elem_bbox = ( + element.bbox.x0, element.bbox.y0, + element.bbox.x1, element.bbox.y1 + ) + + iou = self._calculate_iou(region_bbox, elem_bbox) + if iou > self.dedup_iou_threshold: + logger.debug( + f"Skipping duplicate region (IoU={iou:.2f}): '{region.text[:30]}...'" + ) + is_duplicate = True + break + + if not is_duplicate: + deduplicated.append(region) + + removed_count = len(uncovered_regions) - len(deduplicated) + if removed_count > 0: + logger.debug(f"Removed {removed_count} duplicate regions") + + return deduplicated + + def convert_regions_to_elements( + self, + regions: List[TextRegion], + page_number: int, + start_element_id: int = 0 + ) -> List[DocumentElement]: + """ + Convert raw OCR regions to DocumentElement objects. + + Args: + regions: List of raw OCR regions to convert + page_number: Page number for the elements + start_element_id: Starting ID counter for elements + + Returns: + List of DocumentElement objects + """ + elements = [] + + for idx, region in enumerate(regions): + x0, y0, x1, y1 = region.normalized_bbox + + element = DocumentElement( + element_id=f"gap_fill_{page_number}_{start_element_id + idx}", + type=ElementType.TEXT, + content=region.text, + bbox=BoundingBox(x0=x0, y0=y0, x1=x1, y1=y1), + confidence=region.confidence, + metadata={ + 'source': 'gap_filling', + 'original_confidence': region.confidence + } + ) + elements.append(element) + + return elements + + def recalculate_reading_order( + self, + elements: List[DocumentElement] + ) -> List[int]: + """ + Recalculate reading order for elements based on position. + + Sorts elements by y0 (top to bottom) then x0 (left to right). + + Args: + elements: List of DocumentElement objects + + Returns: + List of element indices in reading order + """ + # Create indexed list with position info + indexed_elements = [ + (idx, e.bbox.y0, e.bbox.x0) + for idx, e in enumerate(elements) + ] + + # Sort by y0 then x0 + indexed_elements.sort(key=lambda x: (x[1], x[2])) + + # Return indices in reading order + return [idx for idx, _, _ in indexed_elements] + + def merge_adjacent_regions( + self, + regions: List[TextRegion], + max_horizontal_gap: float = 20.0, + max_vertical_gap: float = 5.0 + ) -> List[TextRegion]: + """ + Merge fragmented adjacent regions on the same line. + + This is optional and can reduce fragmentation from raw OCR. + + Args: + regions: List of raw OCR regions + max_horizontal_gap: Maximum horizontal gap to merge (pixels) + max_vertical_gap: Maximum vertical gap to merge (pixels) + + Returns: + List of merged regions + """ + if not regions: + return regions + + # Sort by y0, then x0 + sorted_regions = sorted( + regions, + key=lambda r: (r.normalized_bbox[1], r.normalized_bbox[0]) + ) + + merged = [] + current = sorted_regions[0] + + for next_region in sorted_regions[1:]: + curr_bbox = current.normalized_bbox + next_bbox = next_region.normalized_bbox + + # Check if on same line (vertical overlap) + curr_y_center = (curr_bbox[1] + curr_bbox[3]) / 2 + next_y_center = (next_bbox[1] + next_bbox[3]) / 2 + vertical_distance = abs(curr_y_center - next_y_center) + + # Check horizontal gap + horizontal_gap = next_bbox[0] - curr_bbox[2] + + if (vertical_distance < max_vertical_gap and + 0 <= horizontal_gap <= max_horizontal_gap): + # Merge regions + merged_bbox = [ + min(curr_bbox[0], next_bbox[0]), + min(curr_bbox[1], next_bbox[1]), + max(curr_bbox[2], next_bbox[2]), + max(curr_bbox[3], next_bbox[3]) + ] + current = TextRegion( + text=current.text + " " + next_region.text, + bbox=merged_bbox, + confidence=min(current.confidence, next_region.confidence), + page=current.page + ) + else: + merged.append(current) + current = next_region + + merged.append(current) + + if len(merged) < len(regions): + logger.debug(f"Merged {len(regions)} regions into {len(merged)}") + + return merged + + def fill_gaps( + self, + raw_ocr_regions: List[Dict[str, Any]], + pp_structure_elements: List[DocumentElement], + page_number: int, + ocr_dimensions: Optional[Dict[str, Any]] = None, + pp_dimensions: Optional[Dimensions] = None + ) -> Tuple[List[DocumentElement], Dict[str, Any]]: + """ + Main entry point: detect gaps and fill with raw OCR regions. + + Args: + raw_ocr_regions: Raw OCR results (list of dicts with text, bbox, confidence) + pp_structure_elements: PP-StructureV3 elements + page_number: Current page number + ocr_dimensions: OCR image dimensions for coordinate alignment + pp_dimensions: PP-Structure dimensions for coordinate alignment + + Returns: + Tuple of (supplemented_elements, statistics) + """ + statistics = { + 'enabled': self.enabled, + 'activated': False, + 'coverage_ratio': 1.0, + 'raw_ocr_count': len(raw_ocr_regions), + 'pp_structure_count': len(pp_structure_elements), + 'uncovered_count': 0, + 'deduplicated_count': 0, + 'supplemented_count': 0 + } + + if not self.enabled: + logger.debug("Gap filling is disabled") + return [], statistics + + # Convert raw OCR regions to TextRegion objects + text_regions = self._convert_raw_ocr_regions( + raw_ocr_regions, page_number, ocr_dimensions, pp_dimensions + ) + + if not text_regions: + logger.debug("No valid text regions to process") + return [], statistics + + # Check if gap filling should activate + should_activate, coverage_ratio = self.should_activate( + text_regions, pp_structure_elements + ) + statistics['coverage_ratio'] = coverage_ratio + statistics['activated'] = should_activate + + if not should_activate: + return [], statistics + + # Find uncovered regions + uncovered = self.find_uncovered_regions(text_regions, pp_structure_elements) + statistics['uncovered_count'] = len(uncovered) + + if not uncovered: + logger.debug("No uncovered regions found") + return [], statistics + + # Deduplicate against existing TEXT elements + deduplicated = self.deduplicate_regions(uncovered, pp_structure_elements) + statistics['deduplicated_count'] = len(deduplicated) + + if not deduplicated: + logger.debug("All uncovered regions were duplicates") + return [], statistics + + # Optional: Merge adjacent regions + # merged = self.merge_adjacent_regions(deduplicated) + + # Convert to DocumentElements + start_id = len(pp_structure_elements) + supplemented = self.convert_regions_to_elements( + deduplicated, page_number, start_id + ) + statistics['supplemented_count'] = len(supplemented) + + logger.info( + f"Gap filling complete: supplemented {len(supplemented)} elements " + f"(coverage: {coverage_ratio:.2%} -> estimated {(coverage_ratio + len(supplemented)/len(text_regions) if text_regions else 0):.2%})" + ) + + return supplemented, statistics + + def _convert_raw_ocr_regions( + self, + raw_regions: List[Dict[str, Any]], + page_number: int, + ocr_dimensions: Optional[Dict[str, Any]] = None, + pp_dimensions: Optional[Dimensions] = None + ) -> List[TextRegion]: + """ + Convert raw OCR region dicts to TextRegion objects. + + Handles coordinate alignment if dimensions are provided. + + Args: + raw_regions: List of raw OCR region dictionaries + page_number: Current page number + ocr_dimensions: OCR image dimensions + pp_dimensions: PP-Structure dimensions + + Returns: + List of TextRegion objects + """ + text_regions = [] + + # Calculate scale factors if needed + scale_x, scale_y = 1.0, 1.0 + if ocr_dimensions and pp_dimensions: + ocr_width = ocr_dimensions.get('width', 0) + ocr_height = ocr_dimensions.get('height', 0) + + if ocr_width > 0 and pp_dimensions.width > 0: + scale_x = pp_dimensions.width / ocr_width + if ocr_height > 0 and pp_dimensions.height > 0: + scale_y = pp_dimensions.height / ocr_height + + if scale_x != 1.0 or scale_y != 1.0: + logger.debug(f"Coordinate scaling: x={scale_x:.3f}, y={scale_y:.3f}") + + for region in raw_regions: + text = region.get('text', '') + if not text or not text.strip(): + continue + + confidence = region.get('confidence', 0.0) + bbox_raw = region.get('bbox', []) + + # Normalize bbox + if isinstance(bbox_raw, dict): + # Dict format: {x_min, y_min, x_max, y_max} + bbox = [ + bbox_raw.get('x_min', 0), + bbox_raw.get('y_min', 0), + bbox_raw.get('x_max', 0), + bbox_raw.get('y_max', 0) + ] + elif isinstance(bbox_raw, (list, tuple)): + bbox = list(bbox_raw) + else: + continue + + # Apply scaling if needed + if scale_x != 1.0 or scale_y != 1.0: + # Check if nested list format [[x1,y1], [x2,y2], ...] + if len(bbox) >= 1 and isinstance(bbox[0], (list, tuple)): + bbox = [ + [pt[0] * scale_x, pt[1] * scale_y] + for pt in bbox if len(pt) >= 2 + ] + elif len(bbox) == 4 and not isinstance(bbox[0], (list, tuple)): + # Simple [x0, y0, x1, y1] format + bbox = [ + bbox[0] * scale_x, bbox[1] * scale_y, + bbox[2] * scale_x, bbox[3] * scale_y + ] + elif len(bbox) >= 8: + # Flat polygon format [x1, y1, x2, y2, ...] + bbox = [ + bbox[i] * (scale_x if i % 2 == 0 else scale_y) + for i in range(len(bbox)) + ] + + text_regions.append(TextRegion( + text=text, + bbox=bbox, + confidence=confidence, + page=page_number + )) + + return text_regions + + @staticmethod + def _point_in_bbox( + x: float, y: float, + bbox: Tuple[float, float, float, float] + ) -> bool: + """Check if point (x, y) is inside bbox (x0, y0, x1, y1).""" + x0, y0, x1, y1 = bbox + return x0 <= x <= x1 and y0 <= y <= y1 + + @staticmethod + def _calculate_iou( + bbox1: Tuple[float, float, float, float], + bbox2: Tuple[float, float, float, float] + ) -> float: + """ + Calculate Intersection over Union (IoU) of two bboxes. + + Args: + bbox1: First bbox (x0, y0, x1, y1) + bbox2: Second bbox (x0, y0, x1, y1) + + Returns: + IoU value between 0 and 1 + """ + # Calculate intersection + x0 = max(bbox1[0], bbox2[0]) + y0 = max(bbox1[1], bbox2[1]) + x1 = min(bbox1[2], bbox2[2]) + y1 = min(bbox1[3], bbox2[3]) + + if x1 <= x0 or y1 <= y0: + return 0.0 + + intersection = (x1 - x0) * (y1 - y0) + + # Calculate union + area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1]) + area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1]) + union = area1 + area2 - intersection + + if union <= 0: + return 0.0 + + return intersection / union diff --git a/backend/app/services/ocr_service.py b/backend/app/services/ocr_service.py index a419a6b..05ab5c2 100644 --- a/backend/app/services/ocr_service.py +++ b/backend/app/services/ocr_service.py @@ -46,6 +46,19 @@ except ImportError as e: logger = logging.getLogger(__name__) +# Sentinel value for "use PubLayNet default" - explicitly NO model specification +_USE_PUBLAYNET_DEFAULT = "__USE_PUBLAYNET_DEFAULT__" + +# Layout model mapping: user-friendly names to actual model names +# - "chinese": PP-DocLayout-S - Best for Chinese documents (forms, contracts, invoices) +# - "default": PubLayNet-based default model - Best for English documents +# - "cdla": picodet_lcnet_x1_0_fgd_layout_cdla - Alternative for Chinese layout +LAYOUT_MODEL_MAPPING = { + "chinese": "PP-DocLayout-S", + "default": _USE_PUBLAYNET_DEFAULT, # Uses default PubLayNet-based model (no custom model) + "cdla": "picodet_lcnet_x1_0_fgd_layout_cdla", +} + class OCRService: """ @@ -436,77 +449,45 @@ class OCRService: return self.ocr_engines[lang] - def _ensure_structure_engine(self, custom_params: Optional[Dict[str, any]] = None) -> PPStructureV3: + def _ensure_structure_engine(self, layout_model: Optional[str] = None) -> PPStructureV3: """ Get or create PP-Structure engine for layout analysis with GPU support. - Supports custom parameters that override default settings. + Supports layout model selection for different document types. Args: - custom_params: Optional dictionary of custom PP-StructureV3 parameters. - If provided, creates a new engine instance (not cached). - Supported keys: layout_detection_threshold, layout_nms_threshold, - layout_merge_bboxes_mode, layout_unclip_ratio, text_det_thresh, - text_det_box_thresh, text_det_unclip_ratio + layout_model: Layout detection model selection: + - "chinese": PP-DocLayout-S (best for Chinese documents) + - "default": PubLayNet-based (best for English documents) + - "cdla": CDLA model (alternative for Chinese layout) + - None: Use config default Returns: PPStructure engine instance """ - # If custom params provided, create a new engine instance (don't use cache) - if custom_params: - logger.info(f"Creating PP-StructureV3 engine with custom parameters (GPU: {self.use_gpu})") - logger.info(f"Custom params: {custom_params}") + # Resolve layout model name from user-friendly name + resolved_model_name = None + use_publaynet_default = False # Flag to explicitly use PubLayNet default (no model param) - try: - # Base configuration from settings - use_chart = settings.enable_chart_recognition - use_formula = settings.enable_formula_recognition - use_table = settings.enable_table_recognition + if layout_model: + resolved_model_name = LAYOUT_MODEL_MAPPING.get(layout_model) + if layout_model not in LAYOUT_MODEL_MAPPING: + logger.warning(f"Unknown layout model '{layout_model}', using config default") + resolved_model_name = settings.layout_detection_model_name + elif resolved_model_name == _USE_PUBLAYNET_DEFAULT: + # User explicitly selected "default" - use PubLayNet without custom model + use_publaynet_default = True + resolved_model_name = None + logger.info(f"Using layout model: {layout_model} -> PubLayNet default (no custom model)") + else: + logger.info(f"Using layout model: {layout_model} -> {resolved_model_name}") - # Parameter priority: custom > settings default - layout_threshold = custom_params.get('layout_detection_threshold', settings.layout_detection_threshold) - layout_nms = custom_params.get('layout_nms_threshold', settings.layout_nms_threshold) - layout_merge = custom_params.get('layout_merge_bboxes_mode', settings.layout_merge_mode) - layout_unclip = custom_params.get('layout_unclip_ratio', settings.layout_unclip_ratio) - text_thresh = custom_params.get('text_det_thresh', settings.text_det_thresh) - text_box_thresh = custom_params.get('text_det_box_thresh', settings.text_det_box_thresh) - text_unclip = custom_params.get('text_det_unclip_ratio', settings.text_det_unclip_ratio) + # Check if we need to recreate the engine due to different model + current_model = getattr(self, '_current_layout_model', None) + if self.structure_engine is not None and layout_model and layout_model != current_model: + logger.info(f"Layout model changed from {current_model} to {layout_model}, recreating engine") + self.structure_engine = None # Force recreation - logger.info(f"PP-StructureV3 config: table={use_table}, formula={use_formula}, chart={use_chart}") - logger.info(f"Layout config: threshold={layout_threshold}, nms={layout_nms}, merge={layout_merge}, unclip={layout_unclip}") - logger.info(f"Text detection: thresh={text_thresh}, box_thresh={text_box_thresh}, unclip={text_unclip}") - - # Create temporary engine with custom params (not cached) - custom_engine = PPStructureV3( - use_doc_orientation_classify=False, - use_doc_unwarping=False, - use_textline_orientation=False, - use_table_recognition=use_table, - use_formula_recognition=use_formula, - use_chart_recognition=use_chart, - layout_threshold=layout_threshold, - layout_nms=layout_nms, - layout_unclip_ratio=layout_unclip, - layout_merge_bboxes_mode=layout_merge, - text_det_thresh=text_thresh, - text_det_box_thresh=text_box_thresh, - text_det_unclip_ratio=text_unclip, - ) - - logger.info(f"PP-StructureV3 engine with custom params ready (PaddlePaddle {paddle.__version__}, {'GPU' if self.use_gpu else 'CPU'} mode)") - - # Check GPU memory after loading - if self.use_gpu and settings.enable_memory_optimization: - self._check_gpu_memory_usage() - - return custom_engine - - except Exception as e: - logger.error(f"Failed to create PP-StructureV3 engine with custom params: {e}") - # Fall back to default cached engine - logger.warning("Falling back to default cached engine") - custom_params = None # Clear custom params to use cached engine - - # Use cached default engine + # Use cached engine or create new one if self.structure_engine is None: logger.info(f"Initializing PP-StructureV3 engine (GPU: {self.use_gpu})") @@ -524,28 +505,51 @@ class OCRService: text_box_thresh = settings.text_det_box_thresh text_unclip = settings.text_det_unclip_ratio + # Layout model configuration: + # - If use_publaynet_default: don't specify any model (use PubLayNet default) + # - If resolved_model_name: use the specified model + # - Otherwise: use config default + if use_publaynet_default: + layout_model_name = None # Explicitly no model = PubLayNet default + elif resolved_model_name: + layout_model_name = resolved_model_name + else: + layout_model_name = settings.layout_detection_model_name + layout_model_dir = settings.layout_detection_model_dir + logger.info(f"PP-StructureV3 config: table={use_table}, formula={use_formula}, chart={use_chart}") + logger.info(f"Layout model: name={layout_model_name}, dir={layout_model_dir}") logger.info(f"Layout config: threshold={layout_threshold}, nms={layout_nms}, merge={layout_merge}, unclip={layout_unclip}") logger.info(f"Text detection: thresh={text_thresh}, box_thresh={text_box_thresh}, unclip={text_unclip}") - self.structure_engine = PPStructureV3( - use_doc_orientation_classify=False, - use_doc_unwarping=False, - use_textline_orientation=False, - use_table_recognition=use_table, - use_formula_recognition=use_formula, - use_chart_recognition=use_chart, - layout_threshold=layout_threshold, - layout_nms=layout_nms, - layout_unclip_ratio=layout_unclip, - layout_merge_bboxes_mode=layout_merge, # Use 'small' to minimize merging - text_det_thresh=text_thresh, - text_det_box_thresh=text_box_thresh, - text_det_unclip_ratio=text_unclip, - ) + # Build PPStructureV3 kwargs + pp_kwargs = { + 'use_doc_orientation_classify': False, + 'use_doc_unwarping': False, + 'use_textline_orientation': False, + 'use_table_recognition': use_table, + 'use_formula_recognition': use_formula, + 'use_chart_recognition': use_chart, + 'layout_threshold': layout_threshold, + 'layout_nms': layout_nms, + 'layout_unclip_ratio': layout_unclip, + 'layout_merge_bboxes_mode': layout_merge, + 'text_det_thresh': text_thresh, + 'text_det_box_thresh': text_box_thresh, + 'text_det_unclip_ratio': text_unclip, + } + + # Add layout model configuration if specified + if layout_model_name: + pp_kwargs['layout_detection_model_name'] = layout_model_name + if layout_model_dir: + pp_kwargs['layout_detection_model_dir'] = layout_model_dir + + self.structure_engine = PPStructureV3(**pp_kwargs) # Track model loading for cache management self._model_last_used['structure'] = datetime.now() + self._current_layout_model = layout_model # Track current model for recreation check logger.info(f"PP-StructureV3 engine ready (PaddlePaddle {paddle.__version__}, {'GPU' if self.use_gpu else 'CPU'} mode)") @@ -565,17 +569,27 @@ class OCRService: use_formula = settings.enable_formula_recognition use_table = settings.enable_table_recognition layout_threshold = settings.layout_detection_threshold + layout_model_name = settings.layout_detection_model_name + layout_model_dir = settings.layout_detection_model_dir - self.structure_engine = PPStructureV3( - use_doc_orientation_classify=False, - use_doc_unwarping=False, - use_textline_orientation=False, - use_table_recognition=use_table, - use_formula_recognition=use_formula, - use_chart_recognition=use_chart, - layout_threshold=layout_threshold, - ) - logger.info("PP-StructureV3 engine ready (CPU mode - fallback)") + # Build CPU fallback kwargs + cpu_kwargs = { + 'use_doc_orientation_classify': False, + 'use_doc_unwarping': False, + 'use_textline_orientation': False, + 'use_table_recognition': use_table, + 'use_formula_recognition': use_formula, + 'use_chart_recognition': use_chart, + 'layout_threshold': layout_threshold, + } + if layout_model_name: + cpu_kwargs['layout_detection_model_name'] = layout_model_name + if layout_model_dir: + cpu_kwargs['layout_detection_model_dir'] = layout_model_dir + + self.structure_engine = PPStructureV3(**cpu_kwargs) + self._current_layout_model = layout_model # Track current model for recreation check + logger.info(f"PP-StructureV3 engine ready (CPU mode - fallback, layout_model={layout_model_name})") else: raise @@ -813,7 +827,7 @@ class OCRService: confidence_threshold: Optional[float] = None, output_dir: Optional[Path] = None, current_page: int = 0, - pp_structure_params: Optional[Dict[str, any]] = None + layout_model: Optional[str] = None ) -> Dict: """ Process single image with OCR and layout analysis @@ -825,7 +839,7 @@ class OCRService: confidence_threshold: Minimum confidence threshold (uses default if None) output_dir: Optional output directory for saving extracted images current_page: Current page number (0-based) for multi-page documents - pp_structure_params: Optional custom PP-StructureV3 parameters + layout_model: Layout detection model ('chinese', 'default', 'cdla') Returns: Dictionary with OCR results and metadata @@ -894,7 +908,7 @@ class OCRService: confidence_threshold=confidence_threshold, output_dir=output_dir, current_page=page_num - 1, # Convert to 0-based page number for layout data - pp_structure_params=pp_structure_params + layout_model=layout_model ) # Accumulate results @@ -1040,7 +1054,7 @@ class OCRService: image_path, output_dir=output_dir, current_page=current_page, - pp_structure_params=pp_structure_params + layout_model=layout_model ) # Generate Markdown @@ -1078,6 +1092,38 @@ class OCRService: 'height': ocr_height }] + # Generate PP-StructureV3 debug outputs if enabled + if settings.pp_structure_debug_enabled and output_dir: + try: + from app.services.pp_structure_debug import PPStructureDebug + debug_service = PPStructureDebug(output_dir) + + # Save raw results as JSON + debug_service.save_raw_results( + pp_structure_results={ + 'elements': layout_data.get('elements', []), + 'total_elements': layout_data.get('total_elements', 0), + 'element_types': layout_data.get('element_types', {}), + 'reading_order': layout_data.get('reading_order', []), + 'enhanced': True, + 'has_parsing_res_list': True + }, + raw_ocr_regions=text_regions, + filename_prefix=image_path.stem + ) + + # Generate visualization if enabled + if settings.pp_structure_debug_visualization: + debug_service.generate_visualization( + image_path=image_path, + pp_structure_elements=layout_data.get('elements', []), + raw_ocr_regions=text_regions, + filename_prefix=image_path.stem + ) + logger.info(f"Generated PP-StructureV3 debug outputs for {image_path.name}") + except Exception as debug_error: + logger.warning(f"Failed to generate debug outputs: {debug_error}") + logger.info( f"OCR completed: {image_path.name} - " f"{len(text_regions)} regions, " @@ -1164,7 +1210,7 @@ class OCRService: image_path: Path, output_dir: Optional[Path] = None, current_page: int = 0, - pp_structure_params: Optional[Dict[str, any]] = None + layout_model: Optional[str] = None ) -> Tuple[Optional[Dict], List[Dict]]: """ Analyze document layout using PP-StructureV3 with enhanced element extraction @@ -1173,7 +1219,7 @@ class OCRService: image_path: Path to image file output_dir: Optional output directory for saving extracted images (defaults to image_path.parent) current_page: Current page number (0-based) for multi-page documents - pp_structure_params: Optional custom PP-StructureV3 parameters + layout_model: Layout detection model ('chinese', 'default', 'cdla') Returns: Tuple of (layout_data, images_metadata) @@ -1191,7 +1237,7 @@ class OCRService: f"Mode: {'CPU fallback' if self._cpu_fallback_active else 'GPU'}" ) - structure_engine = self._ensure_structure_engine(pp_structure_params) + structure_engine = self._ensure_structure_engine(layout_model) # Try enhanced processing first try: @@ -1425,7 +1471,7 @@ class OCRService: confidence_threshold: Optional[float] = None, output_dir: Optional[Path] = None, force_track: Optional[str] = None, - pp_structure_params: Optional[Dict[str, any]] = None + layout_model: Optional[str] = None ) -> Union[UnifiedDocument, Dict]: """ Process document using dual-track approach. @@ -1437,7 +1483,7 @@ class OCRService: confidence_threshold: Minimum confidence threshold output_dir: Optional output directory for extracted images force_track: Force specific track ("ocr" or "direct"), None for auto-detection - pp_structure_params: Optional custom PP-StructureV3 parameters (used for OCR track only) + layout_model: Layout detection model ('chinese', 'default', 'cdla') (used for OCR track only) Returns: UnifiedDocument if dual-track is enabled, Dict otherwise @@ -1445,7 +1491,7 @@ class OCRService: if not self.dual_track_enabled: # Fallback to traditional OCR processing return self.process_file_traditional( - file_path, lang, detect_layout, confidence_threshold, output_dir, pp_structure_params + file_path, lang, detect_layout, confidence_threshold, output_dir, layout_model ) start_time = datetime.now() @@ -1517,7 +1563,7 @@ class OCRService: ocr_result = self.process_file_traditional( actual_file_path, lang, detect_layout=True, confidence_threshold=confidence_threshold, - output_dir=output_dir, pp_structure_params=pp_structure_params + output_dir=output_dir, layout_model=layout_model ) # Convert OCR result to extract images @@ -1550,7 +1596,7 @@ class OCRService: # Use OCR for scanned documents, images, etc. logger.info("Using OCR track (PaddleOCR)") ocr_result = self.process_file_traditional( - file_path, lang, detect_layout, confidence_threshold, output_dir, pp_structure_params + file_path, lang, detect_layout, confidence_threshold, output_dir, layout_model ) # Convert OCR result to UnifiedDocument using the converter @@ -1580,7 +1626,7 @@ class OCRService: logger.error(f"Error in dual-track processing: {e}") # Fallback to traditional OCR return self.process_file_traditional( - file_path, lang, detect_layout, confidence_threshold, output_dir, pp_structure_params + file_path, lang, detect_layout, confidence_threshold, output_dir, layout_model ) def _merge_ocr_images_into_direct( @@ -1659,7 +1705,7 @@ class OCRService: detect_layout: bool = True, confidence_threshold: Optional[float] = None, output_dir: Optional[Path] = None, - pp_structure_params: Optional[Dict[str, any]] = None + layout_model: Optional[str] = None ) -> Dict: """ Traditional OCR processing (legacy method). @@ -1670,7 +1716,7 @@ class OCRService: detect_layout: Whether to perform layout analysis confidence_threshold: Minimum confidence threshold output_dir: Optional output directory - pp_structure_params: Optional custom PP-StructureV3 parameters + layout_model: Layout detection model ('chinese', 'default', 'cdla') Returns: Dictionary with OCR results in legacy format @@ -1683,7 +1729,7 @@ class OCRService: all_results = [] for i, image_path in enumerate(image_paths): result = self.process_image( - image_path, lang, detect_layout, confidence_threshold, output_dir, i, pp_structure_params + image_path, lang, detect_layout, confidence_threshold, output_dir, i, layout_model ) all_results.append(result) @@ -1699,7 +1745,7 @@ class OCRService: else: # Single image or other file return self.process_image( - file_path, lang, detect_layout, confidence_threshold, output_dir, 0, pp_structure_params + file_path, lang, detect_layout, confidence_threshold, output_dir, 0, layout_model ) def _combine_results(self, results: List[Dict]) -> Dict: @@ -1784,7 +1830,7 @@ class OCRService: output_dir: Optional[Path] = None, use_dual_track: bool = True, force_track: Optional[str] = None, - pp_structure_params: Optional[Dict[str, any]] = None + layout_model: Optional[str] = None ) -> Union[UnifiedDocument, Dict]: """ Main processing method with dual-track support. @@ -1797,7 +1843,7 @@ class OCRService: output_dir: Optional output directory use_dual_track: Whether to use dual-track processing (default True) force_track: Force specific track ("ocr" or "direct") - pp_structure_params: Optional custom PP-StructureV3 parameters (used for OCR track only) + layout_model: Layout detection model ('chinese', 'default', 'cdla') (used for OCR track only) Returns: UnifiedDocument if dual-track is enabled and use_dual_track=True, @@ -1809,12 +1855,12 @@ class OCRService: if (use_dual_track or force_track) and self.dual_track_enabled: # Use dual-track processing (or forced track) return self.process_with_dual_track( - file_path, lang, detect_layout, confidence_threshold, output_dir, force_track, pp_structure_params + file_path, lang, detect_layout, confidence_threshold, output_dir, force_track, layout_model ) else: # Use traditional OCR processing (no force_track support) return self.process_file_traditional( - file_path, lang, detect_layout, confidence_threshold, output_dir, pp_structure_params + file_path, lang, detect_layout, confidence_threshold, output_dir, layout_model ) def process_legacy( diff --git a/backend/app/services/ocr_to_unified_converter.py b/backend/app/services/ocr_to_unified_converter.py index 8971bea..9addffe 100644 --- a/backend/app/services/ocr_to_unified_converter.py +++ b/backend/app/services/ocr_to_unified_converter.py @@ -3,6 +3,9 @@ OCR to UnifiedDocument Converter Converts PP-StructureV3 OCR results to UnifiedDocument format, preserving all structure information and metadata. + +Includes gap filling support to supplement PP-StructureV3 output with raw OCR +regions when significant content loss is detected. """ import logging @@ -16,10 +19,165 @@ from app.models.unified_document import ( BoundingBox, StyleInfo, TableData, ElementType, ProcessingTrack, TableCell, Dimensions ) +from app.services.gap_filling_service import GapFillingService logger = logging.getLogger(__name__) +def trim_empty_columns(table_dict: Dict[str, Any]) -> Dict[str, Any]: + """ + Remove empty columns from a table dictionary. + + A column is considered empty if ALL cells in that column have content that is + empty or whitespace-only (using .strip() to determine emptiness). + + This function: + 1. Identifies columns where every cell's content is empty/whitespace + 2. Removes identified empty columns + 3. Updates cols/columns value + 4. Recalculates each cell's col index + 5. Adjusts col_span when spans cross removed columns + 6. Removes cells entirely when their complete span falls within removed columns + 7. Preserves original bbox (no layout drift) + + Args: + table_dict: Table dictionary with keys: rows, cols/columns, cells + + Returns: + Cleaned table dictionary with empty columns removed + """ + cells = table_dict.get('cells', []) + if not cells: + return table_dict + + # Get original column count + original_cols = table_dict.get('cols', table_dict.get('columns', 0)) + if original_cols == 0: + # Calculate from cells if not provided + max_col = 0 + for cell in cells: + cell_col = cell.get('col', 0) if isinstance(cell, dict) else getattr(cell, 'col', 0) + cell_span = cell.get('col_span', 1) if isinstance(cell, dict) else getattr(cell, 'col_span', 1) + max_col = max(max_col, cell_col + cell_span) + original_cols = max_col + + if original_cols == 0: + return table_dict + + # Build a map: column_index -> list of cell contents + # For cells with col_span > 1, we only check their primary column + column_contents: Dict[int, List[str]] = {i: [] for i in range(original_cols)} + + for cell in cells: + if isinstance(cell, dict): + col = cell.get('col', 0) + col_span = cell.get('col_span', 1) + content = cell.get('content', '') + else: + col = getattr(cell, 'col', 0) + col_span = getattr(cell, 'col_span', 1) + content = getattr(cell, 'content', '') + + # Mark content for each column this cell spans + for c in range(col, min(col + col_span, original_cols)): + if c in column_contents: + column_contents[c].append(str(content).strip() if content else '') + + # Identify empty columns (all content is empty/whitespace) + empty_columns = set() + for col_idx, contents in column_contents.items(): + # A column is empty if ALL cells in it have empty content + # Note: If a column has no cells at all, it's considered empty + if all(c == '' for c in contents): + empty_columns.add(col_idx) + + if not empty_columns: + # No empty columns to remove, just ensure cols is set + result = dict(table_dict) + if result.get('cols', result.get('columns', 0)) == 0: + result['cols'] = original_cols + if 'columns' in result: + result['columns'] = original_cols + return result + + logger.debug(f"Removing empty columns: {sorted(empty_columns)} from table with {original_cols} cols") + + # Build column mapping: old_col -> new_col (or None if removed) + col_mapping: Dict[int, Optional[int]] = {} + new_col = 0 + for old_col in range(original_cols): + if old_col in empty_columns: + col_mapping[old_col] = None + else: + col_mapping[old_col] = new_col + new_col += 1 + + new_cols = new_col + + # Process cells + new_cells = [] + for cell in cells: + if isinstance(cell, dict): + old_col = cell.get('col', 0) + old_col_span = cell.get('col_span', 1) + else: + old_col = getattr(cell, 'col', 0) + old_col_span = getattr(cell, 'col_span', 1) + + # Calculate new col and col_span + # Find the first non-removed column in this cell's span + new_start_col = None + new_end_col = None + + for c in range(old_col, min(old_col + old_col_span, original_cols)): + mapped = col_mapping.get(c) + if mapped is not None: + if new_start_col is None: + new_start_col = mapped + new_end_col = mapped + + # If entire span falls within removed columns, skip this cell + if new_start_col is None: + logger.debug(f"Removing cell at row={cell.get('row', 0) if isinstance(cell, dict) else cell.row}, " + f"col={old_col} (entire span in removed columns)") + continue + + new_col_span = new_end_col - new_start_col + 1 + + # Create new cell + if isinstance(cell, dict): + new_cell = dict(cell) + new_cell['col'] = new_start_col + new_cell['col_span'] = new_col_span + else: + # Handle TableCell objects + new_cell = { + 'row': cell.row, + 'col': new_start_col, + 'row_span': cell.row_span, + 'col_span': new_col_span, + 'content': cell.content + } + if hasattr(cell, 'bbox') and cell.bbox: + new_cell['bbox'] = cell.bbox + if hasattr(cell, 'style') and cell.style: + new_cell['style'] = cell.style + + new_cells.append(new_cell) + + # Build result + result = dict(table_dict) + result['cells'] = new_cells + result['cols'] = new_cols + if 'columns' in result: + result['columns'] = new_cols + + logger.info(f"Trimmed table: {original_cols} -> {new_cols} columns, " + f"{len(cells)} -> {len(new_cells)} cells") + + return result + + class OCRToUnifiedConverter: """ Converter for transforming PP-StructureV3 OCR results to UnifiedDocument format. @@ -30,11 +188,19 @@ class OCRToUnifiedConverter: - Multi-page document assembly - Metadata preservation - Structure relationship mapping + - Gap filling with raw OCR regions (when PP-StructureV3 misses content) """ - def __init__(self): - """Initialize the converter.""" + def __init__(self, enable_gap_filling: bool = True): + """ + Initialize the converter. + + Args: + enable_gap_filling: Whether to enable gap filling with raw OCR regions + """ self.element_counter = 0 + self.gap_filling_service = GapFillingService() if enable_gap_filling else None + self.gap_filling_stats: Dict[str, Any] = {} def convert( self, @@ -120,13 +286,21 @@ class OCRToUnifiedConverter: Extract pages from OCR results. Handles both enhanced PP-StructureV3 results (with parsing_res_list) - and traditional markdown results. + and traditional markdown results. Applies gap filling when enabled. """ pages = [] + # Extract raw OCR text regions for gap filling + raw_text_regions = ocr_results.get('text_regions', []) + ocr_dimensions = ocr_results.get('ocr_dimensions', {}) + # Check if we have enhanced results from PPStructureEnhanced if 'enhanced_results' in ocr_results: - pages = self._extract_from_enhanced_results(ocr_results['enhanced_results']) + pages = self._extract_from_enhanced_results( + ocr_results['enhanced_results'], + raw_text_regions=raw_text_regions, + ocr_dimensions=ocr_dimensions + ) # Check for traditional OCR results with text_regions at top level (from process_file_traditional) elif 'text_regions' in ocr_results: pages = self._extract_from_traditional_ocr(ocr_results) @@ -143,9 +317,21 @@ class OCRToUnifiedConverter: def _extract_from_enhanced_results( self, - enhanced_results: List[Dict[str, Any]] + enhanced_results: List[Dict[str, Any]], + raw_text_regions: Optional[List[Dict[str, Any]]] = None, + ocr_dimensions: Optional[Dict[str, Any]] = None ) -> List[Page]: - """Extract pages from enhanced PP-StructureV3 results.""" + """ + Extract pages from enhanced PP-StructureV3 results. + + Applies gap filling when enabled to supplement PP-StructureV3 output + with raw OCR regions that were not detected by the layout model. + + Args: + enhanced_results: PP-StructureV3 enhanced results + raw_text_regions: Raw OCR text regions for gap filling + ocr_dimensions: OCR image dimensions for coordinate alignment + """ pages = [] for page_idx, page_result in enumerate(enhanced_results): @@ -158,15 +344,52 @@ class OCRToUnifiedConverter: if element: elements.append(element) + # Get page dimensions + pp_dimensions = Dimensions( + width=page_result.get('width', 0), + height=page_result.get('height', 0) + ) + + # Apply gap filling if enabled and raw regions available + if self.gap_filling_service and raw_text_regions: + # Filter raw regions for current page + page_raw_regions = [ + r for r in raw_text_regions + if r.get('page', 0) == page_idx or r.get('page', 1) == page_idx + 1 + ] + + if page_raw_regions: + supplemented, stats = self.gap_filling_service.fill_gaps( + raw_ocr_regions=page_raw_regions, + pp_structure_elements=elements, + page_number=page_idx + 1, + ocr_dimensions=ocr_dimensions, + pp_dimensions=pp_dimensions + ) + + # Store statistics + self.gap_filling_stats[f'page_{page_idx + 1}'] = stats + + if supplemented: + logger.info( + f"Page {page_idx + 1}: Gap filling added {len(supplemented)} elements " + f"(coverage: {stats.get('coverage_ratio', 0):.2%})" + ) + elements.extend(supplemented) + + # Recalculate reading order for combined elements + reading_order = self.gap_filling_service.recalculate_reading_order(elements) + page_result['reading_order'] = reading_order + # Create page page = Page( page_number=page_idx + 1, - dimensions=Dimensions( - width=page_result.get('width', 0), - height=page_result.get('height', 0) - ), + dimensions=pp_dimensions, elements=elements, - metadata={'reading_order': page_result.get('reading_order', [])} + metadata={ + 'reading_order': page_result.get('reading_order', []), + 'gap_filling': self.gap_filling_stats.get(f'page_{page_idx + 1}', {}) + } ) pages.append(page) @@ -500,6 +723,9 @@ class OCRToUnifiedConverter: ) -> Optional[DocumentElement]: """Convert table data to DocumentElement.""" try: + # Clean up empty columns before building TableData + table_dict = trim_empty_columns(table_dict) + # Extract bbox bbox_data = table_dict.get('bbox', [0, 0, 0, 0]) bbox = BoundingBox( @@ -587,14 +813,22 @@ class OCRToUnifiedConverter: cells = [] headers = [] rows = table.find_all('tr') + num_rows = len(rows) - # Track actual column positions accounting for rowspan/colspan - # This is a simplified approach - complex spanning may need enhancement + # First pass: calculate total columns by finding max column extent + # Track cells that span multiple rows: occupied[row][col] = True + occupied: Dict[int, Dict[int, bool]] = {r: {} for r in range(num_rows)} + + # Parse all cells with proper rowspan/colspan handling for row_idx, row in enumerate(rows): row_cells = row.find_all(['td', 'th']) col_idx = 0 for cell in row_cells: + # Skip columns that are occupied by rowspan from previous rows + while occupied[row_idx].get(col_idx, False): + col_idx += 1 + cell_content = cell.get_text(strip=True) rowspan = int(cell.get('rowspan', 1)) colspan = int(cell.get('colspan', 1)) @@ -611,26 +845,66 @@ class OCRToUnifiedConverter: if cell.name == 'th' or row_idx == 0: headers.append(cell_content) + # Mark cells as occupied for rowspan/colspan + for r in range(row_idx, min(row_idx + rowspan, num_rows)): + for c in range(col_idx, col_idx + colspan): + if r not in occupied: + occupied[r] = {} + occupied[r][c] = True + # Advance column index by colspan col_idx += colspan - # Calculate actual dimensions - num_rows = len(rows) - num_cols = max( - sum(int(cell.get('colspan', 1)) for cell in row.find_all(['td', 'th'])) - for row in rows - ) if rows else 0 + # Calculate actual column count from occupied cells + num_cols = 0 + for r in range(num_rows): + if occupied[r]: + max_col_in_row = max(occupied[r].keys()) + 1 + num_cols = max(num_cols, max_col_in_row) logger.debug( f"Parsed HTML table: {num_rows} rows, {num_cols} cols, {len(cells)} cells" ) + # Build table dict for cleanup + table_dict = { + 'rows': num_rows, + 'cols': num_cols, + 'cells': [ + { + 'row': c.row, + 'col': c.col, + 'row_span': c.row_span, + 'col_span': c.col_span, + 'content': c.content + } + for c in cells + ], + 'headers': headers if headers else None, + 'caption': extracted_text if extracted_text else None + } + + # Clean up empty columns + table_dict = trim_empty_columns(table_dict) + + # Convert cleaned cells back to TableCell objects + cleaned_cells = [ + TableCell( + row=c['row'], + col=c['col'], + row_span=c.get('row_span', 1), + col_span=c.get('col_span', 1), + content=c.get('content', '') + ) + for c in table_dict.get('cells', []) + ] + return TableData( - rows=num_rows, - cols=num_cols, - cells=cells, - headers=headers if headers else None, - caption=extracted_text if extracted_text else None + rows=table_dict.get('rows', num_rows), + cols=table_dict.get('cols', num_cols), + cells=cleaned_cells, + headers=table_dict.get('headers'), + caption=table_dict.get('caption') ) except ImportError: diff --git a/backend/app/services/pp_structure_debug.py b/backend/app/services/pp_structure_debug.py new file mode 100644 index 0000000..f15732f --- /dev/null +++ b/backend/app/services/pp_structure_debug.py @@ -0,0 +1,344 @@ +""" +PP-StructureV3 Debug Service + +Provides debugging tools for visualizing and saving PP-StructureV3 results: +- Save raw results as JSON for inspection +- Generate visualization images showing detected bboxes +- Compare raw OCR regions with PP-StructureV3 elements +""" + +import json +import logging +from pathlib import Path +from typing import Dict, List, Any, Optional, Tuple +from datetime import datetime + +from PIL import Image, ImageDraw, ImageFont + +logger = logging.getLogger(__name__) + +# Color palette for different element types (RGB) +ELEMENT_COLORS: Dict[str, Tuple[int, int, int]] = { + 'text': (0, 128, 0), # Green + 'title': (0, 0, 255), # Blue + 'table': (255, 0, 0), # Red + 'figure': (255, 165, 0), # Orange + 'image': (255, 165, 0), # Orange + 'header': (128, 0, 128), # Purple + 'footer': (128, 0, 128), # Purple + 'equation': (0, 255, 255), # Cyan + 'chart': (255, 192, 203), # Pink + 'list': (139, 69, 19), # Brown + 'reference': (128, 128, 128), # Gray + 'default': (255, 0, 255), # Magenta for unknown types +} + +# Color for raw OCR regions +RAW_OCR_COLOR = (255, 215, 0) # Gold + + +class PPStructureDebug: + """Debug service for PP-StructureV3 analysis results.""" + + def __init__(self, output_dir: Path): + """ + Initialize debug service. + + Args: + output_dir: Directory to save debug outputs + """ + self.output_dir = Path(output_dir) + self.output_dir.mkdir(parents=True, exist_ok=True) + + def save_raw_results( + self, + pp_structure_results: Dict[str, Any], + raw_ocr_regions: List[Dict[str, Any]], + filename_prefix: str = "debug" + ) -> Dict[str, Path]: + """ + Save raw PP-StructureV3 results and OCR regions as JSON files. + + Args: + pp_structure_results: Raw PP-StructureV3 analysis results + raw_ocr_regions: Raw OCR text regions + filename_prefix: Prefix for output files + + Returns: + Dictionary with paths to saved files + """ + saved_files = {} + + # Save PP-StructureV3 results + pp_json_path = self.output_dir / f"{filename_prefix}_pp_structure_raw.json" + try: + # Convert any non-serializable types + serializable_results = self._make_serializable(pp_structure_results) + with open(pp_json_path, 'w', encoding='utf-8') as f: + json.dump(serializable_results, f, ensure_ascii=False, indent=2) + saved_files['pp_structure'] = pp_json_path + logger.info(f"Saved PP-StructureV3 raw results to {pp_json_path}") + except Exception as e: + logger.error(f"Failed to save PP-StructureV3 results: {e}") + + # Save raw OCR regions + ocr_json_path = self.output_dir / f"{filename_prefix}_raw_ocr_regions.json" + try: + serializable_ocr = self._make_serializable(raw_ocr_regions) + with open(ocr_json_path, 'w', encoding='utf-8') as f: + json.dump(serializable_ocr, f, ensure_ascii=False, indent=2) + saved_files['raw_ocr'] = ocr_json_path + logger.info(f"Saved raw OCR regions to {ocr_json_path}") + except Exception as e: + logger.error(f"Failed to save raw OCR regions: {e}") + + # Save summary comparison + summary_path = self.output_dir / f"{filename_prefix}_debug_summary.json" + try: + summary = self._generate_summary(pp_structure_results, raw_ocr_regions) + with open(summary_path, 'w', encoding='utf-8') as f: + json.dump(summary, f, ensure_ascii=False, indent=2) + saved_files['summary'] = summary_path + logger.info(f"Saved debug summary to {summary_path}") + except Exception as e: + logger.error(f"Failed to save debug summary: {e}") + + return saved_files + + def generate_visualization( + self, + image_path: Path, + pp_structure_elements: List[Dict[str, Any]], + raw_ocr_regions: Optional[List[Dict[str, Any]]] = None, + filename_prefix: str = "debug", + show_labels: bool = True, + show_raw_ocr: bool = True + ) -> Optional[Path]: + """ + Generate visualization image showing detected elements. + + Args: + image_path: Path to original image + pp_structure_elements: PP-StructureV3 detected elements + raw_ocr_regions: Optional raw OCR regions to overlay + filename_prefix: Prefix for output file + show_labels: Whether to show element type labels + show_raw_ocr: Whether to show raw OCR regions + + Returns: + Path to generated visualization image + """ + try: + # Load original image + img = Image.open(image_path) + if img.mode != 'RGB': + img = img.convert('RGB') + + # Create copy for drawing + viz_img = img.copy() + draw = ImageDraw.Draw(viz_img) + + # Try to load a font, fall back to default + try: + font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 14) + small_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 10) + except (IOError, OSError): + try: + font = ImageFont.truetype("/home/egg/project/Tool_OCR/backend/fonts/NotoSansSC-Regular.ttf", 14) + small_font = ImageFont.truetype("/home/egg/project/Tool_OCR/backend/fonts/NotoSansSC-Regular.ttf", 10) + except (IOError, OSError): + font = ImageFont.load_default() + small_font = font + + # Draw raw OCR regions first (so PP-Structure boxes are on top) + if show_raw_ocr and raw_ocr_regions: + for idx, region in enumerate(raw_ocr_regions): + bbox = self._normalize_bbox(region.get('bbox', [])) + if bbox: + # Draw with dashed style simulation (draw thin lines) + x0, y0, x1, y1 = bbox + draw.rectangle([x0, y0, x1, y1], outline=RAW_OCR_COLOR, width=1) + + # Add small label + if show_labels: + confidence = region.get('confidence', 0) + label = f"OCR:{confidence:.2f}" + draw.text((x0, y0 - 12), label, fill=RAW_OCR_COLOR, font=small_font) + + # Draw PP-StructureV3 elements + for idx, elem in enumerate(pp_structure_elements): + elem_type = elem.get('type', 'default') + if hasattr(elem_type, 'value'): + elem_type = elem_type.value + elem_type = str(elem_type).lower() + + color = ELEMENT_COLORS.get(elem_type, ELEMENT_COLORS['default']) + bbox = self._normalize_bbox(elem.get('bbox', [])) + + if bbox: + x0, y0, x1, y1 = bbox + # Draw thicker rectangle for PP-Structure elements + draw.rectangle([x0, y0, x1, y1], outline=color, width=3) + + # Add label + if show_labels: + label = f"{idx}:{elem_type}" + # Draw label background + text_bbox = draw.textbbox((x0, y0 - 18), label, font=font) + draw.rectangle(text_bbox, fill=(255, 255, 255, 200)) + draw.text((x0, y0 - 18), label, fill=color, font=font) + + # Add legend + self._draw_legend(draw, img.width, font) + + # Add image info + info_text = f"PP-Structure: {len(pp_structure_elements)} elements" + if raw_ocr_regions: + info_text += f" | Raw OCR: {len(raw_ocr_regions)} regions" + info_text += f" | Size: {img.width}x{img.height}" + draw.text((10, img.height - 25), info_text, fill=(0, 0, 0), font=font) + + # Save visualization + viz_path = self.output_dir / f"{filename_prefix}_pp_structure_viz.png" + viz_img.save(viz_path, 'PNG') + logger.info(f"Saved visualization to {viz_path}") + + return viz_path + + except Exception as e: + logger.error(f"Failed to generate visualization: {e}") + import traceback + traceback.print_exc() + return None + + def _draw_legend(self, draw: ImageDraw, img_width: int, font: ImageFont): + """Draw a legend showing element type colors.""" + legend_x = img_width - 150 + legend_y = 10 + + # Draw legend background + draw.rectangle( + [legend_x - 5, legend_y - 5, img_width - 5, legend_y + len(ELEMENT_COLORS) * 18 + 25], + fill=(255, 255, 255, 230), + outline=(0, 0, 0) + ) + + draw.text((legend_x, legend_y), "Legend:", fill=(0, 0, 0), font=font) + legend_y += 20 + + for elem_type, color in ELEMENT_COLORS.items(): + if elem_type == 'default': + continue + draw.rectangle([legend_x, legend_y + 2, legend_x + 12, legend_y + 14], fill=color) + draw.text((legend_x + 18, legend_y), elem_type, fill=(0, 0, 0), font=font) + legend_y += 18 + + # Add raw OCR legend entry + draw.rectangle([legend_x, legend_y + 2, legend_x + 12, legend_y + 14], fill=RAW_OCR_COLOR) + draw.text((legend_x + 18, legend_y), "raw_ocr", fill=(0, 0, 0), font=font) + + def _normalize_bbox(self, bbox: Any) -> Optional[Tuple[float, float, float, float]]: + """Normalize bbox to (x0, y0, x1, y1) format.""" + if not bbox: + return None + + try: + # Handle nested list format [[x1,y1], [x2,y2], [x3,y3], [x4,y4]] + if isinstance(bbox, (list, tuple)) and len(bbox) >= 1: + if isinstance(bbox[0], (list, tuple)): + xs = [pt[0] for pt in bbox if len(pt) >= 2] + ys = [pt[1] for pt in bbox if len(pt) >= 2] + if xs and ys: + return (min(xs), min(ys), max(xs), max(ys)) + + # Handle flat list [x0, y0, x1, y1] + if isinstance(bbox, (list, tuple)) and len(bbox) == 4: + return (float(bbox[0]), float(bbox[1]), float(bbox[2]), float(bbox[3])) + + # Handle flat polygon [x1, y1, x2, y2, ...] + if isinstance(bbox, (list, tuple)) and len(bbox) >= 8: + xs = [bbox[i] for i in range(0, len(bbox), 2)] + ys = [bbox[i] for i in range(1, len(bbox), 2)] + return (min(xs), min(ys), max(xs), max(ys)) + + # Handle dict format + if isinstance(bbox, dict): + return ( + float(bbox.get('x0', bbox.get('x_min', 0))), + float(bbox.get('y0', bbox.get('y_min', 0))), + float(bbox.get('x1', bbox.get('x_max', 0))), + float(bbox.get('y1', bbox.get('y_max', 0))) + ) + + except (TypeError, ValueError, IndexError) as e: + logger.warning(f"Failed to normalize bbox {bbox}: {e}") + + return None + + def _generate_summary( + self, + pp_structure_results: Dict[str, Any], + raw_ocr_regions: List[Dict[str, Any]] + ) -> Dict[str, Any]: + """Generate summary comparing PP-Structure and raw OCR.""" + pp_elements = pp_structure_results.get('elements', []) + + # Count element types + type_counts = {} + for elem in pp_elements: + elem_type = elem.get('type', 'unknown') + if hasattr(elem_type, 'value'): + elem_type = elem_type.value + type_counts[str(elem_type)] = type_counts.get(str(elem_type), 0) + 1 + + # Calculate bounding box coverage + pp_bbox_area = 0 + ocr_bbox_area = 0 + + for elem in pp_elements: + bbox = self._normalize_bbox(elem.get('bbox')) + if bbox: + pp_bbox_area += (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) + + for region in raw_ocr_regions: + bbox = self._normalize_bbox(region.get('bbox')) + if bbox: + ocr_bbox_area += (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) + + return { + 'timestamp': datetime.now().isoformat(), + 'pp_structure': { + 'total_elements': len(pp_elements), + 'element_types': type_counts, + 'total_bbox_area': pp_bbox_area, + 'has_parsing_res_list': pp_structure_results.get('has_parsing_res_list', False) + }, + 'raw_ocr': { + 'total_regions': len(raw_ocr_regions), + 'total_bbox_area': ocr_bbox_area, + 'avg_confidence': sum(r.get('confidence', 0) for r in raw_ocr_regions) / len(raw_ocr_regions) if raw_ocr_regions else 0 + }, + 'comparison': { + 'element_count_ratio': len(pp_elements) / len(raw_ocr_regions) if raw_ocr_regions else 0, + 'area_ratio': pp_bbox_area / ocr_bbox_area if ocr_bbox_area > 0 else 0, + 'potential_gap': len(raw_ocr_regions) - len(pp_elements) if raw_ocr_regions else 0 + } + } + + def _make_serializable(self, obj: Any) -> Any: + """Convert object to JSON-serializable format.""" + if obj is None: + return None + if isinstance(obj, (str, int, float, bool)): + return obj + if isinstance(obj, (list, tuple)): + return [self._make_serializable(item) for item in obj] + if isinstance(obj, dict): + return {str(k): self._make_serializable(v) for k, v in obj.items()} + if hasattr(obj, 'value'): + return obj.value + if hasattr(obj, '__dict__'): + return self._make_serializable(obj.__dict__) + if hasattr(obj, 'tolist'): # numpy array + return obj.tolist() + return str(obj) diff --git a/backend/tests/api/test_layout_model_api.py b/backend/tests/api/test_layout_model_api.py new file mode 100644 index 0000000..947ff41 --- /dev/null +++ b/backend/tests/api/test_layout_model_api.py @@ -0,0 +1,332 @@ +""" +API integration tests for Layout Model Selection feature. + +This replaces the deprecated PP-StructureV3 parameter tests. +""" + +import pytest +from fastapi.testclient import TestClient +from unittest.mock import patch +from app.main import app +from app.core.database import get_db +from app.models.user import User +from app.models.task import Task, TaskStatus, TaskFile + + +@pytest.fixture +def client(): + """Create test client""" + return TestClient(app) + + +@pytest.fixture +def test_user(db_session): + """Create test user""" + user = User( + email="test@example.com", + hashed_password="test_hash", + is_active=True + ) + db_session.add(user) + db_session.commit() + db_session.refresh(user) + return user + + +@pytest.fixture +def test_task(db_session, test_user): + """Create test task with uploaded file""" + task = Task( + user_id=test_user.id, + task_id="test-task-123", + filename="test.pdf", + status=TaskStatus.PENDING + ) + db_session.add(task) + db_session.commit() + db_session.refresh(task) + + # Add task file + task_file = TaskFile( + task_id=task.id, + original_name="test.pdf", + stored_path="/tmp/test.pdf", + file_size=1024, + mime_type="application/pdf" + ) + db_session.add(task_file) + db_session.commit() + + return task + + +class TestLayoutModelSchema: + """Test LayoutModel and ProcessingOptions schema validation""" + + def test_processing_options_accepts_layout_model(self): + """Verify ProcessingOptions schema accepts layout_model parameter""" + from app.schemas.task import ProcessingOptions, LayoutModelEnum + + options = ProcessingOptions( + use_dual_track=True, + language='ch', + layout_model=LayoutModelEnum.CHINESE + ) + + assert options.layout_model == LayoutModelEnum.CHINESE + + def test_layout_model_enum_values(self): + """Verify all layout model enum values are valid""" + from app.schemas.task import LayoutModelEnum + + assert LayoutModelEnum.CHINESE.value == "chinese" + assert LayoutModelEnum.DEFAULT.value == "default" + assert LayoutModelEnum.CDLA.value == "cdla" + + def test_default_layout_model_is_chinese(self): + """Verify default layout model is 'chinese' for best Chinese document support""" + from app.schemas.task import ProcessingOptions + + options = ProcessingOptions() + + # Default should be chinese + assert options.layout_model.value == "chinese" + + def test_layout_model_string_values_accepted(self): + """Verify string values are accepted for layout_model""" + from app.schemas.task import ProcessingOptions + + # String values should be converted to enum + options = ProcessingOptions(layout_model="default") + assert options.layout_model.value == "default" + + options = ProcessingOptions(layout_model="cdla") + assert options.layout_model.value == "cdla" + + def test_invalid_layout_model_rejected(self): + """Verify invalid layout model values are rejected""" + from app.schemas.task import ProcessingOptions + from pydantic import ValidationError + + with pytest.raises(ValidationError): + ProcessingOptions(layout_model="invalid_model") + + +class TestStartTaskEndpoint: + """Test /tasks/{task_id}/start endpoint with layout_model parameter""" + + @patch('app.routers.tasks.process_task_ocr') + def test_start_task_with_layout_model(self, mock_process_ocr, client, test_task, db_session): + """Verify layout_model is accepted and passed to OCR service""" + + # Override get_db dependency + def override_get_db(): + try: + yield db_session + finally: + pass + + # Override auth dependency + def override_get_current_user(): + return test_task.user + + app.dependency_overrides[get_db] = override_get_db + from app.core.deps import get_current_user + app.dependency_overrides[get_current_user] = override_get_current_user + + # Request body with layout_model + request_body = { + "use_dual_track": True, + "language": "ch", + "layout_model": "chinese" + } + + # Make API call + response = client.post( + f"/api/v2/tasks/{test_task.task_id}/start", + json=request_body + ) + + # Verify response + assert response.status_code == 200 + data = response.json() + assert data['status'] == 'processing' + + # Verify background task was called with layout_model + mock_process_ocr.assert_called_once() + call_kwargs = mock_process_ocr.call_args[1] + + assert 'layout_model' in call_kwargs + assert call_kwargs['layout_model'] == 'chinese' + + # Clean up + app.dependency_overrides.clear() + + @patch('app.routers.tasks.process_task_ocr') + def test_start_task_with_default_model(self, mock_process_ocr, client, test_task, db_session): + """Verify 'default' layout model is accepted""" + + def override_get_db(): + try: + yield db_session + finally: + pass + + def override_get_current_user(): + return test_task.user + + app.dependency_overrides[get_db] = override_get_db + from app.core.deps import get_current_user + app.dependency_overrides[get_current_user] = override_get_current_user + + request_body = { + "use_dual_track": True, + "layout_model": "default" + } + + response = client.post( + f"/api/v2/tasks/{test_task.task_id}/start", + json=request_body + ) + + assert response.status_code == 200 + + mock_process_ocr.assert_called_once() + call_kwargs = mock_process_ocr.call_args[1] + assert call_kwargs['layout_model'] == 'default' + + app.dependency_overrides.clear() + + @patch('app.routers.tasks.process_task_ocr') + def test_start_task_with_cdla_model(self, mock_process_ocr, client, test_task, db_session): + """Verify 'cdla' layout model is accepted""" + + def override_get_db(): + try: + yield db_session + finally: + pass + + def override_get_current_user(): + return test_task.user + + app.dependency_overrides[get_db] = override_get_db + from app.core.deps import get_current_user + app.dependency_overrides[get_current_user] = override_get_current_user + + request_body = { + "use_dual_track": True, + "layout_model": "cdla" + } + + response = client.post( + f"/api/v2/tasks/{test_task.task_id}/start", + json=request_body + ) + + assert response.status_code == 200 + + mock_process_ocr.assert_called_once() + call_kwargs = mock_process_ocr.call_args[1] + assert call_kwargs['layout_model'] == 'cdla' + + app.dependency_overrides.clear() + + @patch('app.routers.tasks.process_task_ocr') + def test_start_task_without_layout_model_uses_default(self, mock_process_ocr, client, test_task, db_session): + """Verify task can start without layout_model (uses 'chinese' as default)""" + + def override_get_db(): + try: + yield db_session + finally: + pass + + def override_get_current_user(): + return test_task.user + + app.dependency_overrides[get_db] = override_get_db + from app.core.deps import get_current_user + app.dependency_overrides[get_current_user] = override_get_current_user + + # Request without layout_model + request_body = { + "use_dual_track": True, + "language": "ch" + } + + response = client.post( + f"/api/v2/tasks/{test_task.task_id}/start", + json=request_body + ) + + assert response.status_code == 200 + + mock_process_ocr.assert_called_once() + call_kwargs = mock_process_ocr.call_args[1] + + # layout_model should default to 'chinese' + assert call_kwargs['layout_model'] == 'chinese' + + app.dependency_overrides.clear() + + def test_start_task_with_invalid_layout_model(self, client, test_task, db_session): + """Verify invalid layout_model returns 422 validation error""" + + def override_get_db(): + try: + yield db_session + finally: + pass + + def override_get_current_user(): + return test_task.user + + app.dependency_overrides[get_db] = override_get_db + from app.core.deps import get_current_user + app.dependency_overrides[get_current_user] = override_get_current_user + + # Request with invalid layout_model + request_body = { + "use_dual_track": True, + "layout_model": "invalid_model" + } + + response = client.post( + f"/api/v2/tasks/{test_task.task_id}/start", + json=request_body + ) + + # Should return validation error + assert response.status_code == 422 + + app.dependency_overrides.clear() + + +class TestOpenAPISchema: + """Test OpenAPI schema includes layout_model parameter""" + + def test_openapi_schema_includes_layout_model(self, client): + """Verify OpenAPI schema documents layout_model parameter""" + response = client.get("/openapi.json") + assert response.status_code == 200 + + schema = response.json() + + # Check LayoutModelEnum schema exists + assert 'LayoutModelEnum' in schema['components']['schemas'] + + model_schema = schema['components']['schemas']['LayoutModelEnum'] + + # Verify all 3 model options are documented + assert 'chinese' in model_schema['enum'] + assert 'default' in model_schema['enum'] + assert 'cdla' in model_schema['enum'] + + # Verify ProcessingOptions includes layout_model + options_schema = schema['components']['schemas']['ProcessingOptions'] + assert 'layout_model' in options_schema['properties'] + + +if __name__ == '__main__': + pytest.main([__file__, '-v']) diff --git a/backend/tests/services/test_ppstructure_params.py b/backend/tests/archived/test_ppstructure_params.py similarity index 100% rename from backend/tests/services/test_ppstructure_params.py rename to backend/tests/archived/test_ppstructure_params.py diff --git a/backend/tests/api/test_ppstructure_params_api.py b/backend/tests/archived/test_ppstructure_params_api.py similarity index 100% rename from backend/tests/api/test_ppstructure_params_api.py rename to backend/tests/archived/test_ppstructure_params_api.py diff --git a/backend/tests/e2e/test_ppstructure_params_e2e.py b/backend/tests/archived/test_ppstructure_params_e2e.py similarity index 100% rename from backend/tests/e2e/test_ppstructure_params_e2e.py rename to backend/tests/archived/test_ppstructure_params_e2e.py diff --git a/backend/tests/performance/test_ppstructure_params_performance.py b/backend/tests/archived/test_ppstructure_params_performance.py similarity index 100% rename from backend/tests/performance/test_ppstructure_params_performance.py rename to backend/tests/archived/test_ppstructure_params_performance.py diff --git a/backend/tests/services/test_layout_model.py b/backend/tests/services/test_layout_model.py new file mode 100644 index 0000000..4d2f500 --- /dev/null +++ b/backend/tests/services/test_layout_model.py @@ -0,0 +1,244 @@ +""" +Unit tests for Layout Model Selection feature in OCR Service. + +This replaces the deprecated PP-StructureV3 parameter tests. +""" + +import pytest +import sys +from pathlib import Path +from unittest.mock import Mock, patch, MagicMock + +# Mock all external dependencies before importing OCRService +sys.modules['paddleocr'] = MagicMock() +sys.modules['PIL'] = MagicMock() +sys.modules['pdf2image'] = MagicMock() + +# Mock paddle with version attribute +paddle_mock = MagicMock() +paddle_mock.__version__ = '2.5.0' +paddle_mock.device.get_device.return_value = 'cpu' +paddle_mock.device.get_available_device.return_value = 'cpu' +sys.modules['paddle'] = paddle_mock + +# Mock torch +torch_mock = MagicMock() +torch_mock.cuda.is_available.return_value = False +sys.modules['torch'] = torch_mock + +from app.services.ocr_service import OCRService, LAYOUT_MODEL_MAPPING, _USE_PUBLAYNET_DEFAULT +from app.core.config import settings + + +class TestLayoutModelMapping: + """Test layout model name mapping""" + + def test_layout_model_mapping_exists(self): + """Verify LAYOUT_MODEL_MAPPING constant exists and has correct values""" + assert 'chinese' in LAYOUT_MODEL_MAPPING + assert 'default' in LAYOUT_MODEL_MAPPING + assert 'cdla' in LAYOUT_MODEL_MAPPING + + def test_chinese_model_maps_to_pp_doclayout(self): + """Verify 'chinese' maps to PP-DocLayout-S""" + assert LAYOUT_MODEL_MAPPING['chinese'] == 'PP-DocLayout-S' + + def test_default_model_maps_to_publaynet_sentinel(self): + """Verify 'default' maps to sentinel value for PubLayNet default""" + # The 'default' model uses a sentinel value that signals "use PubLayNet default (no custom model)" + assert LAYOUT_MODEL_MAPPING['default'] == _USE_PUBLAYNET_DEFAULT + + def test_cdla_model_maps_to_picodet(self): + """Verify 'cdla' maps to picodet_lcnet_x1_0_fgd_layout_cdla""" + assert LAYOUT_MODEL_MAPPING['cdla'] == 'picodet_lcnet_x1_0_fgd_layout_cdla' + + +class TestLayoutModelEngine: + """Test engine creation with different layout models""" + + def test_chinese_model_creates_engine_with_pp_doclayout(self): + """Verify 'chinese' layout model uses PP-DocLayout-S""" + ocr_service = OCRService() + + with patch.object(ocr_service, 'structure_engine', None): + with patch('app.services.ocr_service.PPStructureV3') as mock_ppstructure: + mock_engine = Mock() + mock_ppstructure.return_value = mock_engine + + engine = ocr_service._ensure_structure_engine(layout_model='chinese') + + mock_ppstructure.assert_called_once() + call_kwargs = mock_ppstructure.call_args[1] + + assert call_kwargs.get('layout_detection_model_name') == 'PP-DocLayout-S' + + def test_default_model_creates_engine_without_model_name(self): + """Verify 'default' layout model does not specify model name (uses default)""" + ocr_service = OCRService() + + with patch.object(ocr_service, 'structure_engine', None): + with patch('app.services.ocr_service.PPStructureV3') as mock_ppstructure: + mock_engine = Mock() + mock_ppstructure.return_value = mock_engine + + engine = ocr_service._ensure_structure_engine(layout_model='default') + + mock_ppstructure.assert_called_once() + call_kwargs = mock_ppstructure.call_args[1] + + # For 'default', layout_detection_model_name should be None or not set + assert call_kwargs.get('layout_detection_model_name') is None + + def test_cdla_model_creates_engine_with_picodet(self): + """Verify 'cdla' layout model uses picodet_lcnet_x1_0_fgd_layout_cdla""" + ocr_service = OCRService() + + with patch.object(ocr_service, 'structure_engine', None): + with patch('app.services.ocr_service.PPStructureV3') as mock_ppstructure: + mock_engine = Mock() + mock_ppstructure.return_value = mock_engine + + engine = ocr_service._ensure_structure_engine(layout_model='cdla') + + mock_ppstructure.assert_called_once() + call_kwargs = mock_ppstructure.call_args[1] + + assert call_kwargs.get('layout_detection_model_name') == 'picodet_lcnet_x1_0_fgd_layout_cdla' + + def test_none_layout_model_uses_chinese_default(self): + """Verify None layout_model defaults to 'chinese' model""" + ocr_service = OCRService() + + with patch.object(ocr_service, 'structure_engine', None): + with patch('app.services.ocr_service.PPStructureV3') as mock_ppstructure: + mock_engine = Mock() + mock_ppstructure.return_value = mock_engine + + # Pass None for layout_model + engine = ocr_service._ensure_structure_engine(layout_model=None) + + mock_ppstructure.assert_called_once() + call_kwargs = mock_ppstructure.call_args[1] + + # Should use 'chinese' model as default + assert call_kwargs.get('layout_detection_model_name') == 'PP-DocLayout-S' + + +class TestLayoutModelCaching: + """Test engine caching behavior with layout models""" + + def test_same_layout_model_uses_cached_engine(self): + """Verify same layout model reuses cached engine""" + ocr_service = OCRService() + + with patch('app.services.ocr_service.PPStructureV3') as mock_ppstructure: + mock_engine = Mock() + mock_ppstructure.return_value = mock_engine + + # First call with 'chinese' + engine1 = ocr_service._ensure_structure_engine(layout_model='chinese') + + # Second call with same model should use cache + engine2 = ocr_service._ensure_structure_engine(layout_model='chinese') + + # Verify only one engine was created + assert mock_ppstructure.call_count == 1 + assert engine1 is engine2 + + def test_different_layout_model_creates_new_engine(self): + """Verify different layout model creates new engine""" + ocr_service = OCRService() + + with patch('app.services.ocr_service.PPStructureV3') as mock_ppstructure: + mock_engine1 = Mock() + mock_engine2 = Mock() + mock_ppstructure.side_effect = [mock_engine1, mock_engine2] + + # First call with 'chinese' + engine1 = ocr_service._ensure_structure_engine(layout_model='chinese') + + # Second call with 'cdla' should create new engine + engine2 = ocr_service._ensure_structure_engine(layout_model='cdla') + + # Verify two engines were created + assert mock_ppstructure.call_count == 2 + assert engine1 is not engine2 + + +class TestLayoutModelFlow: + """Test layout model parameter flow through processing pipeline""" + + def test_layout_model_passed_to_engine_creation(self): + """Verify layout_model is passed through to _ensure_structure_engine""" + ocr_service = OCRService() + + # Test that _ensure_structure_engine accepts layout_model parameter + with patch('app.services.ocr_service.PPStructureV3') as mock_ppstructure: + mock_engine = Mock() + mock_ppstructure.return_value = mock_engine + + # Call with specific layout_model + engine = ocr_service._ensure_structure_engine(layout_model='cdla') + + # Verify correct model was requested + mock_ppstructure.assert_called_once() + call_kwargs = mock_ppstructure.call_args[1] + assert call_kwargs.get('layout_detection_model_name') == 'picodet_lcnet_x1_0_fgd_layout_cdla' + + def test_layout_model_default_behavior(self): + """Verify default layout model behavior when None is passed""" + ocr_service = OCRService() + + with patch('app.services.ocr_service.PPStructureV3') as mock_ppstructure: + mock_engine = Mock() + mock_ppstructure.return_value = mock_engine + + # Call without layout_model (None) + engine = ocr_service._ensure_structure_engine(layout_model=None) + + # Should use config default (PP-DocLayout-S) + mock_ppstructure.assert_called_once() + call_kwargs = mock_ppstructure.call_args[1] + assert call_kwargs.get('layout_detection_model_name') == settings.layout_detection_model_name + + def test_layout_model_unknown_value_falls_back(self): + """Verify unknown layout model falls back to config default""" + ocr_service = OCRService() + + with patch('app.services.ocr_service.PPStructureV3') as mock_ppstructure: + mock_engine = Mock() + mock_ppstructure.return_value = mock_engine + + # Call with unknown layout_model + engine = ocr_service._ensure_structure_engine(layout_model='unknown_model') + + # Should use config default + mock_ppstructure.assert_called_once() + call_kwargs = mock_ppstructure.call_args[1] + assert call_kwargs.get('layout_detection_model_name') == settings.layout_detection_model_name + + +class TestLayoutModelLogging: + """Test layout model logging""" + + def test_layout_model_is_logged(self): + """Verify layout model selection is logged""" + ocr_service = OCRService() + + with patch('app.services.ocr_service.PPStructureV3') as mock_ppstructure: + with patch('app.services.ocr_service.logger') as mock_logger: + mock_engine = Mock() + mock_ppstructure.return_value = mock_engine + + # Call with specific layout_model + ocr_service._ensure_structure_engine(layout_model='cdla') + + # Verify logging occurred + assert mock_logger.info.call_count >= 1 + # Check that model name was logged + log_calls = [str(call) for call in mock_logger.info.call_args_list] + assert any('cdla' in str(call).lower() or 'layout' in str(call).lower() for call in log_calls) + + +if __name__ == '__main__': + pytest.main([__file__, '-v']) diff --git a/backend/tests/test_gap_filling.py b/backend/tests/test_gap_filling.py new file mode 100644 index 0000000..e80fb7b --- /dev/null +++ b/backend/tests/test_gap_filling.py @@ -0,0 +1,503 @@ +""" +Tests for Gap Filling Service + +Tests the detection and filling of gaps in PP-StructureV3 output +using raw OCR text regions. +""" + +import pytest +from typing import List, Dict, Any + +from app.services.gap_filling_service import GapFillingService, TextRegion, SKIP_ELEMENT_TYPES +from app.models.unified_document import DocumentElement, BoundingBox, ElementType, Dimensions + + +class TestGapFillingService: + """Tests for GapFillingService class.""" + + @pytest.fixture + def service(self) -> GapFillingService: + """Create a GapFillingService instance with default settings.""" + return GapFillingService( + coverage_threshold=0.7, + iou_threshold=0.15, + confidence_threshold=0.3, + dedup_iou_threshold=0.5, + enabled=True + ) + + @pytest.fixture + def disabled_service(self) -> GapFillingService: + """Create a disabled GapFillingService instance.""" + return GapFillingService(enabled=False) + + @pytest.fixture + def sample_raw_regions(self) -> List[TextRegion]: + """Create sample raw OCR text regions.""" + return [ + TextRegion(text="Header text", bbox=[100, 50, 300, 80], confidence=0.95, page=1), + TextRegion(text="Title of document", bbox=[100, 100, 500, 150], confidence=0.92, page=1), + TextRegion(text="First paragraph", bbox=[100, 200, 500, 250], confidence=0.90, page=1), + TextRegion(text="Second paragraph", bbox=[100, 300, 500, 350], confidence=0.88, page=1), + TextRegion(text="Footer note", bbox=[100, 900, 300, 930], confidence=0.85, page=1), + # Low confidence region (should be filtered) + TextRegion(text="Noise", bbox=[50, 50, 80, 80], confidence=0.1, page=1), + ] + + @pytest.fixture + def sample_pp_elements(self) -> List[DocumentElement]: + """Create sample PP-StructureV3 elements that cover only some regions.""" + return [ + DocumentElement( + element_id="pp_1", + type=ElementType.TITLE, + content="Title of document", + bbox=BoundingBox(x0=100, y0=100, x1=500, y1=150), + confidence=0.95 + ), + DocumentElement( + element_id="pp_2", + type=ElementType.TEXT, + content="First paragraph", + bbox=BoundingBox(x0=100, y0=200, x1=500, y1=250), + confidence=0.90 + ), + # Note: Header, Second paragraph, and Footer are NOT covered + ] + + def test_service_initialization(self, service: GapFillingService): + """Test service initializes with correct parameters.""" + assert service.enabled is True + assert service.coverage_threshold == 0.7 + assert service.iou_threshold == 0.15 + assert service.confidence_threshold == 0.3 + assert service.dedup_iou_threshold == 0.5 + + def test_disabled_service(self, disabled_service: GapFillingService): + """Test disabled service does not activate.""" + regions = [TextRegion(text="Test", bbox=[0, 0, 100, 100], confidence=0.9, page=1)] + elements = [] + + should_activate, coverage = disabled_service.should_activate(regions, elements) + assert should_activate is False + assert coverage == 1.0 + + def test_should_activate_low_coverage( + self, + service: GapFillingService, + sample_raw_regions: List[TextRegion], + sample_pp_elements: List[DocumentElement] + ): + """Test activation when coverage is below threshold.""" + # Filter out low confidence regions + valid_regions = [r for r in sample_raw_regions if r.confidence >= 0.3] + + should_activate, coverage = service.should_activate(valid_regions, sample_pp_elements) + + # Only 2 out of 5 valid regions are covered (Title, First paragraph) + assert should_activate is True + assert coverage < 0.7 # Below threshold + + def test_should_not_activate_high_coverage(self, service: GapFillingService): + """Test no activation when coverage is above threshold.""" + # All regions covered + regions = [ + TextRegion(text="Text 1", bbox=[100, 100, 200, 150], confidence=0.9, page=1), + TextRegion(text="Text 2", bbox=[100, 200, 200, 250], confidence=0.9, page=1), + ] + + elements = [ + DocumentElement( + element_id="pp_1", + type=ElementType.TEXT, + content="Text 1", + bbox=BoundingBox(x0=50, y0=50, x1=250, y1=200), # Covers first region + confidence=0.95 + ), + DocumentElement( + element_id="pp_2", + type=ElementType.TEXT, + content="Text 2", + bbox=BoundingBox(x0=50, y0=180, x1=250, y1=300), # Covers second region + confidence=0.95 + ), + ] + + should_activate, coverage = service.should_activate(regions, elements) + + assert should_activate is False + assert coverage >= 0.7 + + def test_find_uncovered_regions( + self, + service: GapFillingService, + sample_raw_regions: List[TextRegion], + sample_pp_elements: List[DocumentElement] + ): + """Test finding uncovered regions.""" + uncovered = service.find_uncovered_regions(sample_raw_regions, sample_pp_elements) + + # Should find Header, Second paragraph, Footer (not Title, First paragraph, or low-confidence Noise) + assert len(uncovered) == 3 + + uncovered_texts = [r.text for r in uncovered] + assert "Header text" in uncovered_texts + assert "Second paragraph" in uncovered_texts + assert "Footer note" in uncovered_texts + assert "Title of document" not in uncovered_texts # Covered + assert "First paragraph" not in uncovered_texts # Covered + assert "Noise" not in uncovered_texts # Low confidence + + def test_coverage_by_center_point(self, service: GapFillingService): + """Test coverage detection via center point.""" + region = TextRegion(text="Test", bbox=[150, 150, 250, 200], confidence=0.9, page=1) + + element = DocumentElement( + element_id="pp_1", + type=ElementType.TEXT, + content="Container", + bbox=BoundingBox(x0=100, y0=100, x1=300, y1=250), # Contains region's center + confidence=0.95 + ) + + is_covered = service._is_region_covered(region, [element]) + assert is_covered is True + + def test_coverage_by_iou(self, service: GapFillingService): + """Test coverage detection via IoU threshold.""" + region = TextRegion(text="Test", bbox=[100, 100, 200, 150], confidence=0.9, page=1) + + element = DocumentElement( + element_id="pp_1", + type=ElementType.TEXT, + content="Overlap", + bbox=BoundingBox(x0=150, y0=100, x1=250, y1=150), # Partial overlap + confidence=0.95 + ) + + # Calculate expected IoU + # Intersection: (150-200) x (100-150) = 50 x 50 = 2500 + # Union: 100x50 + 100x50 - 2500 = 7500 + # IoU = 2500/7500 = 0.33 > 0.15 threshold + + is_covered = service._is_region_covered(region, [element]) + assert is_covered is True + + def test_deduplication( + self, + service: GapFillingService, + sample_pp_elements: List[DocumentElement] + ): + """Test deduplication removes high-overlap regions.""" + uncovered = [ + # High overlap with pp_2 (First paragraph) + TextRegion(text="First paragraph variant", bbox=[100, 200, 500, 250], confidence=0.9, page=1), + # No overlap + TextRegion(text="Unique region", bbox=[100, 500, 300, 550], confidence=0.9, page=1), + ] + + deduplicated = service.deduplicate_regions(uncovered, sample_pp_elements) + + assert len(deduplicated) == 1 + assert deduplicated[0].text == "Unique region" + + def test_convert_regions_to_elements(self, service: GapFillingService): + """Test conversion of TextRegions to DocumentElements.""" + regions = [ + TextRegion(text="Test text 1", bbox=[100, 100, 200, 150], confidence=0.85, page=1), + TextRegion(text="Test text 2", bbox=[100, 200, 200, 250], confidence=0.90, page=1), + ] + + elements = service.convert_regions_to_elements(regions, page_number=1, start_element_id=0) + + assert len(elements) == 2 + assert elements[0].element_id == "gap_fill_1_0" + assert elements[0].type == ElementType.TEXT + assert elements[0].content == "Test text 1" + assert elements[0].confidence == 0.85 + assert elements[0].metadata.get('source') == 'gap_filling' + + assert elements[1].element_id == "gap_fill_1_1" + assert elements[1].content == "Test text 2" + + def test_recalculate_reading_order(self, service: GapFillingService): + """Test reading order recalculation.""" + elements = [ + DocumentElement( + element_id="e3", + type=ElementType.TEXT, + content="Bottom", + bbox=BoundingBox(x0=100, y0=300, x1=200, y1=350), + confidence=0.9 + ), + DocumentElement( + element_id="e1", + type=ElementType.TEXT, + content="Top", + bbox=BoundingBox(x0=100, y0=100, x1=200, y1=150), + confidence=0.9 + ), + DocumentElement( + element_id="e2", + type=ElementType.TEXT, + content="Middle", + bbox=BoundingBox(x0=100, y0=200, x1=200, y1=250), + confidence=0.9 + ), + ] + + reading_order = service.recalculate_reading_order(elements) + + # Should be sorted by y0: Top (100), Middle (200), Bottom (300) + assert reading_order == [1, 2, 0] # Indices of elements in reading order + + def test_fill_gaps_integration( + self, + service: GapFillingService, + ): + """Integration test for fill_gaps method.""" + # Raw OCR regions (dict format as received from OCR service) + raw_regions = [ + {'text': 'Header', 'bbox': [100, 50, 300, 80], 'confidence': 0.95, 'page': 1}, + {'text': 'Title', 'bbox': [100, 100, 500, 150], 'confidence': 0.92, 'page': 1}, + {'text': 'Paragraph 1', 'bbox': [100, 200, 500, 250], 'confidence': 0.90, 'page': 1}, + {'text': 'Paragraph 2', 'bbox': [100, 300, 500, 350], 'confidence': 0.88, 'page': 1}, + {'text': 'Paragraph 3', 'bbox': [100, 400, 500, 450], 'confidence': 0.86, 'page': 1}, + {'text': 'Footer', 'bbox': [100, 900, 300, 930], 'confidence': 0.85, 'page': 1}, + ] + + # PP-StructureV3 only detected Title (missing 5 out of 6 regions = 16.7% coverage) + pp_elements = [ + DocumentElement( + element_id="pp_1", + type=ElementType.TITLE, + content="Title", + bbox=BoundingBox(x0=100, y0=100, x1=500, y1=150), + confidence=0.95 + ), + ] + + supplemented, stats = service.fill_gaps( + raw_ocr_regions=raw_regions, + pp_structure_elements=pp_elements, + page_number=1 + ) + + # Should have activated and supplemented missing regions + assert stats['activated'] is True + assert stats['coverage_ratio'] < 0.7 + assert len(supplemented) == 5 # Header, Paragraph 1, 2, 3, Footer + + def test_fill_gaps_no_activation_when_coverage_high(self, service: GapFillingService): + """Test fill_gaps does not activate when coverage is high.""" + raw_regions = [ + {'text': 'Text 1', 'bbox': [100, 100, 200, 150], 'confidence': 0.9, 'page': 1}, + ] + + pp_elements = [ + DocumentElement( + element_id="pp_1", + type=ElementType.TEXT, + content="Text 1", + bbox=BoundingBox(x0=50, y0=50, x1=250, y1=200), # Fully covers + confidence=0.95 + ), + ] + + supplemented, stats = service.fill_gaps( + raw_ocr_regions=raw_regions, + pp_structure_elements=pp_elements, + page_number=1 + ) + + assert stats['activated'] is False + assert len(supplemented) == 0 + + def test_skip_element_types_not_supplemented(self, service: GapFillingService): + """Test that TABLE/IMAGE/etc. elements are not supplemented over.""" + raw_regions = [ + {'text': 'Table cell text', 'bbox': [100, 100, 200, 150], 'confidence': 0.9, 'page': 1}, + ] + + # PP-StructureV3 has a table covering this region + pp_elements = [ + DocumentElement( + element_id="pp_1", + type=ElementType.TABLE, + content="
+ {t('processing.layoutModel.note')} +
+- Note: These parameters only apply when using the OCR track. Adjusting them - can help improve accuracy for specific document types. -
-