feat: simplify layout model selection and archive proposals
Changes: - Replace PP-Structure 7-slider parameter UI with simple 3-option layout model selector - Add layout model mapping: chinese (PP-DocLayout-S), default (PubLayNet), cdla - Add LayoutModelSelector component and zh-TW translations - Fix "default" model behavior with sentinel value for PubLayNet - Add gap filling service for OCR track coverage improvement - Add PP-Structure debug utilities - Archive completed/incomplete proposals: - add-ocr-track-gap-filling (complete) - fix-ocr-track-table-rendering (incomplete) - simplify-ppstructure-model-selection (22/25 tasks) - Add new layout model tests, archive old PP-Structure param tests - Update OpenSpec ocr-processing spec with layout model requirements 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
332
backend/tests/api/test_layout_model_api.py
Normal file
332
backend/tests/api/test_layout_model_api.py
Normal file
@@ -0,0 +1,332 @@
|
||||
"""
|
||||
API integration tests for Layout Model Selection feature.
|
||||
|
||||
This replaces the deprecated PP-StructureV3 parameter tests.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import patch
|
||||
from app.main import app
|
||||
from app.core.database import get_db
|
||||
from app.models.user import User
|
||||
from app.models.task import Task, TaskStatus, TaskFile
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
"""Create test client"""
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_user(db_session):
|
||||
"""Create test user"""
|
||||
user = User(
|
||||
email="test@example.com",
|
||||
hashed_password="test_hash",
|
||||
is_active=True
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
db_session.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_task(db_session, test_user):
|
||||
"""Create test task with uploaded file"""
|
||||
task = Task(
|
||||
user_id=test_user.id,
|
||||
task_id="test-task-123",
|
||||
filename="test.pdf",
|
||||
status=TaskStatus.PENDING
|
||||
)
|
||||
db_session.add(task)
|
||||
db_session.commit()
|
||||
db_session.refresh(task)
|
||||
|
||||
# Add task file
|
||||
task_file = TaskFile(
|
||||
task_id=task.id,
|
||||
original_name="test.pdf",
|
||||
stored_path="/tmp/test.pdf",
|
||||
file_size=1024,
|
||||
mime_type="application/pdf"
|
||||
)
|
||||
db_session.add(task_file)
|
||||
db_session.commit()
|
||||
|
||||
return task
|
||||
|
||||
|
||||
class TestLayoutModelSchema:
|
||||
"""Test LayoutModel and ProcessingOptions schema validation"""
|
||||
|
||||
def test_processing_options_accepts_layout_model(self):
|
||||
"""Verify ProcessingOptions schema accepts layout_model parameter"""
|
||||
from app.schemas.task import ProcessingOptions, LayoutModelEnum
|
||||
|
||||
options = ProcessingOptions(
|
||||
use_dual_track=True,
|
||||
language='ch',
|
||||
layout_model=LayoutModelEnum.CHINESE
|
||||
)
|
||||
|
||||
assert options.layout_model == LayoutModelEnum.CHINESE
|
||||
|
||||
def test_layout_model_enum_values(self):
|
||||
"""Verify all layout model enum values are valid"""
|
||||
from app.schemas.task import LayoutModelEnum
|
||||
|
||||
assert LayoutModelEnum.CHINESE.value == "chinese"
|
||||
assert LayoutModelEnum.DEFAULT.value == "default"
|
||||
assert LayoutModelEnum.CDLA.value == "cdla"
|
||||
|
||||
def test_default_layout_model_is_chinese(self):
|
||||
"""Verify default layout model is 'chinese' for best Chinese document support"""
|
||||
from app.schemas.task import ProcessingOptions
|
||||
|
||||
options = ProcessingOptions()
|
||||
|
||||
# Default should be chinese
|
||||
assert options.layout_model.value == "chinese"
|
||||
|
||||
def test_layout_model_string_values_accepted(self):
|
||||
"""Verify string values are accepted for layout_model"""
|
||||
from app.schemas.task import ProcessingOptions
|
||||
|
||||
# String values should be converted to enum
|
||||
options = ProcessingOptions(layout_model="default")
|
||||
assert options.layout_model.value == "default"
|
||||
|
||||
options = ProcessingOptions(layout_model="cdla")
|
||||
assert options.layout_model.value == "cdla"
|
||||
|
||||
def test_invalid_layout_model_rejected(self):
|
||||
"""Verify invalid layout model values are rejected"""
|
||||
from app.schemas.task import ProcessingOptions
|
||||
from pydantic import ValidationError
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
ProcessingOptions(layout_model="invalid_model")
|
||||
|
||||
|
||||
class TestStartTaskEndpoint:
|
||||
"""Test /tasks/{task_id}/start endpoint with layout_model parameter"""
|
||||
|
||||
@patch('app.routers.tasks.process_task_ocr')
|
||||
def test_start_task_with_layout_model(self, mock_process_ocr, client, test_task, db_session):
|
||||
"""Verify layout_model is accepted and passed to OCR service"""
|
||||
|
||||
# Override get_db dependency
|
||||
def override_get_db():
|
||||
try:
|
||||
yield db_session
|
||||
finally:
|
||||
pass
|
||||
|
||||
# Override auth dependency
|
||||
def override_get_current_user():
|
||||
return test_task.user
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
from app.core.deps import get_current_user
|
||||
app.dependency_overrides[get_current_user] = override_get_current_user
|
||||
|
||||
# Request body with layout_model
|
||||
request_body = {
|
||||
"use_dual_track": True,
|
||||
"language": "ch",
|
||||
"layout_model": "chinese"
|
||||
}
|
||||
|
||||
# Make API call
|
||||
response = client.post(
|
||||
f"/api/v2/tasks/{test_task.task_id}/start",
|
||||
json=request_body
|
||||
)
|
||||
|
||||
# Verify response
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data['status'] == 'processing'
|
||||
|
||||
# Verify background task was called with layout_model
|
||||
mock_process_ocr.assert_called_once()
|
||||
call_kwargs = mock_process_ocr.call_args[1]
|
||||
|
||||
assert 'layout_model' in call_kwargs
|
||||
assert call_kwargs['layout_model'] == 'chinese'
|
||||
|
||||
# Clean up
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
@patch('app.routers.tasks.process_task_ocr')
|
||||
def test_start_task_with_default_model(self, mock_process_ocr, client, test_task, db_session):
|
||||
"""Verify 'default' layout model is accepted"""
|
||||
|
||||
def override_get_db():
|
||||
try:
|
||||
yield db_session
|
||||
finally:
|
||||
pass
|
||||
|
||||
def override_get_current_user():
|
||||
return test_task.user
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
from app.core.deps import get_current_user
|
||||
app.dependency_overrides[get_current_user] = override_get_current_user
|
||||
|
||||
request_body = {
|
||||
"use_dual_track": True,
|
||||
"layout_model": "default"
|
||||
}
|
||||
|
||||
response = client.post(
|
||||
f"/api/v2/tasks/{test_task.task_id}/start",
|
||||
json=request_body
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
mock_process_ocr.assert_called_once()
|
||||
call_kwargs = mock_process_ocr.call_args[1]
|
||||
assert call_kwargs['layout_model'] == 'default'
|
||||
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
@patch('app.routers.tasks.process_task_ocr')
|
||||
def test_start_task_with_cdla_model(self, mock_process_ocr, client, test_task, db_session):
|
||||
"""Verify 'cdla' layout model is accepted"""
|
||||
|
||||
def override_get_db():
|
||||
try:
|
||||
yield db_session
|
||||
finally:
|
||||
pass
|
||||
|
||||
def override_get_current_user():
|
||||
return test_task.user
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
from app.core.deps import get_current_user
|
||||
app.dependency_overrides[get_current_user] = override_get_current_user
|
||||
|
||||
request_body = {
|
||||
"use_dual_track": True,
|
||||
"layout_model": "cdla"
|
||||
}
|
||||
|
||||
response = client.post(
|
||||
f"/api/v2/tasks/{test_task.task_id}/start",
|
||||
json=request_body
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
mock_process_ocr.assert_called_once()
|
||||
call_kwargs = mock_process_ocr.call_args[1]
|
||||
assert call_kwargs['layout_model'] == 'cdla'
|
||||
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
@patch('app.routers.tasks.process_task_ocr')
|
||||
def test_start_task_without_layout_model_uses_default(self, mock_process_ocr, client, test_task, db_session):
|
||||
"""Verify task can start without layout_model (uses 'chinese' as default)"""
|
||||
|
||||
def override_get_db():
|
||||
try:
|
||||
yield db_session
|
||||
finally:
|
||||
pass
|
||||
|
||||
def override_get_current_user():
|
||||
return test_task.user
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
from app.core.deps import get_current_user
|
||||
app.dependency_overrides[get_current_user] = override_get_current_user
|
||||
|
||||
# Request without layout_model
|
||||
request_body = {
|
||||
"use_dual_track": True,
|
||||
"language": "ch"
|
||||
}
|
||||
|
||||
response = client.post(
|
||||
f"/api/v2/tasks/{test_task.task_id}/start",
|
||||
json=request_body
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
mock_process_ocr.assert_called_once()
|
||||
call_kwargs = mock_process_ocr.call_args[1]
|
||||
|
||||
# layout_model should default to 'chinese'
|
||||
assert call_kwargs['layout_model'] == 'chinese'
|
||||
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
def test_start_task_with_invalid_layout_model(self, client, test_task, db_session):
|
||||
"""Verify invalid layout_model returns 422 validation error"""
|
||||
|
||||
def override_get_db():
|
||||
try:
|
||||
yield db_session
|
||||
finally:
|
||||
pass
|
||||
|
||||
def override_get_current_user():
|
||||
return test_task.user
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
from app.core.deps import get_current_user
|
||||
app.dependency_overrides[get_current_user] = override_get_current_user
|
||||
|
||||
# Request with invalid layout_model
|
||||
request_body = {
|
||||
"use_dual_track": True,
|
||||
"layout_model": "invalid_model"
|
||||
}
|
||||
|
||||
response = client.post(
|
||||
f"/api/v2/tasks/{test_task.task_id}/start",
|
||||
json=request_body
|
||||
)
|
||||
|
||||
# Should return validation error
|
||||
assert response.status_code == 422
|
||||
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
class TestOpenAPISchema:
|
||||
"""Test OpenAPI schema includes layout_model parameter"""
|
||||
|
||||
def test_openapi_schema_includes_layout_model(self, client):
|
||||
"""Verify OpenAPI schema documents layout_model parameter"""
|
||||
response = client.get("/openapi.json")
|
||||
assert response.status_code == 200
|
||||
|
||||
schema = response.json()
|
||||
|
||||
# Check LayoutModelEnum schema exists
|
||||
assert 'LayoutModelEnum' in schema['components']['schemas']
|
||||
|
||||
model_schema = schema['components']['schemas']['LayoutModelEnum']
|
||||
|
||||
# Verify all 3 model options are documented
|
||||
assert 'chinese' in model_schema['enum']
|
||||
assert 'default' in model_schema['enum']
|
||||
assert 'cdla' in model_schema['enum']
|
||||
|
||||
# Verify ProcessingOptions includes layout_model
|
||||
options_schema = schema['components']['schemas']['ProcessingOptions']
|
||||
assert 'layout_model' in options_schema['properties']
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__, '-v'])
|
||||
244
backend/tests/services/test_layout_model.py
Normal file
244
backend/tests/services/test_layout_model.py
Normal file
@@ -0,0 +1,244 @@
|
||||
"""
|
||||
Unit tests for Layout Model Selection feature in OCR Service.
|
||||
|
||||
This replaces the deprecated PP-StructureV3 parameter tests.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
|
||||
# Mock all external dependencies before importing OCRService
|
||||
sys.modules['paddleocr'] = MagicMock()
|
||||
sys.modules['PIL'] = MagicMock()
|
||||
sys.modules['pdf2image'] = MagicMock()
|
||||
|
||||
# Mock paddle with version attribute
|
||||
paddle_mock = MagicMock()
|
||||
paddle_mock.__version__ = '2.5.0'
|
||||
paddle_mock.device.get_device.return_value = 'cpu'
|
||||
paddle_mock.device.get_available_device.return_value = 'cpu'
|
||||
sys.modules['paddle'] = paddle_mock
|
||||
|
||||
# Mock torch
|
||||
torch_mock = MagicMock()
|
||||
torch_mock.cuda.is_available.return_value = False
|
||||
sys.modules['torch'] = torch_mock
|
||||
|
||||
from app.services.ocr_service import OCRService, LAYOUT_MODEL_MAPPING, _USE_PUBLAYNET_DEFAULT
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
class TestLayoutModelMapping:
|
||||
"""Test layout model name mapping"""
|
||||
|
||||
def test_layout_model_mapping_exists(self):
|
||||
"""Verify LAYOUT_MODEL_MAPPING constant exists and has correct values"""
|
||||
assert 'chinese' in LAYOUT_MODEL_MAPPING
|
||||
assert 'default' in LAYOUT_MODEL_MAPPING
|
||||
assert 'cdla' in LAYOUT_MODEL_MAPPING
|
||||
|
||||
def test_chinese_model_maps_to_pp_doclayout(self):
|
||||
"""Verify 'chinese' maps to PP-DocLayout-S"""
|
||||
assert LAYOUT_MODEL_MAPPING['chinese'] == 'PP-DocLayout-S'
|
||||
|
||||
def test_default_model_maps_to_publaynet_sentinel(self):
|
||||
"""Verify 'default' maps to sentinel value for PubLayNet default"""
|
||||
# The 'default' model uses a sentinel value that signals "use PubLayNet default (no custom model)"
|
||||
assert LAYOUT_MODEL_MAPPING['default'] == _USE_PUBLAYNET_DEFAULT
|
||||
|
||||
def test_cdla_model_maps_to_picodet(self):
|
||||
"""Verify 'cdla' maps to picodet_lcnet_x1_0_fgd_layout_cdla"""
|
||||
assert LAYOUT_MODEL_MAPPING['cdla'] == 'picodet_lcnet_x1_0_fgd_layout_cdla'
|
||||
|
||||
|
||||
class TestLayoutModelEngine:
|
||||
"""Test engine creation with different layout models"""
|
||||
|
||||
def test_chinese_model_creates_engine_with_pp_doclayout(self):
|
||||
"""Verify 'chinese' layout model uses PP-DocLayout-S"""
|
||||
ocr_service = OCRService()
|
||||
|
||||
with patch.object(ocr_service, 'structure_engine', None):
|
||||
with patch('app.services.ocr_service.PPStructureV3') as mock_ppstructure:
|
||||
mock_engine = Mock()
|
||||
mock_ppstructure.return_value = mock_engine
|
||||
|
||||
engine = ocr_service._ensure_structure_engine(layout_model='chinese')
|
||||
|
||||
mock_ppstructure.assert_called_once()
|
||||
call_kwargs = mock_ppstructure.call_args[1]
|
||||
|
||||
assert call_kwargs.get('layout_detection_model_name') == 'PP-DocLayout-S'
|
||||
|
||||
def test_default_model_creates_engine_without_model_name(self):
|
||||
"""Verify 'default' layout model does not specify model name (uses default)"""
|
||||
ocr_service = OCRService()
|
||||
|
||||
with patch.object(ocr_service, 'structure_engine', None):
|
||||
with patch('app.services.ocr_service.PPStructureV3') as mock_ppstructure:
|
||||
mock_engine = Mock()
|
||||
mock_ppstructure.return_value = mock_engine
|
||||
|
||||
engine = ocr_service._ensure_structure_engine(layout_model='default')
|
||||
|
||||
mock_ppstructure.assert_called_once()
|
||||
call_kwargs = mock_ppstructure.call_args[1]
|
||||
|
||||
# For 'default', layout_detection_model_name should be None or not set
|
||||
assert call_kwargs.get('layout_detection_model_name') is None
|
||||
|
||||
def test_cdla_model_creates_engine_with_picodet(self):
|
||||
"""Verify 'cdla' layout model uses picodet_lcnet_x1_0_fgd_layout_cdla"""
|
||||
ocr_service = OCRService()
|
||||
|
||||
with patch.object(ocr_service, 'structure_engine', None):
|
||||
with patch('app.services.ocr_service.PPStructureV3') as mock_ppstructure:
|
||||
mock_engine = Mock()
|
||||
mock_ppstructure.return_value = mock_engine
|
||||
|
||||
engine = ocr_service._ensure_structure_engine(layout_model='cdla')
|
||||
|
||||
mock_ppstructure.assert_called_once()
|
||||
call_kwargs = mock_ppstructure.call_args[1]
|
||||
|
||||
assert call_kwargs.get('layout_detection_model_name') == 'picodet_lcnet_x1_0_fgd_layout_cdla'
|
||||
|
||||
def test_none_layout_model_uses_chinese_default(self):
|
||||
"""Verify None layout_model defaults to 'chinese' model"""
|
||||
ocr_service = OCRService()
|
||||
|
||||
with patch.object(ocr_service, 'structure_engine', None):
|
||||
with patch('app.services.ocr_service.PPStructureV3') as mock_ppstructure:
|
||||
mock_engine = Mock()
|
||||
mock_ppstructure.return_value = mock_engine
|
||||
|
||||
# Pass None for layout_model
|
||||
engine = ocr_service._ensure_structure_engine(layout_model=None)
|
||||
|
||||
mock_ppstructure.assert_called_once()
|
||||
call_kwargs = mock_ppstructure.call_args[1]
|
||||
|
||||
# Should use 'chinese' model as default
|
||||
assert call_kwargs.get('layout_detection_model_name') == 'PP-DocLayout-S'
|
||||
|
||||
|
||||
class TestLayoutModelCaching:
|
||||
"""Test engine caching behavior with layout models"""
|
||||
|
||||
def test_same_layout_model_uses_cached_engine(self):
|
||||
"""Verify same layout model reuses cached engine"""
|
||||
ocr_service = OCRService()
|
||||
|
||||
with patch('app.services.ocr_service.PPStructureV3') as mock_ppstructure:
|
||||
mock_engine = Mock()
|
||||
mock_ppstructure.return_value = mock_engine
|
||||
|
||||
# First call with 'chinese'
|
||||
engine1 = ocr_service._ensure_structure_engine(layout_model='chinese')
|
||||
|
||||
# Second call with same model should use cache
|
||||
engine2 = ocr_service._ensure_structure_engine(layout_model='chinese')
|
||||
|
||||
# Verify only one engine was created
|
||||
assert mock_ppstructure.call_count == 1
|
||||
assert engine1 is engine2
|
||||
|
||||
def test_different_layout_model_creates_new_engine(self):
|
||||
"""Verify different layout model creates new engine"""
|
||||
ocr_service = OCRService()
|
||||
|
||||
with patch('app.services.ocr_service.PPStructureV3') as mock_ppstructure:
|
||||
mock_engine1 = Mock()
|
||||
mock_engine2 = Mock()
|
||||
mock_ppstructure.side_effect = [mock_engine1, mock_engine2]
|
||||
|
||||
# First call with 'chinese'
|
||||
engine1 = ocr_service._ensure_structure_engine(layout_model='chinese')
|
||||
|
||||
# Second call with 'cdla' should create new engine
|
||||
engine2 = ocr_service._ensure_structure_engine(layout_model='cdla')
|
||||
|
||||
# Verify two engines were created
|
||||
assert mock_ppstructure.call_count == 2
|
||||
assert engine1 is not engine2
|
||||
|
||||
|
||||
class TestLayoutModelFlow:
|
||||
"""Test layout model parameter flow through processing pipeline"""
|
||||
|
||||
def test_layout_model_passed_to_engine_creation(self):
|
||||
"""Verify layout_model is passed through to _ensure_structure_engine"""
|
||||
ocr_service = OCRService()
|
||||
|
||||
# Test that _ensure_structure_engine accepts layout_model parameter
|
||||
with patch('app.services.ocr_service.PPStructureV3') as mock_ppstructure:
|
||||
mock_engine = Mock()
|
||||
mock_ppstructure.return_value = mock_engine
|
||||
|
||||
# Call with specific layout_model
|
||||
engine = ocr_service._ensure_structure_engine(layout_model='cdla')
|
||||
|
||||
# Verify correct model was requested
|
||||
mock_ppstructure.assert_called_once()
|
||||
call_kwargs = mock_ppstructure.call_args[1]
|
||||
assert call_kwargs.get('layout_detection_model_name') == 'picodet_lcnet_x1_0_fgd_layout_cdla'
|
||||
|
||||
def test_layout_model_default_behavior(self):
|
||||
"""Verify default layout model behavior when None is passed"""
|
||||
ocr_service = OCRService()
|
||||
|
||||
with patch('app.services.ocr_service.PPStructureV3') as mock_ppstructure:
|
||||
mock_engine = Mock()
|
||||
mock_ppstructure.return_value = mock_engine
|
||||
|
||||
# Call without layout_model (None)
|
||||
engine = ocr_service._ensure_structure_engine(layout_model=None)
|
||||
|
||||
# Should use config default (PP-DocLayout-S)
|
||||
mock_ppstructure.assert_called_once()
|
||||
call_kwargs = mock_ppstructure.call_args[1]
|
||||
assert call_kwargs.get('layout_detection_model_name') == settings.layout_detection_model_name
|
||||
|
||||
def test_layout_model_unknown_value_falls_back(self):
|
||||
"""Verify unknown layout model falls back to config default"""
|
||||
ocr_service = OCRService()
|
||||
|
||||
with patch('app.services.ocr_service.PPStructureV3') as mock_ppstructure:
|
||||
mock_engine = Mock()
|
||||
mock_ppstructure.return_value = mock_engine
|
||||
|
||||
# Call with unknown layout_model
|
||||
engine = ocr_service._ensure_structure_engine(layout_model='unknown_model')
|
||||
|
||||
# Should use config default
|
||||
mock_ppstructure.assert_called_once()
|
||||
call_kwargs = mock_ppstructure.call_args[1]
|
||||
assert call_kwargs.get('layout_detection_model_name') == settings.layout_detection_model_name
|
||||
|
||||
|
||||
class TestLayoutModelLogging:
|
||||
"""Test layout model logging"""
|
||||
|
||||
def test_layout_model_is_logged(self):
|
||||
"""Verify layout model selection is logged"""
|
||||
ocr_service = OCRService()
|
||||
|
||||
with patch('app.services.ocr_service.PPStructureV3') as mock_ppstructure:
|
||||
with patch('app.services.ocr_service.logger') as mock_logger:
|
||||
mock_engine = Mock()
|
||||
mock_ppstructure.return_value = mock_engine
|
||||
|
||||
# Call with specific layout_model
|
||||
ocr_service._ensure_structure_engine(layout_model='cdla')
|
||||
|
||||
# Verify logging occurred
|
||||
assert mock_logger.info.call_count >= 1
|
||||
# Check that model name was logged
|
||||
log_calls = [str(call) for call in mock_logger.info.call_args_list]
|
||||
assert any('cdla' in str(call).lower() or 'layout' in str(call).lower() for call in log_calls)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__, '-v'])
|
||||
503
backend/tests/test_gap_filling.py
Normal file
503
backend/tests/test_gap_filling.py
Normal file
@@ -0,0 +1,503 @@
|
||||
"""
|
||||
Tests for Gap Filling Service
|
||||
|
||||
Tests the detection and filling of gaps in PP-StructureV3 output
|
||||
using raw OCR text regions.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from typing import List, Dict, Any
|
||||
|
||||
from app.services.gap_filling_service import GapFillingService, TextRegion, SKIP_ELEMENT_TYPES
|
||||
from app.models.unified_document import DocumentElement, BoundingBox, ElementType, Dimensions
|
||||
|
||||
|
||||
class TestGapFillingService:
|
||||
"""Tests for GapFillingService class."""
|
||||
|
||||
@pytest.fixture
|
||||
def service(self) -> GapFillingService:
|
||||
"""Create a GapFillingService instance with default settings."""
|
||||
return GapFillingService(
|
||||
coverage_threshold=0.7,
|
||||
iou_threshold=0.15,
|
||||
confidence_threshold=0.3,
|
||||
dedup_iou_threshold=0.5,
|
||||
enabled=True
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def disabled_service(self) -> GapFillingService:
|
||||
"""Create a disabled GapFillingService instance."""
|
||||
return GapFillingService(enabled=False)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_raw_regions(self) -> List[TextRegion]:
|
||||
"""Create sample raw OCR text regions."""
|
||||
return [
|
||||
TextRegion(text="Header text", bbox=[100, 50, 300, 80], confidence=0.95, page=1),
|
||||
TextRegion(text="Title of document", bbox=[100, 100, 500, 150], confidence=0.92, page=1),
|
||||
TextRegion(text="First paragraph", bbox=[100, 200, 500, 250], confidence=0.90, page=1),
|
||||
TextRegion(text="Second paragraph", bbox=[100, 300, 500, 350], confidence=0.88, page=1),
|
||||
TextRegion(text="Footer note", bbox=[100, 900, 300, 930], confidence=0.85, page=1),
|
||||
# Low confidence region (should be filtered)
|
||||
TextRegion(text="Noise", bbox=[50, 50, 80, 80], confidence=0.1, page=1),
|
||||
]
|
||||
|
||||
@pytest.fixture
|
||||
def sample_pp_elements(self) -> List[DocumentElement]:
|
||||
"""Create sample PP-StructureV3 elements that cover only some regions."""
|
||||
return [
|
||||
DocumentElement(
|
||||
element_id="pp_1",
|
||||
type=ElementType.TITLE,
|
||||
content="Title of document",
|
||||
bbox=BoundingBox(x0=100, y0=100, x1=500, y1=150),
|
||||
confidence=0.95
|
||||
),
|
||||
DocumentElement(
|
||||
element_id="pp_2",
|
||||
type=ElementType.TEXT,
|
||||
content="First paragraph",
|
||||
bbox=BoundingBox(x0=100, y0=200, x1=500, y1=250),
|
||||
confidence=0.90
|
||||
),
|
||||
# Note: Header, Second paragraph, and Footer are NOT covered
|
||||
]
|
||||
|
||||
def test_service_initialization(self, service: GapFillingService):
|
||||
"""Test service initializes with correct parameters."""
|
||||
assert service.enabled is True
|
||||
assert service.coverage_threshold == 0.7
|
||||
assert service.iou_threshold == 0.15
|
||||
assert service.confidence_threshold == 0.3
|
||||
assert service.dedup_iou_threshold == 0.5
|
||||
|
||||
def test_disabled_service(self, disabled_service: GapFillingService):
|
||||
"""Test disabled service does not activate."""
|
||||
regions = [TextRegion(text="Test", bbox=[0, 0, 100, 100], confidence=0.9, page=1)]
|
||||
elements = []
|
||||
|
||||
should_activate, coverage = disabled_service.should_activate(regions, elements)
|
||||
assert should_activate is False
|
||||
assert coverage == 1.0
|
||||
|
||||
def test_should_activate_low_coverage(
|
||||
self,
|
||||
service: GapFillingService,
|
||||
sample_raw_regions: List[TextRegion],
|
||||
sample_pp_elements: List[DocumentElement]
|
||||
):
|
||||
"""Test activation when coverage is below threshold."""
|
||||
# Filter out low confidence regions
|
||||
valid_regions = [r for r in sample_raw_regions if r.confidence >= 0.3]
|
||||
|
||||
should_activate, coverage = service.should_activate(valid_regions, sample_pp_elements)
|
||||
|
||||
# Only 2 out of 5 valid regions are covered (Title, First paragraph)
|
||||
assert should_activate is True
|
||||
assert coverage < 0.7 # Below threshold
|
||||
|
||||
def test_should_not_activate_high_coverage(self, service: GapFillingService):
|
||||
"""Test no activation when coverage is above threshold."""
|
||||
# All regions covered
|
||||
regions = [
|
||||
TextRegion(text="Text 1", bbox=[100, 100, 200, 150], confidence=0.9, page=1),
|
||||
TextRegion(text="Text 2", bbox=[100, 200, 200, 250], confidence=0.9, page=1),
|
||||
]
|
||||
|
||||
elements = [
|
||||
DocumentElement(
|
||||
element_id="pp_1",
|
||||
type=ElementType.TEXT,
|
||||
content="Text 1",
|
||||
bbox=BoundingBox(x0=50, y0=50, x1=250, y1=200), # Covers first region
|
||||
confidence=0.95
|
||||
),
|
||||
DocumentElement(
|
||||
element_id="pp_2",
|
||||
type=ElementType.TEXT,
|
||||
content="Text 2",
|
||||
bbox=BoundingBox(x0=50, y0=180, x1=250, y1=300), # Covers second region
|
||||
confidence=0.95
|
||||
),
|
||||
]
|
||||
|
||||
should_activate, coverage = service.should_activate(regions, elements)
|
||||
|
||||
assert should_activate is False
|
||||
assert coverage >= 0.7
|
||||
|
||||
def test_find_uncovered_regions(
|
||||
self,
|
||||
service: GapFillingService,
|
||||
sample_raw_regions: List[TextRegion],
|
||||
sample_pp_elements: List[DocumentElement]
|
||||
):
|
||||
"""Test finding uncovered regions."""
|
||||
uncovered = service.find_uncovered_regions(sample_raw_regions, sample_pp_elements)
|
||||
|
||||
# Should find Header, Second paragraph, Footer (not Title, First paragraph, or low-confidence Noise)
|
||||
assert len(uncovered) == 3
|
||||
|
||||
uncovered_texts = [r.text for r in uncovered]
|
||||
assert "Header text" in uncovered_texts
|
||||
assert "Second paragraph" in uncovered_texts
|
||||
assert "Footer note" in uncovered_texts
|
||||
assert "Title of document" not in uncovered_texts # Covered
|
||||
assert "First paragraph" not in uncovered_texts # Covered
|
||||
assert "Noise" not in uncovered_texts # Low confidence
|
||||
|
||||
def test_coverage_by_center_point(self, service: GapFillingService):
|
||||
"""Test coverage detection via center point."""
|
||||
region = TextRegion(text="Test", bbox=[150, 150, 250, 200], confidence=0.9, page=1)
|
||||
|
||||
element = DocumentElement(
|
||||
element_id="pp_1",
|
||||
type=ElementType.TEXT,
|
||||
content="Container",
|
||||
bbox=BoundingBox(x0=100, y0=100, x1=300, y1=250), # Contains region's center
|
||||
confidence=0.95
|
||||
)
|
||||
|
||||
is_covered = service._is_region_covered(region, [element])
|
||||
assert is_covered is True
|
||||
|
||||
def test_coverage_by_iou(self, service: GapFillingService):
|
||||
"""Test coverage detection via IoU threshold."""
|
||||
region = TextRegion(text="Test", bbox=[100, 100, 200, 150], confidence=0.9, page=1)
|
||||
|
||||
element = DocumentElement(
|
||||
element_id="pp_1",
|
||||
type=ElementType.TEXT,
|
||||
content="Overlap",
|
||||
bbox=BoundingBox(x0=150, y0=100, x1=250, y1=150), # Partial overlap
|
||||
confidence=0.95
|
||||
)
|
||||
|
||||
# Calculate expected IoU
|
||||
# Intersection: (150-200) x (100-150) = 50 x 50 = 2500
|
||||
# Union: 100x50 + 100x50 - 2500 = 7500
|
||||
# IoU = 2500/7500 = 0.33 > 0.15 threshold
|
||||
|
||||
is_covered = service._is_region_covered(region, [element])
|
||||
assert is_covered is True
|
||||
|
||||
def test_deduplication(
|
||||
self,
|
||||
service: GapFillingService,
|
||||
sample_pp_elements: List[DocumentElement]
|
||||
):
|
||||
"""Test deduplication removes high-overlap regions."""
|
||||
uncovered = [
|
||||
# High overlap with pp_2 (First paragraph)
|
||||
TextRegion(text="First paragraph variant", bbox=[100, 200, 500, 250], confidence=0.9, page=1),
|
||||
# No overlap
|
||||
TextRegion(text="Unique region", bbox=[100, 500, 300, 550], confidence=0.9, page=1),
|
||||
]
|
||||
|
||||
deduplicated = service.deduplicate_regions(uncovered, sample_pp_elements)
|
||||
|
||||
assert len(deduplicated) == 1
|
||||
assert deduplicated[0].text == "Unique region"
|
||||
|
||||
def test_convert_regions_to_elements(self, service: GapFillingService):
|
||||
"""Test conversion of TextRegions to DocumentElements."""
|
||||
regions = [
|
||||
TextRegion(text="Test text 1", bbox=[100, 100, 200, 150], confidence=0.85, page=1),
|
||||
TextRegion(text="Test text 2", bbox=[100, 200, 200, 250], confidence=0.90, page=1),
|
||||
]
|
||||
|
||||
elements = service.convert_regions_to_elements(regions, page_number=1, start_element_id=0)
|
||||
|
||||
assert len(elements) == 2
|
||||
assert elements[0].element_id == "gap_fill_1_0"
|
||||
assert elements[0].type == ElementType.TEXT
|
||||
assert elements[0].content == "Test text 1"
|
||||
assert elements[0].confidence == 0.85
|
||||
assert elements[0].metadata.get('source') == 'gap_filling'
|
||||
|
||||
assert elements[1].element_id == "gap_fill_1_1"
|
||||
assert elements[1].content == "Test text 2"
|
||||
|
||||
def test_recalculate_reading_order(self, service: GapFillingService):
|
||||
"""Test reading order recalculation."""
|
||||
elements = [
|
||||
DocumentElement(
|
||||
element_id="e3",
|
||||
type=ElementType.TEXT,
|
||||
content="Bottom",
|
||||
bbox=BoundingBox(x0=100, y0=300, x1=200, y1=350),
|
||||
confidence=0.9
|
||||
),
|
||||
DocumentElement(
|
||||
element_id="e1",
|
||||
type=ElementType.TEXT,
|
||||
content="Top",
|
||||
bbox=BoundingBox(x0=100, y0=100, x1=200, y1=150),
|
||||
confidence=0.9
|
||||
),
|
||||
DocumentElement(
|
||||
element_id="e2",
|
||||
type=ElementType.TEXT,
|
||||
content="Middle",
|
||||
bbox=BoundingBox(x0=100, y0=200, x1=200, y1=250),
|
||||
confidence=0.9
|
||||
),
|
||||
]
|
||||
|
||||
reading_order = service.recalculate_reading_order(elements)
|
||||
|
||||
# Should be sorted by y0: Top (100), Middle (200), Bottom (300)
|
||||
assert reading_order == [1, 2, 0] # Indices of elements in reading order
|
||||
|
||||
def test_fill_gaps_integration(
|
||||
self,
|
||||
service: GapFillingService,
|
||||
):
|
||||
"""Integration test for fill_gaps method."""
|
||||
# Raw OCR regions (dict format as received from OCR service)
|
||||
raw_regions = [
|
||||
{'text': 'Header', 'bbox': [100, 50, 300, 80], 'confidence': 0.95, 'page': 1},
|
||||
{'text': 'Title', 'bbox': [100, 100, 500, 150], 'confidence': 0.92, 'page': 1},
|
||||
{'text': 'Paragraph 1', 'bbox': [100, 200, 500, 250], 'confidence': 0.90, 'page': 1},
|
||||
{'text': 'Paragraph 2', 'bbox': [100, 300, 500, 350], 'confidence': 0.88, 'page': 1},
|
||||
{'text': 'Paragraph 3', 'bbox': [100, 400, 500, 450], 'confidence': 0.86, 'page': 1},
|
||||
{'text': 'Footer', 'bbox': [100, 900, 300, 930], 'confidence': 0.85, 'page': 1},
|
||||
]
|
||||
|
||||
# PP-StructureV3 only detected Title (missing 5 out of 6 regions = 16.7% coverage)
|
||||
pp_elements = [
|
||||
DocumentElement(
|
||||
element_id="pp_1",
|
||||
type=ElementType.TITLE,
|
||||
content="Title",
|
||||
bbox=BoundingBox(x0=100, y0=100, x1=500, y1=150),
|
||||
confidence=0.95
|
||||
),
|
||||
]
|
||||
|
||||
supplemented, stats = service.fill_gaps(
|
||||
raw_ocr_regions=raw_regions,
|
||||
pp_structure_elements=pp_elements,
|
||||
page_number=1
|
||||
)
|
||||
|
||||
# Should have activated and supplemented missing regions
|
||||
assert stats['activated'] is True
|
||||
assert stats['coverage_ratio'] < 0.7
|
||||
assert len(supplemented) == 5 # Header, Paragraph 1, 2, 3, Footer
|
||||
|
||||
def test_fill_gaps_no_activation_when_coverage_high(self, service: GapFillingService):
|
||||
"""Test fill_gaps does not activate when coverage is high."""
|
||||
raw_regions = [
|
||||
{'text': 'Text 1', 'bbox': [100, 100, 200, 150], 'confidence': 0.9, 'page': 1},
|
||||
]
|
||||
|
||||
pp_elements = [
|
||||
DocumentElement(
|
||||
element_id="pp_1",
|
||||
type=ElementType.TEXT,
|
||||
content="Text 1",
|
||||
bbox=BoundingBox(x0=50, y0=50, x1=250, y1=200), # Fully covers
|
||||
confidence=0.95
|
||||
),
|
||||
]
|
||||
|
||||
supplemented, stats = service.fill_gaps(
|
||||
raw_ocr_regions=raw_regions,
|
||||
pp_structure_elements=pp_elements,
|
||||
page_number=1
|
||||
)
|
||||
|
||||
assert stats['activated'] is False
|
||||
assert len(supplemented) == 0
|
||||
|
||||
def test_skip_element_types_not_supplemented(self, service: GapFillingService):
|
||||
"""Test that TABLE/IMAGE/etc. elements are not supplemented over."""
|
||||
raw_regions = [
|
||||
{'text': 'Table cell text', 'bbox': [100, 100, 200, 150], 'confidence': 0.9, 'page': 1},
|
||||
]
|
||||
|
||||
# PP-StructureV3 has a table covering this region
|
||||
pp_elements = [
|
||||
DocumentElement(
|
||||
element_id="pp_1",
|
||||
type=ElementType.TABLE,
|
||||
content="<table>...</table>",
|
||||
bbox=BoundingBox(x0=50, y0=50, x1=250, y1=200),
|
||||
confidence=0.95
|
||||
),
|
||||
]
|
||||
|
||||
# The region should be considered covered by the table
|
||||
supplemented, stats = service.fill_gaps(
|
||||
raw_ocr_regions=raw_regions,
|
||||
pp_structure_elements=pp_elements,
|
||||
page_number=1
|
||||
)
|
||||
|
||||
# Should not supplement because the table covers it
|
||||
assert len(supplemented) == 0
|
||||
|
||||
def test_coordinate_scaling(self, service: GapFillingService):
|
||||
"""Test coordinate alignment with different dimensions."""
|
||||
# OCR was done at 2000x3000, PP-Structure at 1000x1500
|
||||
ocr_dimensions = {'width': 2000, 'height': 3000}
|
||||
pp_dimensions = Dimensions(width=1000, height=1500)
|
||||
|
||||
raw_regions = [
|
||||
# At OCR scale: (200, 300) to (400, 450) -> at PP scale: (100, 150) to (200, 225)
|
||||
{'text': 'Scaled text', 'bbox': [200, 300, 400, 450], 'confidence': 0.9, 'page': 1},
|
||||
]
|
||||
|
||||
pp_elements = [
|
||||
DocumentElement(
|
||||
element_id="pp_1",
|
||||
type=ElementType.TEXT,
|
||||
content="Scaled text",
|
||||
bbox=BoundingBox(x0=100, y0=150, x1=200, y1=225), # Should cover after scaling
|
||||
confidence=0.95
|
||||
),
|
||||
]
|
||||
|
||||
supplemented, stats = service.fill_gaps(
|
||||
raw_ocr_regions=raw_regions,
|
||||
pp_structure_elements=pp_elements,
|
||||
page_number=1,
|
||||
ocr_dimensions=ocr_dimensions,
|
||||
pp_dimensions=pp_dimensions
|
||||
)
|
||||
|
||||
# After scaling, the region should be covered
|
||||
assert stats['coverage_ratio'] >= 0.7 or len(supplemented) == 0
|
||||
|
||||
def test_iou_calculation(self, service: GapFillingService):
|
||||
"""Test IoU calculation accuracy."""
|
||||
# Two identical boxes
|
||||
bbox1 = (0, 0, 100, 100)
|
||||
bbox2 = (0, 0, 100, 100)
|
||||
assert service._calculate_iou(bbox1, bbox2) == 1.0
|
||||
|
||||
# No overlap
|
||||
bbox1 = (0, 0, 100, 100)
|
||||
bbox2 = (200, 200, 300, 300)
|
||||
assert service._calculate_iou(bbox1, bbox2) == 0.0
|
||||
|
||||
# 50% overlap
|
||||
bbox1 = (0, 0, 100, 100)
|
||||
bbox2 = (50, 0, 150, 100) # Shifted right by 50
|
||||
# Intersection: 50x100 = 5000
|
||||
# Union: 10000 + 10000 - 5000 = 15000
|
||||
# IoU = 5000/15000 = 0.333...
|
||||
iou = service._calculate_iou(bbox1, bbox2)
|
||||
assert abs(iou - 1/3) < 0.01
|
||||
|
||||
def test_point_in_bbox(self, service: GapFillingService):
|
||||
"""Test point-in-bbox check."""
|
||||
bbox = (100, 100, 200, 200)
|
||||
|
||||
# Inside
|
||||
assert service._point_in_bbox(150, 150, bbox) is True
|
||||
|
||||
# On edge
|
||||
assert service._point_in_bbox(100, 100, bbox) is True
|
||||
assert service._point_in_bbox(200, 200, bbox) is True
|
||||
|
||||
# Outside
|
||||
assert service._point_in_bbox(50, 150, bbox) is False
|
||||
assert service._point_in_bbox(250, 150, bbox) is False
|
||||
|
||||
def test_merge_adjacent_regions(self, service: GapFillingService):
|
||||
"""Test merging of adjacent text regions."""
|
||||
regions = [
|
||||
TextRegion(text="Hello", bbox=[100, 100, 150, 130], confidence=0.9, page=1),
|
||||
TextRegion(text="World", bbox=[160, 100, 210, 130], confidence=0.85, page=1), # Adjacent
|
||||
TextRegion(text="Far away", bbox=[100, 300, 200, 330], confidence=0.9, page=1), # Not adjacent
|
||||
]
|
||||
|
||||
merged = service.merge_adjacent_regions(regions, max_horizontal_gap=20, max_vertical_gap=10)
|
||||
|
||||
assert len(merged) == 2
|
||||
# First two should be merged
|
||||
assert "Hello" in merged[0].text and "World" in merged[0].text
|
||||
assert merged[1].text == "Far away"
|
||||
|
||||
|
||||
class TestTextRegion:
|
||||
"""Tests for TextRegion dataclass."""
|
||||
|
||||
def test_normalized_bbox_4_values(self):
|
||||
"""Test bbox normalization with 4 values."""
|
||||
region = TextRegion(text="Test", bbox=[100, 200, 300, 400], confidence=0.9, page=1)
|
||||
assert region.normalized_bbox == (100, 200, 300, 400)
|
||||
|
||||
def test_normalized_bbox_polygon_flat(self):
|
||||
"""Test bbox normalization with flat polygon format (8 values)."""
|
||||
# Polygon: 4 points as flat list [x1, y1, x2, y2, x3, y3, x4, y4]
|
||||
region = TextRegion(
|
||||
text="Test",
|
||||
bbox=[100, 200, 300, 200, 300, 400, 100, 400],
|
||||
confidence=0.9,
|
||||
page=1
|
||||
)
|
||||
assert region.normalized_bbox == (100, 200, 300, 400)
|
||||
|
||||
def test_normalized_bbox_polygon_nested(self):
|
||||
"""Test bbox normalization with nested polygon format (PaddleOCR format)."""
|
||||
# PaddleOCR format: [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
|
||||
region = TextRegion(
|
||||
text="Test",
|
||||
bbox=[[100, 200], [300, 200], [300, 400], [100, 400]],
|
||||
confidence=0.9,
|
||||
page=1
|
||||
)
|
||||
assert region.normalized_bbox == (100, 200, 300, 400)
|
||||
|
||||
def test_normalized_bbox_numpy_polygon(self):
|
||||
"""Test bbox normalization with numpy-like nested format."""
|
||||
# Sometimes PaddleOCR returns numpy arrays converted to lists
|
||||
region = TextRegion(
|
||||
text="Test",
|
||||
bbox=[[100.5, 200.5], [300.5, 200.5], [300.5, 400.5], [100.5, 400.5]],
|
||||
confidence=0.9,
|
||||
page=1
|
||||
)
|
||||
bbox = region.normalized_bbox
|
||||
assert bbox[0] == 100.5
|
||||
assert bbox[1] == 200.5
|
||||
assert bbox[2] == 300.5
|
||||
assert bbox[3] == 400.5
|
||||
|
||||
def test_center_calculation(self):
|
||||
"""Test center point calculation."""
|
||||
region = TextRegion(text="Test", bbox=[100, 200, 300, 400], confidence=0.9, page=1)
|
||||
assert region.center == (200, 300)
|
||||
|
||||
def test_center_calculation_nested_bbox(self):
|
||||
"""Test center point calculation with nested bbox format."""
|
||||
region = TextRegion(
|
||||
text="Test",
|
||||
bbox=[[100, 200], [300, 200], [300, 400], [100, 400]],
|
||||
confidence=0.9,
|
||||
page=1
|
||||
)
|
||||
assert region.center == (200, 300)
|
||||
|
||||
|
||||
class TestOCRToUnifiedConverterIntegration:
|
||||
"""Integration tests for OCRToUnifiedConverter with gap filling."""
|
||||
|
||||
def test_converter_with_gap_filling_enabled(self):
|
||||
"""Test converter initializes with gap filling enabled."""
|
||||
from app.services.ocr_to_unified_converter import OCRToUnifiedConverter
|
||||
|
||||
converter = OCRToUnifiedConverter(enable_gap_filling=True)
|
||||
assert converter.gap_filling_service is not None
|
||||
|
||||
def test_converter_with_gap_filling_disabled(self):
|
||||
"""Test converter initializes without gap filling."""
|
||||
from app.services.ocr_to_unified_converter import OCRToUnifiedConverter
|
||||
|
||||
converter = OCRToUnifiedConverter(enable_gap_filling=False)
|
||||
assert converter.gap_filling_service is None
|
||||
Reference in New Issue
Block a user