"""Message service layer for database operations"""
from sqlalchemy.orm import Session
from sqlalchemy import desc, and_, func, text
from sqlalchemy.exc import IntegrityError
from typing import List, Optional, Dict, Any, Tuple
from datetime import datetime, timedelta
import uuid
import logging
import re

from app.core.config import get_settings
from app.modules.realtime.models import Message, MessageType, MessageReaction, MessageEditHistory
from app.modules.realtime.schemas import (
    MessageCreate,
    MessageResponse,
    MessageListResponse,
    ReactionSummary,
)
from app.modules.auth.models import User

settings = get_settings()
logger = logging.getLogger(__name__)

# An @mention is "@" followed by a run of non-whitespace characters. Display
# names that contain spaces therefore cannot be matched (their user_id can be).
_MENTION_PATTERN = re.compile(r'@(\S+)')

# Punctuation that commonly trails a mention ("@alice," / "@张三。"); stripped
# from a token when the raw token does not match any room member.
_TRAILING_PUNCTUATION = ".,;:!?)]}'\"，。！？；：、）」』】"

# Clause punctuation at which a token is cut as a last resort. CJK text often
# has no space after a mention ("@张三，你好"), so the rest of the clause ends
# up inside the same \S+ token.
_CLAUSE_BREAK_PATTERN = re.compile(r'[,;!?，。！？；：、]')


