""" Test suite for rate limiting functionality. Tests the rate limiting feature on the login endpoint to ensure protection against brute force attacks. """ import pytest from unittest.mock import patch, MagicMock, AsyncMock from app.services.auth_client import AuthAPIError class TestRateLimiting: """Test rate limiting on the login endpoint.""" def test_login_rate_limit_exceeded(self, client): """ Test that the login endpoint returns 429 after exceeding rate limit. GIVEN a client IP has made 5 login attempts within 1 minute WHEN the client attempts another login THEN the system returns HTTP 429 Too Many Requests AND the response includes a Retry-After header """ # Mock the external auth service to return auth error with patch("app.api.auth.router.verify_credentials", new_callable=AsyncMock) as mock_verify: mock_verify.side_effect = AuthAPIError("Invalid credentials") login_data = {"email": "test@example.com", "password": "wrongpassword"} # Make 5 requests (the limit) for i in range(5): response = client.post("/api/auth/login", json=login_data) # These should fail due to invalid credentials (401), but not rate limit assert response.status_code == 401, f"Request {i+1} expected 401, got {response.status_code}" # The 6th request should be rate limited response = client.post("/api/auth/login", json=login_data) assert response.status_code == 429, f"Expected 429 Too Many Requests, got {response.status_code}" # Response should contain error details data = response.json() assert "error" in data or "detail" in data, "Response should contain error details" def test_login_within_rate_limit(self, client): """ Test that requests within the rate limit are allowed. GIVEN a client IP has not exceeded the rate limit WHEN the client makes login requests THEN the requests are processed normally (not rate limited) """ with patch("app.api.auth.router.verify_credentials", new_callable=AsyncMock) as mock_verify: mock_verify.side_effect = AuthAPIError("Invalid credentials") login_data = {"email": "test@example.com", "password": "wrongpassword"} # Make requests within the limit for i in range(3): response = client.post("/api/auth/login", json=login_data) # These should fail due to invalid credentials (401), but not be rate limited assert response.status_code == 401, f"Request {i+1} expected 401, got {response.status_code}" def test_rate_limit_response_format(self, client): """ Test that the 429 response format matches API standards. GIVEN the rate limit has been exceeded WHEN the client receives a 429 response THEN the response body contains appropriate error information """ with patch("app.api.auth.router.verify_credentials", new_callable=AsyncMock) as mock_verify: mock_verify.side_effect = AuthAPIError("Invalid credentials") login_data = {"email": "test@example.com", "password": "wrongpassword"} # Exhaust the rate limit for _ in range(5): client.post("/api/auth/login", json=login_data) # The next request should be rate limited response = client.post("/api/auth/login", json=login_data) assert response.status_code == 429 # Check response body contains error information data = response.json() assert "error" in data or "detail" in data, "Response should contain error details" class TestRateLimiterConfiguration: """Test rate limiter configuration.""" def test_limiter_uses_redis_storage(self): """ Test that the limiter is configured with Redis storage. GIVEN the rate limiter configuration WHEN we inspect the storage URI THEN it should be configured to use Redis """ from app.core.rate_limiter import limiter from app.core.config import settings # The limiter should be configured assert limiter is not None # Verify Redis URL is properly configured assert settings.REDIS_URL.startswith("redis://") def test_limiter_uses_remote_address_key(self): """ Test that the limiter uses client IP as the key. GIVEN the rate limiter configuration WHEN we check the key function THEN it should use get_remote_address """ from app.core.rate_limiter import limiter from slowapi.util import get_remote_address # The key function should be get_remote_address assert limiter._key_func == get_remote_address