"""
API integration tests for the Layout Model Selection feature.

This replaces the deprecated PP-StructureV3 parameter tests.
"""

from unittest.mock import patch

import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from pydantic import ValidationError

from app.schemas.task import LayoutModelEnum, ProcessingOptions


def process_task_ocr(**kwargs):
    """Stand-in for the background OCR task launcher.

    Raises:
        NotImplementedError: always, unless a test has patched this name.
    """
    # Stubbed background task launcher (patched in tests)
    raise NotImplementedError


def create_test_app() -> FastAPI:
    """Build a minimal FastAPI app exposing only the start-task endpoint.

    Returns:
        A fresh application with ``POST /api/v2/tasks/{task_id}/start``
        registered.
    """
    test_app = FastAPI()

    @test_app.post("/api/v2/tasks/{task_id}/start")
    def start_task(task_id: str, options: ProcessingOptions):
        # ``process_task_ocr`` is resolved from module globals on every call,
        # which is what lets ``patch(__name__ + ".process_task_ocr")`` swap
        # in a mock for the duration of a test.
        process_task_ocr(
            task_id=task_id,
            layout_model=options.layout_model.value,
        )
        return {"status": "processing"}

    return test_app


@pytest.fixture
def client():
    """Create a test client bound to a fresh test app."""
    return TestClient(create_test_app())


@pytest.fixture
def test_task_id():
    """Return the fixed task id used in the endpoint URLs."""
    return "test-task-123"


class TestLayoutModelSchema:
    """Test LayoutModelEnum and ProcessingOptions schema validation."""

    def test_processing_options_accepts_layout_model(self):
        """Verify ProcessingOptions schema accepts layout_model parameter."""
        options = ProcessingOptions(
            use_dual_track=True,
            language='ch',
            layout_model=LayoutModelEnum.CHINESE,
        )
        assert options.layout_model == LayoutModelEnum.CHINESE

    def test_layout_model_enum_values(self):
        """Verify all layout model enum values are valid."""
        assert LayoutModelEnum.CHINESE.value == "chinese"
        assert LayoutModelEnum.DEFAULT.value == "default"
        assert LayoutModelEnum.CDLA.value == "cdla"

    def test_default_layout_model_is_chinese(self):
        """Verify default layout model is 'chinese' for best Chinese document support."""
        options = ProcessingOptions()

        # Default should be chinese
        assert options.layout_model.value == "chinese"

    def test_layout_model_string_values_accepted(self):
        """Verify string values are accepted for layout_model."""
        # String values should be converted to enum
        options = ProcessingOptions(layout_model="default")
        assert options.layout_model.value == "default"

        options = ProcessingOptions(layout_model="cdla")
        assert options.layout_model.value == "cdla"

    def test_invalid_layout_model_rejected(self):
        """Verify invalid layout model values are rejected."""
        with pytest.raises(ValidationError):
            ProcessingOptions(layout_model="invalid_model")


class TestStartTaskEndpoint:
    """Test /tasks/{task_id}/start endpoint with layout_model parameter."""

    @patch(__name__ + ".process_task_ocr")
    def test_start_task_with_layout_model(self, mock_process_ocr, client, test_task_id):
        """Verify layout_model is accepted and passed to OCR service."""
        # Request body with layout_model
        request_body = {
            "use_dual_track": True,
            "language": "ch",
            "layout_model": "chinese",
        }

        # Make API call
        response = client.post(
            f"/api/v2/tasks/{test_task_id}/start",
            json=request_body,
        )

        # Verify response
        assert response.status_code == 200
        data = response.json()
        assert data['status'] == 'processing'

        # Verify background task was called with layout_model
        mock_process_ocr.assert_called_once()
        call_kwargs = mock_process_ocr.call_args[1]
        assert 'layout_model' in call_kwargs
        assert call_kwargs['layout_model'] == 'chinese'

    @patch(__name__ + ".process_task_ocr")
    def test_start_task_with_default_model(self, mock_process_ocr, client, test_task_id):
        """Verify 'default' layout model is accepted."""
        request_body = {
            "use_dual_track": True,
            "layout_model": "default",
        }

        response = client.post(
            f"/api/v2/tasks/{test_task_id}/start",
            json=request_body,
        )

        assert response.status_code == 200
        mock_process_ocr.assert_called_once()
        call_kwargs = mock_process_ocr.call_args[1]
        assert call_kwargs['layout_model'] == 'default'

    @patch(__name__ + ".process_task_ocr")
    def test_start_task_with_cdla_model(self, mock_process_ocr, client, test_task_id):
        """Verify 'cdla' layout model is accepted."""
        request_body = {
            "use_dual_track": True,
            "layout_model": "cdla",
        }

        response = client.post(
            f"/api/v2/tasks/{test_task_id}/start",
            json=request_body,
        )

        assert response.status_code == 200
        mock_process_ocr.assert_called_once()
        call_kwargs = mock_process_ocr.call_args[1]
        assert call_kwargs['layout_model'] == 'cdla'

    @patch(__name__ + ".process_task_ocr")
    def test_start_task_without_layout_model_uses_default(self, mock_process_ocr, client, test_task_id):
        """Verify task can start without layout_model (uses 'chinese' as default)."""
        # Request without layout_model
        request_body = {
            "use_dual_track": True,
            "language": "ch",
        }

        response = client.post(
            f"/api/v2/tasks/{test_task_id}/start",
            json=request_body,
        )

        assert response.status_code == 200
        mock_process_ocr.assert_called_once()
        call_kwargs = mock_process_ocr.call_args[1]

        # layout_model should default to 'chinese'
        assert call_kwargs['layout_model'] == 'chinese'

    @patch(__name__ + ".process_task_ocr")
    def test_start_task_with_invalid_layout_model(self, mock_process_ocr, client, test_task_id):
        """Verify invalid layout_model returns 422 and launches no OCR task."""
        # Request with invalid layout_model
        request_body = {
            "use_dual_track": True,
            "layout_model": "invalid_model",
        }

        response = client.post(
            f"/api/v2/tasks/{test_task_id}/start",
            json=request_body,
        )

        # Should return validation error
        assert response.status_code == 422

        # A rejected request must never reach the background launcher.
        mock_process_ocr.assert_not_called()


class TestOpenAPISchema:
    """Test OpenAPI schema includes layout_model parameter."""

    def test_openapi_schema_includes_layout_model(self, client):
        """Verify OpenAPI schema documents layout_model parameter."""
        response = client.get("/openapi.json")
        assert response.status_code == 200

        schema = response.json()

        # Check LayoutModelEnum schema exists
        assert 'LayoutModelEnum' in schema['components']['schemas']
        model_schema = schema['components']['schemas']['LayoutModelEnum']

        # Verify all 3 model options are documented
        assert 'chinese' in model_schema['enum']
        assert 'default' in model_schema['enum']
        assert 'cdla' in model_schema['enum']

        # Verify ProcessingOptions includes layout_model
        options_schema = schema['components']['schemas']['ProcessingOptions']
        assert 'layout_model' in options_schema['properties']


if __name__ == '__main__':
    # pytest.main() returns the run's exit code; propagate it so that running
    # this file directly reports failures to the calling shell / CI.
    raise SystemExit(pytest.main([__file__, '-v']))