class MessageService:
    """Service for message operations"""

    @staticmethod
    def _resolve_mention(token: str, lookup: Dict[str, str]) -> Optional[str]:
        """
        Resolve one raw @mention token to a user_id

        Candidates are tried from most to least literal, so any token that
        matched exactly before still resolves the same way:
        1. the token as captured,
        2. the token with trailing punctuation stripped,
        3. the token cut at the first clause punctuation, then stripped.

        Args:
            token: Text captured after "@"
            lookup: Lower-cased display_name/user_id -> user_id

        Returns:
            The matched user_id, or None if no candidate matches
        """
        candidates = (
            token,
            token.rstrip(_TRAILING_PUNCTUATION),
            _CLAUSE_BREAK_PATTERN.split(token, maxsplit=1)[0].rstrip(_TRAILING_PUNCTUATION),
        )
        for candidate in candidates:
            if not candidate:
                continue
            user_id = lookup.get(candidate.lower())
            if user_id is not None:
                return user_id
        return None

    @staticmethod
    def parse_mentions(content: str, room_members: List[Dict[str, str]]) -> List[str]:
        """
        Parse @mentions from message content and resolve to user IDs

        A mention is "@" followed by non-whitespace text. It matches a member
        whose display_name or user_id equals it case-insensitively. Trailing
        punctuation ("@alice,") and a directly following clause ("@张三，你好")
        are tolerated. When several members share a name, the first one in
        room_members wins.

        Args:
            content: Message content that may contain @mentions
            room_members: List of room members with user_id and display_name

        Returns:
            List of mentioned user_ids, deduplicated, in order of first mention
        """
        lookup: Dict[str, str] = {}
        for member in room_members or []:
            user_id = member.get('user_id') or ''
            if not user_id:
                # Nothing meaningful can be reported for a member without an id
                continue
            display_name = member.get('display_name') or user_id
            # setdefault keeps the first member claiming a name (first-match-wins)
            lookup.setdefault(display_name.lower(), user_id)
            lookup.setdefault(user_id.lower(), user_id)

        mentioned_ids: List[str] = []
        if not lookup:
            return mentioned_ids

        for token in _MENTION_PATTERN.findall(content or ''):
            user_id = MessageService._resolve_mention(token, lookup)
            if user_id is not None and user_id not in mentioned_ids:
                mentioned_ids.append(user_id)
        return mentioned_ids

    @staticmethod
    def create_message(
        db: Session,
        room_id: str,
        sender_id: str,
        content: str,
        message_type: MessageType = MessageType.TEXT,
        metadata: Optional[Dict[str, Any]] = None,
        mentions: Optional[List[str]] = None,
        max_retries: int = 3
    ) -> Message:
        """
        Create a new message with race condition protection

        Uses SELECT ... FOR UPDATE to lock the sequence number calculation,
        preventing duplicate sequence numbers when multiple users send messages
        simultaneously. If the insert still fails with an IntegrityError
        (e.g. a duplicate sequence number), the transaction is rolled back and
        retried.

        Args:
            db: Database session
            room_id: Room ID
            sender_id: User ID who sent the message
            content: Message content
            message_type: Type of message
            metadata: Optional metadata (file refs, etc.)
            mentions: List of mentioned user_ids (parsed from @mentions)
            max_retries: Maximum attempts when the insert raises IntegrityError

        Returns:
            Created Message object

        Raises:
            IntegrityError: If max retries exceeded
            Exception: Any other database error is re-raised after rolling back
        """
        last_error: Optional[IntegrityError] = None
        for attempt in range(max_retries):
            try:
                # Use FOR UPDATE to lock rows while calculating next sequence.
                # This prevents race conditions where two transactions read
                # the same max_seq and try to insert duplicate sequence numbers.
                # NOTE(review): PostgreSQL rejects FOR UPDATE combined with an
                # aggregate such as MAX(); this relies on a database (e.g.
                # MySQL/InnoDB) that accepts it. Confirm the target DB.
                result = db.execute(
                    text("""
                        SELECT COALESCE(MAX(sequence_number), 0)
                        FROM tr_messages
                        WHERE room_id = :room_id
                        FOR UPDATE
                    """),
                    {"room_id": room_id}
                )
                next_seq = (result.scalar() or 0) + 1

                message = Message(
                    message_id=str(uuid.uuid4()),
                    room_id=room_id,
                    sender_id=sender_id,
                    content=content,
                    message_type=message_type,
                    message_metadata=metadata or {},
                    mentions=mentions or [],
                    created_at=datetime.utcnow(),
                    sequence_number=next_seq
                )

                db.add(message)
                db.commit()
                db.refresh(message)
                return message

            except IntegrityError as e:
                last_error = e
                db.rollback()
                logger.warning(
                    "Sequence number conflict on attempt %d/%d for room %s: %s",
                    attempt + 1, max_retries, room_id, e
                )
                if attempt == max_retries - 1:
                    logger.error(
                        "Failed to create message after %d attempts for room %s",
                        max_retries, room_id
                    )
                    raise
            except Exception:
                # Never leave the caller's session stuck in a failed transaction
                db.rollback()
                raise

        # Only reachable when max_retries < 1
        raise last_error if last_error else IntegrityError(
            "Failed to create message", None, None
        )

    @staticmethod
    def get_message(db: Session, message_id: str) -> Optional[Message]:
        """
        Get a non-deleted message by ID

        Args:
            db: Database session
            message_id: Message ID

        Returns:
            Message object, or None if it does not exist or is soft-deleted
        """
        return db.query(Message).filter(
            Message.message_id == message_id,
            Message.deleted_at.is_(None)
        ).first()

    @staticmethod
    def get_display_name(db: Session, sender_id: str) -> Optional[str]:
        """
        Get display name for a sender from users table

        Args:
            db: Database session
            sender_id: User ID (email)

        Returns:
            Display name or None if not found
        """
        user = db.query(User.display_name).filter(User.user_id == sender_id).first()
        return user[0] if user else None

    @staticmethod
    def get_messages(
        db: Session,
        room_id: str,
        limit: int = 50,
        before_timestamp: Optional[datetime] = None,
        offset: int = 0,
        include_deleted: bool = False
    ) -> MessageListResponse:
        """
        Get paginated messages for a room, newest first

        Args:
            db: Database session
            room_id: Room ID
            limit: Number of messages to return
            before_timestamp: Get messages before this timestamp
            offset: Offset for pagination
            include_deleted: Include soft-deleted messages

        Returns:
            MessageListResponse with messages and pagination info
        """
        # Base query with LEFT JOIN to users table for display names
        query = (
            db.query(Message, User.display_name)
            .outerjoin(User, Message.sender_id == User.user_id)
            .filter(Message.room_id == room_id)
        )
        # Separate count query without the join, with the same filters
        count_query = db.query(Message).filter(Message.room_id == room_id)

        if not include_deleted:
            query = query.filter(Message.deleted_at.is_(None))
            count_query = count_query.filter(Message.deleted_at.is_(None))

        if before_timestamp is not None:
            query = query.filter(Message.created_at < before_timestamp)
            count_query = count_query.filter(Message.created_at < before_timestamp)

        total = count_query.count()

        # sequence_number breaks created_at ties so offset pages are stable
        # (no duplicated or skipped rows between pages).
        results = (
            query.order_by(desc(Message.created_at), desc(Message.sequence_number))
            .offset(offset)
            .limit(limit)
            .all()
        )

        return MessageListResponse(
            messages=MessageService._build_message_responses(db, results),
            total=total,
            limit=limit,
            offset=offset,
            has_more=(offset + len(results)) < total
        )

    @staticmethod
    def edit_message(
        db: Session,
        message_id: str,
        user_id: str,
        new_content: str
    ) -> Optional[Message]:
        """
        Edit a message (must be own, not deleted, and within the edit window)

        The edit window is settings.MESSAGE_EDIT_TIME_LIMIT_MINUTES. The
        previous content is stored in MessageEditHistory.

        Args:
            db: Database session
            message_id: Message ID to edit
            user_id: User ID making the edit
            new_content: New message content

        Returns:
            Updated Message object or None if not allowed
        """
        # Soft-deleted messages must not be editable
        message = db.query(Message).filter(
            Message.message_id == message_id,
            Message.deleted_at.is_(None)
        ).first()
        if message is None or message.sender_id != user_id:
            return None

        # One timestamp for both the time-limit check and the edit records
        now = datetime.utcnow()
        edit_limit = timedelta(minutes=settings.MESSAGE_EDIT_TIME_LIMIT_MINUTES)
        if now - message.created_at > edit_limit:
            return None

        # Store original content in edit history
        db.add(MessageEditHistory(
            message_id=message_id,
            original_content=message.content,
            edited_by=user_id,
            edited_at=now
        ))

        message.content = new_content
        message.edited_at = now

        db.commit()
        db.refresh(message)
        return message

    @staticmethod
    def delete_message(
        db: Session,
        message_id: str,
        user_id: str,
        is_admin: bool = False
    ) -> Optional[Message]:
        """
        Soft delete a message

        Deleting an already-deleted message keeps its original deleted_at.

        Args:
            db: Database session
            message_id: Message ID to delete
            user_id: User ID making the deletion
            is_admin: Whether user is admin (can delete any message)

        Returns:
            Deleted Message object or None if not allowed
        """
        message = db.query(Message).filter(Message.message_id == message_id).first()
        if not message:
            return None

        # Check permissions (owner or admin)
        if not is_admin and message.sender_id != user_id:
            return None

        if message.deleted_at is None:
            message.deleted_at = datetime.utcnow()
            db.commit()
            db.refresh(message)
        return message

    @staticmethod
    def search_messages(
        db: Session,
        room_id: str,
        query: str,
        limit: int = 50,
        offset: int = 0
    ) -> MessageListResponse:
        """
        Search non-deleted messages by substring of their content

        Args:
            db: Database session
            room_id: Room ID to search in
            query: Search text; matched literally ('%' and '_' are escaped)
            limit: Number of results
            offset: Offset for pagination

        Returns:
            MessageListResponse with search results, newest first
        """
        # Simple LIKE search (for PostgreSQL, use full-text search).
        # autoescape stops user-supplied '%' / '_' from acting as wildcards.
        search_filter = and_(
            Message.room_id == room_id,
            Message.deleted_at.is_(None),
            Message.content.contains(query, autoescape=True)
        )

        total = db.query(Message).filter(search_filter).count()

        # Query with LEFT JOIN for display names; sequence_number breaks ties
        results = (
            db.query(Message, User.display_name)
            .outerjoin(User, Message.sender_id == User.user_id)
            .filter(search_filter)
            .order_by(desc(Message.created_at), desc(Message.sequence_number))
            .offset(offset)
            .limit(limit)
            .all()
        )

        return MessageListResponse(
            messages=MessageService._build_message_responses(db, results),
            total=total,
            limit=limit,
            offset=offset,
            has_more=(offset + len(results)) < total
        )

    @staticmethod
    def _find_reaction(
        db: Session,
        message_id: str,
        user_id: str,
        emoji: str
    ) -> Optional[MessageReaction]:
        """Return the reaction row for (message_id, user_id, emoji), if any."""
        return db.query(MessageReaction).filter(
            and_(
                MessageReaction.message_id == message_id,
                MessageReaction.user_id == user_id,
                MessageReaction.emoji == emoji
            )
        ).first()

    @staticmethod
    def add_reaction(
        db: Session,
        message_id: str,
        user_id: str,
        emoji: str
    ) -> Optional[MessageReaction]:
        """
        Add a reaction to a message (idempotent)

        Args:
            db: Database session
            message_id: Message ID
            user_id: User ID adding reaction
            emoji: Emoji character

        Returns:
            The new MessageReaction, or the existing one if this user already
            reacted with this emoji

        Raises:
            IntegrityError: If the insert fails and no matching reaction exists
        """
        existing = MessageService._find_reaction(db, message_id, user_id, emoji)
        if existing:
            return existing

        reaction = MessageReaction(
            message_id=message_id,
            user_id=user_id,
            emoji=emoji,
            created_at=datetime.utcnow()
        )
        db.add(reaction)
        try:
            db.commit()
        except IntegrityError:
            # A concurrent request may have inserted the same reaction between
            # the existence check and this commit; return that row instead.
            db.rollback()
            existing = MessageService._find_reaction(db, message_id, user_id, emoji)
            if existing is not None:
                return existing
            raise
        db.refresh(reaction)
        return reaction

    @staticmethod
    def remove_reaction(
        db: Session,
        message_id: str,
        user_id: str,
        emoji: str
    ) -> bool:
        """
        Remove a reaction from a message

        Args:
            db: Database session
            message_id: Message ID
            user_id: User ID removing reaction
            emoji: Emoji character

        Returns:
            True if removed, False if not found
        """
        reaction = MessageService._find_reaction(db, message_id, user_id, emoji)
        if not reaction:
            return False

        db.delete(reaction)
        db.commit()
        return True

    @staticmethod
    def get_message_reactions(
        db: Session,
        message_id: str
    ) -> List[ReactionSummary]:
        """
        Get aggregated reactions for a message

        Args:
            db: Database session
            message_id: Message ID

        Returns:
            One ReactionSummary per emoji, in order of first appearance
        """
        reactions = db.query(MessageReaction).filter(
            MessageReaction.message_id == message_id
        ).all()

        users_by_emoji: Dict[str, List[str]] = {}
        for reaction in reactions:
            users_by_emoji.setdefault(reaction.emoji, []).append(reaction.user_id)

        return [
            ReactionSummary(emoji=emoji, count=len(users), users=users)
            for emoji, users in users_by_emoji.items()
        ]

    @staticmethod
    def _build_message_responses(
        db: Session,
        rows: List[Tuple[Message, Optional[str]]]
    ) -> List[MessageResponse]:
        """
        Convert (Message, display_name) rows into MessageResponse objects

        Reaction counts for all rows are fetched in one query.

        Args:
            db: Database session
            rows: Query results of (Message, User.display_name) pairs

        Returns:
            MessageResponse list in the same order as rows
        """
        counts_by_message = MessageService._get_reaction_counts_bulk(
            db, [msg.message_id for msg, _ in rows]
        )

        responses: List[MessageResponse] = []
        for msg, display_name in rows:
            response = MessageResponse.from_orm(msg)
            response.reaction_counts = counts_by_message.get(msg.message_id, {})
            response.sender_display_name = display_name or msg.sender_id
            responses.append(response)
        return responses

    @staticmethod
    def _get_reaction_counts_bulk(
        db: Session,
        message_ids: List[str]
    ) -> Dict[str, Dict[str, int]]:
        """
        Get reaction counts for many messages in a single query

        Args:
            db: Database session
            message_ids: Message IDs to count reactions for

        Returns:
            message_id -> {emoji -> count}; messages without reactions are absent
        """
        if not message_ids:
            return {}

        rows = (
            db.query(
                MessageReaction.message_id,
                MessageReaction.emoji,
                func.count(MessageReaction.reaction_id)
            )
            .filter(MessageReaction.message_id.in_(message_ids))
            .group_by(MessageReaction.message_id, MessageReaction.emoji)
            .all()
        )

        counts: Dict[str, Dict[str, int]] = {}
        for msg_id, emoji, count in rows:
            counts.setdefault(msg_id, {})[emoji] = count
        return counts

    @staticmethod
    def _get_reaction_counts(db: Session, message_id: str) -> Dict[str, int]:
        """
        Get reaction counts for a message

        Args:
            db: Database session
            message_id: Message ID

        Returns:
            Dictionary of emoji -> count
        """
        result = (
            db.query(MessageReaction.emoji, func.count(MessageReaction.reaction_id))
            .filter(MessageReaction.message_id == message_id)
            .group_by(MessageReaction.emoji)
            .all()
        )
        return {emoji: count for emoji, count in result}