"""Tests for JWT helpers, authentication endpoints, and refresh tokens."""

import pytest

from app.core.security import (
    create_access_token,
    decode_access_token,
    create_token_payload,
    generate_refresh_token,
    store_refresh_token,
    validate_refresh_token,
    invalidate_refresh_token,
    invalidate_all_user_refresh_tokens,
    decode_refresh_token_user_id,
    get_refresh_token_key,
)


class TestJWT:
    """Access-token creation, decoding, and payload construction."""

    def test_create_access_token(self):
        """A token built from a plain claims dict is a non-None string."""
        claims = {"sub": "user123", "email": "test@example.com"}
        encoded = create_access_token(claims)

        assert encoded is not None
        assert isinstance(encoded, str)

    def test_decode_valid_token(self):
        """Claims put into a token come back out when it is decoded."""
        claims = create_token_payload(
            user_id="user123",
            email="test@example.com",
            role="engineer",
            department_id="dept123",
            is_system_admin=False,
        )
        decoded = decode_access_token(create_access_token(claims))

        assert decoded is not None
        expected_claims = (
            ("sub", "user123"),
            ("email", "test@example.com"),
            ("role", "engineer"),
        )
        for claim, expected in expected_claims:
            assert decoded[claim] == expected
        # Identity check on purpose: the flag must be the boolean False.
        assert decoded["is_system_admin"] is False

    def test_decode_invalid_token(self):
        """Decoding a malformed token yields None."""
        assert decode_access_token("invalid.token.here") is None

    def test_token_payload_structure(self):
        """The built payload exposes every expected key."""
        built = create_token_payload(
            user_id="user123",
            email="test@example.com",
            role="engineer",
            department_id="dept123",
            is_system_admin=False,
        )

        required_keys = (
            "sub",
            "email",
            "role",
            "department_id",
            "is_system_admin",
        )
        for key in required_keys:
            assert key in built


class TestAuthEndpoints:
    """HTTP-level checks for the authentication routes."""

    def test_get_me_without_auth(self, client):
        """/me without credentials is rejected."""
        reply = client.get("/api/auth/me")

        # 401 means unauthenticated; 403 would mean unauthorized.
        assert reply.status_code == 401

    def test_get_me_with_auth(self, client, admin_token):
        """/me with a valid bearer token returns the admin profile."""
        bearer = {"Authorization": f"Bearer {admin_token}"}
        reply = client.get("/api/auth/me", headers=bearer)

        assert reply.status_code == 200
        profile = reply.json()
        assert profile["email"] == "ymirliu@panjit.com.tw"
        assert profile["is_system_admin"] is True

    def test_logout(self, client, auth_headers, mock_redis):
        """Logout succeeds and leaves no session entry behind."""
        reply = client.post("/api/auth/logout", headers=auth_headers)

        assert reply.status_code == 200
        # The session key must be gone after logout.
        session_entry = mock_redis.get(
            "session:00000000-0000-0000-0000-000000000001"
        )
        assert session_entry is None


