""" API integration tests for Layout Model Selection feature. This replaces the deprecated PP-StructureV3 parameter tests. """ import pytest from fastapi.testclient import TestClient from unittest.mock import patch from app.main import app from app.core.database import get_db from app.models.user import User from app.models.task import Task, TaskStatus, TaskFile @pytest.fixture def client(): """Create test client""" return TestClient(app) @pytest.fixture def test_user(db_session): """Create test user""" user = User( email="test@example.com", hashed_password="test_hash", is_active=True ) db_session.add(user) db_session.commit() db_session.refresh(user) return user @pytest.fixture def test_task(db_session, test_user): """Create test task with uploaded file""" task = Task( user_id=test_user.id, task_id="test-task-123", filename="test.pdf", status=TaskStatus.PENDING ) db_session.add(task) db_session.commit() db_session.refresh(task) # Add task file task_file = TaskFile( task_id=task.id, original_name="test.pdf", stored_path="/tmp/test.pdf", file_size=1024, mime_type="application/pdf" ) db_session.add(task_file) db_session.commit() return task class TestLayoutModelSchema: """Test LayoutModel and ProcessingOptions schema validation""" def test_processing_options_accepts_layout_model(self): """Verify ProcessingOptions schema accepts layout_model parameter""" from app.schemas.task import ProcessingOptions, LayoutModelEnum options = ProcessingOptions( use_dual_track=True, language='ch', layout_model=LayoutModelEnum.CHINESE ) assert options.layout_model == LayoutModelEnum.CHINESE def test_layout_model_enum_values(self): """Verify all layout model enum values are valid""" from app.schemas.task import LayoutModelEnum assert LayoutModelEnum.CHINESE.value == "chinese" assert LayoutModelEnum.DEFAULT.value == "default" assert LayoutModelEnum.CDLA.value == "cdla" def test_default_layout_model_is_chinese(self): """Verify default layout model is 'chinese' for best Chinese document support""" from app.schemas.task import ProcessingOptions options = ProcessingOptions() # Default should be chinese assert options.layout_model.value == "chinese" def test_layout_model_string_values_accepted(self): """Verify string values are accepted for layout_model""" from app.schemas.task import ProcessingOptions # String values should be converted to enum options = ProcessingOptions(layout_model="default") assert options.layout_model.value == "default" options = ProcessingOptions(layout_model="cdla") assert options.layout_model.value == "cdla" def test_invalid_layout_model_rejected(self): """Verify invalid layout model values are rejected""" from app.schemas.task import ProcessingOptions from pydantic import ValidationError with pytest.raises(ValidationError): ProcessingOptions(layout_model="invalid_model") class TestStartTaskEndpoint: """Test /tasks/{task_id}/start endpoint with layout_model parameter""" @patch('app.routers.tasks.process_task_ocr') def test_start_task_with_layout_model(self, mock_process_ocr, client, test_task, db_session): """Verify layout_model is accepted and passed to OCR service""" # Override get_db dependency def override_get_db(): try: yield db_session finally: pass # Override auth dependency def override_get_current_user(): return test_task.user app.dependency_overrides[get_db] = override_get_db from app.core.deps import get_current_user app.dependency_overrides[get_current_user] = override_get_current_user # Request body with layout_model request_body = { "use_dual_track": True, "language": "ch", "layout_model": "chinese" } # Make API call response = client.post( f"/api/v2/tasks/{test_task.task_id}/start", json=request_body ) # Verify response assert response.status_code == 200 data = response.json() assert data['status'] == 'processing' # Verify background task was called with layout_model mock_process_ocr.assert_called_once() call_kwargs = mock_process_ocr.call_args[1] assert 'layout_model' in call_kwargs assert call_kwargs['layout_model'] == 'chinese' # Clean up app.dependency_overrides.clear() @patch('app.routers.tasks.process_task_ocr') def test_start_task_with_default_model(self, mock_process_ocr, client, test_task, db_session): """Verify 'default' layout model is accepted""" def override_get_db(): try: yield db_session finally: pass def override_get_current_user(): return test_task.user app.dependency_overrides[get_db] = override_get_db from app.core.deps import get_current_user app.dependency_overrides[get_current_user] = override_get_current_user request_body = { "use_dual_track": True, "layout_model": "default" } response = client.post( f"/api/v2/tasks/{test_task.task_id}/start", json=request_body ) assert response.status_code == 200 mock_process_ocr.assert_called_once() call_kwargs = mock_process_ocr.call_args[1] assert call_kwargs['layout_model'] == 'default' app.dependency_overrides.clear() @patch('app.routers.tasks.process_task_ocr') def test_start_task_with_cdla_model(self, mock_process_ocr, client, test_task, db_session): """Verify 'cdla' layout model is accepted""" def override_get_db(): try: yield db_session finally: pass def override_get_current_user(): return test_task.user app.dependency_overrides[get_db] = override_get_db from app.core.deps import get_current_user app.dependency_overrides[get_current_user] = override_get_current_user request_body = { "use_dual_track": True, "layout_model": "cdla" } response = client.post( f"/api/v2/tasks/{test_task.task_id}/start", json=request_body ) assert response.status_code == 200 mock_process_ocr.assert_called_once() call_kwargs = mock_process_ocr.call_args[1] assert call_kwargs['layout_model'] == 'cdla' app.dependency_overrides.clear() @patch('app.routers.tasks.process_task_ocr') def test_start_task_without_layout_model_uses_default(self, mock_process_ocr, client, test_task, db_session): """Verify task can start without layout_model (uses 'chinese' as default)""" def override_get_db(): try: yield db_session finally: pass def override_get_current_user(): return test_task.user app.dependency_overrides[get_db] = override_get_db from app.core.deps import get_current_user app.dependency_overrides[get_current_user] = override_get_current_user # Request without layout_model request_body = { "use_dual_track": True, "language": "ch" } response = client.post( f"/api/v2/tasks/{test_task.task_id}/start", json=request_body ) assert response.status_code == 200 mock_process_ocr.assert_called_once() call_kwargs = mock_process_ocr.call_args[1] # layout_model should default to 'chinese' assert call_kwargs['layout_model'] == 'chinese' app.dependency_overrides.clear() def test_start_task_with_invalid_layout_model(self, client, test_task, db_session): """Verify invalid layout_model returns 422 validation error""" def override_get_db(): try: yield db_session finally: pass def override_get_current_user(): return test_task.user app.dependency_overrides[get_db] = override_get_db from app.core.deps import get_current_user app.dependency_overrides[get_current_user] = override_get_current_user # Request with invalid layout_model request_body = { "use_dual_track": True, "layout_model": "invalid_model" } response = client.post( f"/api/v2/tasks/{test_task.task_id}/start", json=request_body ) # Should return validation error assert response.status_code == 422 app.dependency_overrides.clear() class TestOpenAPISchema: """Test OpenAPI schema includes layout_model parameter""" def test_openapi_schema_includes_layout_model(self, client): """Verify OpenAPI schema documents layout_model parameter""" response = client.get("/openapi.json") assert response.status_code == 200 schema = response.json() # Check LayoutModelEnum schema exists assert 'LayoutModelEnum' in schema['components']['schemas'] model_schema = schema['components']['schemas']['LayoutModelEnum'] # Verify all 3 model options are documented assert 'chinese' in model_schema['enum'] assert 'default' in model_schema['enum'] assert 'cdla' in model_schema['enum'] # Verify ProcessingOptions includes layout_model options_schema = schema['components']['schemas']['ProcessingOptions'] assert 'layout_model' in options_schema['properties'] if __name__ == '__main__': pytest.main([__file__, '-v'])