first
This commit is contained in:
528
backend/tests/test_ocr_service.py
Normal file
528
backend/tests/test_ocr_service.py
Normal file
@@ -0,0 +1,528 @@
|
||||
"""
|
||||
Tool_OCR - OCR Service Unit Tests
|
||||
Tests for app/services/ocr_service.py
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
|
||||
from app.services.ocr_service import OCRService
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestOCRServiceInit:
|
||||
"""Test OCR service initialization"""
|
||||
|
||||
def test_init(self):
|
||||
"""Test OCR service initialization"""
|
||||
service = OCRService()
|
||||
|
||||
assert service is not None
|
||||
assert service.ocr_engines == {}
|
||||
assert service.structure_engine is None
|
||||
assert service.confidence_threshold > 0
|
||||
assert len(service.ocr_languages) > 0
|
||||
|
||||
def test_supported_languages(self):
|
||||
"""Test that supported languages are configured"""
|
||||
service = OCRService()
|
||||
|
||||
# Should have at least Chinese and English
|
||||
assert 'ch' in service.ocr_languages or 'en' in service.ocr_languages
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestOCREngineLazyLoading:
|
||||
"""Test OCR engine lazy loading"""
|
||||
|
||||
@patch('app.services.ocr_service.PaddleOCR')
|
||||
def test_get_ocr_engine_creates_new_engine(self, mock_paddle_ocr):
|
||||
"""Test that get_ocr_engine creates engine on first call"""
|
||||
mock_engine = Mock()
|
||||
mock_paddle_ocr.return_value = mock_engine
|
||||
|
||||
service = OCRService()
|
||||
engine = service.get_ocr_engine(lang='en')
|
||||
|
||||
assert engine == mock_engine
|
||||
mock_paddle_ocr.assert_called_once()
|
||||
assert 'en' in service.ocr_engines
|
||||
|
||||
@patch('app.services.ocr_service.PaddleOCR')
|
||||
def test_get_ocr_engine_reuses_existing_engine(self, mock_paddle_ocr):
|
||||
"""Test that get_ocr_engine reuses existing engine"""
|
||||
mock_engine = Mock()
|
||||
mock_paddle_ocr.return_value = mock_engine
|
||||
|
||||
service = OCRService()
|
||||
|
||||
# First call creates engine
|
||||
engine1 = service.get_ocr_engine(lang='en')
|
||||
# Second call should reuse
|
||||
engine2 = service.get_ocr_engine(lang='en')
|
||||
|
||||
assert engine1 == engine2
|
||||
mock_paddle_ocr.assert_called_once()
|
||||
|
||||
@patch('app.services.ocr_service.PaddleOCR')
|
||||
def test_get_ocr_engine_different_languages(self, mock_paddle_ocr):
|
||||
"""Test that different languages get different engines"""
|
||||
mock_paddle_ocr.return_value = Mock()
|
||||
|
||||
service = OCRService()
|
||||
|
||||
engine_en = service.get_ocr_engine(lang='en')
|
||||
engine_ch = service.get_ocr_engine(lang='ch')
|
||||
|
||||
assert 'en' in service.ocr_engines
|
||||
assert 'ch' in service.ocr_engines
|
||||
assert mock_paddle_ocr.call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestStructureEngineLazyLoading:
|
||||
"""Test structure engine lazy loading"""
|
||||
|
||||
@patch('app.services.ocr_service.PPStructureV3')
|
||||
def test_get_structure_engine_creates_new_engine(self, mock_structure):
|
||||
"""Test that get_structure_engine creates engine on first call"""
|
||||
mock_engine = Mock()
|
||||
mock_structure.return_value = mock_engine
|
||||
|
||||
service = OCRService()
|
||||
engine = service.get_structure_engine()
|
||||
|
||||
assert engine == mock_engine
|
||||
mock_structure.assert_called_once()
|
||||
assert service.structure_engine == mock_engine
|
||||
|
||||
@patch('app.services.ocr_service.PPStructureV3')
|
||||
def test_get_structure_engine_reuses_existing_engine(self, mock_structure):
|
||||
"""Test that get_structure_engine reuses existing engine"""
|
||||
mock_engine = Mock()
|
||||
mock_structure.return_value = mock_engine
|
||||
|
||||
service = OCRService()
|
||||
|
||||
# First call creates engine
|
||||
engine1 = service.get_structure_engine()
|
||||
# Second call should reuse
|
||||
engine2 = service.get_structure_engine()
|
||||
|
||||
assert engine1 == engine2
|
||||
mock_structure.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestProcessImageMocked:
|
||||
"""Test image processing with mocked OCR engines"""
|
||||
|
||||
@patch('app.services.ocr_service.PaddleOCR')
|
||||
def test_process_image_success(self, mock_paddle_ocr, sample_image_path):
|
||||
"""Test successful image processing"""
|
||||
# Mock OCR results - PaddleOCR 3.x format
|
||||
mock_ocr_results = [{
|
||||
'rec_texts': ['Hello World', 'Test Text'],
|
||||
'rec_scores': [0.95, 0.88],
|
||||
'rec_polys': [
|
||||
[[10, 10], [100, 10], [100, 30], [10, 30]],
|
||||
[[10, 40], [100, 40], [100, 60], [10, 60]]
|
||||
]
|
||||
}]
|
||||
|
||||
mock_engine = Mock()
|
||||
mock_engine.ocr.return_value = mock_ocr_results
|
||||
mock_paddle_ocr.return_value = mock_engine
|
||||
|
||||
service = OCRService()
|
||||
result = service.process_image(sample_image_path, detect_layout=False)
|
||||
|
||||
assert result['status'] == 'success'
|
||||
assert result['file_name'] == sample_image_path.name
|
||||
assert result['language'] == 'ch'
|
||||
assert result['total_text_regions'] == 2
|
||||
assert result['average_confidence'] > 0.8
|
||||
assert len(result['text_regions']) == 2
|
||||
assert 'markdown_content' in result
|
||||
assert 'processing_time' in result
|
||||
|
||||
@patch('app.services.ocr_service.PaddleOCR')
|
||||
def test_process_image_filters_low_confidence(self, mock_paddle_ocr, sample_image_path):
|
||||
"""Test that low confidence results are filtered"""
|
||||
# Mock OCR results with varying confidence - PaddleOCR 3.x format
|
||||
mock_ocr_results = [{
|
||||
'rec_texts': ['High Confidence', 'Low Confidence'],
|
||||
'rec_scores': [0.95, 0.50],
|
||||
'rec_polys': [
|
||||
[[10, 10], [100, 10], [100, 30], [10, 30]],
|
||||
[[10, 40], [100, 40], [100, 60], [10, 60]]
|
||||
]
|
||||
}]
|
||||
|
||||
mock_engine = Mock()
|
||||
mock_engine.ocr.return_value = mock_ocr_results
|
||||
mock_paddle_ocr.return_value = mock_engine
|
||||
|
||||
service = OCRService()
|
||||
result = service.process_image(
|
||||
sample_image_path,
|
||||
detect_layout=False,
|
||||
confidence_threshold=0.80
|
||||
)
|
||||
|
||||
assert result['status'] == 'success'
|
||||
assert result['total_text_regions'] == 1 # Only high confidence
|
||||
assert result['text_regions'][0]['text'] == 'High Confidence'
|
||||
|
||||
@patch('app.services.ocr_service.PaddleOCR')
|
||||
def test_process_image_empty_results(self, mock_paddle_ocr, sample_image_path):
|
||||
"""Test processing image with no text detected"""
|
||||
mock_ocr_results = [[]]
|
||||
|
||||
mock_engine = Mock()
|
||||
mock_engine.ocr.return_value = mock_ocr_results
|
||||
mock_paddle_ocr.return_value = mock_engine
|
||||
|
||||
service = OCRService()
|
||||
result = service.process_image(sample_image_path, detect_layout=False)
|
||||
|
||||
assert result['status'] == 'success'
|
||||
assert result['total_text_regions'] == 0
|
||||
assert result['average_confidence'] == 0.0
|
||||
|
||||
@patch('app.services.ocr_service.PaddleOCR')
|
||||
def test_process_image_error_handling(self, mock_paddle_ocr, sample_image_path):
|
||||
"""Test error handling during OCR processing"""
|
||||
mock_engine = Mock()
|
||||
mock_engine.ocr.side_effect = Exception("OCR engine error")
|
||||
mock_paddle_ocr.return_value = mock_engine
|
||||
|
||||
service = OCRService()
|
||||
result = service.process_image(sample_image_path, detect_layout=False)
|
||||
|
||||
assert result['status'] == 'error'
|
||||
assert 'error_message' in result
|
||||
assert 'OCR engine error' in result['error_message']
|
||||
|
||||
@patch('app.services.ocr_service.PaddleOCR')
|
||||
def test_process_image_different_languages(self, mock_paddle_ocr, sample_image_path):
|
||||
"""Test processing with different languages"""
|
||||
mock_ocr_results = [[
|
||||
[[[10, 10], [100, 10], [100, 30], [10, 30]], ('Text', 0.95)]
|
||||
]]
|
||||
|
||||
mock_engine = Mock()
|
||||
mock_engine.ocr.return_value = mock_ocr_results
|
||||
mock_paddle_ocr.return_value = mock_engine
|
||||
|
||||
service = OCRService()
|
||||
|
||||
# Test English
|
||||
result_en = service.process_image(sample_image_path, lang='en', detect_layout=False)
|
||||
assert result_en['language'] == 'en'
|
||||
|
||||
# Test Chinese
|
||||
result_ch = service.process_image(sample_image_path, lang='ch', detect_layout=False)
|
||||
assert result_ch['language'] == 'ch'
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestLayoutAnalysisMocked:
|
||||
"""Test layout analysis with mocked structure engine"""
|
||||
|
||||
@patch('app.services.ocr_service.PPStructureV3')
|
||||
def test_analyze_layout_success(self, mock_structure, sample_image_path):
|
||||
"""Test successful layout analysis"""
|
||||
# Create mock page result with markdown attribute (PP-StructureV3 format)
|
||||
mock_page_result = Mock()
|
||||
mock_page_result.markdown = {
|
||||
'markdown_texts': 'Document Title\n\nParagraph content',
|
||||
'markdown_images': {}
|
||||
}
|
||||
|
||||
# PP-Structure predict() returns a list of page results
|
||||
mock_engine = Mock()
|
||||
mock_engine.predict.return_value = [mock_page_result]
|
||||
mock_structure.return_value = mock_engine
|
||||
|
||||
service = OCRService()
|
||||
layout_data, images_metadata = service.analyze_layout(sample_image_path)
|
||||
|
||||
assert layout_data is not None
|
||||
assert layout_data['total_elements'] == 1
|
||||
assert len(layout_data['elements']) == 1
|
||||
assert layout_data['elements'][0]['type'] == 'text'
|
||||
assert 'Document Title' in layout_data['elements'][0]['content']
|
||||
|
||||
@patch('app.services.ocr_service.PPStructureV3')
|
||||
def test_analyze_layout_with_table(self, mock_structure, sample_image_path):
|
||||
"""Test layout analysis with table element"""
|
||||
# Create mock page result with table in markdown (PP-StructureV3 format)
|
||||
mock_page_result = Mock()
|
||||
mock_page_result.markdown = {
|
||||
'markdown_texts': '<table><tr><td>Cell 1</td></tr></table>',
|
||||
'markdown_images': {}
|
||||
}
|
||||
|
||||
# PP-Structure predict() returns a list of page results
|
||||
mock_engine = Mock()
|
||||
mock_engine.predict.return_value = [mock_page_result]
|
||||
mock_structure.return_value = mock_engine
|
||||
|
||||
service = OCRService()
|
||||
layout_data, images_metadata = service.analyze_layout(sample_image_path)
|
||||
|
||||
assert layout_data is not None
|
||||
assert layout_data['elements'][0]['type'] == 'table'
|
||||
# Content should contain the HTML table
|
||||
assert '<table>' in layout_data['elements'][0]['content']
|
||||
|
||||
@patch('app.services.ocr_service.PPStructureV3')
|
||||
def test_analyze_layout_error_handling(self, mock_structure, sample_image_path):
|
||||
"""Test error handling in layout analysis"""
|
||||
mock_engine = Mock()
|
||||
mock_engine.side_effect = Exception("Structure analysis error")
|
||||
mock_structure.return_value = mock_engine
|
||||
|
||||
service = OCRService()
|
||||
layout_data, images_metadata = service.analyze_layout(sample_image_path)
|
||||
|
||||
assert layout_data is None
|
||||
assert images_metadata == []
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestMarkdownGeneration:
|
||||
"""Test Markdown generation"""
|
||||
|
||||
def test_generate_markdown_from_text_regions(self):
|
||||
"""Test Markdown generation from text regions only"""
|
||||
service = OCRService()
|
||||
|
||||
text_regions = [
|
||||
{'text': 'First line', 'bbox': [[10, 10], [100, 10], [100, 30], [10, 30]]},
|
||||
{'text': 'Second line', 'bbox': [[10, 40], [100, 40], [100, 60], [10, 60]]},
|
||||
{'text': 'Third line', 'bbox': [[10, 70], [100, 70], [100, 90], [10, 90]]},
|
||||
]
|
||||
|
||||
markdown = service.generate_markdown(text_regions)
|
||||
|
||||
assert 'First line' in markdown
|
||||
assert 'Second line' in markdown
|
||||
assert 'Third line' in markdown
|
||||
|
||||
def test_generate_markdown_with_layout(self):
|
||||
"""Test Markdown generation with layout information"""
|
||||
service = OCRService()
|
||||
|
||||
text_regions = []
|
||||
layout_data = {
|
||||
'elements': [
|
||||
{'type': 'title', 'content': 'Document Title'},
|
||||
{'type': 'text', 'content': 'Paragraph text'},
|
||||
{'type': 'figure', 'element_id': 0},
|
||||
]
|
||||
}
|
||||
|
||||
markdown = service.generate_markdown(text_regions, layout_data)
|
||||
|
||||
assert '# Document Title' in markdown
|
||||
assert 'Paragraph text' in markdown
|
||||
assert '![Figure 0]' in markdown
|
||||
|
||||
def test_generate_markdown_with_table(self):
|
||||
"""Test Markdown generation with table"""
|
||||
service = OCRService()
|
||||
|
||||
layout_data = {
|
||||
'elements': [
|
||||
{
|
||||
'type': 'table',
|
||||
'content': '<table><tr><td>Cell</td></tr></table>'
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
markdown = service.generate_markdown([], layout_data)
|
||||
|
||||
assert '<table>' in markdown
|
||||
|
||||
def test_generate_markdown_empty_input(self):
|
||||
"""Test Markdown generation with empty input"""
|
||||
service = OCRService()
|
||||
|
||||
markdown = service.generate_markdown([])
|
||||
|
||||
assert markdown == ""
|
||||
|
||||
def test_generate_markdown_sorts_by_position(self):
|
||||
"""Test that text regions are sorted by vertical position"""
|
||||
service = OCRService()
|
||||
|
||||
# Create text regions in reverse order
|
||||
text_regions = [
|
||||
{'text': 'Bottom', 'bbox': [[10, 90], [100, 90], [100, 110], [10, 110]]},
|
||||
{'text': 'Top', 'bbox': [[10, 10], [100, 10], [100, 30], [10, 30]]},
|
||||
{'text': 'Middle', 'bbox': [[10, 50], [100, 50], [100, 70], [10, 70]]},
|
||||
]
|
||||
|
||||
markdown = service.generate_markdown(text_regions)
|
||||
lines = markdown.strip().split('\n')
|
||||
|
||||
# Should be sorted top to bottom
|
||||
assert lines[0] == 'Top'
|
||||
assert lines[1] == 'Middle'
|
||||
assert lines[2] == 'Bottom'
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSaveResults:
|
||||
"""Test saving OCR results"""
|
||||
|
||||
def test_save_results_success(self, temp_dir):
|
||||
"""Test successful saving of results"""
|
||||
service = OCRService()
|
||||
|
||||
result = {
|
||||
'status': 'success',
|
||||
'file_name': 'test.png',
|
||||
'text_regions': [{'text': 'Hello', 'confidence': 0.95}],
|
||||
'markdown_content': '# Hello\n\nTest content',
|
||||
}
|
||||
|
||||
json_path, md_path = service.save_results(result, temp_dir, 'test123')
|
||||
|
||||
assert json_path is not None
|
||||
assert md_path is not None
|
||||
assert json_path.exists()
|
||||
assert md_path.exists()
|
||||
|
||||
# Verify JSON content
|
||||
with open(json_path, 'r') as f:
|
||||
saved_result = json.load(f)
|
||||
assert saved_result['file_name'] == 'test.png'
|
||||
|
||||
# Verify Markdown content
|
||||
md_content = md_path.read_text()
|
||||
assert 'Hello' in md_content
|
||||
|
||||
def test_save_results_creates_directory(self, temp_dir):
|
||||
"""Test that save_results creates output directory if needed"""
|
||||
service = OCRService()
|
||||
output_dir = temp_dir / "subdir" / "results"
|
||||
|
||||
result = {
|
||||
'status': 'success',
|
||||
'markdown_content': 'Test',
|
||||
}
|
||||
|
||||
json_path, md_path = service.save_results(result, output_dir, 'test')
|
||||
|
||||
assert output_dir.exists()
|
||||
assert json_path.exists()
|
||||
|
||||
def test_save_results_handles_unicode(self, temp_dir):
|
||||
"""Test saving results with Unicode characters"""
|
||||
service = OCRService()
|
||||
|
||||
result = {
|
||||
'status': 'success',
|
||||
'text_regions': [{'text': '你好世界', 'confidence': 0.95}],
|
||||
'markdown_content': '# 你好世界\n\n测试内容',
|
||||
}
|
||||
|
||||
json_path, md_path = service.save_results(result, temp_dir, 'unicode_test')
|
||||
|
||||
# Verify Unicode is preserved
|
||||
with open(json_path, 'r', encoding='utf-8') as f:
|
||||
saved_result = json.load(f)
|
||||
assert saved_result['text_regions'][0]['text'] == '你好世界'
|
||||
|
||||
md_content = md_path.read_text(encoding='utf-8')
|
||||
assert '你好世界' in md_content
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestEdgeCases:
|
||||
"""Test edge cases and error handling"""
|
||||
|
||||
@patch('app.services.ocr_service.PaddleOCR')
|
||||
def test_process_image_with_none_results(self, mock_paddle_ocr, sample_image_path):
|
||||
"""Test processing when OCR returns None"""
|
||||
mock_engine = Mock()
|
||||
mock_engine.ocr.return_value = None
|
||||
mock_paddle_ocr.return_value = mock_engine
|
||||
|
||||
service = OCRService()
|
||||
result = service.process_image(sample_image_path, detect_layout=False)
|
||||
|
||||
assert result['status'] == 'success'
|
||||
assert result['total_text_regions'] == 0
|
||||
|
||||
@patch('app.services.ocr_service.PaddleOCR')
|
||||
def test_process_image_with_custom_threshold(self, mock_paddle_ocr, sample_image_path):
|
||||
"""Test processing with custom confidence threshold"""
|
||||
# PaddleOCR 3.x format
|
||||
mock_ocr_results = [{
|
||||
'rec_texts': ['Text'],
|
||||
'rec_scores': [0.85],
|
||||
'rec_polys': [[[10, 10], [100, 10], [100, 30], [10, 30]]]
|
||||
}]
|
||||
|
||||
mock_engine = Mock()
|
||||
mock_engine.ocr.return_value = mock_ocr_results
|
||||
mock_paddle_ocr.return_value = mock_engine
|
||||
|
||||
service = OCRService()
|
||||
|
||||
# With high threshold - should filter out
|
||||
result_high = service.process_image(
|
||||
sample_image_path,
|
||||
detect_layout=False,
|
||||
confidence_threshold=0.90
|
||||
)
|
||||
assert result_high['total_text_regions'] == 0
|
||||
|
||||
# With low threshold - should include
|
||||
result_low = service.process_image(
|
||||
sample_image_path,
|
||||
detect_layout=False,
|
||||
confidence_threshold=0.80
|
||||
)
|
||||
assert result_low['total_text_regions'] == 1
|
||||
|
||||
|
||||
# Integration tests that require actual PaddleOCR models
|
||||
@pytest.mark.requires_models
|
||||
@pytest.mark.slow
|
||||
class TestOCRServiceIntegration:
|
||||
"""
|
||||
Integration tests that require actual PaddleOCR models
|
||||
These tests will download models (~900MB) on first run
|
||||
Run with: pytest -m requires_models
|
||||
"""
|
||||
|
||||
def test_real_ocr_engine_initialization(self):
|
||||
"""Test real PaddleOCR engine initialization"""
|
||||
service = OCRService()
|
||||
engine = service.get_ocr_engine(lang='en')
|
||||
|
||||
assert engine is not None
|
||||
assert hasattr(engine, 'ocr')
|
||||
|
||||
def test_real_structure_engine_initialization(self):
|
||||
"""Test real PP-Structure engine initialization"""
|
||||
service = OCRService()
|
||||
engine = service.get_structure_engine()
|
||||
|
||||
assert engine is not None
|
||||
|
||||
def test_real_image_processing(self, sample_image_with_text):
|
||||
"""Test processing real image with text"""
|
||||
service = OCRService()
|
||||
result = service.process_image(sample_image_with_text, lang='en')
|
||||
|
||||
assert result['status'] == 'success'
|
||||
assert result['total_text_regions'] > 0
|
||||
Reference in New Issue
Block a user