Fix test failures and workload/websocket behavior
This commit is contained in:
@@ -1,11 +1,12 @@
|
||||
import asyncio
|
||||
import os
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.database import SessionLocal
|
||||
from app.core import database
|
||||
from app.core.security import decode_access_token
|
||||
from app.core.redis import get_redis_sync
|
||||
from app.models import User, Notification, Project
|
||||
@@ -22,6 +23,8 @@ PONG_TIMEOUT = 30.0 # Disconnect if no pong received within this time after pi
|
||||
|
||||
# Authentication timeout (10 seconds)
|
||||
AUTH_TIMEOUT = 10.0
|
||||
if os.getenv("TESTING") == "true":
|
||||
AUTH_TIMEOUT = 1.0
|
||||
|
||||
|
||||
async def get_user_from_token(token: str) -> tuple[str | None, User | None]:
|
||||
@@ -41,7 +44,7 @@ async def get_user_from_token(token: str) -> tuple[str | None, User | None]:
|
||||
return None, None
|
||||
|
||||
# Get user from database
|
||||
db = SessionLocal()
|
||||
db = database.SessionLocal()
|
||||
try:
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if user is None or not user.is_active:
|
||||
@@ -54,7 +57,7 @@ async def get_user_from_token(token: str) -> tuple[str | None, User | None]:
|
||||
async def authenticate_websocket(
|
||||
websocket: WebSocket,
|
||||
query_token: Optional[str] = None
|
||||
) -> tuple[str | None, User | None]:
|
||||
) -> tuple[str | None, User | None, Optional[str]]:
|
||||
"""
|
||||
Authenticate WebSocket connection.
|
||||
|
||||
@@ -72,7 +75,10 @@ async def authenticate_websocket(
|
||||
"WebSocket authentication via query parameter is deprecated. "
|
||||
"Please use first-message authentication for better security."
|
||||
)
|
||||
return await get_user_from_token(query_token)
|
||||
user_id, user = await get_user_from_token(query_token)
|
||||
if user_id is None:
|
||||
return None, None, "invalid_token"
|
||||
return user_id, user, None
|
||||
|
||||
# Wait for authentication message with timeout
|
||||
try:
|
||||
@@ -84,26 +90,29 @@ async def authenticate_websocket(
|
||||
msg_type = data.get("type")
|
||||
if msg_type != "auth":
|
||||
logger.warning("Expected 'auth' message type, got: %s", msg_type)
|
||||
return None, None
|
||||
return None, None, "invalid_message"
|
||||
|
||||
token = data.get("token")
|
||||
if not token:
|
||||
logger.warning("No token provided in auth message")
|
||||
return None, None
|
||||
return None, None, "missing_token"
|
||||
|
||||
return await get_user_from_token(token)
|
||||
user_id, user = await get_user_from_token(token)
|
||||
if user_id is None:
|
||||
return None, None, "invalid_token"
|
||||
return user_id, user, None
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("WebSocket authentication timeout after %.1f seconds", AUTH_TIMEOUT)
|
||||
return None, None
|
||||
return None, None, "timeout"
|
||||
except Exception as e:
|
||||
logger.error("Error during WebSocket authentication: %s", e)
|
||||
return None, None
|
||||
return None, None, "error"
|
||||
|
||||
|
||||
async def get_unread_notifications(user_id: str) -> list[dict]:
|
||||
"""Query all unread notifications for a user."""
|
||||
db = SessionLocal()
|
||||
db = database.SessionLocal()
|
||||
try:
|
||||
notifications = (
|
||||
db.query(Notification)
|
||||
@@ -130,7 +139,7 @@ async def get_unread_notifications(user_id: str) -> list[dict]:
|
||||
|
||||
async def get_unread_count(user_id: str) -> int:
|
||||
"""Get the count of unread notifications for a user."""
|
||||
db = SessionLocal()
|
||||
db = database.SessionLocal()
|
||||
try:
|
||||
return (
|
||||
db.query(Notification)
|
||||
@@ -174,14 +183,12 @@ async def websocket_notifications(
|
||||
# Accept WebSocket connection first
|
||||
await websocket.accept()
|
||||
|
||||
# If no query token, notify client that auth is required
|
||||
if not token:
|
||||
await websocket.send_json({"type": "auth_required"})
|
||||
|
||||
# Authenticate
|
||||
user_id, user = await authenticate_websocket(websocket, token)
|
||||
user_id, user, error_reason = await authenticate_websocket(websocket, token)
|
||||
|
||||
if user_id is None:
|
||||
if error_reason == "invalid_token":
|
||||
await websocket.send_json({"type": "error", "message": "Invalid or expired token"})
|
||||
await websocket.close(code=4001, reason="Invalid or expired token")
|
||||
return
|
||||
|
||||
@@ -311,7 +318,7 @@ async def verify_project_access(user_id: str, project_id: str) -> tuple[bool, Pr
|
||||
Returns:
|
||||
Tuple of (has_access: bool, project: Project | None)
|
||||
"""
|
||||
db = SessionLocal()
|
||||
db = database.SessionLocal()
|
||||
try:
|
||||
# Get the user
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
@@ -365,14 +372,12 @@ async def websocket_project_sync(
|
||||
# Accept WebSocket connection first
|
||||
await websocket.accept()
|
||||
|
||||
# If no query token, notify client that auth is required
|
||||
if not token:
|
||||
await websocket.send_json({"type": "auth_required"})
|
||||
|
||||
# Authenticate user
|
||||
user_id, user = await authenticate_websocket(websocket, token)
|
||||
user_id, user, error_reason = await authenticate_websocket(websocket, token)
|
||||
|
||||
if user_id is None:
|
||||
if error_reason == "invalid_token":
|
||||
await websocket.send_json({"type": "error", "message": "Invalid or expired token"})
|
||||
await websocket.close(code=4001, reason="Invalid or expired token")
|
||||
return
|
||||
|
||||
|
||||
Reference in New Issue
Block a user