"""
Tool_OCR - OCR Service Unit Tests

Tests for app/services/ocr_service.py
"""

import json
from pathlib import Path
from unittest.mock import MagicMock, Mock, patch

import pytest

from app.services.ocr_service import OCRService


@pytest.mark.unit
class TestOCRServiceInit:
    """Test OCR service initialization"""

    def test_init(self):
        """A fresh service has no engines loaded and sane configuration."""
        service = OCRService()

        assert service is not None
        # Engines are created lazily, so nothing is loaded at construction time.
        assert service.ocr_engines == {}
        assert service.structure_engine is None
        assert service.confidence_threshold > 0
        assert len(service.ocr_languages) > 0

    def test_supported_languages(self):
        """Test that supported languages are configured"""
        service = OCRService()

        # At least one of Chinese ('ch') or English ('en') must be configured.
        # This is intentionally an OR: either language alone satisfies the check.
        assert 'ch' in service.ocr_languages or 'en' in service.ocr_languages


@pytest.mark.unit
class TestOCREngineLazyLoading:
    """Test OCR engine lazy loading"""

    @patch('app.services.ocr_service.PaddleOCR')
    def test_get_ocr_engine_creates_new_engine(self, mock_paddle_ocr):
        """Test that get_ocr_engine creates engine on first call"""
        mock_engine = Mock()
        mock_paddle_ocr.return_value = mock_engine

        service = OCRService()
        engine = service.get_ocr_engine(lang='en')

        assert engine is mock_engine
        mock_paddle_ocr.assert_called_once()
        assert 'en' in service.ocr_engines

    @patch('app.services.ocr_service.PaddleOCR')
    def test_get_ocr_engine_reuses_existing_engine(self, mock_paddle_ocr):
        """Test that get_ocr_engine reuses existing engine"""
        mock_engine = Mock()
        mock_paddle_ocr.return_value = mock_engine

        service = OCRService()

        # First call creates engine
        engine1 = service.get_ocr_engine(lang='en')
        # Second call should reuse the cached engine, not construct a new one
        engine2 = service.get_ocr_engine(lang='en')

        assert engine1 is engine2
        mock_paddle_ocr.assert_called_once()

    @patch('app.services.ocr_service.PaddleOCR')
    def test_get_ocr_engine_different_languages(self, mock_paddle_ocr):
        """Test that different languages get different engines"""
        # Hand out a distinct mock per constructor call. A single shared
        # return_value would make both "engines" the same object, and the
        # test could not tell them apart.
        mock_en, mock_ch = Mock(), Mock()
        mock_paddle_ocr.side_effect = [mock_en, mock_ch]

        service = OCRService()

        engine_en = service.get_ocr_engine(lang='en')
        engine_ch = service.get_ocr_engine(lang='ch')

        assert engine_en is mock_en
        assert engine_ch is mock_ch
        assert engine_en is not engine_ch
        assert 'en' in service.ocr_engines
        assert 'ch' in service.ocr_engines
        assert mock_paddle_ocr.call_count == 2


@pytest.mark.unit
class TestStructureEngineLazyLoading:
    """Verify the PP-Structure engine is built on first request and then cached."""

    @patch('app.services.ocr_service.PPStructureV3')
    def test_get_structure_engine_creates_new_engine(self, structure_cls):
        """The first get_structure_engine() call constructs and stores the engine."""
        fake_engine = Mock()
        structure_cls.return_value = fake_engine

        svc = OCRService()
        built = svc.get_structure_engine()

        assert built is fake_engine
        structure_cls.assert_called_once()
        assert svc.structure_engine is fake_engine

    @patch('app.services.ocr_service.PPStructureV3')
    def test_get_structure_engine_reuses_existing_engine(self, structure_cls):
        """A repeated call hands back the cached engine instead of building another."""
        structure_cls.return_value = Mock()

        svc = OCRService()
        # Two back-to-back requests: the first builds, the second must reuse
        first = svc.get_structure_engine()
        second = svc.get_structure_engine()

        assert first is second
        structure_cls.assert_called_once()


