- Add ActionBar component with expandable toolbar for mobile - Add @mention functionality with autocomplete dropdown - Add browser notification system (push, sound, vibration) - Add NotificationSettings modal for user preferences - Add mention badges on room list cards - Add ReportPreview with Markdown rendering and copy/download - Add message copy functionality with hover actions - Add backend mentions field to messages with Alembic migration - Add lots field to rooms, remove templates - Optimize WebSocket database session handling - Various UX polish (animations, accessibility) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
519 lines
16 KiB
Python
519 lines
16 KiB
Python
"""Message service layer for database operations"""
|
|
from sqlalchemy.orm import Session
|
|
from sqlalchemy import desc, and_, func, text
|
|
from sqlalchemy.exc import IntegrityError
|
|
from typing import List, Optional, Dict, Any, Tuple
|
|
from datetime import datetime, timedelta
|
|
import uuid
|
|
import logging
|
|
import re
|
|
|
|
from app.core.config import get_settings
|
|
from app.modules.realtime.models import Message, MessageType, MessageReaction, MessageEditHistory
|
|
from app.modules.realtime.schemas import (
|
|
MessageCreate,
|
|
MessageResponse,
|
|
MessageListResponse,
|
|
ReactionSummary
|
|
)
|
|
from app.modules.auth.models import User
|
|
|
|
settings = get_settings()
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class MessageService:
|
|
"""Service for message operations"""
|
|
|
|
@staticmethod
|
|
def parse_mentions(content: str, room_members: List[Dict[str, str]]) -> List[str]:
|
|
"""
|
|
Parse @mentions from message content and resolve to user IDs
|
|
|
|
Args:
|
|
content: Message content that may contain @mentions
|
|
room_members: List of room members with user_id and display_name
|
|
|
|
Returns:
|
|
List of mentioned user_ids
|
|
"""
|
|
# Pattern matches @displayname (alphanumeric, spaces, Chinese chars, etc.)
|
|
# Captures text after @ until we hit a character that's not part of a name
|
|
mention_pattern = r'@(\S+)'
|
|
matches = re.findall(mention_pattern, content)
|
|
|
|
mentioned_ids = []
|
|
for mention_text in matches:
|
|
# Try to match against display names or user IDs
|
|
for member in room_members:
|
|
display_name = member.get('display_name', '') or member.get('user_id', '')
|
|
user_id = member.get('user_id', '')
|
|
|
|
# Match against display_name or user_id (case-insensitive)
|
|
if (mention_text.lower() == display_name.lower() or
|
|
mention_text.lower() == user_id.lower()):
|
|
if user_id not in mentioned_ids:
|
|
mentioned_ids.append(user_id)
|
|
break
|
|
|
|
return mentioned_ids
|
|
|
|
@staticmethod
|
|
def create_message(
|
|
db: Session,
|
|
room_id: str,
|
|
sender_id: str,
|
|
content: str,
|
|
message_type: MessageType = MessageType.TEXT,
|
|
metadata: Optional[Dict[str, Any]] = None,
|
|
mentions: Optional[List[str]] = None,
|
|
max_retries: int = 3
|
|
) -> Message:
|
|
"""
|
|
Create a new message with race condition protection
|
|
|
|
Uses SELECT ... FOR UPDATE to lock the sequence number calculation,
|
|
preventing duplicate sequence numbers when multiple users send
|
|
messages simultaneously.
|
|
|
|
Args:
|
|
db: Database session
|
|
room_id: Room ID
|
|
sender_id: User ID who sent the message
|
|
content: Message content
|
|
message_type: Type of message
|
|
metadata: Optional metadata (file refs, etc.)
|
|
mentions: List of mentioned user_ids (parsed from @mentions)
|
|
max_retries: Maximum retry attempts for deadlock handling
|
|
|
|
Returns:
|
|
Created Message object
|
|
|
|
Raises:
|
|
IntegrityError: If max retries exceeded
|
|
"""
|
|
last_error = None
|
|
|
|
for attempt in range(max_retries):
|
|
try:
|
|
# Use FOR UPDATE to lock rows while calculating next sequence
|
|
# This prevents race conditions where two transactions read
|
|
# the same max_seq and try to insert duplicate sequence numbers
|
|
result = db.execute(
|
|
text("""
|
|
SELECT COALESCE(MAX(sequence_number), 0)
|
|
FROM tr_messages
|
|
WHERE room_id = :room_id
|
|
FOR UPDATE
|
|
"""),
|
|
{"room_id": room_id}
|
|
)
|
|
max_seq = result.scalar()
|
|
next_seq = (max_seq or 0) + 1
|
|
|
|
message = Message(
|
|
message_id=str(uuid.uuid4()),
|
|
room_id=room_id,
|
|
sender_id=sender_id,
|
|
content=content,
|
|
message_type=message_type,
|
|
message_metadata=metadata or {},
|
|
mentions=mentions or [],
|
|
created_at=datetime.utcnow(),
|
|
sequence_number=next_seq
|
|
)
|
|
|
|
db.add(message)
|
|
db.commit()
|
|
db.refresh(message)
|
|
|
|
return message
|
|
|
|
except IntegrityError as e:
|
|
last_error = e
|
|
db.rollback()
|
|
logger.warning(
|
|
f"Sequence number conflict on attempt {attempt + 1}/{max_retries} "
|
|
f"for room {room_id}: {e}"
|
|
)
|
|
if attempt == max_retries - 1:
|
|
logger.error(
|
|
f"Failed to create message after {max_retries} attempts "
|
|
f"for room {room_id}"
|
|
)
|
|
raise
|
|
|
|
# Should not reach here, but just in case
|
|
raise last_error if last_error else IntegrityError(
|
|
"Failed to create message", None, None
|
|
)
|
|
|
|
@staticmethod
|
|
def get_message(db: Session, message_id: str) -> Optional[Message]:
|
|
"""
|
|
Get a message by ID
|
|
|
|
Args:
|
|
db: Database session
|
|
message_id: Message ID
|
|
|
|
Returns:
|
|
Message object or None
|
|
"""
|
|
return db.query(Message).filter(
|
|
Message.message_id == message_id,
|
|
Message.deleted_at.is_(None)
|
|
).first()
|
|
|
|
@staticmethod
|
|
def get_display_name(db: Session, sender_id: str) -> Optional[str]:
|
|
"""
|
|
Get display name for a sender from users table
|
|
|
|
Args:
|
|
db: Database session
|
|
sender_id: User ID (email)
|
|
|
|
Returns:
|
|
Display name or None if not found
|
|
"""
|
|
user = db.query(User.display_name).filter(User.user_id == sender_id).first()
|
|
return user[0] if user else None
|
|
|
|
@staticmethod
|
|
def get_messages(
|
|
db: Session,
|
|
room_id: str,
|
|
limit: int = 50,
|
|
before_timestamp: Optional[datetime] = None,
|
|
offset: int = 0,
|
|
include_deleted: bool = False
|
|
) -> MessageListResponse:
|
|
"""
|
|
Get paginated messages for a room
|
|
|
|
Args:
|
|
db: Database session
|
|
room_id: Room ID
|
|
limit: Number of messages to return
|
|
before_timestamp: Get messages before this timestamp
|
|
offset: Offset for pagination
|
|
include_deleted: Include soft-deleted messages
|
|
|
|
Returns:
|
|
MessageListResponse with messages and pagination info
|
|
"""
|
|
# Build base query with LEFT JOIN to users table for display names
|
|
query = db.query(Message, User.display_name).outerjoin(
|
|
User, Message.sender_id == User.user_id
|
|
).filter(Message.room_id == room_id)
|
|
|
|
if not include_deleted:
|
|
query = query.filter(Message.deleted_at.is_(None))
|
|
|
|
if before_timestamp:
|
|
query = query.filter(Message.created_at < before_timestamp)
|
|
|
|
# Get total count (need separate query without join for accurate count)
|
|
count_query = db.query(Message).filter(Message.room_id == room_id)
|
|
if not include_deleted:
|
|
count_query = count_query.filter(Message.deleted_at.is_(None))
|
|
if before_timestamp:
|
|
count_query = count_query.filter(Message.created_at < before_timestamp)
|
|
total = count_query.count()
|
|
|
|
# Get messages with display names in reverse chronological order
|
|
results = query.order_by(desc(Message.created_at)).offset(offset).limit(limit).all()
|
|
|
|
# Get reaction counts for each message and build responses
|
|
message_responses = []
|
|
for msg, display_name in results:
|
|
reaction_counts = MessageService._get_reaction_counts(db, msg.message_id)
|
|
msg_response = MessageResponse.from_orm(msg)
|
|
msg_response.reaction_counts = reaction_counts
|
|
msg_response.sender_display_name = display_name or msg.sender_id
|
|
message_responses.append(msg_response)
|
|
|
|
return MessageListResponse(
|
|
messages=message_responses,
|
|
total=total,
|
|
limit=limit,
|
|
offset=offset,
|
|
has_more=(offset + len(results)) < total
|
|
)
|
|
|
|
@staticmethod
|
|
def edit_message(
|
|
db: Session,
|
|
message_id: str,
|
|
user_id: str,
|
|
new_content: str
|
|
) -> Optional[Message]:
|
|
"""
|
|
Edit a message (must be own message and within 15 minutes)
|
|
|
|
Args:
|
|
db: Database session
|
|
message_id: Message ID to edit
|
|
user_id: User ID making the edit
|
|
new_content: New message content
|
|
|
|
Returns:
|
|
Updated Message object or None if not allowed
|
|
"""
|
|
message = db.query(Message).filter(Message.message_id == message_id).first()
|
|
|
|
if not message:
|
|
return None
|
|
|
|
# Check permissions
|
|
if message.sender_id != user_id:
|
|
return None
|
|
|
|
# Check time limit (configurable via MESSAGE_EDIT_TIME_LIMIT_MINUTES)
|
|
time_diff = datetime.utcnow() - message.created_at
|
|
if time_diff > timedelta(minutes=settings.MESSAGE_EDIT_TIME_LIMIT_MINUTES):
|
|
return None
|
|
|
|
# Store original content in edit history
|
|
edit_history = MessageEditHistory(
|
|
message_id=message_id,
|
|
original_content=message.content,
|
|
edited_by=user_id,
|
|
edited_at=datetime.utcnow()
|
|
)
|
|
db.add(edit_history)
|
|
|
|
# Update message
|
|
message.content = new_content
|
|
message.edited_at = datetime.utcnow()
|
|
|
|
db.commit()
|
|
db.refresh(message)
|
|
|
|
return message
|
|
|
|
@staticmethod
|
|
def delete_message(
|
|
db: Session,
|
|
message_id: str,
|
|
user_id: str,
|
|
is_admin: bool = False
|
|
) -> Optional[Message]:
|
|
"""
|
|
Soft delete a message
|
|
|
|
Args:
|
|
db: Database session
|
|
message_id: Message ID to delete
|
|
user_id: User ID making the deletion
|
|
is_admin: Whether user is admin (can delete any message)
|
|
|
|
Returns:
|
|
Deleted Message object or None if not allowed
|
|
"""
|
|
message = db.query(Message).filter(Message.message_id == message_id).first()
|
|
|
|
if not message:
|
|
return None
|
|
|
|
# Check permissions (owner or admin)
|
|
if not is_admin and message.sender_id != user_id:
|
|
return None
|
|
|
|
# Soft delete
|
|
message.deleted_at = datetime.utcnow()
|
|
|
|
db.commit()
|
|
db.refresh(message)
|
|
|
|
return message
|
|
|
|
@staticmethod
|
|
def search_messages(
|
|
db: Session,
|
|
room_id: str,
|
|
query: str,
|
|
limit: int = 50,
|
|
offset: int = 0
|
|
) -> MessageListResponse:
|
|
"""
|
|
Search messages by content
|
|
|
|
Args:
|
|
db: Database session
|
|
room_id: Room ID to search in
|
|
query: Search query
|
|
limit: Number of results
|
|
offset: Offset for pagination
|
|
|
|
Returns:
|
|
MessageListResponse with search results
|
|
"""
|
|
# Simple LIKE search (for PostgreSQL, use full-text search)
|
|
search_filter = and_(
|
|
Message.room_id == room_id,
|
|
Message.deleted_at.is_(None),
|
|
Message.content.contains(query)
|
|
)
|
|
|
|
total = db.query(Message).filter(search_filter).count()
|
|
|
|
# Query with LEFT JOIN for display names
|
|
results = (
|
|
db.query(Message, User.display_name)
|
|
.outerjoin(User, Message.sender_id == User.user_id)
|
|
.filter(search_filter)
|
|
.order_by(desc(Message.created_at))
|
|
.offset(offset)
|
|
.limit(limit)
|
|
.all()
|
|
)
|
|
|
|
message_responses = []
|
|
for msg, display_name in results:
|
|
reaction_counts = MessageService._get_reaction_counts(db, msg.message_id)
|
|
msg_response = MessageResponse.from_orm(msg)
|
|
msg_response.reaction_counts = reaction_counts
|
|
msg_response.sender_display_name = display_name or msg.sender_id
|
|
message_responses.append(msg_response)
|
|
|
|
return MessageListResponse(
|
|
messages=message_responses,
|
|
total=total,
|
|
limit=limit,
|
|
offset=offset,
|
|
has_more=(offset + len(results)) < total
|
|
)
|
|
|
|
@staticmethod
|
|
def add_reaction(
|
|
db: Session,
|
|
message_id: str,
|
|
user_id: str,
|
|
emoji: str
|
|
) -> Optional[MessageReaction]:
|
|
"""
|
|
Add a reaction to a message
|
|
|
|
Args:
|
|
db: Database session
|
|
message_id: Message ID
|
|
user_id: User ID adding reaction
|
|
emoji: Emoji character
|
|
|
|
Returns:
|
|
MessageReaction object or None if already exists
|
|
"""
|
|
# Check if reaction already exists
|
|
existing = db.query(MessageReaction).filter(
|
|
and_(
|
|
MessageReaction.message_id == message_id,
|
|
MessageReaction.user_id == user_id,
|
|
MessageReaction.emoji == emoji
|
|
)
|
|
).first()
|
|
|
|
if existing:
|
|
return existing
|
|
|
|
reaction = MessageReaction(
|
|
message_id=message_id,
|
|
user_id=user_id,
|
|
emoji=emoji,
|
|
created_at=datetime.utcnow()
|
|
)
|
|
|
|
db.add(reaction)
|
|
db.commit()
|
|
db.refresh(reaction)
|
|
|
|
return reaction
|
|
|
|
@staticmethod
|
|
def remove_reaction(
|
|
db: Session,
|
|
message_id: str,
|
|
user_id: str,
|
|
emoji: str
|
|
) -> bool:
|
|
"""
|
|
Remove a reaction from a message
|
|
|
|
Args:
|
|
db: Database session
|
|
message_id: Message ID
|
|
user_id: User ID removing reaction
|
|
emoji: Emoji character
|
|
|
|
Returns:
|
|
True if removed, False if not found
|
|
"""
|
|
reaction = db.query(MessageReaction).filter(
|
|
and_(
|
|
MessageReaction.message_id == message_id,
|
|
MessageReaction.user_id == user_id,
|
|
MessageReaction.emoji == emoji
|
|
)
|
|
).first()
|
|
|
|
if not reaction:
|
|
return False
|
|
|
|
db.delete(reaction)
|
|
db.commit()
|
|
|
|
return True
|
|
|
|
@staticmethod
|
|
def get_message_reactions(
|
|
db: Session,
|
|
message_id: str
|
|
) -> List[ReactionSummary]:
|
|
"""
|
|
Get aggregated reactions for a message
|
|
|
|
Args:
|
|
db: Database session
|
|
message_id: Message ID
|
|
|
|
Returns:
|
|
List of ReactionSummary objects
|
|
"""
|
|
reactions = db.query(MessageReaction).filter(
|
|
MessageReaction.message_id == message_id
|
|
).all()
|
|
|
|
# Group by emoji
|
|
reaction_map: Dict[str, List[str]] = {}
|
|
for reaction in reactions:
|
|
if reaction.emoji not in reaction_map:
|
|
reaction_map[reaction.emoji] = []
|
|
reaction_map[reaction.emoji].append(reaction.user_id)
|
|
|
|
return [
|
|
ReactionSummary(emoji=emoji, count=len(users), users=users)
|
|
for emoji, users in reaction_map.items()
|
|
]
|
|
|
|
@staticmethod
|
|
def _get_reaction_counts(db: Session, message_id: str) -> Dict[str, int]:
|
|
"""
|
|
Get reaction counts for a message
|
|
|
|
Args:
|
|
db: Database session
|
|
message_id: Message ID
|
|
|
|
Returns:
|
|
Dictionary of emoji -> count
|
|
"""
|
|
result = (
|
|
db.query(MessageReaction.emoji, func.count(MessageReaction.reaction_id))
|
|
.filter(MessageReaction.message_id == message_id)
|
|
.group_by(MessageReaction.emoji)
|
|
.all()
|
|
)
|
|
|
|
return {emoji: count for emoji, count in result}
|