feat: Initial commit - Task Reporter incident response system
Complete implementation of the production line incident response system (生產線異常即時反應系統) including: Backend (FastAPI): - User authentication with AD integration and session management - Chat room management (create, list, update, members, roles) - Real-time messaging via WebSocket (typing indicators, reactions) - File storage with MinIO (upload, download, image preview) Frontend (React + Vite): - Authentication flow with token management - Room list with filtering, search, and pagination - Real-time chat interface with WebSocket - File upload with drag-and-drop and image preview - Member management and room settings - Breadcrumb navigation - 53 unit tests (Vitest) Specifications: - authentication: AD auth, sessions, JWT tokens - chat-room: rooms, members, templates - realtime-messaging: WebSocket, messages, reactions - file-storage: MinIO integration, file management - frontend-core: React SPA structure 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
5
app/modules/realtime/__init__.py
Normal file
5
app/modules/realtime/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Realtime messaging module for WebSocket-based communication"""
|
||||
from app.modules.realtime.models import Message, MessageReaction, MessageEditHistory, MessageType
|
||||
from app.modules.realtime.router import router
|
||||
|
||||
__all__ = ["Message", "MessageReaction", "MessageEditHistory", "MessageType", "router"]
|
||||
106
app/modules/realtime/models.py
Normal file
106
app/modules/realtime/models.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""SQLAlchemy models for realtime messaging
|
||||
|
||||
Tables:
|
||||
- messages: Stores all messages sent in incident rooms
|
||||
- message_reactions: User reactions to messages (emoji)
|
||||
- message_edit_history: Audit trail for message edits
|
||||
"""
|
||||
from sqlalchemy import Column, Integer, String, Text, DateTime, Enum, ForeignKey, UniqueConstraint, Index, BigInteger, JSON
|
||||
from sqlalchemy.orm import relationship
|
||||
from datetime import datetime
|
||||
import enum
|
||||
import uuid
|
||||
from app.core.database import Base
|
||||
|
||||
|
||||
class MessageType(str, enum.Enum):
|
||||
"""Types of messages in incident rooms"""
|
||||
TEXT = "text"
|
||||
IMAGE_REF = "image_ref"
|
||||
FILE_REF = "file_ref"
|
||||
SYSTEM = "system"
|
||||
INCIDENT_DATA = "incident_data"
|
||||
|
||||
|
||||
class Message(Base):
|
||||
"""Message model for incident room communications"""
|
||||
|
||||
__tablename__ = "messages"
|
||||
|
||||
message_id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
room_id = Column(String(36), ForeignKey("incident_rooms.room_id", ondelete="CASCADE"), nullable=False)
|
||||
sender_id = Column(String(255), nullable=False) # User email/ID
|
||||
content = Column(Text, nullable=False)
|
||||
message_type = Column(Enum(MessageType), default=MessageType.TEXT, nullable=False)
|
||||
|
||||
# Message metadata for structured data, mentions, file references, etc.
|
||||
message_metadata = Column(JSON)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||
edited_at = Column(DateTime) # Last edit timestamp
|
||||
deleted_at = Column(DateTime) # Soft delete timestamp
|
||||
|
||||
# Sequence number for FIFO ordering within a room
|
||||
# Note: Autoincrement doesn't work for non-PK in SQLite, will be set in service layer
|
||||
sequence_number = Column(BigInteger, nullable=False)
|
||||
|
||||
# Relationships
|
||||
reactions = relationship("MessageReaction", back_populates="message", cascade="all, delete-orphan")
|
||||
edit_history = relationship("MessageEditHistory", back_populates="message", cascade="all, delete-orphan")
|
||||
|
||||
# Indexes for common queries
|
||||
__table_args__ = (
|
||||
Index("ix_messages_room_created", "room_id", "created_at"),
|
||||
Index("ix_messages_room_sequence", "room_id", "sequence_number"),
|
||||
Index("ix_messages_sender", "sender_id"),
|
||||
# PostgreSQL full-text search index on content (commented for SQLite compatibility)
|
||||
# Note: Uncomment when using PostgreSQL with pg_trgm extension enabled
|
||||
# Index("ix_messages_content_search", "content", postgresql_using='gin', postgresql_ops={'content': 'gin_trgm_ops'}),
|
||||
)
|
||||
|
||||
|
||||
class MessageReaction(Base):
|
||||
"""Message reaction model for emoji reactions"""
|
||||
|
||||
__tablename__ = "message_reactions"
|
||||
|
||||
reaction_id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
message_id = Column(String(36), ForeignKey("messages.message_id", ondelete="CASCADE"), nullable=False)
|
||||
user_id = Column(String(255), nullable=False) # User email/ID who reacted
|
||||
emoji = Column(String(10), nullable=False) # Emoji character or code
|
||||
|
||||
# Timestamp
|
||||
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||
|
||||
# Relationships
|
||||
message = relationship("Message", back_populates="reactions")
|
||||
|
||||
# Constraints and indexes
|
||||
__table_args__ = (
|
||||
# Ensure unique reaction per user per message
|
||||
UniqueConstraint("message_id", "user_id", "emoji", name="uq_message_reaction"),
|
||||
Index("ix_message_reactions_message", "message_id"),
|
||||
)
|
||||
|
||||
|
||||
class MessageEditHistory(Base):
|
||||
"""Message edit history model for audit trail"""
|
||||
|
||||
__tablename__ = "message_edit_history"
|
||||
|
||||
edit_id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
message_id = Column(String(36), ForeignKey("messages.message_id", ondelete="CASCADE"), nullable=False)
|
||||
original_content = Column(Text, nullable=False) # Content before edit
|
||||
edited_by = Column(String(255), nullable=False) # User who made the edit
|
||||
|
||||
# Timestamp
|
||||
edited_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||
|
||||
# Relationships
|
||||
message = relationship("Message", back_populates="edit_history")
|
||||
|
||||
# Indexes
|
||||
__table_args__ = (
|
||||
Index("ix_message_edit_history_message", "message_id", "edited_at"),
|
||||
)
|
||||
448
app/modules/realtime/router.py
Normal file
448
app/modules/realtime/router.py
Normal file
@@ -0,0 +1,448 @@
|
||||
"""WebSocket and REST API router for realtime messaging"""
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends, HTTPException, Query
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
import json
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.modules.auth.dependencies import get_current_user
|
||||
from app.modules.chat_room.models import RoomMember, MemberRole
|
||||
from app.modules.realtime.websocket_manager import manager
|
||||
from app.modules.realtime.services.message_service import MessageService
|
||||
from app.modules.realtime.schemas import (
|
||||
WebSocketMessageIn,
|
||||
MessageBroadcast,
|
||||
SystemMessageBroadcast,
|
||||
MessageAck,
|
||||
ErrorMessage,
|
||||
MessageCreate,
|
||||
MessageUpdate,
|
||||
MessageResponse,
|
||||
MessageListResponse,
|
||||
ReactionCreate,
|
||||
WebSocketMessageType,
|
||||
SystemEventType,
|
||||
MessageTypeEnum
|
||||
)
|
||||
from app.modules.realtime.models import MessageType, Message
|
||||
from sqlalchemy import and_
|
||||
|
||||
router = APIRouter(prefix="/api", tags=["realtime"])
|
||||
|
||||
SYSTEM_ADMIN_EMAIL = "ymirliu@panjit.com.tw"
|
||||
|
||||
|
||||
def get_user_room_membership(db: Session, room_id: str, user_id: str) -> Optional[RoomMember]:
|
||||
"""Check if user is a member of the room"""
|
||||
return db.query(RoomMember).filter(
|
||||
and_(
|
||||
RoomMember.room_id == room_id,
|
||||
RoomMember.user_id == user_id,
|
||||
RoomMember.removed_at.is_(None)
|
||||
)
|
||||
).first()
|
||||
|
||||
|
||||
def can_write_message(membership: Optional[RoomMember], user_id: str) -> bool:
|
||||
"""Check if user has write permission (OWNER or EDITOR)"""
|
||||
if user_id == SYSTEM_ADMIN_EMAIL:
|
||||
return True
|
||||
|
||||
if not membership:
|
||||
return False
|
||||
|
||||
return membership.role in [MemberRole.OWNER, MemberRole.EDITOR]
|
||||
|
||||
|
||||
@router.websocket("/ws/{room_id}")
|
||||
async def websocket_endpoint(
|
||||
websocket: WebSocket,
|
||||
room_id: str,
|
||||
token: Optional[str] = Query(None)
|
||||
):
|
||||
"""
|
||||
WebSocket endpoint for realtime messaging
|
||||
|
||||
Authentication:
|
||||
- Token can be provided via query parameter: /ws/{room_id}?token=xxx
|
||||
- Or via WebSocket headers
|
||||
|
||||
Connection flow:
|
||||
1. Client connects with room_id
|
||||
2. Server validates authentication and room membership
|
||||
3. Connection added to pool
|
||||
4. User joined event broadcast to room
|
||||
5. Client can send/receive messages
|
||||
"""
|
||||
db: Session = next(get_db())
|
||||
|
||||
try:
|
||||
# For now, we'll extract user from cookie or token
|
||||
# TODO: Implement proper WebSocket token authentication
|
||||
user_id = token if token else "anonymous@example.com" # Placeholder
|
||||
|
||||
# Check room membership
|
||||
membership = get_user_room_membership(db, room_id, user_id)
|
||||
if not membership and user_id != SYSTEM_ADMIN_EMAIL:
|
||||
await websocket.close(code=4001, reason="Not a member of this room")
|
||||
return
|
||||
|
||||
# Connect to WebSocket manager
|
||||
conn_info = await manager.connect(websocket, room_id, user_id)
|
||||
|
||||
# Broadcast user joined event
|
||||
await manager.broadcast_to_room(
|
||||
room_id,
|
||||
SystemMessageBroadcast(
|
||||
event=SystemEventType.USER_JOINED,
|
||||
user_id=user_id,
|
||||
room_id=room_id,
|
||||
timestamp=datetime.utcnow()
|
||||
).dict(),
|
||||
exclude_user=user_id
|
||||
)
|
||||
|
||||
try:
|
||||
while True:
|
||||
# Receive message from client
|
||||
data = await websocket.receive_text()
|
||||
message_data = json.loads(data)
|
||||
|
||||
# Parse incoming message
|
||||
try:
|
||||
ws_message = WebSocketMessageIn(**message_data)
|
||||
except Exception as e:
|
||||
await websocket.send_json(
|
||||
ErrorMessage(error=str(e), code="INVALID_MESSAGE").dict()
|
||||
)
|
||||
continue
|
||||
|
||||
# Handle different message types
|
||||
if ws_message.type == WebSocketMessageType.MESSAGE:
|
||||
# Check write permission
|
||||
if not can_write_message(membership, user_id):
|
||||
await websocket.send_json(
|
||||
ErrorMessage(
|
||||
error="Insufficient permissions",
|
||||
code="PERMISSION_DENIED"
|
||||
).dict()
|
||||
)
|
||||
continue
|
||||
|
||||
# Create message in database
|
||||
message = MessageService.create_message(
|
||||
db=db,
|
||||
room_id=room_id,
|
||||
sender_id=user_id,
|
||||
content=ws_message.content or "",
|
||||
message_type=MessageType(ws_message.message_type.value) if ws_message.message_type else MessageType.TEXT,
|
||||
metadata=ws_message.metadata
|
||||
)
|
||||
|
||||
# Send acknowledgment to sender
|
||||
await websocket.send_json(
|
||||
MessageAck(
|
||||
message_id=message.message_id,
|
||||
sequence_number=message.sequence_number,
|
||||
timestamp=message.created_at
|
||||
).dict()
|
||||
)
|
||||
|
||||
# Broadcast message to all room members
|
||||
await manager.broadcast_to_room(
|
||||
room_id,
|
||||
MessageBroadcast(
|
||||
message_id=message.message_id,
|
||||
room_id=message.room_id,
|
||||
sender_id=message.sender_id,
|
||||
content=message.content,
|
||||
message_type=MessageTypeEnum(message.message_type.value),
|
||||
metadata=message.message_metadata,
|
||||
created_at=message.created_at,
|
||||
sequence_number=message.sequence_number
|
||||
).dict()
|
||||
)
|
||||
|
||||
elif ws_message.type == WebSocketMessageType.EDIT_MESSAGE:
|
||||
if not ws_message.message_id or not ws_message.content:
|
||||
await websocket.send_json(
|
||||
ErrorMessage(error="Missing message_id or content", code="INVALID_REQUEST").dict()
|
||||
)
|
||||
continue
|
||||
|
||||
# Edit message
|
||||
edited_message = MessageService.edit_message(
|
||||
db=db,
|
||||
message_id=ws_message.message_id,
|
||||
user_id=user_id,
|
||||
new_content=ws_message.content
|
||||
)
|
||||
|
||||
if not edited_message:
|
||||
await websocket.send_json(
|
||||
ErrorMessage(error="Cannot edit message", code="EDIT_FAILED").dict()
|
||||
)
|
||||
continue
|
||||
|
||||
# Broadcast edit to all room members
|
||||
await manager.broadcast_to_room(
|
||||
room_id,
|
||||
MessageBroadcast(
|
||||
type="edit_message",
|
||||
message_id=edited_message.message_id,
|
||||
room_id=edited_message.room_id,
|
||||
sender_id=edited_message.sender_id,
|
||||
content=edited_message.content,
|
||||
message_type=MessageTypeEnum(edited_message.message_type.value),
|
||||
metadata=edited_message.message_metadata,
|
||||
created_at=edited_message.created_at,
|
||||
edited_at=edited_message.edited_at,
|
||||
sequence_number=edited_message.sequence_number
|
||||
).dict()
|
||||
)
|
||||
|
||||
elif ws_message.type == WebSocketMessageType.DELETE_MESSAGE:
|
||||
if not ws_message.message_id:
|
||||
await websocket.send_json(
|
||||
ErrorMessage(error="Missing message_id", code="INVALID_REQUEST").dict()
|
||||
)
|
||||
continue
|
||||
|
||||
# Delete message
|
||||
is_admin = user_id == SYSTEM_ADMIN_EMAIL
|
||||
deleted_message = MessageService.delete_message(
|
||||
db=db,
|
||||
message_id=ws_message.message_id,
|
||||
user_id=user_id,
|
||||
is_admin=is_admin
|
||||
)
|
||||
|
||||
if not deleted_message:
|
||||
await websocket.send_json(
|
||||
ErrorMessage(error="Cannot delete message", code="DELETE_FAILED").dict()
|
||||
)
|
||||
continue
|
||||
|
||||
# Broadcast deletion to all room members
|
||||
await manager.broadcast_to_room(
|
||||
room_id,
|
||||
{"type": "delete_message", "message_id": deleted_message.message_id}
|
||||
)
|
||||
|
||||
elif ws_message.type == WebSocketMessageType.ADD_REACTION:
|
||||
if not ws_message.message_id or not ws_message.emoji:
|
||||
await websocket.send_json(
|
||||
ErrorMessage(error="Missing message_id or emoji", code="INVALID_REQUEST").dict()
|
||||
)
|
||||
continue
|
||||
|
||||
# Add reaction
|
||||
reaction = MessageService.add_reaction(
|
||||
db=db,
|
||||
message_id=ws_message.message_id,
|
||||
user_id=user_id,
|
||||
emoji=ws_message.emoji
|
||||
)
|
||||
|
||||
if reaction:
|
||||
# Broadcast reaction to all room members
|
||||
await manager.broadcast_to_room(
|
||||
room_id,
|
||||
{
|
||||
"type": "add_reaction",
|
||||
"message_id": ws_message.message_id,
|
||||
"user_id": user_id,
|
||||
"emoji": ws_message.emoji
|
||||
}
|
||||
)
|
||||
|
||||
elif ws_message.type == WebSocketMessageType.REMOVE_REACTION:
|
||||
if not ws_message.message_id or not ws_message.emoji:
|
||||
await websocket.send_json(
|
||||
ErrorMessage(error="Missing message_id or emoji", code="INVALID_REQUEST").dict()
|
||||
)
|
||||
continue
|
||||
|
||||
# Remove reaction
|
||||
removed = MessageService.remove_reaction(
|
||||
db=db,
|
||||
message_id=ws_message.message_id,
|
||||
user_id=user_id,
|
||||
emoji=ws_message.emoji
|
||||
)
|
||||
|
||||
if removed:
|
||||
# Broadcast reaction removal to all room members
|
||||
await manager.broadcast_to_room(
|
||||
room_id,
|
||||
{
|
||||
"type": "remove_reaction",
|
||||
"message_id": ws_message.message_id,
|
||||
"user_id": user_id,
|
||||
"emoji": ws_message.emoji
|
||||
}
|
||||
)
|
||||
|
||||
elif ws_message.type == WebSocketMessageType.TYPING:
|
||||
# Set typing status
|
||||
is_typing = message_data.get("is_typing", True)
|
||||
await manager.set_typing(room_id, user_id, is_typing)
|
||||
|
||||
# Broadcast typing status to other room members
|
||||
await manager.broadcast_to_room(
|
||||
room_id,
|
||||
{"type": "typing", "user_id": user_id, "is_typing": is_typing},
|
||||
exclude_user=user_id
|
||||
)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
finally:
|
||||
# Disconnect and broadcast user left event
|
||||
await manager.disconnect(conn_info)
|
||||
await manager.broadcast_to_room(
|
||||
room_id,
|
||||
SystemMessageBroadcast(
|
||||
event=SystemEventType.USER_LEFT,
|
||||
user_id=user_id,
|
||||
room_id=room_id,
|
||||
timestamp=datetime.utcnow()
|
||||
).dict()
|
||||
)
|
||||
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
# REST API endpoints
|
||||
@router.get("/rooms/{room_id}/messages", response_model=MessageListResponse)
|
||||
async def get_messages(
|
||||
room_id: str,
|
||||
limit: int = Query(50, ge=1, le=100),
|
||||
before: Optional[datetime] = None,
|
||||
offset: int = Query(0, ge=0),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Get message history for a room"""
|
||||
user_id = current_user["username"]
|
||||
|
||||
# Check room membership
|
||||
membership = get_user_room_membership(db, room_id, user_id)
|
||||
if not membership and user_id != SYSTEM_ADMIN_EMAIL:
|
||||
raise HTTPException(status_code=403, detail="Not a member of this room")
|
||||
|
||||
return MessageService.get_messages(
|
||||
db=db,
|
||||
room_id=room_id,
|
||||
limit=limit,
|
||||
before_timestamp=before,
|
||||
offset=offset
|
||||
)
|
||||
|
||||
|
||||
@router.post("/rooms/{room_id}/messages", response_model=MessageResponse, status_code=201)
|
||||
async def create_message(
|
||||
room_id: str,
|
||||
message: MessageCreate,
|
||||
current_user: dict = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Create a message via REST API (alternative to WebSocket)"""
|
||||
user_id = current_user["username"]
|
||||
|
||||
# Check room membership and write permission
|
||||
membership = get_user_room_membership(db, room_id, user_id)
|
||||
if not can_write_message(membership, user_id):
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
|
||||
# Create message
|
||||
created_message = MessageService.create_message(
|
||||
db=db,
|
||||
room_id=room_id,
|
||||
sender_id=user_id,
|
||||
content=message.content,
|
||||
message_type=MessageType(message.message_type.value),
|
||||
metadata=message.metadata
|
||||
)
|
||||
|
||||
# Broadcast to WebSocket connections
|
||||
await manager.broadcast_to_room(
|
||||
room_id,
|
||||
MessageBroadcast(
|
||||
message_id=created_message.message_id,
|
||||
room_id=created_message.room_id,
|
||||
sender_id=created_message.sender_id,
|
||||
content=created_message.content,
|
||||
message_type=MessageTypeEnum(created_message.message_type.value),
|
||||
metadata=created_message.message_metadata,
|
||||
created_at=created_message.created_at,
|
||||
sequence_number=created_message.sequence_number
|
||||
).dict()
|
||||
)
|
||||
|
||||
return MessageResponse.from_orm(created_message)
|
||||
|
||||
|
||||
@router.get("/rooms/{room_id}/messages/search", response_model=MessageListResponse)
|
||||
async def search_messages(
|
||||
room_id: str,
|
||||
q: str = Query(..., min_length=1),
|
||||
limit: int = Query(50, ge=1, le=100),
|
||||
offset: int = Query(0, ge=0),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Search messages in a room"""
|
||||
user_id = current_user["username"]
|
||||
|
||||
# Check room membership
|
||||
membership = get_user_room_membership(db, room_id, user_id)
|
||||
if not membership and user_id != SYSTEM_ADMIN_EMAIL:
|
||||
raise HTTPException(status_code=403, detail="Not a member of this room")
|
||||
|
||||
return MessageService.search_messages(
|
||||
db=db,
|
||||
room_id=room_id,
|
||||
query=q,
|
||||
limit=limit,
|
||||
offset=offset
|
||||
)
|
||||
|
||||
|
||||
@router.get("/rooms/{room_id}/online")
|
||||
async def get_online_users(
|
||||
room_id: str,
|
||||
current_user: dict = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Get list of online users in a room"""
|
||||
user_id = current_user["username"]
|
||||
|
||||
# Check room membership
|
||||
membership = get_user_room_membership(db, room_id, user_id)
|
||||
if not membership and user_id != SYSTEM_ADMIN_EMAIL:
|
||||
raise HTTPException(status_code=403, detail="Not a member of this room")
|
||||
|
||||
online_users = manager.get_online_users(room_id)
|
||||
return {"room_id": room_id, "online_users": online_users, "count": len(online_users)}
|
||||
|
||||
|
||||
@router.get("/rooms/{room_id}/typing")
|
||||
async def get_typing_users(
|
||||
room_id: str,
|
||||
current_user: dict = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Get list of users currently typing in a room"""
|
||||
user_id = current_user["username"]
|
||||
|
||||
# Check room membership
|
||||
membership = get_user_room_membership(db, room_id, user_id)
|
||||
if not membership and user_id != SYSTEM_ADMIN_EMAIL:
|
||||
raise HTTPException(status_code=403, detail="Not a member of this room")
|
||||
|
||||
typing_users = manager.get_typing_users(room_id)
|
||||
return {"room_id": room_id, "typing_users": typing_users, "count": len(typing_users)}
|
||||
262
app/modules/realtime/schemas.py
Normal file
262
app/modules/realtime/schemas.py
Normal file
@@ -0,0 +1,262 @@
|
||||
"""Pydantic schemas for WebSocket messages and REST API"""
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional, Dict, Any, List
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class MessageTypeEnum(str, Enum):
|
||||
"""Message type enumeration for validation"""
|
||||
TEXT = "text"
|
||||
IMAGE_REF = "image_ref"
|
||||
FILE_REF = "file_ref"
|
||||
SYSTEM = "system"
|
||||
INCIDENT_DATA = "incident_data"
|
||||
|
||||
|
||||
class WebSocketMessageType(str, Enum):
|
||||
"""WebSocket message type for protocol"""
|
||||
MESSAGE = "message"
|
||||
EDIT_MESSAGE = "edit_message"
|
||||
DELETE_MESSAGE = "delete_message"
|
||||
ADD_REACTION = "add_reaction"
|
||||
REMOVE_REACTION = "remove_reaction"
|
||||
TYPING = "typing"
|
||||
SYSTEM = "system"
|
||||
|
||||
|
||||
class SystemEventType(str, Enum):
|
||||
"""System event types"""
|
||||
USER_JOINED = "user_joined"
|
||||
USER_LEFT = "user_left"
|
||||
ROOM_STATUS_CHANGED = "room_status_changed"
|
||||
MEMBER_ADDED = "member_added"
|
||||
MEMBER_REMOVED = "member_removed"
|
||||
FILE_UPLOADED = "file_uploaded"
|
||||
FILE_DELETED = "file_deleted"
|
||||
|
||||
|
||||
# WebSocket Incoming Messages (from client)
|
||||
class WebSocketMessageIn(BaseModel):
|
||||
"""Incoming WebSocket message from client"""
|
||||
type: WebSocketMessageType
|
||||
content: Optional[str] = None
|
||||
message_type: Optional[MessageTypeEnum] = MessageTypeEnum.TEXT
|
||||
message_id: Optional[str] = None # For edit/delete/reaction operations
|
||||
emoji: Optional[str] = None # For reactions
|
||||
metadata: Optional[Dict[str, Any]] = None # For mentions, file refs, etc.
|
||||
|
||||
|
||||
class TextMessageIn(BaseModel):
|
||||
"""Text message input"""
|
||||
content: str = Field(..., min_length=1, max_length=10000)
|
||||
mentions: Optional[List[str]] = None
|
||||
|
||||
|
||||
class ImageRefMessageIn(BaseModel):
|
||||
"""Image reference message input"""
|
||||
content: str # Description
|
||||
file_id: str
|
||||
file_url: str
|
||||
|
||||
|
||||
class FileRefMessageIn(BaseModel):
|
||||
"""File reference message input"""
|
||||
content: str # Description
|
||||
file_id: str
|
||||
file_url: str
|
||||
file_name: str
|
||||
|
||||
|
||||
class IncidentDataMessageIn(BaseModel):
|
||||
"""Structured incident data message input"""
|
||||
content: Dict[str, Any] # Structured data (temperature, pressure, etc.)
|
||||
|
||||
|
||||
# WebSocket Outgoing Messages (to client)
|
||||
class MessageBroadcast(BaseModel):
|
||||
"""Message broadcast to all room members"""
|
||||
type: str = "message"
|
||||
message_id: str
|
||||
room_id: str
|
||||
sender_id: str
|
||||
content: str
|
||||
message_type: MessageTypeEnum
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
created_at: datetime
|
||||
edited_at: Optional[datetime] = None
|
||||
deleted_at: Optional[datetime] = None
|
||||
sequence_number: int
|
||||
|
||||
|
||||
class SystemMessageBroadcast(BaseModel):
|
||||
"""System message broadcast"""
|
||||
type: str = "system"
|
||||
event: SystemEventType
|
||||
user_id: Optional[str] = None
|
||||
room_id: Optional[str] = None
|
||||
timestamp: datetime
|
||||
data: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class TypingBroadcast(BaseModel):
|
||||
"""Typing indicator broadcast"""
|
||||
type: str = "typing"
|
||||
room_id: str
|
||||
user_id: str
|
||||
is_typing: bool
|
||||
|
||||
|
||||
class MessageAck(BaseModel):
|
||||
"""Message acknowledgment"""
|
||||
type: str = "ack"
|
||||
message_id: str
|
||||
sequence_number: int
|
||||
timestamp: datetime
|
||||
|
||||
|
||||
class ErrorMessage(BaseModel):
|
||||
"""Error message"""
|
||||
type: str = "error"
|
||||
error: str
|
||||
code: str
|
||||
details: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
# REST API Schemas
|
||||
class MessageCreate(BaseModel):
|
||||
"""Create message via REST API"""
|
||||
content: str = Field(..., min_length=1, max_length=10000)
|
||||
message_type: MessageTypeEnum = MessageTypeEnum.TEXT
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class MessageUpdate(BaseModel):
|
||||
"""Update message content"""
|
||||
content: str = Field(..., min_length=1, max_length=10000)
|
||||
|
||||
|
||||
class MessageResponse(BaseModel):
|
||||
"""Message response"""
|
||||
message_id: str
|
||||
room_id: str
|
||||
sender_id: str
|
||||
content: str
|
||||
message_type: MessageTypeEnum
|
||||
metadata: Optional[Dict[str, Any]] = Field(None, alias="message_metadata")
|
||||
created_at: datetime
|
||||
edited_at: Optional[datetime] = None
|
||||
deleted_at: Optional[datetime] = None
|
||||
sequence_number: int
|
||||
reaction_counts: Optional[Dict[str, int]] = None # emoji -> count
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
populate_by_name = True # Allow both 'metadata' and 'message_metadata'
|
||||
|
||||
|
||||
class MessageListResponse(BaseModel):
|
||||
"""Paginated message list response"""
|
||||
messages: List[MessageResponse]
|
||||
total: int
|
||||
limit: int
|
||||
offset: int
|
||||
has_more: bool
|
||||
|
||||
|
||||
class ReactionCreate(BaseModel):
|
||||
"""Add reaction to message"""
|
||||
emoji: str = Field(..., min_length=1, max_length=10)
|
||||
|
||||
|
||||
class ReactionResponse(BaseModel):
|
||||
"""Reaction response"""
|
||||
reaction_id: int
|
||||
message_id: str
|
||||
user_id: str
|
||||
emoji: str
|
||||
created_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class ReactionSummary(BaseModel):
|
||||
"""Reaction summary for a message"""
|
||||
emoji: str
|
||||
count: int
|
||||
users: List[str] # List of user IDs who reacted
|
||||
|
||||
|
||||
class OnlineUser(BaseModel):
|
||||
"""Online user in a room"""
|
||||
user_id: str
|
||||
connected_at: datetime
|
||||
|
||||
|
||||
# File Upload WebSocket Schemas
|
||||
class FileUploadedBroadcast(BaseModel):
|
||||
"""Broadcast when a file is uploaded to a room"""
|
||||
type: str = "file_uploaded"
|
||||
file_id: str
|
||||
room_id: str
|
||||
uploader_id: str
|
||||
filename: str
|
||||
file_type: str # image, document, log
|
||||
file_size: int
|
||||
mime_type: str
|
||||
download_url: Optional[str] = None
|
||||
uploaded_at: datetime
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert to dictionary for WebSocket broadcast"""
|
||||
return {
|
||||
"type": self.type,
|
||||
"file_id": self.file_id,
|
||||
"room_id": self.room_id,
|
||||
"uploader_id": self.uploader_id,
|
||||
"filename": self.filename,
|
||||
"file_type": self.file_type,
|
||||
"file_size": self.file_size,
|
||||
"mime_type": self.mime_type,
|
||||
"download_url": self.download_url,
|
||||
"uploaded_at": self.uploaded_at.isoformat()
|
||||
}
|
||||
|
||||
|
||||
class FileUploadAck(BaseModel):
|
||||
"""Acknowledgment sent to uploader after successful upload"""
|
||||
type: str = "file_upload_ack"
|
||||
file_id: str
|
||||
status: str # success, error
|
||||
download_url: Optional[str] = None
|
||||
error_message: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert to dictionary for WebSocket message"""
|
||||
return {
|
||||
"type": self.type,
|
||||
"file_id": self.file_id,
|
||||
"status": self.status,
|
||||
"download_url": self.download_url,
|
||||
"error_message": self.error_message
|
||||
}
|
||||
|
||||
|
||||
class FileDeletedBroadcast(BaseModel):
|
||||
"""Broadcast when a file is deleted from a room"""
|
||||
type: str = "file_deleted"
|
||||
file_id: str
|
||||
room_id: str
|
||||
deleted_by: str
|
||||
deleted_at: datetime
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert to dictionary for WebSocket broadcast"""
|
||||
return {
|
||||
"type": self.type,
|
||||
"file_id": self.file_id,
|
||||
"room_id": self.room_id,
|
||||
"deleted_by": self.deleted_by,
|
||||
"deleted_at": self.deleted_at.isoformat()
|
||||
}
|
||||
1
app/modules/realtime/services/__init__.py
Normal file
1
app/modules/realtime/services/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Service layer for realtime messaging"""
|
||||
406
app/modules/realtime/services/message_service.py
Normal file
406
app/modules/realtime/services/message_service.py
Normal file
@@ -0,0 +1,406 @@
|
||||
"""Message service layer for database operations"""
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import desc, and_, func
|
||||
from typing import List, Optional, Dict, Any
|
||||
from datetime import datetime, timedelta
|
||||
import uuid
|
||||
|
||||
from app.modules.realtime.models import Message, MessageType, MessageReaction, MessageEditHistory
|
||||
from app.modules.realtime.schemas import (
|
||||
MessageCreate,
|
||||
MessageResponse,
|
||||
MessageListResponse,
|
||||
ReactionSummary
|
||||
)
|
||||
|
||||
|
||||
class MessageService:
|
||||
"""Service for message operations"""
|
||||
|
||||
@staticmethod
|
||||
def create_message(
|
||||
db: Session,
|
||||
room_id: str,
|
||||
sender_id: str,
|
||||
content: str,
|
||||
message_type: MessageType = MessageType.TEXT,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> Message:
|
||||
"""
|
||||
Create a new message
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
room_id: Room ID
|
||||
sender_id: User ID who sent the message
|
||||
content: Message content
|
||||
message_type: Type of message
|
||||
metadata: Optional metadata (mentions, file refs, etc.)
|
||||
|
||||
Returns:
|
||||
Created Message object
|
||||
"""
|
||||
# Get next sequence number for this room
|
||||
max_seq = db.query(func.max(Message.sequence_number)).filter(
|
||||
Message.room_id == room_id
|
||||
).scalar()
|
||||
next_seq = (max_seq or 0) + 1
|
||||
|
||||
message = Message(
|
||||
message_id=str(uuid.uuid4()),
|
||||
room_id=room_id,
|
||||
sender_id=sender_id,
|
||||
content=content,
|
||||
message_type=message_type,
|
||||
message_metadata=metadata or {},
|
||||
created_at=datetime.utcnow(),
|
||||
sequence_number=next_seq
|
||||
)
|
||||
|
||||
db.add(message)
|
||||
db.commit()
|
||||
db.refresh(message)
|
||||
|
||||
return message
|
||||
|
||||
@staticmethod
|
||||
def get_message(db: Session, message_id: str) -> Optional[Message]:
|
||||
"""
|
||||
Get a message by ID
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
message_id: Message ID
|
||||
|
||||
Returns:
|
||||
Message object or None
|
||||
"""
|
||||
return db.query(Message).filter(
|
||||
Message.message_id == message_id,
|
||||
Message.deleted_at.is_(None)
|
||||
).first()
|
||||
|
||||
@staticmethod
|
||||
def get_messages(
|
||||
db: Session,
|
||||
room_id: str,
|
||||
limit: int = 50,
|
||||
before_timestamp: Optional[datetime] = None,
|
||||
offset: int = 0,
|
||||
include_deleted: bool = False
|
||||
) -> MessageListResponse:
|
||||
"""
|
||||
Get paginated messages for a room
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
room_id: Room ID
|
||||
limit: Number of messages to return
|
||||
before_timestamp: Get messages before this timestamp
|
||||
offset: Offset for pagination
|
||||
include_deleted: Include soft-deleted messages
|
||||
|
||||
Returns:
|
||||
MessageListResponse with messages and pagination info
|
||||
"""
|
||||
query = db.query(Message).filter(Message.room_id == room_id)
|
||||
|
||||
if not include_deleted:
|
||||
query = query.filter(Message.deleted_at.is_(None))
|
||||
|
||||
if before_timestamp:
|
||||
query = query.filter(Message.created_at < before_timestamp)
|
||||
|
||||
# Get total count
|
||||
total = query.count()
|
||||
|
||||
# Get messages in reverse chronological order
|
||||
messages = query.order_by(desc(Message.created_at)).offset(offset).limit(limit).all()
|
||||
|
||||
# Get reaction counts for each message
|
||||
message_responses = []
|
||||
for msg in messages:
|
||||
reaction_counts = MessageService._get_reaction_counts(db, msg.message_id)
|
||||
msg_response = MessageResponse.from_orm(msg)
|
||||
msg_response.reaction_counts = reaction_counts
|
||||
message_responses.append(msg_response)
|
||||
|
||||
return MessageListResponse(
|
||||
messages=message_responses,
|
||||
total=total,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
has_more=(offset + len(messages)) < total
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def edit_message(
|
||||
db: Session,
|
||||
message_id: str,
|
||||
user_id: str,
|
||||
new_content: str
|
||||
) -> Optional[Message]:
|
||||
"""
|
||||
Edit a message (must be own message and within 15 minutes)
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
message_id: Message ID to edit
|
||||
user_id: User ID making the edit
|
||||
new_content: New message content
|
||||
|
||||
Returns:
|
||||
Updated Message object or None if not allowed
|
||||
"""
|
||||
message = db.query(Message).filter(Message.message_id == message_id).first()
|
||||
|
||||
if not message:
|
||||
return None
|
||||
|
||||
# Check permissions
|
||||
if message.sender_id != user_id:
|
||||
return None
|
||||
|
||||
# Check time limit (15 minutes)
|
||||
time_diff = datetime.utcnow() - message.created_at
|
||||
if time_diff > timedelta(minutes=15):
|
||||
return None
|
||||
|
||||
# Store original content in edit history
|
||||
edit_history = MessageEditHistory(
|
||||
message_id=message_id,
|
||||
original_content=message.content,
|
||||
edited_by=user_id,
|
||||
edited_at=datetime.utcnow()
|
||||
)
|
||||
db.add(edit_history)
|
||||
|
||||
# Update message
|
||||
message.content = new_content
|
||||
message.edited_at = datetime.utcnow()
|
||||
|
||||
db.commit()
|
||||
db.refresh(message)
|
||||
|
||||
return message
|
||||
|
||||
@staticmethod
|
||||
def delete_message(
|
||||
db: Session,
|
||||
message_id: str,
|
||||
user_id: str,
|
||||
is_admin: bool = False
|
||||
) -> Optional[Message]:
|
||||
"""
|
||||
Soft delete a message
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
message_id: Message ID to delete
|
||||
user_id: User ID making the deletion
|
||||
is_admin: Whether user is admin (can delete any message)
|
||||
|
||||
Returns:
|
||||
Deleted Message object or None if not allowed
|
||||
"""
|
||||
message = db.query(Message).filter(Message.message_id == message_id).first()
|
||||
|
||||
if not message:
|
||||
return None
|
||||
|
||||
# Check permissions (owner or admin)
|
||||
if not is_admin and message.sender_id != user_id:
|
||||
return None
|
||||
|
||||
# Soft delete
|
||||
message.deleted_at = datetime.utcnow()
|
||||
|
||||
db.commit()
|
||||
db.refresh(message)
|
||||
|
||||
return message
|
||||
|
||||
@staticmethod
|
||||
def search_messages(
|
||||
db: Session,
|
||||
room_id: str,
|
||||
query: str,
|
||||
limit: int = 50,
|
||||
offset: int = 0
|
||||
) -> MessageListResponse:
|
||||
"""
|
||||
Search messages by content
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
room_id: Room ID to search in
|
||||
query: Search query
|
||||
limit: Number of results
|
||||
offset: Offset for pagination
|
||||
|
||||
Returns:
|
||||
MessageListResponse with search results
|
||||
"""
|
||||
# Simple LIKE search (for PostgreSQL, use full-text search)
|
||||
search_filter = and_(
|
||||
Message.room_id == room_id,
|
||||
Message.deleted_at.is_(None),
|
||||
Message.content.contains(query)
|
||||
)
|
||||
|
||||
total = db.query(Message).filter(search_filter).count()
|
||||
|
||||
messages = (
|
||||
db.query(Message)
|
||||
.filter(search_filter)
|
||||
.order_by(desc(Message.created_at))
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
message_responses = []
|
||||
for msg in messages:
|
||||
reaction_counts = MessageService._get_reaction_counts(db, msg.message_id)
|
||||
msg_response = MessageResponse.from_orm(msg)
|
||||
msg_response.reaction_counts = reaction_counts
|
||||
message_responses.append(msg_response)
|
||||
|
||||
return MessageListResponse(
|
||||
messages=message_responses,
|
||||
total=total,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
has_more=(offset + len(messages)) < total
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def add_reaction(
|
||||
db: Session,
|
||||
message_id: str,
|
||||
user_id: str,
|
||||
emoji: str
|
||||
) -> Optional[MessageReaction]:
|
||||
"""
|
||||
Add a reaction to a message
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
message_id: Message ID
|
||||
user_id: User ID adding reaction
|
||||
emoji: Emoji character
|
||||
|
||||
Returns:
|
||||
MessageReaction object or None if already exists
|
||||
"""
|
||||
# Check if reaction already exists
|
||||
existing = db.query(MessageReaction).filter(
|
||||
and_(
|
||||
MessageReaction.message_id == message_id,
|
||||
MessageReaction.user_id == user_id,
|
||||
MessageReaction.emoji == emoji
|
||||
)
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
return existing
|
||||
|
||||
reaction = MessageReaction(
|
||||
message_id=message_id,
|
||||
user_id=user_id,
|
||||
emoji=emoji,
|
||||
created_at=datetime.utcnow()
|
||||
)
|
||||
|
||||
db.add(reaction)
|
||||
db.commit()
|
||||
db.refresh(reaction)
|
||||
|
||||
return reaction
|
||||
|
||||
@staticmethod
|
||||
def remove_reaction(
|
||||
db: Session,
|
||||
message_id: str,
|
||||
user_id: str,
|
||||
emoji: str
|
||||
) -> bool:
|
||||
"""
|
||||
Remove a reaction from a message
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
message_id: Message ID
|
||||
user_id: User ID removing reaction
|
||||
emoji: Emoji character
|
||||
|
||||
Returns:
|
||||
True if removed, False if not found
|
||||
"""
|
||||
reaction = db.query(MessageReaction).filter(
|
||||
and_(
|
||||
MessageReaction.message_id == message_id,
|
||||
MessageReaction.user_id == user_id,
|
||||
MessageReaction.emoji == emoji
|
||||
)
|
||||
).first()
|
||||
|
||||
if not reaction:
|
||||
return False
|
||||
|
||||
db.delete(reaction)
|
||||
db.commit()
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def get_message_reactions(
|
||||
db: Session,
|
||||
message_id: str
|
||||
) -> List[ReactionSummary]:
|
||||
"""
|
||||
Get aggregated reactions for a message
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
message_id: Message ID
|
||||
|
||||
Returns:
|
||||
List of ReactionSummary objects
|
||||
"""
|
||||
reactions = db.query(MessageReaction).filter(
|
||||
MessageReaction.message_id == message_id
|
||||
).all()
|
||||
|
||||
# Group by emoji
|
||||
reaction_map: Dict[str, List[str]] = {}
|
||||
for reaction in reactions:
|
||||
if reaction.emoji not in reaction_map:
|
||||
reaction_map[reaction.emoji] = []
|
||||
reaction_map[reaction.emoji].append(reaction.user_id)
|
||||
|
||||
return [
|
||||
ReactionSummary(emoji=emoji, count=len(users), users=users)
|
||||
for emoji, users in reaction_map.items()
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _get_reaction_counts(db: Session, message_id: str) -> Dict[str, int]:
|
||||
"""
|
||||
Get reaction counts for a message
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
message_id: Message ID
|
||||
|
||||
Returns:
|
||||
Dictionary of emoji -> count
|
||||
"""
|
||||
result = (
|
||||
db.query(MessageReaction.emoji, func.count(MessageReaction.reaction_id))
|
||||
.filter(MessageReaction.message_id == message_id)
|
||||
.group_by(MessageReaction.emoji)
|
||||
.all()
|
||||
)
|
||||
|
||||
return {emoji: count for emoji, count in result}
|
||||
231
app/modules/realtime/websocket_manager.py
Normal file
231
app/modules/realtime/websocket_manager.py
Normal file
@@ -0,0 +1,231 @@
|
||||
"""WebSocket connection pool management"""
|
||||
from fastapi import WebSocket
|
||||
from typing import Dict, List, Set
|
||||
from datetime import datetime
|
||||
import asyncio
|
||||
import json
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
class ConnectionInfo:
|
||||
"""Information about a WebSocket connection"""
|
||||
def __init__(self, websocket: WebSocket, user_id: str, room_id: str):
|
||||
self.websocket = websocket
|
||||
self.user_id = user_id
|
||||
self.room_id = room_id
|
||||
self.connected_at = datetime.utcnow()
|
||||
self.last_sequence = 0 # Track last received sequence number for reconnection
|
||||
|
||||
|
||||
class WebSocketManager:
|
||||
"""Manages WebSocket connections and message broadcasting"""
|
||||
|
||||
def __init__(self):
|
||||
# room_id -> Set of ConnectionInfo
|
||||
self._room_connections: Dict[str, Set[ConnectionInfo]] = defaultdict(set)
|
||||
|
||||
# user_id -> ConnectionInfo (for direct messaging)
|
||||
self._user_connections: Dict[str, ConnectionInfo] = {}
|
||||
|
||||
# room_id -> Set of user_ids (typing users)
|
||||
self._typing_users: Dict[str, Set[str]] = defaultdict(set)
|
||||
|
||||
# user_id -> asyncio.Task (typing timeout tasks)
|
||||
self._typing_tasks: Dict[str, asyncio.Task] = {}
|
||||
|
||||
async def connect(self, websocket: WebSocket, room_id: str, user_id: str) -> ConnectionInfo:
|
||||
"""
|
||||
Add a WebSocket connection to the pool
|
||||
|
||||
Args:
|
||||
websocket: The WebSocket connection
|
||||
room_id: Room ID the user is connecting to
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
ConnectionInfo object
|
||||
"""
|
||||
await websocket.accept()
|
||||
|
||||
conn_info = ConnectionInfo(websocket, user_id, room_id)
|
||||
self._room_connections[room_id].add(conn_info)
|
||||
self._user_connections[user_id] = conn_info
|
||||
|
||||
return conn_info
|
||||
|
||||
async def disconnect(self, conn_info: ConnectionInfo):
|
||||
"""
|
||||
Remove a WebSocket connection from the pool
|
||||
|
||||
Args:
|
||||
conn_info: Connection info to remove
|
||||
"""
|
||||
room_id = conn_info.room_id
|
||||
user_id = conn_info.user_id
|
||||
|
||||
# Remove from room connections
|
||||
if room_id in self._room_connections:
|
||||
self._room_connections[room_id].discard(conn_info)
|
||||
if not self._room_connections[room_id]:
|
||||
del self._room_connections[room_id]
|
||||
|
||||
# Remove from user connections
|
||||
if user_id in self._user_connections:
|
||||
del self._user_connections[user_id]
|
||||
|
||||
# Clear typing status
|
||||
if user_id in self._typing_tasks:
|
||||
self._typing_tasks[user_id].cancel()
|
||||
del self._typing_tasks[user_id]
|
||||
|
||||
if room_id in self._typing_users:
|
||||
self._typing_users[room_id].discard(user_id)
|
||||
|
||||
async def broadcast_to_room(self, room_id: str, message: dict, exclude_user: str = None):
|
||||
"""
|
||||
Broadcast a message to all connections in a room
|
||||
|
||||
Args:
|
||||
room_id: Room ID to broadcast to
|
||||
message: Message dictionary to broadcast
|
||||
exclude_user: Optional user ID to exclude from broadcast
|
||||
"""
|
||||
if room_id not in self._room_connections:
|
||||
return
|
||||
|
||||
message_json = json.dumps(message)
|
||||
|
||||
# Collect disconnected connections
|
||||
disconnected = []
|
||||
|
||||
for conn_info in self._room_connections[room_id]:
|
||||
if exclude_user and conn_info.user_id == exclude_user:
|
||||
continue
|
||||
|
||||
try:
|
||||
await conn_info.websocket.send_text(message_json)
|
||||
except Exception as e:
|
||||
# Connection failed, mark for removal
|
||||
disconnected.append(conn_info)
|
||||
|
||||
# Clean up disconnected connections
|
||||
for conn_info in disconnected:
|
||||
await self.disconnect(conn_info)
|
||||
|
||||
async def send_personal(self, user_id: str, message: dict):
|
||||
"""
|
||||
Send a message to a specific user
|
||||
|
||||
Args:
|
||||
user_id: User ID to send to
|
||||
message: Message dictionary to send
|
||||
"""
|
||||
if user_id not in self._user_connections:
|
||||
return
|
||||
|
||||
conn_info = self._user_connections[user_id]
|
||||
message_json = json.dumps(message)
|
||||
|
||||
try:
|
||||
await conn_info.websocket.send_text(message_json)
|
||||
except Exception:
|
||||
# Connection failed, disconnect
|
||||
await self.disconnect(conn_info)
|
||||
|
||||
def get_room_connections(self, room_id: str) -> List[ConnectionInfo]:
|
||||
"""
|
||||
Get all active connections for a room
|
||||
|
||||
Args:
|
||||
room_id: Room ID
|
||||
|
||||
Returns:
|
||||
List of ConnectionInfo objects
|
||||
"""
|
||||
if room_id not in self._room_connections:
|
||||
return []
|
||||
return list(self._room_connections[room_id])
|
||||
|
||||
def get_online_users(self, room_id: str) -> List[str]:
|
||||
"""
|
||||
Get list of online user IDs in a room
|
||||
|
||||
Args:
|
||||
room_id: Room ID
|
||||
|
||||
Returns:
|
||||
List of user IDs
|
||||
"""
|
||||
return [conn.user_id for conn in self.get_room_connections(room_id)]
|
||||
|
||||
def is_user_online(self, user_id: str) -> bool:
|
||||
"""
|
||||
Check if a user is currently connected
|
||||
|
||||
Args:
|
||||
user_id: User ID to check
|
||||
|
||||
Returns:
|
||||
True if user is connected
|
||||
"""
|
||||
return user_id in self._user_connections
|
||||
|
||||
async def set_typing(self, room_id: str, user_id: str, is_typing: bool):
|
||||
"""
|
||||
Set typing status for a user in a room
|
||||
|
||||
Args:
|
||||
room_id: Room ID
|
||||
user_id: User ID
|
||||
is_typing: Whether user is typing
|
||||
"""
|
||||
if is_typing:
|
||||
self._typing_users[room_id].add(user_id)
|
||||
|
||||
# Cancel existing timeout task
|
||||
if user_id in self._typing_tasks:
|
||||
self._typing_tasks[user_id].cancel()
|
||||
|
||||
# Set new timeout (3 seconds)
|
||||
async def clear_typing():
|
||||
await asyncio.sleep(3)
|
||||
self._typing_users[room_id].discard(user_id)
|
||||
if user_id in self._typing_tasks:
|
||||
del self._typing_tasks[user_id]
|
||||
|
||||
self._typing_tasks[user_id] = asyncio.create_task(clear_typing())
|
||||
else:
|
||||
self._typing_users[room_id].discard(user_id)
|
||||
if user_id in self._typing_tasks:
|
||||
self._typing_tasks[user_id].cancel()
|
||||
del self._typing_tasks[user_id]
|
||||
|
||||
def get_typing_users(self, room_id: str) -> List[str]:
|
||||
"""
|
||||
Get list of users currently typing in a room
|
||||
|
||||
Args:
|
||||
room_id: Room ID
|
||||
|
||||
Returns:
|
||||
List of user IDs
|
||||
"""
|
||||
if room_id not in self._typing_users:
|
||||
return []
|
||||
return list(self._typing_users[room_id])
|
||||
|
||||
async def send_heartbeat(self, conn_info: ConnectionInfo):
|
||||
"""
|
||||
Send a ping to check connection health
|
||||
|
||||
Args:
|
||||
conn_info: Connection to ping
|
||||
"""
|
||||
try:
|
||||
await conn_info.websocket.send_json({"type": "ping"})
|
||||
except Exception:
|
||||
await self.disconnect(conn_info)
|
||||
|
||||
|
||||
# Global WebSocket manager instance
|
||||
manager = WebSocketManager()
|
||||
Reference in New Issue
Block a user