@pytest.mark.unit
class TestProcessImageMocked:
    """Test image processing with mocked OCR engines"""

    @patch('app.services.ocr_service.PaddleOCR')
    def test_process_image_success(self, mock_paddle_ocr, sample_image_path):
        """Test successful image processing"""
        # Mock OCR results - PaddleOCR 3.x format: one dict per page holding
        # parallel rec_texts / rec_scores / rec_polys lists.
        mock_ocr_results = [{
            'rec_texts': ['Hello World', 'Test Text'],
            'rec_scores': [0.95, 0.88],
            'rec_polys': [
                [[10, 10], [100, 10], [100, 30], [10, 30]],
                [[10, 40], [100, 40], [100, 60], [10, 60]],
            ],
        }]

        mock_engine = Mock()
        mock_engine.ocr.return_value = mock_ocr_results
        mock_paddle_ocr.return_value = mock_engine

        service = OCRService()
        result = service.process_image(sample_image_path, detect_layout=False)

        assert result['status'] == 'success'
        assert result['file_name'] == sample_image_path.name
        assert result['language'] == 'ch'
        assert result['total_text_regions'] == 2
        assert result['average_confidence'] > 0.8
        assert len(result['text_regions']) == 2
        assert 'markdown_content' in result
        assert 'processing_time' in result

    @patch('app.services.ocr_service.PaddleOCR')
    def test_process_image_filters_low_confidence(self, mock_paddle_ocr, sample_image_path):
        """Test that low confidence results are filtered"""
        # Mock OCR results with varying confidence - PaddleOCR 3.x format
        mock_ocr_results = [{
            'rec_texts': ['High Confidence', 'Low Confidence'],
            'rec_scores': [0.95, 0.50],
            'rec_polys': [
                [[10, 10], [100, 10], [100, 30], [10, 30]],
                [[10, 40], [100, 40], [100, 60], [10, 60]],
            ],
        }]

        mock_engine = Mock()
        mock_engine.ocr.return_value = mock_ocr_results
        mock_paddle_ocr.return_value = mock_engine

        service = OCRService()
        result = service.process_image(
            sample_image_path,
            detect_layout=False,
            confidence_threshold=0.80,
        )

        assert result['status'] == 'success'
        assert result['total_text_regions'] == 1  # Only high confidence
        assert result['text_regions'][0]['text'] == 'High Confidence'

    @patch('app.services.ocr_service.PaddleOCR')
    def test_process_image_empty_results(self, mock_paddle_ocr, sample_image_path):
        """Test processing image with no text detected"""
        # One page containing no detections (nested-list shape)
        mock_ocr_results = [[]]

        mock_engine = Mock()
        mock_engine.ocr.return_value = mock_ocr_results
        mock_paddle_ocr.return_value = mock_engine

        service = OCRService()
        result = service.process_image(sample_image_path, detect_layout=False)

        assert result['status'] == 'success'
        assert result['total_text_regions'] == 0
        assert result['average_confidence'] == 0.0

    @patch('app.services.ocr_service.PaddleOCR')
    def test_process_image_error_handling(self, mock_paddle_ocr, sample_image_path):
        """Test error handling during OCR processing"""
        mock_engine = Mock()
        mock_engine.ocr.side_effect = Exception("OCR engine error")
        mock_paddle_ocr.return_value = mock_engine

        service = OCRService()
        result = service.process_image(sample_image_path, detect_layout=False)

        # The engine failure is reported in the result rather than raised
        assert result['status'] == 'error'
        assert 'error_message' in result
        assert 'OCR engine error' in result['error_message']

    @patch('app.services.ocr_service.PaddleOCR')
    def test_process_image_different_languages(self, mock_paddle_ocr, sample_image_path):
        """Test processing with different languages"""
        # PaddleOCR 3.x format, consistent with the other success-path tests here
        mock_ocr_results = [{
            'rec_texts': ['Text'],
            'rec_scores': [0.95],
            'rec_polys': [[[10, 10], [100, 10], [100, 30], [10, 30]]],
        }]

        mock_engine = Mock()
        mock_engine.ocr.return_value = mock_ocr_results
        mock_paddle_ocr.return_value = mock_engine

        service = OCRService()

        # Test English
        result_en = service.process_image(sample_image_path, lang='en', detect_layout=False)
        assert result_en['language'] == 'en'

        # Test Chinese
        result_ch = service.process_image(sample_image_path, lang='ch', detect_layout=False)
        assert result_ch['language'] == 'ch'


