Implemented proposals from comprehensive QA review: 1. extend-csrf-protection - Add POST to CSRF protected methods in frontend - Global CSRF middleware for all state-changing operations - Update tests with CSRF token fixtures 2. tighten-cors-websocket-security - Replace wildcard CORS with explicit method/header lists - Disable query parameter auth in production (code 4002) - Add per-user WebSocket connection limit (max 5, code 4005) 3. shorten-jwt-expiry - Reduce JWT expiry from 7 days to 60 minutes - Add refresh token support with 7-day expiry - Implement token rotation on refresh - Frontend auto-refresh when token near expiry (<5 min) 4. fix-frontend-quality - Add React.lazy() code splitting for all pages - Fix useCallback dependency arrays (Dashboard, Comments) - Add localStorage data validation in AuthContext - Complete i18n for AttachmentUpload component 5. enhance-backend-validation - Add SecurityAuditMiddleware for access denied logging - Add ErrorSanitizerMiddleware for production error messages - Protect /health/detailed with admin authentication - Add input length validation (comment 5000, desc 10000) All 521 backend tests passing. Frontend builds successfully. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
279 lines
9.3 KiB
Python
279 lines
9.3 KiB
Python
import pytest
|
|
from app.core.security import (
|
|
create_access_token,
|
|
decode_access_token,
|
|
create_token_payload,
|
|
generate_refresh_token,
|
|
store_refresh_token,
|
|
validate_refresh_token,
|
|
invalidate_refresh_token,
|
|
invalidate_all_user_refresh_tokens,
|
|
decode_refresh_token_user_id,
|
|
get_refresh_token_key,
|
|
)
|
|
|
|
|
|
class TestJWT:
    """Test JWT access-token creation, decoding, and payload construction."""

    def test_create_access_token(self):
        """create_access_token returns a non-empty compact-serialized JWT string."""
        data = {"sub": "user123", "email": "test@example.com"}
        token = create_access_token(data)

        assert token is not None
        assert isinstance(token, str)
        # A signed JWT in compact serialization is header.payload.signature,
        # i.e. exactly three dot-separated segments.
        assert token.count(".") == 2

    def test_decode_valid_token(self):
        """A token built from create_token_payload round-trips every claim."""
        data = create_token_payload(
            user_id="user123",
            email="test@example.com",
            role="engineer",
            department_id="dept123",
            is_system_admin=False,
        )
        token = create_access_token(data)
        payload = decode_access_token(token)

        assert payload is not None
        assert payload["sub"] == "user123"
        assert payload["email"] == "test@example.com"
        assert payload["role"] == "engineer"
        # department_id was previously only checked for presence
        # (test_token_payload_structure), never for its round-tripped value.
        assert payload["department_id"] == "dept123"
        assert payload["is_system_admin"] is False

    def test_decode_invalid_token(self):
        """decode_access_token returns None for a malformed token."""
        payload = decode_access_token("invalid.token.here")
        assert payload is None

    def test_token_payload_structure(self):
        """create_token_payload emits all the claims the API relies on."""
        payload = create_token_payload(
            user_id="user123",
            email="test@example.com",
            role="engineer",
            department_id="dept123",
            is_system_admin=False,
        )

        assert "sub" in payload
        assert "email" in payload
        assert "role" in payload
        assert "department_id" in payload
        assert "is_system_admin" in payload
|
|
|
|
|
|
class TestAuthEndpoints:
    """Exercise the /api/auth HTTP endpoints through the test client."""

    def test_get_me_without_auth(self, client):
        """/api/auth/me rejects a request that carries no credentials."""
        resp = client.get("/api/auth/me")
        # No credentials at all -> 401; 403 would mean authenticated but forbidden.
        assert resp.status_code == 401

    def test_get_me_with_auth(self, client, admin_token):
        """/api/auth/me returns the admin's profile for a valid bearer token."""
        bearer = {"Authorization": f"Bearer {admin_token}"}
        resp = client.get("/api/auth/me", headers=bearer)
        assert resp.status_code == 200

        profile = resp.json()
        assert profile["email"] == "ymirliu@panjit.com.tw"
        assert profile["is_system_admin"] is True

    def test_logout(self, client, auth_headers, mock_redis):
        """POST /api/auth/logout succeeds and leaves no session key in Redis."""
        session_key = "session:00000000-0000-0000-0000-000000000001"

        resp = client.post("/api/auth/logout", headers=auth_headers)
        assert resp.status_code == 200

        # NOTE(review): this only proves the key is absent afterwards. Unless the
        # auth_headers / mock_redis fixtures seed this session beforehand, the
        # check would also pass if logout never touched Redis — TODO confirm.
        assert mock_redis.get(session_key) is None
