"""WebSocket and REST API router for realtime messaging""" from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends, HTTPException, Query from fastapi.responses import JSONResponse from sqlalchemy.orm import Session from typing import Optional from datetime import datetime import json from app.core.database import get_db, get_db_context from app.core.config import get_settings from app.modules.auth.dependencies import get_current_user from app.modules.auth.services.session_service import session_service from app.modules.chat_room.models import RoomMember, MemberRole from app.modules.realtime.websocket_manager import manager, json_serializer from app.modules.realtime.services.message_service import MessageService from app.modules.realtime.schemas import ( WebSocketMessageIn, MessageBroadcast, SystemMessageBroadcast, MessageAck, ErrorMessage, MessageCreate, MessageUpdate, MessageResponse, MessageListResponse, ReactionCreate, WebSocketMessageType, SystemEventType, MessageTypeEnum ) from app.modules.realtime.models import MessageType, Message from sqlalchemy import and_ router = APIRouter(prefix="/api", tags=["realtime"]) settings = get_settings() async def ws_send_json(websocket: WebSocket, data: dict): """Send JSON with custom datetime serializer""" await websocket.send_text(json.dumps(data, default=json_serializer)) def get_user_room_membership(db: Session, room_id: str, user_id: str) -> Optional[RoomMember]: """Check if user is a member of the room""" return db.query(RoomMember).filter( and_( RoomMember.room_id == room_id, RoomMember.user_id == user_id, RoomMember.removed_at.is_(None) ) ).first() def is_system_admin(user_id: str) -> bool: """Check if user is the system administrator""" return bool(settings.SYSTEM_ADMIN_EMAIL and user_id == settings.SYSTEM_ADMIN_EMAIL) def can_write_message(membership: Optional[RoomMember], user_id: str) -> bool: """Check if user has write permission (OWNER or EDITOR)""" if is_system_admin(user_id): return True if not membership: return False return membership.role in [MemberRole.OWNER, MemberRole.EDITOR] @router.websocket("/ws/{room_id}") async def websocket_endpoint( websocket: WebSocket, room_id: str, token: Optional[str] = Query(None) ): """ WebSocket endpoint for realtime messaging Authentication: - Token can be provided via query parameter: /ws/{room_id}?token=xxx - Or via WebSocket headers Connection flow: 1. Client connects with room_id 2. Server validates authentication and room membership 3. Connection added to pool 4. User joined event broadcast to room 5. Client can send/receive messages Note: Uses short-lived database sessions for each operation to prevent connection pool exhaustion with many concurrent WebSocket connections. """ # Authenticate and get user info using short-lived session if not token: await websocket.close(code=4001, reason="Authentication required") return # Authenticate token and check membership with short session with get_db_context() as db: user_session = session_service.get_session_by_token(db, token) if not user_session: await websocket.close(code=4001, reason="Invalid or expired token") return user_id = user_session.username # Check room membership and cache the role membership = get_user_room_membership(db, room_id, user_id) if not membership and not is_system_admin(user_id): await websocket.close(code=4001, reason="Not a member of this room") return # Cache membership role for permission checks (avoid holding DB reference) user_role = membership.role if membership else None # Connect to WebSocket manager (no DB needed) conn_info = await manager.connect(websocket, room_id, user_id) # Broadcast user joined event (no DB needed) await manager.broadcast_to_room( room_id, SystemMessageBroadcast( event=SystemEventType.USER_JOINED, user_id=user_id, room_id=room_id, timestamp=datetime.utcnow() ).dict(), exclude_user=user_id ) try: while True: # Receive message from client data = await websocket.receive_text() message_data = json.loads(data) # Parse incoming message try: ws_message = WebSocketMessageIn(**message_data) except Exception as e: await ws_send_json(websocket, ErrorMessage(error=str(e), code="INVALID_MESSAGE").dict() ) continue # Handle different message types if ws_message.type == WebSocketMessageType.MESSAGE: # Check write permission using cached role if not _can_write_with_role(user_role, user_id): await ws_send_json(websocket, ErrorMessage( error="Insufficient permissions", code="PERMISSION_DENIED" ).dict() ) continue # Create message in database with short session with get_db_context() as db: message = MessageService.create_message( db=db, room_id=room_id, sender_id=user_id, content=ws_message.content or "", message_type=MessageType(ws_message.message_type.value) if ws_message.message_type else MessageType.TEXT, metadata=ws_message.metadata ) # Get sender display name display_name = MessageService.get_display_name(db, user_id) # Extract data before session closes msg_data = { "message_id": message.message_id, "room_id": message.room_id, "sender_id": message.sender_id, "sender_display_name": display_name or user_id, "content": message.content, "message_type": message.message_type.value, "metadata": message.message_metadata, "created_at": message.created_at, "sequence_number": message.sequence_number } # Send acknowledgment to sender await ws_send_json(websocket, MessageAck( message_id=msg_data["message_id"], sequence_number=msg_data["sequence_number"], timestamp=msg_data["created_at"] ).dict() ) # Broadcast message to all room members await manager.broadcast_to_room( room_id, MessageBroadcast( message_id=msg_data["message_id"], room_id=msg_data["room_id"], sender_id=msg_data["sender_id"], sender_display_name=msg_data["sender_display_name"], content=msg_data["content"], message_type=MessageTypeEnum(msg_data["message_type"]), metadata=msg_data["metadata"], created_at=msg_data["created_at"], sequence_number=msg_data["sequence_number"] ).dict() ) elif ws_message.type == WebSocketMessageType.EDIT_MESSAGE: if not ws_message.message_id or not ws_message.content: await ws_send_json(websocket, ErrorMessage(error="Missing message_id or content", code="INVALID_REQUEST").dict() ) continue # Edit message with short session with get_db_context() as db: edited_message = MessageService.edit_message( db=db, message_id=ws_message.message_id, user_id=user_id, new_content=ws_message.content ) if edited_message: # Get sender display name display_name = MessageService.get_display_name(db, edited_message.sender_id) edit_data = { "message_id": edited_message.message_id, "room_id": edited_message.room_id, "sender_id": edited_message.sender_id, "sender_display_name": display_name or edited_message.sender_id, "content": edited_message.content, "message_type": edited_message.message_type.value, "metadata": edited_message.message_metadata, "created_at": edited_message.created_at, "edited_at": edited_message.edited_at, "sequence_number": edited_message.sequence_number } else: edit_data = None if not edit_data: await ws_send_json(websocket, ErrorMessage(error="Cannot edit message", code="EDIT_FAILED").dict() ) continue # Broadcast edit to all room members await manager.broadcast_to_room( room_id, MessageBroadcast( type="edit_message", message_id=edit_data["message_id"], room_id=edit_data["room_id"], sender_id=edit_data["sender_id"], sender_display_name=edit_data["sender_display_name"], content=edit_data["content"], message_type=MessageTypeEnum(edit_data["message_type"]), metadata=edit_data["metadata"], created_at=edit_data["created_at"], edited_at=edit_data["edited_at"], sequence_number=edit_data["sequence_number"] ).dict() ) elif ws_message.type == WebSocketMessageType.DELETE_MESSAGE: if not ws_message.message_id: await ws_send_json(websocket, ErrorMessage(error="Missing message_id", code="INVALID_REQUEST").dict() ) continue # Delete message with short session with get_db_context() as db: deleted_message = MessageService.delete_message( db=db, message_id=ws_message.message_id, user_id=user_id, is_admin=is_system_admin(user_id) ) deleted_msg_id = deleted_message.message_id if deleted_message else None if not deleted_msg_id: await ws_send_json(websocket, ErrorMessage(error="Cannot delete message", code="DELETE_FAILED").dict() ) continue # Broadcast deletion to all room members await manager.broadcast_to_room( room_id, {"type": "delete_message", "message_id": deleted_msg_id} ) elif ws_message.type == WebSocketMessageType.ADD_REACTION: if not ws_message.message_id or not ws_message.emoji: await ws_send_json(websocket, ErrorMessage(error="Missing message_id or emoji", code="INVALID_REQUEST").dict() ) continue # Add reaction with short session with get_db_context() as db: reaction = MessageService.add_reaction( db=db, message_id=ws_message.message_id, user_id=user_id, emoji=ws_message.emoji ) reaction_added = reaction is not None if reaction_added: # Broadcast reaction to all room members await manager.broadcast_to_room( room_id, { "type": "add_reaction", "message_id": ws_message.message_id, "user_id": user_id, "emoji": ws_message.emoji } ) elif ws_message.type == WebSocketMessageType.REMOVE_REACTION: if not ws_message.message_id or not ws_message.emoji: await ws_send_json(websocket, ErrorMessage(error="Missing message_id or emoji", code="INVALID_REQUEST").dict() ) continue # Remove reaction with short session with get_db_context() as db: removed = MessageService.remove_reaction( db=db, message_id=ws_message.message_id, user_id=user_id, emoji=ws_message.emoji ) if removed: # Broadcast reaction removal to all room members await manager.broadcast_to_room( room_id, { "type": "remove_reaction", "message_id": ws_message.message_id, "user_id": user_id, "emoji": ws_message.emoji } ) elif ws_message.type == WebSocketMessageType.TYPING: # Set typing status (no DB needed) is_typing = message_data.get("is_typing", True) await manager.set_typing(room_id, user_id, is_typing) # Broadcast typing status to other room members await manager.broadcast_to_room( room_id, {"type": "typing", "user_id": user_id, "is_typing": is_typing}, exclude_user=user_id ) except WebSocketDisconnect: pass finally: # Disconnect and broadcast user left event await manager.disconnect(conn_info) await manager.broadcast_to_room( room_id, SystemMessageBroadcast( event=SystemEventType.USER_LEFT, user_id=user_id, room_id=room_id, timestamp=datetime.utcnow() ).dict() ) def _can_write_with_role(role: Optional[MemberRole], user_id: str) -> bool: """Check if user has write permission based on cached role""" if is_system_admin(user_id): return True if not role: return False return role in [MemberRole.OWNER, MemberRole.EDITOR] # REST API endpoints @router.get("/rooms/{room_id}/messages", response_model=MessageListResponse) async def get_messages( room_id: str, limit: int = Query(50, ge=1, le=100), before: Optional[datetime] = None, offset: int = Query(0, ge=0), current_user: dict = Depends(get_current_user), db: Session = Depends(get_db) ): """Get message history for a room""" user_id = current_user["username"] # Check room membership membership = get_user_room_membership(db, room_id, user_id) if not membership and not is_system_admin(user_id): raise HTTPException(status_code=403, detail="Not a member of this room") return MessageService.get_messages( db=db, room_id=room_id, limit=limit, before_timestamp=before, offset=offset ) @router.post("/rooms/{room_id}/messages", response_model=MessageResponse, status_code=201) async def create_message( room_id: str, message: MessageCreate, current_user: dict = Depends(get_current_user), db: Session = Depends(get_db) ): """Create a message via REST API (alternative to WebSocket)""" user_id = current_user["username"] # Check room membership and write permission membership = get_user_room_membership(db, room_id, user_id) if not can_write_message(membership, user_id): raise HTTPException(status_code=403, detail="Insufficient permissions") # Create message created_message = MessageService.create_message( db=db, room_id=room_id, sender_id=user_id, content=message.content, message_type=MessageType(message.message_type.value), metadata=message.metadata ) # Get sender display name display_name = MessageService.get_display_name(db, user_id) sender_display_name = display_name or user_id # Broadcast to WebSocket connections await manager.broadcast_to_room( room_id, MessageBroadcast( message_id=created_message.message_id, room_id=created_message.room_id, sender_id=created_message.sender_id, sender_display_name=sender_display_name, content=created_message.content, message_type=MessageTypeEnum(created_message.message_type.value), metadata=created_message.message_metadata, created_at=created_message.created_at, sequence_number=created_message.sequence_number ).dict() ) # Build response with display name response = MessageResponse.from_orm(created_message) response.sender_display_name = sender_display_name return response @router.get("/rooms/{room_id}/messages/search", response_model=MessageListResponse) async def search_messages( room_id: str, q: str = Query(..., min_length=1), limit: int = Query(50, ge=1, le=100), offset: int = Query(0, ge=0), current_user: dict = Depends(get_current_user), db: Session = Depends(get_db) ): """Search messages in a room""" user_id = current_user["username"] # Check room membership membership = get_user_room_membership(db, room_id, user_id) if not membership and not is_system_admin(user_id): raise HTTPException(status_code=403, detail="Not a member of this room") return MessageService.search_messages( db=db, room_id=room_id, query=q, limit=limit, offset=offset ) @router.get("/rooms/{room_id}/online") async def get_online_users( room_id: str, current_user: dict = Depends(get_current_user), db: Session = Depends(get_db) ): """Get list of online users in a room""" user_id = current_user["username"] # Check room membership membership = get_user_room_membership(db, room_id, user_id) if not membership and not is_system_admin(user_id): raise HTTPException(status_code=403, detail="Not a member of this room") online_users = manager.get_online_users(room_id) return {"room_id": room_id, "online_users": online_users, "count": len(online_users)} @router.get("/rooms/{room_id}/typing") async def get_typing_users( room_id: str, current_user: dict = Depends(get_current_user), db: Session = Depends(get_db) ): """Get list of users currently typing in a room""" user_id = current_user["username"] # Check room membership membership = get_user_room_membership(db, room_id, user_id) if not membership and not is_system_admin(user_id): raise HTTPException(status_code=403, detail="Not a member of this room") typing_users = manager.get_typing_users(room_id) return {"room_id": room_id, "typing_users": typing_users, "count": len(typing_users)}