@pytest.mark.unit
class TestLayoutAnalysisMocked:
    """Test layout analysis with mocked structure engine"""

    @patch('app.services.ocr_service.PPStructureV3')
    def test_analyze_layout_success(self, mock_structure, sample_image_path):
        """Test successful layout analysis"""
        # Create mock page result with markdown attribute (PP-StructureV3 format)
        mock_page_result = Mock()
        mock_page_result.markdown = {
            'markdown_texts': 'Document Title\n\nParagraph content',
            'markdown_images': {},
        }

        # PP-Structure predict() returns a list of page results
        mock_engine = Mock()
        mock_engine.predict.return_value = [mock_page_result]
        mock_structure.return_value = mock_engine

        service = OCRService()
        layout_data, images_metadata = service.analyze_layout(sample_image_path)

        assert layout_data is not None
        assert layout_data['total_elements'] == 1
        assert len(layout_data['elements']) == 1
        assert layout_data['elements'][0]['type'] == 'text'
        assert 'Document Title' in layout_data['elements'][0]['content']

    @patch('app.services.ocr_service.PPStructureV3')
    def test_analyze_layout_with_table(self, mock_structure, sample_image_path):
        """Test layout analysis with table element"""
        # Create mock page result with table in markdown (PP-StructureV3 format)
        mock_page_result = Mock()
        mock_page_result.markdown = {
            'markdown_texts': '<table><tr><td>Cell 1</td></tr></table>',
            'markdown_images': {},
        }

        # PP-Structure predict() returns a list of page results
        mock_engine = Mock()
        mock_engine.predict.return_value = [mock_page_result]
        mock_structure.return_value = mock_engine

        service = OCRService()
        layout_data, images_metadata = service.analyze_layout(sample_image_path)

        assert layout_data is not None
        assert layout_data['elements'][0]['type'] == 'table'
        # Content should contain the HTML table
        assert '<table>' in layout_data['elements'][0]['content']

    @patch('app.services.ocr_service.PPStructureV3')
    def test_analyze_layout_error_handling(self, mock_structure, sample_image_path):
        """Test error handling in layout analysis"""
        # analyze_layout() drives the engine through predict() (see the tests
        # above), so the failure must be injected there. Setting side_effect
        # on the engine object itself would only fire if the engine were
        # called directly, which it is not.
        mock_engine = Mock()
        mock_engine.predict.side_effect = Exception("Structure analysis error")
        mock_structure.return_value = mock_engine

        service = OCRService()
        layout_data, images_metadata = service.analyze_layout(sample_image_path)

        # Make sure the injected error path was actually exercised
        mock_engine.predict.assert_called()
        assert layout_data is None
        assert images_metadata == []


@pytest.mark.unit
class TestMarkdownGeneration:
    """Checks for the Markdown produced by OCRService.generate_markdown."""

    def test_generate_markdown_from_text_regions(self):
        """Every plain text region (no layout data) shows up in the Markdown."""
        svc = OCRService()
        regions = [
            {'text': 'First line', 'bbox': [[10, 10], [100, 10], [100, 30], [10, 30]]},
            {'text': 'Second line', 'bbox': [[10, 40], [100, 40], [100, 60], [10, 60]]},
            {'text': 'Third line', 'bbox': [[10, 70], [100, 70], [100, 90], [10, 90]]},
        ]

        md = svc.generate_markdown(regions)

        for expected in ('First line', 'Second line', 'Third line'):
            assert expected in md

    def test_generate_markdown_with_layout(self):
        """A title becomes a heading, text is kept, and a figure gets a placeholder."""
        svc = OCRService()
        layout = {
            'elements': [
                {'type': 'title', 'content': 'Document Title'},
                {'type': 'text', 'content': 'Paragraph text'},
                {'type': 'figure', 'element_id': 0},
            ]
        }

        md = svc.generate_markdown([], layout)

        for fragment in ('# Document Title', 'Paragraph text', '![Figure 0]'):
            assert fragment in md

    def test_generate_markdown_with_table(self):
        """A table element's HTML markup is carried into the Markdown."""
        svc = OCRService()
        table_html = '<table><tr><td>Cell</td></tr></table>'
        layout = {'elements': [{'type': 'table', 'content': table_html}]}

        assert '<table>' in svc.generate_markdown([], layout)

    def test_generate_markdown_empty_input(self):
        """No regions and no layout produce an empty string."""
        assert OCRService().generate_markdown([]) == ""

    def test_generate_markdown_sorts_by_position(self):
        """Regions are emitted top-to-bottom whatever order they arrive in."""
        svc = OCRService()

        # Deliberately shuffled input: bottom, top, middle
        regions = [
            {'text': 'Bottom', 'bbox': [[10, 90], [100, 90], [100, 110], [10, 110]]},
            {'text': 'Top', 'bbox': [[10, 10], [100, 10], [100, 30], [10, 30]]},
            {'text': 'Middle', 'bbox': [[10, 50], [100, 50], [100, 70], [10, 70]]},
        ]

        emitted = svc.generate_markdown(regions).strip().split('\n')

        # Expect ascending vertical order in the first three lines
        assert emitted[:3] == ['Top', 'Middle', 'Bottom']


