"""Authentication middleware for protected routes.

For every request whose path is not on the public allow-list it handles:

1. Token validation -- the internal bearer token is looked up in the
   session store.
2. Inactivity timeout -- sessions idle longer than
   ``settings.SESSION_INACTIVITY_DAYS`` (documented as 3 days) are deleted.
3. Automatic AD token refresh -- when the AD token expires within
   ``settings.TOKEN_REFRESH_THRESHOLD_MINUTES`` (documented as 5 minutes),
   the stored encrypted password is used to re-authenticate against AD.
4. Retry counter management -- failed refreshes are counted and the session
   is terminated once ``settings.MAX_REFRESH_ATTEMPTS`` (documented as 3)
   is reached.

Authentication failures are returned as JSON 401 responses. HTTP middleware
runs outside FastAPI's exception handlers, so an ``HTTPException`` escaping
the middleware would otherwise surface to the client as a 500 error.
"""

import logging
from datetime import datetime, timedelta

from fastapi import Request, HTTPException, status
from fastapi.responses import JSONResponse

from app.core.database import SessionLocal
from app.core.config import get_settings
from app.modules.auth.services.session_service import session_service
from app.modules.auth.services.encryption import encryption_service
from app.modules.auth.services.ad_client import ad_auth_service

settings = get_settings()
logger = logging.getLogger(__name__)


class AuthMiddleware:
    """Authentication middleware.

    Used as a function-style HTTP middleware: the instance is awaited with
    ``(request, call_next)`` and returns a response.
    """

    # Paths that are reachable without a bearer token.
    PUBLIC_PATHS = frozenset(
        {"/api/auth/login", "/api/auth/logout", "/docs", "/openapi.json"}
    )

    # Authorization scheme prefix expected in the ``Authorization`` header.
    BEARER_PREFIX = "Bearer "

    async def __call__(self, request: Request, call_next):
        """Authenticate the request, then pass it to the next handler.

        On success ``request.state.user`` is set to a dict with ``id``,
        ``username`` and ``display_name`` before ``call_next`` is awaited.

        Args:
            request: The incoming request.
            call_next: The downstream ASGI handler supplied by the framework.

        Returns:
            The downstream response, or a JSON error response (status and
            ``detail`` taken from the ``HTTPException``) when authentication
            fails.
        """
        # Skip auth for login/logout and the API docs.
        if request.url.path in self.PUBLIC_PATHS:
            return await call_next(request)

        try:
            request.state.user = await self._authenticate(request)
        except HTTPException as exc:
            # Middleware sits outside FastAPI's exception handlers, so a raised
            # HTTPException would become a 500. Render the same JSON body that
            # FastAPI's default HTTPException handler would produce.
            return JSONResponse(
                status_code=exc.status_code,
                content={"detail": exc.detail},
                headers=exc.headers,
            )

        # The DB session is closed by now, so no connection is held while the
        # downstream handler runs.
        return await call_next(request)

    async def _authenticate(self, request: Request) -> dict:
        """Validate the bearer token and return the user info for the request.

        Also enforces the inactivity timeout and the refresh-attempt limit,
        refreshes the AD token when it is close to expiry, and records the
        activity timestamp.

        Args:
            request: The incoming request carrying the ``Authorization`` header.

        Returns:
            A dict with the session's ``id``, ``username`` and ``display_name``.

        Raises:
            HTTPException: 401 when the header is missing or malformed, the
                token is unknown, the session is inactive, refresh attempts
                are exhausted, or the AD token refresh fails.
        """
        authorization = request.headers.get("Authorization")
        if not authorization or not authorization.startswith(self.BEARER_PREFIX):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Authentication required",
            )
        # Strip only the leading scheme; str.replace would also remove any
        # later "Bearer " substring.
        internal_token = authorization[len(self.BEARER_PREFIX):]

        # NOTE(review): SessionLocal and session_service appear to be
        # synchronous, so these DB calls block the event loop -- confirm.
        db = SessionLocal()
        try:
            user_session = session_service.get_session_by_token(db, internal_token)
            if not user_session:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="Invalid or expired token",
                )

            now = datetime.utcnow()

            # Inactivity timeout: drop the session if idle for too long.
            inactivity_limit = now - timedelta(days=settings.SESSION_INACTIVITY_DAYS)
            if user_session.last_activity < inactivity_limit:
                session_service.delete_session(db, user_session.id)
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="Session expired due to inactivity. Please login again.",
                )

            # Defensive check: a session that already used up its refresh
            # attempts must not be used any further.
            if user_session.refresh_attempt_count >= settings.MAX_REFRESH_ATTEMPTS:
                session_service.delete_session(db, user_session.id)
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="Session expired due to authentication failures. Please login again.",
                )

            # Refresh the AD token if it is about to expire.
            time_until_expiry = user_session.ad_token_expires_at - now
            if time_until_expiry < timedelta(minutes=settings.TOKEN_REFRESH_THRESHOLD_MINUTES):
                await self._refresh_ad_token(db, user_session)

            session_service.update_activity(db, user_session.id)

            return {
                "id": user_session.id,
                "username": user_session.username,
                "display_name": user_session.display_name,
            }
        finally:
            db.close()

    async def _refresh_ad_token(self, db, user_session):
        """Re-authenticate with AD using the stored encrypted password.

        On success the session's AD token and expiry are updated. On failure
        (``ValueError`` or ``ConnectionError``) the session's refresh-attempt
        counter is incremented. The session is deleted once the counter
        reaches ``settings.MAX_REFRESH_ATTEMPTS``.

        Args:
            db: Open database session.
            user_session: The session row being refreshed.

        Raises:
            HTTPException: 401 whenever the refresh fails. The detail message
                differs depending on whether the session was terminated.
        """
        try:
            password = encryption_service.decrypt_password(user_session.encrypted_password)
            ad_result = await ad_auth_service.authenticate(user_session.username, password)
            # NOTE(review): the refresh-attempt counter is not reset here;
            # presumably update_ad_token resets it -- confirm.
            session_service.update_ad_token(
                db, user_session.id, ad_result["token"], ad_result["expires_at"]
            )
            logger.info("AD token refreshed successfully for user: %s", user_session.username)
        except (ValueError, ConnectionError) as e:
            new_count = session_service.increment_refresh_attempts(db, user_session.id)
            logger.warning(
                "AD token refresh failed for user %s. Attempt %s/%s",
                user_session.username,
                new_count,
                settings.MAX_REFRESH_ATTEMPTS,
            )

            if new_count >= settings.MAX_REFRESH_ATTEMPTS:
                # Repeated failures usually mean the stored password is stale.
                session_service.delete_session(db, user_session.id)
                logger.error(
                    "Session terminated for %s after %s failed refresh attempts",
                    user_session.username,
                    new_count,
                )
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="Session terminated. Your password may have been changed. Please login again.",
                ) from e

            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Token refresh failed. Please try again or re-login if issue persists.",
            ) from e


auth_middleware = AuthMiddleware()