feat: Add Chat UX improvements with notifications and @mention support
- Add ActionBar component with expandable toolbar for mobile - Add @mention functionality with autocomplete dropdown - Add browser notification system (push, sound, vibration) - Add NotificationSettings modal for user preferences - Add mention badges on room list cards - Add ReportPreview with Markdown rendering and copy/download - Add message copy functionality with hover actions - Add backend mentions field to messages with Alembic migration - Add lots field to rooms, remove templates - Optimize WebSocket database session handling - Various UX polish (animations, accessibility) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -35,9 +35,12 @@ class Message(Base):
|
||||
content = Column(Text, nullable=False)
|
||||
message_type = Column(Enum(MessageType), default=MessageType.TEXT, nullable=False)
|
||||
|
||||
# Message metadata for structured data, mentions, file references, etc.
|
||||
# Message metadata for structured data, file references, etc.
|
||||
message_metadata = Column(JSON)
|
||||
|
||||
# @Mention tracking - stores array of mentioned user_ids
|
||||
mentions = Column(JSON, default=list)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||
edited_at = Column(DateTime) # Last edit timestamp
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Optional
|
||||
from datetime import datetime
|
||||
import json
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.database import get_db, get_db_context
|
||||
from app.core.config import get_settings
|
||||
from app.modules.auth.dependencies import get_current_user
|
||||
from app.modules.auth.services.session_service import session_service
|
||||
@@ -87,15 +87,17 @@ async def websocket_endpoint(
|
||||
3. Connection added to pool
|
||||
4. User joined event broadcast to room
|
||||
5. Client can send/receive messages
|
||||
|
||||
Note: Uses short-lived database sessions for each operation to prevent
|
||||
connection pool exhaustion with many concurrent WebSocket connections.
|
||||
"""
|
||||
db: Session = next(get_db())
|
||||
|
||||
try:
|
||||
# Authenticate token via session lookup
|
||||
if not token:
|
||||
await websocket.close(code=4001, reason="Authentication required")
|
||||
return
|
||||
# Authenticate and get user info using short-lived session
|
||||
if not token:
|
||||
await websocket.close(code=4001, reason="Authentication required")
|
||||
return
|
||||
|
||||
# Authenticate token and check membership with short session
|
||||
with get_db_context() as db:
|
||||
user_session = session_service.get_session_by_token(db, token)
|
||||
if not user_session:
|
||||
await websocket.close(code=4001, reason="Invalid or expired token")
|
||||
@@ -103,55 +105,59 @@ async def websocket_endpoint(
|
||||
|
||||
user_id = user_session.username
|
||||
|
||||
# Check room membership
|
||||
# Check room membership and cache the role
|
||||
membership = get_user_room_membership(db, room_id, user_id)
|
||||
if not membership and not is_system_admin(user_id):
|
||||
await websocket.close(code=4001, reason="Not a member of this room")
|
||||
return
|
||||
|
||||
# Connect to WebSocket manager
|
||||
conn_info = await manager.connect(websocket, room_id, user_id)
|
||||
# Cache membership role for permission checks (avoid holding DB reference)
|
||||
user_role = membership.role if membership else None
|
||||
|
||||
# Broadcast user joined event
|
||||
await manager.broadcast_to_room(
|
||||
room_id,
|
||||
SystemMessageBroadcast(
|
||||
event=SystemEventType.USER_JOINED,
|
||||
user_id=user_id,
|
||||
room_id=room_id,
|
||||
timestamp=datetime.utcnow()
|
||||
).dict(),
|
||||
exclude_user=user_id
|
||||
)
|
||||
# Connect to WebSocket manager (no DB needed)
|
||||
conn_info = await manager.connect(websocket, room_id, user_id)
|
||||
|
||||
try:
|
||||
while True:
|
||||
# Receive message from client
|
||||
data = await websocket.receive_text()
|
||||
message_data = json.loads(data)
|
||||
# Broadcast user joined event (no DB needed)
|
||||
await manager.broadcast_to_room(
|
||||
room_id,
|
||||
SystemMessageBroadcast(
|
||||
event=SystemEventType.USER_JOINED,
|
||||
user_id=user_id,
|
||||
room_id=room_id,
|
||||
timestamp=datetime.utcnow()
|
||||
).dict(),
|
||||
exclude_user=user_id
|
||||
)
|
||||
|
||||
# Parse incoming message
|
||||
try:
|
||||
ws_message = WebSocketMessageIn(**message_data)
|
||||
except Exception as e:
|
||||
try:
|
||||
while True:
|
||||
# Receive message from client
|
||||
data = await websocket.receive_text()
|
||||
message_data = json.loads(data)
|
||||
|
||||
# Parse incoming message
|
||||
try:
|
||||
ws_message = WebSocketMessageIn(**message_data)
|
||||
except Exception as e:
|
||||
await ws_send_json(websocket,
|
||||
ErrorMessage(error=str(e), code="INVALID_MESSAGE").dict()
|
||||
)
|
||||
continue
|
||||
|
||||
# Handle different message types
|
||||
if ws_message.type == WebSocketMessageType.MESSAGE:
|
||||
# Check write permission using cached role
|
||||
if not _can_write_with_role(user_role, user_id):
|
||||
await ws_send_json(websocket,
|
||||
ErrorMessage(error=str(e), code="INVALID_MESSAGE").dict()
|
||||
ErrorMessage(
|
||||
error="Insufficient permissions",
|
||||
code="PERMISSION_DENIED"
|
||||
).dict()
|
||||
)
|
||||
continue
|
||||
|
||||
# Handle different message types
|
||||
if ws_message.type == WebSocketMessageType.MESSAGE:
|
||||
# Check write permission
|
||||
if not can_write_message(membership, user_id):
|
||||
await ws_send_json(websocket,
|
||||
ErrorMessage(
|
||||
error="Insufficient permissions",
|
||||
code="PERMISSION_DENIED"
|
||||
).dict()
|
||||
)
|
||||
continue
|
||||
|
||||
# Create message in database
|
||||
# Create message in database with short session
|
||||
with get_db_context() as db:
|
||||
message = MessageService.create_message(
|
||||
db=db,
|
||||
room_id=room_id,
|
||||
@@ -160,131 +166,170 @@ async def websocket_endpoint(
|
||||
message_type=MessageType(ws_message.message_type.value) if ws_message.message_type else MessageType.TEXT,
|
||||
metadata=ws_message.metadata
|
||||
)
|
||||
# Get sender display name
|
||||
display_name = MessageService.get_display_name(db, user_id)
|
||||
# Extract data before session closes
|
||||
msg_data = {
|
||||
"message_id": message.message_id,
|
||||
"room_id": message.room_id,
|
||||
"sender_id": message.sender_id,
|
||||
"sender_display_name": display_name or user_id,
|
||||
"content": message.content,
|
||||
"message_type": message.message_type.value,
|
||||
"metadata": message.message_metadata,
|
||||
"created_at": message.created_at,
|
||||
"sequence_number": message.sequence_number
|
||||
}
|
||||
|
||||
# Send acknowledgment to sender
|
||||
# Send acknowledgment to sender
|
||||
await ws_send_json(websocket,
|
||||
MessageAck(
|
||||
message_id=msg_data["message_id"],
|
||||
sequence_number=msg_data["sequence_number"],
|
||||
timestamp=msg_data["created_at"]
|
||||
).dict()
|
||||
)
|
||||
|
||||
# Broadcast message to all room members
|
||||
await manager.broadcast_to_room(
|
||||
room_id,
|
||||
MessageBroadcast(
|
||||
message_id=msg_data["message_id"],
|
||||
room_id=msg_data["room_id"],
|
||||
sender_id=msg_data["sender_id"],
|
||||
sender_display_name=msg_data["sender_display_name"],
|
||||
content=msg_data["content"],
|
||||
message_type=MessageTypeEnum(msg_data["message_type"]),
|
||||
metadata=msg_data["metadata"],
|
||||
created_at=msg_data["created_at"],
|
||||
sequence_number=msg_data["sequence_number"]
|
||||
).dict()
|
||||
)
|
||||
|
||||
elif ws_message.type == WebSocketMessageType.EDIT_MESSAGE:
|
||||
if not ws_message.message_id or not ws_message.content:
|
||||
await ws_send_json(websocket,
|
||||
MessageAck(
|
||||
message_id=message.message_id,
|
||||
sequence_number=message.sequence_number,
|
||||
timestamp=message.created_at
|
||||
).dict()
|
||||
ErrorMessage(error="Missing message_id or content", code="INVALID_REQUEST").dict()
|
||||
)
|
||||
continue
|
||||
|
||||
# Broadcast message to all room members
|
||||
await manager.broadcast_to_room(
|
||||
room_id,
|
||||
MessageBroadcast(
|
||||
message_id=message.message_id,
|
||||
room_id=message.room_id,
|
||||
sender_id=message.sender_id,
|
||||
content=message.content,
|
||||
message_type=MessageTypeEnum(message.message_type.value),
|
||||
metadata=message.message_metadata,
|
||||
created_at=message.created_at,
|
||||
sequence_number=message.sequence_number
|
||||
).dict()
|
||||
)
|
||||
|
||||
elif ws_message.type == WebSocketMessageType.EDIT_MESSAGE:
|
||||
if not ws_message.message_id or not ws_message.content:
|
||||
await ws_send_json(websocket,
|
||||
ErrorMessage(error="Missing message_id or content", code="INVALID_REQUEST").dict()
|
||||
)
|
||||
continue
|
||||
|
||||
# Edit message
|
||||
# Edit message with short session
|
||||
with get_db_context() as db:
|
||||
edited_message = MessageService.edit_message(
|
||||
db=db,
|
||||
message_id=ws_message.message_id,
|
||||
user_id=user_id,
|
||||
new_content=ws_message.content
|
||||
)
|
||||
if edited_message:
|
||||
# Get sender display name
|
||||
display_name = MessageService.get_display_name(db, edited_message.sender_id)
|
||||
edit_data = {
|
||||
"message_id": edited_message.message_id,
|
||||
"room_id": edited_message.room_id,
|
||||
"sender_id": edited_message.sender_id,
|
||||
"sender_display_name": display_name or edited_message.sender_id,
|
||||
"content": edited_message.content,
|
||||
"message_type": edited_message.message_type.value,
|
||||
"metadata": edited_message.message_metadata,
|
||||
"created_at": edited_message.created_at,
|
||||
"edited_at": edited_message.edited_at,
|
||||
"sequence_number": edited_message.sequence_number
|
||||
}
|
||||
else:
|
||||
edit_data = None
|
||||
|
||||
if not edited_message:
|
||||
await ws_send_json(websocket,
|
||||
ErrorMessage(error="Cannot edit message", code="EDIT_FAILED").dict()
|
||||
)
|
||||
continue
|
||||
|
||||
# Broadcast edit to all room members
|
||||
await manager.broadcast_to_room(
|
||||
room_id,
|
||||
MessageBroadcast(
|
||||
type="edit_message",
|
||||
message_id=edited_message.message_id,
|
||||
room_id=edited_message.room_id,
|
||||
sender_id=edited_message.sender_id,
|
||||
content=edited_message.content,
|
||||
message_type=MessageTypeEnum(edited_message.message_type.value),
|
||||
metadata=edited_message.message_metadata,
|
||||
created_at=edited_message.created_at,
|
||||
edited_at=edited_message.edited_at,
|
||||
sequence_number=edited_message.sequence_number
|
||||
).dict()
|
||||
if not edit_data:
|
||||
await ws_send_json(websocket,
|
||||
ErrorMessage(error="Cannot edit message", code="EDIT_FAILED").dict()
|
||||
)
|
||||
continue
|
||||
|
||||
elif ws_message.type == WebSocketMessageType.DELETE_MESSAGE:
|
||||
if not ws_message.message_id:
|
||||
await ws_send_json(websocket,
|
||||
ErrorMessage(error="Missing message_id", code="INVALID_REQUEST").dict()
|
||||
)
|
||||
continue
|
||||
# Broadcast edit to all room members
|
||||
await manager.broadcast_to_room(
|
||||
room_id,
|
||||
MessageBroadcast(
|
||||
type="edit_message",
|
||||
message_id=edit_data["message_id"],
|
||||
room_id=edit_data["room_id"],
|
||||
sender_id=edit_data["sender_id"],
|
||||
sender_display_name=edit_data["sender_display_name"],
|
||||
content=edit_data["content"],
|
||||
message_type=MessageTypeEnum(edit_data["message_type"]),
|
||||
metadata=edit_data["metadata"],
|
||||
created_at=edit_data["created_at"],
|
||||
edited_at=edit_data["edited_at"],
|
||||
sequence_number=edit_data["sequence_number"]
|
||||
).dict()
|
||||
)
|
||||
|
||||
# Delete message
|
||||
elif ws_message.type == WebSocketMessageType.DELETE_MESSAGE:
|
||||
if not ws_message.message_id:
|
||||
await ws_send_json(websocket,
|
||||
ErrorMessage(error="Missing message_id", code="INVALID_REQUEST").dict()
|
||||
)
|
||||
continue
|
||||
|
||||
# Delete message with short session
|
||||
with get_db_context() as db:
|
||||
deleted_message = MessageService.delete_message(
|
||||
db=db,
|
||||
message_id=ws_message.message_id,
|
||||
user_id=user_id,
|
||||
is_admin=is_system_admin(user_id)
|
||||
)
|
||||
deleted_msg_id = deleted_message.message_id if deleted_message else None
|
||||
|
||||
if not deleted_message:
|
||||
await ws_send_json(websocket,
|
||||
ErrorMessage(error="Cannot delete message", code="DELETE_FAILED").dict()
|
||||
)
|
||||
continue
|
||||
|
||||
# Broadcast deletion to all room members
|
||||
await manager.broadcast_to_room(
|
||||
room_id,
|
||||
{"type": "delete_message", "message_id": deleted_message.message_id}
|
||||
if not deleted_msg_id:
|
||||
await ws_send_json(websocket,
|
||||
ErrorMessage(error="Cannot delete message", code="DELETE_FAILED").dict()
|
||||
)
|
||||
continue
|
||||
|
||||
elif ws_message.type == WebSocketMessageType.ADD_REACTION:
|
||||
if not ws_message.message_id or not ws_message.emoji:
|
||||
await ws_send_json(websocket,
|
||||
ErrorMessage(error="Missing message_id or emoji", code="INVALID_REQUEST").dict()
|
||||
)
|
||||
continue
|
||||
# Broadcast deletion to all room members
|
||||
await manager.broadcast_to_room(
|
||||
room_id,
|
||||
{"type": "delete_message", "message_id": deleted_msg_id}
|
||||
)
|
||||
|
||||
# Add reaction
|
||||
elif ws_message.type == WebSocketMessageType.ADD_REACTION:
|
||||
if not ws_message.message_id or not ws_message.emoji:
|
||||
await ws_send_json(websocket,
|
||||
ErrorMessage(error="Missing message_id or emoji", code="INVALID_REQUEST").dict()
|
||||
)
|
||||
continue
|
||||
|
||||
# Add reaction with short session
|
||||
with get_db_context() as db:
|
||||
reaction = MessageService.add_reaction(
|
||||
db=db,
|
||||
message_id=ws_message.message_id,
|
||||
user_id=user_id,
|
||||
emoji=ws_message.emoji
|
||||
)
|
||||
reaction_added = reaction is not None
|
||||
|
||||
if reaction:
|
||||
# Broadcast reaction to all room members
|
||||
await manager.broadcast_to_room(
|
||||
room_id,
|
||||
{
|
||||
"type": "add_reaction",
|
||||
"message_id": ws_message.message_id,
|
||||
"user_id": user_id,
|
||||
"emoji": ws_message.emoji
|
||||
}
|
||||
)
|
||||
if reaction_added:
|
||||
# Broadcast reaction to all room members
|
||||
await manager.broadcast_to_room(
|
||||
room_id,
|
||||
{
|
||||
"type": "add_reaction",
|
||||
"message_id": ws_message.message_id,
|
||||
"user_id": user_id,
|
||||
"emoji": ws_message.emoji
|
||||
}
|
||||
)
|
||||
|
||||
elif ws_message.type == WebSocketMessageType.REMOVE_REACTION:
|
||||
if not ws_message.message_id or not ws_message.emoji:
|
||||
await ws_send_json(websocket,
|
||||
ErrorMessage(error="Missing message_id or emoji", code="INVALID_REQUEST").dict()
|
||||
)
|
||||
continue
|
||||
elif ws_message.type == WebSocketMessageType.REMOVE_REACTION:
|
||||
if not ws_message.message_id or not ws_message.emoji:
|
||||
await ws_send_json(websocket,
|
||||
ErrorMessage(error="Missing message_id or emoji", code="INVALID_REQUEST").dict()
|
||||
)
|
||||
continue
|
||||
|
||||
# Remove reaction
|
||||
# Remove reaction with short session
|
||||
with get_db_context() as db:
|
||||
removed = MessageService.remove_reaction(
|
||||
db=db,
|
||||
message_id=ws_message.message_id,
|
||||
@@ -292,47 +337,53 @@ async def websocket_endpoint(
|
||||
emoji=ws_message.emoji
|
||||
)
|
||||
|
||||
if removed:
|
||||
# Broadcast reaction removal to all room members
|
||||
await manager.broadcast_to_room(
|
||||
room_id,
|
||||
{
|
||||
"type": "remove_reaction",
|
||||
"message_id": ws_message.message_id,
|
||||
"user_id": user_id,
|
||||
"emoji": ws_message.emoji
|
||||
}
|
||||
)
|
||||
|
||||
elif ws_message.type == WebSocketMessageType.TYPING:
|
||||
# Set typing status
|
||||
is_typing = message_data.get("is_typing", True)
|
||||
await manager.set_typing(room_id, user_id, is_typing)
|
||||
|
||||
# Broadcast typing status to other room members
|
||||
if removed:
|
||||
# Broadcast reaction removal to all room members
|
||||
await manager.broadcast_to_room(
|
||||
room_id,
|
||||
{"type": "typing", "user_id": user_id, "is_typing": is_typing},
|
||||
exclude_user=user_id
|
||||
{
|
||||
"type": "remove_reaction",
|
||||
"message_id": ws_message.message_id,
|
||||
"user_id": user_id,
|
||||
"emoji": ws_message.emoji
|
||||
}
|
||||
)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
finally:
|
||||
# Disconnect and broadcast user left event
|
||||
await manager.disconnect(conn_info)
|
||||
await manager.broadcast_to_room(
|
||||
room_id,
|
||||
SystemMessageBroadcast(
|
||||
event=SystemEventType.USER_LEFT,
|
||||
user_id=user_id,
|
||||
room_id=room_id,
|
||||
timestamp=datetime.utcnow()
|
||||
).dict()
|
||||
)
|
||||
elif ws_message.type == WebSocketMessageType.TYPING:
|
||||
# Set typing status (no DB needed)
|
||||
is_typing = message_data.get("is_typing", True)
|
||||
await manager.set_typing(room_id, user_id, is_typing)
|
||||
|
||||
# Broadcast typing status to other room members
|
||||
await manager.broadcast_to_room(
|
||||
room_id,
|
||||
{"type": "typing", "user_id": user_id, "is_typing": is_typing},
|
||||
exclude_user=user_id
|
||||
)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
finally:
|
||||
db.close()
|
||||
# Disconnect and broadcast user left event
|
||||
await manager.disconnect(conn_info)
|
||||
await manager.broadcast_to_room(
|
||||
room_id,
|
||||
SystemMessageBroadcast(
|
||||
event=SystemEventType.USER_LEFT,
|
||||
user_id=user_id,
|
||||
room_id=room_id,
|
||||
timestamp=datetime.utcnow()
|
||||
).dict()
|
||||
)
|
||||
|
||||
|
||||
def _can_write_with_role(role: Optional[MemberRole], user_id: str) -> bool:
|
||||
"""Check if user has write permission based on cached role"""
|
||||
if is_system_admin(user_id):
|
||||
return True
|
||||
if not role:
|
||||
return False
|
||||
return role in [MemberRole.OWNER, MemberRole.EDITOR]
|
||||
|
||||
|
||||
# REST API endpoints
|
||||
@@ -387,6 +438,10 @@ async def create_message(
|
||||
metadata=message.metadata
|
||||
)
|
||||
|
||||
# Get sender display name
|
||||
display_name = MessageService.get_display_name(db, user_id)
|
||||
sender_display_name = display_name or user_id
|
||||
|
||||
# Broadcast to WebSocket connections
|
||||
await manager.broadcast_to_room(
|
||||
room_id,
|
||||
@@ -394,6 +449,7 @@ async def create_message(
|
||||
message_id=created_message.message_id,
|
||||
room_id=created_message.room_id,
|
||||
sender_id=created_message.sender_id,
|
||||
sender_display_name=sender_display_name,
|
||||
content=created_message.content,
|
||||
message_type=MessageTypeEnum(created_message.message_type.value),
|
||||
metadata=created_message.message_metadata,
|
||||
@@ -402,7 +458,10 @@ async def create_message(
|
||||
).dict()
|
||||
)
|
||||
|
||||
return MessageResponse.from_orm(created_message)
|
||||
# Build response with display name
|
||||
response = MessageResponse.from_orm(created_message)
|
||||
response.sender_display_name = sender_display_name
|
||||
return response
|
||||
|
||||
|
||||
@router.get("/rooms/{room_id}/messages/search", response_model=MessageListResponse)
|
||||
|
||||
@@ -80,6 +80,7 @@ class MessageBroadcast(BaseModel):
|
||||
message_id: str
|
||||
room_id: str
|
||||
sender_id: str
|
||||
sender_display_name: Optional[str] = None # Display name from users table
|
||||
content: str
|
||||
message_type: MessageTypeEnum
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
@@ -141,6 +142,7 @@ class MessageResponse(BaseModel):
|
||||
message_id: str
|
||||
room_id: str
|
||||
sender_id: str
|
||||
sender_display_name: Optional[str] = None # Display name from users table
|
||||
content: str
|
||||
message_type: MessageTypeEnum
|
||||
metadata: Optional[Dict[str, Any]] = Field(None, alias="message_metadata")
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
"""Message service layer for database operations"""
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import desc, and_, func
|
||||
from typing import List, Optional, Dict, Any
|
||||
from sqlalchemy import desc, and_, func, text
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from typing import List, Optional, Dict, Any, Tuple
|
||||
from datetime import datetime, timedelta
|
||||
import uuid
|
||||
import logging
|
||||
import re
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.modules.realtime.models import Message, MessageType, MessageReaction, MessageEditHistory
|
||||
@@ -13,13 +16,48 @@ from app.modules.realtime.schemas import (
|
||||
MessageListResponse,
|
||||
ReactionSummary
|
||||
)
|
||||
from app.modules.auth.models import User
|
||||
|
||||
settings = get_settings()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageService:
|
||||
"""Service for message operations"""
|
||||
|
||||
@staticmethod
|
||||
def parse_mentions(content: str, room_members: List[Dict[str, str]]) -> List[str]:
|
||||
"""
|
||||
Parse @mentions from message content and resolve to user IDs
|
||||
|
||||
Args:
|
||||
content: Message content that may contain @mentions
|
||||
room_members: List of room members with user_id and display_name
|
||||
|
||||
Returns:
|
||||
List of mentioned user_ids
|
||||
"""
|
||||
# Pattern matches @displayname (alphanumeric, spaces, Chinese chars, etc.)
|
||||
# Captures text after @ until we hit a character that's not part of a name
|
||||
mention_pattern = r'@(\S+)'
|
||||
matches = re.findall(mention_pattern, content)
|
||||
|
||||
mentioned_ids = []
|
||||
for mention_text in matches:
|
||||
# Try to match against display names or user IDs
|
||||
for member in room_members:
|
||||
display_name = member.get('display_name', '') or member.get('user_id', '')
|
||||
user_id = member.get('user_id', '')
|
||||
|
||||
# Match against display_name or user_id (case-insensitive)
|
||||
if (mention_text.lower() == display_name.lower() or
|
||||
mention_text.lower() == user_id.lower()):
|
||||
if user_id not in mentioned_ids:
|
||||
mentioned_ids.append(user_id)
|
||||
break
|
||||
|
||||
return mentioned_ids
|
||||
|
||||
@staticmethod
|
||||
def create_message(
|
||||
db: Session,
|
||||
@@ -27,10 +65,16 @@ class MessageService:
|
||||
sender_id: str,
|
||||
content: str,
|
||||
message_type: MessageType = MessageType.TEXT,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
mentions: Optional[List[str]] = None,
|
||||
max_retries: int = 3
|
||||
) -> Message:
|
||||
"""
|
||||
Create a new message
|
||||
Create a new message with race condition protection
|
||||
|
||||
Uses SELECT ... FOR UPDATE to lock the sequence number calculation,
|
||||
preventing duplicate sequence numbers when multiple users send
|
||||
messages simultaneously.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
@@ -38,34 +82,72 @@ class MessageService:
|
||||
sender_id: User ID who sent the message
|
||||
content: Message content
|
||||
message_type: Type of message
|
||||
metadata: Optional metadata (mentions, file refs, etc.)
|
||||
metadata: Optional metadata (file refs, etc.)
|
||||
mentions: List of mentioned user_ids (parsed from @mentions)
|
||||
max_retries: Maximum retry attempts for deadlock handling
|
||||
|
||||
Returns:
|
||||
Created Message object
|
||||
|
||||
Raises:
|
||||
IntegrityError: If max retries exceeded
|
||||
"""
|
||||
# Get next sequence number for this room
|
||||
max_seq = db.query(func.max(Message.sequence_number)).filter(
|
||||
Message.room_id == room_id
|
||||
).scalar()
|
||||
next_seq = (max_seq or 0) + 1
|
||||
last_error = None
|
||||
|
||||
message = Message(
|
||||
message_id=str(uuid.uuid4()),
|
||||
room_id=room_id,
|
||||
sender_id=sender_id,
|
||||
content=content,
|
||||
message_type=message_type,
|
||||
message_metadata=metadata or {},
|
||||
created_at=datetime.utcnow(),
|
||||
sequence_number=next_seq
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
# Use FOR UPDATE to lock rows while calculating next sequence
|
||||
# This prevents race conditions where two transactions read
|
||||
# the same max_seq and try to insert duplicate sequence numbers
|
||||
result = db.execute(
|
||||
text("""
|
||||
SELECT COALESCE(MAX(sequence_number), 0)
|
||||
FROM tr_messages
|
||||
WHERE room_id = :room_id
|
||||
FOR UPDATE
|
||||
"""),
|
||||
{"room_id": room_id}
|
||||
)
|
||||
max_seq = result.scalar()
|
||||
next_seq = (max_seq or 0) + 1
|
||||
|
||||
message = Message(
|
||||
message_id=str(uuid.uuid4()),
|
||||
room_id=room_id,
|
||||
sender_id=sender_id,
|
||||
content=content,
|
||||
message_type=message_type,
|
||||
message_metadata=metadata or {},
|
||||
mentions=mentions or [],
|
||||
created_at=datetime.utcnow(),
|
||||
sequence_number=next_seq
|
||||
)
|
||||
|
||||
db.add(message)
|
||||
db.commit()
|
||||
db.refresh(message)
|
||||
|
||||
return message
|
||||
|
||||
except IntegrityError as e:
|
||||
last_error = e
|
||||
db.rollback()
|
||||
logger.warning(
|
||||
f"Sequence number conflict on attempt {attempt + 1}/{max_retries} "
|
||||
f"for room {room_id}: {e}"
|
||||
)
|
||||
if attempt == max_retries - 1:
|
||||
logger.error(
|
||||
f"Failed to create message after {max_retries} attempts "
|
||||
f"for room {room_id}"
|
||||
)
|
||||
raise
|
||||
|
||||
# Should not reach here, but just in case
|
||||
raise last_error if last_error else IntegrityError(
|
||||
"Failed to create message", None, None
|
||||
)
|
||||
|
||||
db.add(message)
|
||||
db.commit()
|
||||
db.refresh(message)
|
||||
|
||||
return message
|
||||
|
||||
@staticmethod
|
||||
def get_message(db: Session, message_id: str) -> Optional[Message]:
|
||||
"""
|
||||
@@ -83,6 +165,21 @@ class MessageService:
|
||||
Message.deleted_at.is_(None)
|
||||
).first()
|
||||
|
||||
@staticmethod
|
||||
def get_display_name(db: Session, sender_id: str) -> Optional[str]:
|
||||
"""
|
||||
Get display name for a sender from users table
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
sender_id: User ID (email)
|
||||
|
||||
Returns:
|
||||
Display name or None if not found
|
||||
"""
|
||||
user = db.query(User.display_name).filter(User.user_id == sender_id).first()
|
||||
return user[0] if user else None
|
||||
|
||||
@staticmethod
|
||||
def get_messages(
|
||||
db: Session,
|
||||
@@ -106,7 +203,10 @@ class MessageService:
|
||||
Returns:
|
||||
MessageListResponse with messages and pagination info
|
||||
"""
|
||||
query = db.query(Message).filter(Message.room_id == room_id)
|
||||
# Build base query with LEFT JOIN to users table for display names
|
||||
query = db.query(Message, User.display_name).outerjoin(
|
||||
User, Message.sender_id == User.user_id
|
||||
).filter(Message.room_id == room_id)
|
||||
|
||||
if not include_deleted:
|
||||
query = query.filter(Message.deleted_at.is_(None))
|
||||
@@ -114,18 +214,24 @@ class MessageService:
|
||||
if before_timestamp:
|
||||
query = query.filter(Message.created_at < before_timestamp)
|
||||
|
||||
# Get total count
|
||||
total = query.count()
|
||||
# Get total count (need separate query without join for accurate count)
|
||||
count_query = db.query(Message).filter(Message.room_id == room_id)
|
||||
if not include_deleted:
|
||||
count_query = count_query.filter(Message.deleted_at.is_(None))
|
||||
if before_timestamp:
|
||||
count_query = count_query.filter(Message.created_at < before_timestamp)
|
||||
total = count_query.count()
|
||||
|
||||
# Get messages in reverse chronological order
|
||||
messages = query.order_by(desc(Message.created_at)).offset(offset).limit(limit).all()
|
||||
# Get messages with display names in reverse chronological order
|
||||
results = query.order_by(desc(Message.created_at)).offset(offset).limit(limit).all()
|
||||
|
||||
# Get reaction counts for each message
|
||||
# Get reaction counts for each message and build responses
|
||||
message_responses = []
|
||||
for msg in messages:
|
||||
for msg, display_name in results:
|
||||
reaction_counts = MessageService._get_reaction_counts(db, msg.message_id)
|
||||
msg_response = MessageResponse.from_orm(msg)
|
||||
msg_response.reaction_counts = reaction_counts
|
||||
msg_response.sender_display_name = display_name or msg.sender_id
|
||||
message_responses.append(msg_response)
|
||||
|
||||
return MessageListResponse(
|
||||
@@ -133,7 +239,7 @@ class MessageService:
|
||||
total=total,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
has_more=(offset + len(messages)) < total
|
||||
has_more=(offset + len(results)) < total
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -253,8 +359,10 @@ class MessageService:
|
||||
|
||||
total = db.query(Message).filter(search_filter).count()
|
||||
|
||||
messages = (
|
||||
db.query(Message)
|
||||
# Query with LEFT JOIN for display names
|
||||
results = (
|
||||
db.query(Message, User.display_name)
|
||||
.outerjoin(User, Message.sender_id == User.user_id)
|
||||
.filter(search_filter)
|
||||
.order_by(desc(Message.created_at))
|
||||
.offset(offset)
|
||||
@@ -263,10 +371,11 @@ class MessageService:
|
||||
)
|
||||
|
||||
message_responses = []
|
||||
for msg in messages:
|
||||
for msg, display_name in results:
|
||||
reaction_counts = MessageService._get_reaction_counts(db, msg.message_id)
|
||||
msg_response = MessageResponse.from_orm(msg)
|
||||
msg_response.reaction_counts = reaction_counts
|
||||
msg_response.sender_display_name = display_name or msg.sender_id
|
||||
message_responses.append(msg_response)
|
||||
|
||||
return MessageListResponse(
|
||||
@@ -274,7 +383,7 @@ class MessageService:
|
||||
total=total,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
has_more=(offset + len(messages)) < total
|
||||
has_more=(offset + len(results)) < total
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
Reference in New Issue
Block a user