|
|
|
|
|
|
class TestRefreshToken:
    """Test refresh-token generation, storage, validation, and revocation."""

    def test_generate_refresh_token(self):
        """generate_refresh_token returns a reasonably long URL-safe string."""
        token = generate_refresh_token()
        assert token is not None
        assert isinstance(token, str)
        assert len(token) > 20  # URL-safe base64 encoded 32 bytes
        # Pin the URL-safe alphabet the comment above documents:
        # ASCII letters, digits, '-' and '_' only (no '+', '/', '=').
        assert all(ch.isascii() and (ch.isalnum() or ch in "-_") for ch in token)

    def test_generate_unique_refresh_tokens(self):
        """Each generated token is unique."""
        tokens = [generate_refresh_token() for _ in range(100)]
        assert len(set(tokens)) == 100  # All tokens should be unique

    def test_store_and_validate_refresh_token(self, mock_redis):
        """A stored token validates only for its owner and only verbatim."""
        user_id = "test-user-123"
        token = generate_refresh_token()

        # Store the token
        store_refresh_token(mock_redis, user_id, token)

        # Validate the token
        assert validate_refresh_token(mock_redis, user_id, token) is True

        # Wrong user should fail
        assert validate_refresh_token(mock_redis, "wrong-user", token) is False

        # Wrong token should fail
        assert validate_refresh_token(mock_redis, user_id, "wrong-token") is False

    def test_invalidate_refresh_token(self, mock_redis):
        """Invalidating a single token makes it fail validation afterwards."""
        user_id = "test-user-123"
        token = generate_refresh_token()

        # Store and verify
        store_refresh_token(mock_redis, user_id, token)
        assert validate_refresh_token(mock_redis, user_id, token) is True

        # Invalidate
        result = invalidate_refresh_token(mock_redis, user_id, token)
        assert result is True

        # Should no longer be valid
        assert validate_refresh_token(mock_redis, user_id, token) is False

    def test_invalidate_all_user_refresh_tokens(self, mock_redis):
        """Bulk revocation removes every token of one user and nobody else's."""
        user_id = "test-user-123"
        other_user_id = "other-user-789"
        tokens = [generate_refresh_token() for _ in range(3)]
        other_token = generate_refresh_token()

        # Store multiple tokens
        for token in tokens:
            store_refresh_token(mock_redis, user_id, token)
        # A token owned by a different user: the bulk revoke must be scoped to
        # user_id, so this token has to survive and must not be counted.
        store_refresh_token(mock_redis, other_user_id, other_token)

        # Verify all are valid
        for token in tokens:
            assert validate_refresh_token(mock_redis, user_id, token) is True
        assert validate_refresh_token(mock_redis, other_user_id, other_token) is True

        # Invalidate all
        count = invalidate_all_user_refresh_tokens(mock_redis, user_id)
        assert count == 3

        # All should be invalid now
        for token in tokens:
            assert validate_refresh_token(mock_redis, user_id, token) is False

        # ...but the other user's session is untouched.
        assert validate_refresh_token(mock_redis, other_user_id, other_token) is True

    def test_decode_refresh_token_user_id(self, mock_redis):
        """The owning user id can be recovered from a stored refresh token."""
        user_id = "test-user-456"
        token = generate_refresh_token()

        # Store the token
        store_refresh_token(mock_redis, user_id, token)

        # Find user ID
        found_user_id = decode_refresh_token_user_id(token, mock_redis)
        assert found_user_id == user_id

        # Invalid token should return None
        assert decode_refresh_token_user_id("invalid-token", mock_redis) is None
|
|
|
|
|
|
class TestRefreshTokenEndpoint:
    """Test the POST /api/auth/refresh endpoint, including token rotation."""

    def test_refresh_token_success(self, client, db, mock_redis):
        """A valid refresh token yields a fresh token pair for the same user."""
        user_id = "00000000-0000-0000-0000-000000000001"

        # Generate and store a refresh token
        refresh_token = generate_refresh_token()
        store_refresh_token(mock_redis, user_id, refresh_token)

        # Call refresh endpoint
        response = client.post(
            "/api/auth/refresh",
            json={"refresh_token": refresh_token},
        )

        assert response.status_code == 200
        data = response.json()
        assert "access_token" in data
        assert "refresh_token" in data
        assert data["token_type"] == "bearer"
        assert data["expires_in"] > 0

        # The new access token must be issued for the user who owned the
        # refresh token; without this check, an endpoint that minted a token
        # for the wrong user would still pass.
        access_payload = decode_access_token(data["access_token"])
        assert access_payload is not None
        assert access_payload["sub"] == user_id

        # Old refresh token should be invalidated (rotation)
        assert validate_refresh_token(mock_redis, user_id, refresh_token) is False

        # New refresh token should be valid
        assert validate_refresh_token(mock_redis, user_id, data["refresh_token"]) is True

    def test_refresh_token_invalid(self, client, mock_redis):
        """An unknown refresh token is rejected with 401."""
        response = client.post(
            "/api/auth/refresh",
            json={"refresh_token": "invalid-token"},
        )

        assert response.status_code == 401
        assert "Invalid or expired refresh token" in response.json()["detail"]

    def test_refresh_token_rotation(self, client, db, mock_redis):
        """A refresh token is single-use; its replacement keeps working."""
        user_id = "00000000-0000-0000-0000-000000000001"

        # Generate and store initial refresh token
        initial_token = generate_refresh_token()
        store_refresh_token(mock_redis, user_id, initial_token)

        # First refresh
        response1 = client.post(
            "/api/auth/refresh",
            json={"refresh_token": initial_token},
        )
        assert response1.status_code == 200
        new_token = response1.json()["refresh_token"]

        # Try to reuse the old token (should fail due to rotation)
        response2 = client.post(
            "/api/auth/refresh",
            json={"refresh_token": initial_token},
        )
        assert response2.status_code == 401

        # New token should still work
        response3 = client.post(
            "/api/auth/refresh",
            json={"refresh_token": new_token},
        )
        assert response3.status_code == 200

    def test_refresh_token_disabled_user(self, client, db, mock_redis):
        """A disabled user's refresh token is refused with 403."""
        from app.models.user import User

        # Create a disabled user
        disabled_user = User(
            id="disabled-user-123",
            email="disabled@example.com",
            name="Disabled User",
            is_active=False,
        )
        db.add(disabled_user)
        db.commit()

        # Generate and store refresh token for disabled user
        refresh_token = generate_refresh_token()
        store_refresh_token(mock_redis, disabled_user.id, refresh_token)

        # Try to refresh
        response = client.post(
            "/api/auth/refresh",
            json={"refresh_token": refresh_token},
        )

        assert response.status_code == 403
        assert "disabled" in response.json()["detail"].lower()
|