class TestRefreshToken:
    """Refresh-token generation, storage, validation, and invalidation."""

    def test_generate_refresh_token(self):
        """A generated token is a reasonably long string."""
        generated = generate_refresh_token()

        assert generated is not None
        assert isinstance(generated, str)
        assert len(generated) > 20  # URL-safe base64 encoded 32 bytes

    def test_generate_unique_refresh_tokens(self):
        """One hundred generated tokens contain no duplicates."""
        distinct = {generate_refresh_token() for _ in range(100)}

        # A set drops duplicates, so a full-size set means all were unique.
        assert len(distinct) == 100

    def test_store_and_validate_refresh_token(self, mock_redis):
        """A stored token validates only for its own user and value."""
        owner = "test-user-123"
        issued = generate_refresh_token()

        store_refresh_token(mock_redis, owner, issued)

        # Matching user and token.
        assert validate_refresh_token(mock_redis, owner, issued) is True
        # Right token, wrong user.
        assert validate_refresh_token(mock_redis, "wrong-user", issued) is False
        # Right user, wrong token.
        assert validate_refresh_token(mock_redis, owner, "wrong-token") is False

    def test_invalidate_refresh_token(self, mock_redis):
        """An invalidated token stops validating."""
        owner = "test-user-123"
        issued = generate_refresh_token()

        store_refresh_token(mock_redis, owner, issued)
        assert validate_refresh_token(mock_redis, owner, issued) is True

        was_removed = invalidate_refresh_token(mock_redis, owner, issued)
        assert was_removed is True

        assert validate_refresh_token(mock_redis, owner, issued) is False

    def test_invalidate_all_user_refresh_tokens(self, mock_redis):
        """Bulk invalidation removes every token stored for the user."""
        owner = "test-user-123"
        issued = [generate_refresh_token() for _ in range(3)]

        for tok in issued:
            store_refresh_token(mock_redis, owner, tok)

        # Sanity check: everything is valid before the bulk invalidation.
        for tok in issued:
            assert validate_refresh_token(mock_redis, owner, tok) is True

        removed_count = invalidate_all_user_refresh_tokens(mock_redis, owner)
        assert removed_count == 3

        for tok in issued:
            assert validate_refresh_token(mock_redis, owner, tok) is False

    def test_decode_refresh_token_user_id(self, mock_redis):
        """The owning user ID can be looked up from a stored token."""
        owner = "test-user-456"
        issued = generate_refresh_token()

        store_refresh_token(mock_redis, owner, issued)

        resolved = decode_refresh_token_user_id(issued, mock_redis)
        assert resolved == owner

        # A token that was never stored resolves to None.
        assert decode_refresh_token_user_id("invalid-token", mock_redis) is None


class TestRefreshTokenEndpoint:
    """HTTP-level checks for the refresh-token route."""

    def test_refresh_token_success(self, client, db, mock_redis):
        """A valid refresh token is exchanged for a new token pair."""
        owner = "00000000-0000-0000-0000-000000000001"

        old_token = generate_refresh_token()
        store_refresh_token(mock_redis, owner, old_token)

        reply = client.post(
            "/api/auth/refresh",
            json={"refresh_token": old_token},
        )

        assert reply.status_code == 200
        body = reply.json()
        assert "access_token" in body
        assert "refresh_token" in body
        assert body["token_type"] == "bearer"
        assert body["expires_in"] > 0

        # Rotation: the submitted token is retired...
        assert validate_refresh_token(mock_redis, owner, old_token) is False
        # ...and the newly issued one is valid.
        assert (
            validate_refresh_token(mock_redis, owner, body["refresh_token"])
            is True
        )

    def test_refresh_token_invalid(self, client, mock_redis):
        """An unknown refresh token is rejected with a 401."""
        reply = client.post(
            "/api/auth/refresh",
            json={"refresh_token": "invalid-token"},
        )

        assert reply.status_code == 401
        assert "Invalid or expired refresh token" in reply.json()["detail"]

    def test_refresh_token_rotation(self, client, db, mock_redis):
        """After a refresh, the old token fails and the new one works."""
        owner = "00000000-0000-0000-0000-000000000001"
        refresh_url = "/api/auth/refresh"

        first_token = generate_refresh_token()
        store_refresh_token(mock_redis, owner, first_token)

        # First exchange succeeds and hands back a rotated token.
        first_reply = client.post(
            refresh_url, json={"refresh_token": first_token}
        )
        assert first_reply.status_code == 200
        rotated_token = first_reply.json()["refresh_token"]

        # Replaying the original token must fail.
        replay_reply = client.post(
            refresh_url, json={"refresh_token": first_token}
        )
        assert replay_reply.status_code == 401

        # The rotated token is still usable.
        rotated_reply = client.post(
            refresh_url, json={"refresh_token": rotated_token}
        )
        assert rotated_reply.status_code == 200

    def test_refresh_token_disabled_user(self, client, db, mock_redis):
        """A user with is_active=False cannot refresh tokens."""
        from app.models.user import User

        inactive = User(
            id="disabled-user-123",
            email="disabled@example.com",
            name="Disabled User",
            is_active=False,
        )
        db.add(inactive)
        db.commit()

        # Store a refresh token under the inactive user's ID.
        issued = generate_refresh_token()
        store_refresh_token(mock_redis, inactive.id, issued)

        reply = client.post(
            "/api/auth/refresh",
            json={"refresh_token": issued},
        )

        assert reply.status_code == 403
        assert "disabled" in reply.json()["detail"].lower()