Fix test failures and workload/websocket behavior

This commit is contained in:
beabigegg
2026-01-11 08:37:21 +08:00
parent 3bdc6ff1c9
commit f5f870da56
49 changed files with 3006 additions and 1132 deletions

View File

@@ -1,11 +1,12 @@
import asyncio
import os
import logging
import time
from typing import Optional
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
from sqlalchemy.orm import Session
from app.core.database import SessionLocal
from app.core import database
from app.core.security import decode_access_token
from app.core.redis import get_redis_sync
from app.models import User, Notification, Project
@@ -22,6 +23,8 @@ PONG_TIMEOUT = 30.0 # Disconnect if no pong received within this time after pi
# Authentication timeout (10 seconds)
AUTH_TIMEOUT = 10.0
if os.getenv("TESTING") == "true":
AUTH_TIMEOUT = 1.0
async def get_user_from_token(token: str) -> tuple[str | None, User | None]:
@@ -41,7 +44,7 @@ async def get_user_from_token(token: str) -> tuple[str | None, User | None]:
return None, None
# Get user from database
db = SessionLocal()
db = database.SessionLocal()
try:
user = db.query(User).filter(User.id == user_id).first()
if user is None or not user.is_active:
@@ -54,7 +57,7 @@ async def get_user_from_token(token: str) -> tuple[str | None, User | None]:
async def authenticate_websocket(
websocket: WebSocket,
query_token: Optional[str] = None
) -> tuple[str | None, User | None]:
) -> tuple[str | None, User | None, Optional[str]]:
"""
Authenticate WebSocket connection.
@@ -72,7 +75,10 @@ async def authenticate_websocket(
"WebSocket authentication via query parameter is deprecated. "
"Please use first-message authentication for better security."
)
return await get_user_from_token(query_token)
user_id, user = await get_user_from_token(query_token)
if user_id is None:
return None, None, "invalid_token"
return user_id, user, None
# Wait for authentication message with timeout
try:
@@ -84,26 +90,29 @@ async def authenticate_websocket(
msg_type = data.get("type")
if msg_type != "auth":
logger.warning("Expected 'auth' message type, got: %s", msg_type)
return None, None
return None, None, "invalid_message"
token = data.get("token")
if not token:
logger.warning("No token provided in auth message")
return None, None
return None, None, "missing_token"
return await get_user_from_token(token)
user_id, user = await get_user_from_token(token)
if user_id is None:
return None, None, "invalid_token"
return user_id, user, None
except asyncio.TimeoutError:
logger.warning("WebSocket authentication timeout after %.1f seconds", AUTH_TIMEOUT)
return None, None
return None, None, "timeout"
except Exception as e:
logger.error("Error during WebSocket authentication: %s", e)
return None, None
return None, None, "error"
async def get_unread_notifications(user_id: str) -> list[dict]:
"""Query all unread notifications for a user."""
db = SessionLocal()
db = database.SessionLocal()
try:
notifications = (
db.query(Notification)
@@ -130,7 +139,7 @@ async def get_unread_notifications(user_id: str) -> list[dict]:
async def get_unread_count(user_id: str) -> int:
"""Get the count of unread notifications for a user."""
db = SessionLocal()
db = database.SessionLocal()
try:
return (
db.query(Notification)
@@ -174,14 +183,12 @@ async def websocket_notifications(
# Accept WebSocket connection first
await websocket.accept()
# If no query token, notify client that auth is required
if not token:
await websocket.send_json({"type": "auth_required"})
# Authenticate
user_id, user = await authenticate_websocket(websocket, token)
user_id, user, error_reason = await authenticate_websocket(websocket, token)
if user_id is None:
if error_reason == "invalid_token":
await websocket.send_json({"type": "error", "message": "Invalid or expired token"})
await websocket.close(code=4001, reason="Invalid or expired token")
return
@@ -311,7 +318,7 @@ async def verify_project_access(user_id: str, project_id: str) -> tuple[bool, Pr
Returns:
Tuple of (has_access: bool, project: Project | None)
"""
db = SessionLocal()
db = database.SessionLocal()
try:
# Get the user
user = db.query(User).filter(User.id == user_id).first()
@@ -365,14 +372,12 @@ async def websocket_project_sync(
# Accept WebSocket connection first
await websocket.accept()
# If no query token, notify client that auth is required
if not token:
await websocket.send_json({"type": "auth_required"})
# Authenticate user
user_id, user = await authenticate_websocket(websocket, token)
user_id, user, error_reason = await authenticate_websocket(websocket, token)
if user_id is None:
if error_reason == "invalid_token":
await websocket.send_json({"type": "error", "message": "Invalid or expired token"})
await websocket.close(code=4001, reason="Invalid or expired token")
return