@pytest.mark.unit
class TestSaveResults:
    """Test saving OCR results"""

    def test_save_results_success(self, temp_dir):
        """Test successful saving of results"""
        service = OCRService()

        result = {
            'status': 'success',
            'file_name': 'test.png',
            'text_regions': [{'text': 'Hello', 'confidence': 0.95}],
            'markdown_content': '# Hello\n\nTest content',
        }

        json_path, md_path = service.save_results(result, temp_dir, 'test123')

        assert json_path is not None
        assert md_path is not None
        assert json_path.exists()
        assert md_path.exists()

        # Verify JSON content. Read with an explicit encoding (as the unicode
        # test below does) so the check does not depend on the platform's
        # default locale encoding.
        with open(json_path, 'r', encoding='utf-8') as f:
            saved_result = json.load(f)
        assert saved_result['file_name'] == 'test.png'

        # Verify Markdown content
        md_content = md_path.read_text(encoding='utf-8')
        assert 'Hello' in md_content

    def test_save_results_creates_directory(self, temp_dir):
        """Test that save_results creates output directory if needed"""
        service = OCRService()
        output_dir = temp_dir / "subdir" / "results"

        result = {
            'status': 'success',
            'markdown_content': 'Test',
        }

        json_path, md_path = service.save_results(result, output_dir, 'test')

        assert output_dir.exists()
        assert json_path.exists()
        # Both output files should land in the newly created directory
        assert md_path.exists()

    def test_save_results_handles_unicode(self, temp_dir):
        """Test saving results with Unicode characters"""
        service = OCRService()

        result = {
            'status': 'success',
            'text_regions': [{'text': '你好世界', 'confidence': 0.95}],
            'markdown_content': '# 你好世界\n\n测试内容',
        }

        json_path, md_path = service.save_results(result, temp_dir, 'unicode_test')

        # Verify Unicode is preserved
        with open(json_path, 'r', encoding='utf-8') as f:
            saved_result = json.load(f)
        assert saved_result['text_regions'][0]['text'] == '你好世界'

        md_content = md_path.read_text(encoding='utf-8')
        assert '你好世界' in md_content


@pytest.mark.unit
class TestEdgeCases:
    """Edge cases and error paths of process_image."""

    @patch('app.services.ocr_service.PaddleOCR')
    def test_process_image_with_none_results(self, paddle_cls, sample_image_path):
        """A None return from the OCR engine is reported as success with zero regions."""
        fake_engine = Mock()
        fake_engine.ocr.return_value = None
        paddle_cls.return_value = fake_engine

        outcome = OCRService().process_image(sample_image_path, detect_layout=False)

        assert outcome['status'] == 'success'
        assert outcome['total_text_regions'] == 0

    @patch('app.services.ocr_service.PaddleOCR')
    def test_process_image_with_custom_threshold(self, paddle_cls, sample_image_path):
        """A per-call confidence_threshold decides whether a 0.85-score region is kept."""
        # Single region in PaddleOCR 3.x format
        fake_engine = Mock()
        fake_engine.ocr.return_value = [{
            'rec_texts': ['Text'],
            'rec_scores': [0.85],
            'rec_polys': [[[10, 10], [100, 10], [100, 30], [10, 30]]],
        }]
        paddle_cls.return_value = fake_engine

        svc = OCRService()

        # 0.90 sits above the score (region filtered out); 0.80 sits below it (kept)
        for threshold, expected_regions in ((0.90, 0), (0.80, 1)):
            outcome = svc.process_image(
                sample_image_path,
                detect_layout=False,
                confidence_threshold=threshold,
            )
            assert outcome['total_text_regions'] == expected_regions


# Integration tests that require actual PaddleOCR models
@pytest.mark.requires_models
@pytest.mark.slow
class TestOCRServiceIntegration:
    """
    Integration tests that require actual PaddleOCR models.

    These tests download models (~900MB) on first run.
    Run with: pytest -m requires_models
    """

    def test_real_ocr_engine_initialization(self):
        """Test real PaddleOCR engine initialization"""
        service = OCRService()
        engine = service.get_ocr_engine(lang='en')

        assert engine is not None
        assert hasattr(engine, 'ocr')

    def test_real_structure_engine_initialization(self):
        """Test real PP-Structure engine initialization"""
        service = OCRService()
        engine = service.get_structure_engine()

        assert engine is not None
        # The service runs layout analysis through the engine's predict()
        # method, so check that it exists, just as the OCR engine test above
        # checks for 'ocr'.
        assert hasattr(engine, 'predict')

    def test_real_image_processing(self, sample_image_with_text):
        """Test processing real image with text"""
        service = OCRService()
        result = service.process_image(sample_image_with_text, lang='en')

        assert result['status'] == 'success'
        assert result['total_text_regions'] > 0