"""Redis Pub/Sub service for cross-process notification broadcasting.""" import asyncio import json import logging import uuid from datetime import datetime from typing import Optional, Callable, Any import redis.asyncio as aioredis from app.core.config import settings logger = logging.getLogger(__name__) # Redis retry configuration MAX_REDIS_RETRIES = 3 REDIS_RETRY_DELAY = 0.5 # seconds (base delay for exponential backoff) # Global async Redis client for pub/sub _pubsub_redis: Optional[aioredis.Redis] = None def get_channel_name(user_id: str) -> str: """Get the Redis channel name for a user's notifications.""" return f"notifications:{user_id}" def get_project_channel_name(project_id: str) -> str: """Get the Redis channel name for project task events.""" return f"project:{project_id}:tasks" async def get_pubsub_redis() -> aioredis.Redis: """Get or create the async Redis client for pub/sub.""" global _pubsub_redis if _pubsub_redis is None: _pubsub_redis = aioredis.from_url( settings.REDIS_URL, encoding="utf-8", decode_responses=True, ) return _pubsub_redis async def close_pubsub_redis() -> None: """Close the async Redis client.""" global _pubsub_redis if _pubsub_redis is not None: await _pubsub_redis.close() _pubsub_redis = None async def publish_notification(user_id: str, notification: dict) -> bool: """ Publish a notification to a user's channel. Args: user_id: The user ID to send the notification to notification: The notification data (will be JSON serialized) Returns: True if published successfully, False otherwise """ try: redis_client = await get_pubsub_redis() channel = get_channel_name(user_id) message = json.dumps(notification, default=str) await redis_client.publish(channel, message) logger.debug(f"Published notification to channel {channel}") return True except Exception as e: logger.error(f"Failed to publish notification: {e}") return False class NotificationSubscriber: """ Subscriber for user notification channels. Used by WebSocket connections to receive real-time updates. """ def __init__(self, user_id: str): self.user_id = user_id self.channel = get_channel_name(user_id) self.pubsub: Optional[aioredis.client.PubSub] = None self._running = False async def start(self) -> None: """Start subscribing to the user's notification channel.""" redis_client = await get_pubsub_redis() self.pubsub = redis_client.pubsub() await self.pubsub.subscribe(self.channel) self._running = True logger.debug(f"Subscribed to channel {self.channel}") async def stop(self) -> None: """Stop subscribing and clean up.""" self._running = False if self.pubsub: await self.pubsub.unsubscribe(self.channel) await self.pubsub.close() self.pubsub = None logger.debug(f"Unsubscribed from channel {self.channel}") async def listen(self, callback: Callable[[dict], Any]) -> None: """ Listen for messages and call the callback for each notification. Args: callback: Async function to call with each notification dict """ if not self.pubsub: raise RuntimeError("Subscriber not started. Call start() first.") try: async for message in self.pubsub.listen(): if not self._running: break if message["type"] == "message": try: data = json.loads(message["data"]) await callback(data) except json.JSONDecodeError: logger.warning(f"Invalid JSON in notification: {message['data']}") except Exception as e: logger.error(f"Error processing notification: {e}") except Exception as e: if self._running: logger.error(f"Error in notification listener: {e}") @property def is_running(self) -> bool: return self._running async def _reset_pubsub_redis() -> None: """Reset the Redis connection on failure.""" global _pubsub_redis if _pubsub_redis is not None: try: await _pubsub_redis.close() except Exception: pass _pubsub_redis = None async def publish_task_event( project_id: str, event_type: str, task_data: dict, triggered_by: str ) -> bool: """ Publish a task event to a project's channel with retry logic. Args: project_id: The project ID event_type: Event type (task_created, task_updated, task_status_changed, task_deleted, task_assigned) task_data: The task data to include in the event triggered_by: User ID who triggered this event Returns: True if published successfully, False otherwise """ channel = get_project_channel_name(project_id) message = json.dumps({ "type": event_type, "event_id": str(uuid.uuid4()), # Unique event ID for multi-tab deduplication "data": task_data, "triggered_by": triggered_by, "timestamp": datetime.utcnow().isoformat(), }, default=str) for attempt in range(MAX_REDIS_RETRIES): try: redis_client = await get_pubsub_redis() # Test connection with ping before publishing await redis_client.ping() await redis_client.publish(channel, message) logger.debug(f"Published task event '{event_type}' to channel {channel}") return True except Exception as e: logger.warning(f"Redis publish attempt {attempt + 1}/{MAX_REDIS_RETRIES} failed: {e}") if attempt < MAX_REDIS_RETRIES - 1: # Exponential backoff await asyncio.sleep(REDIS_RETRY_DELAY * (attempt + 1)) # Reset connection on failure await _reset_pubsub_redis() else: logger.error(f"Failed to publish task event '{event_type}' after {MAX_REDIS_RETRIES} attempts") return False return False class ProjectTaskSubscriber: """ Subscriber for project task events via Redis Pub/Sub. Used by WebSocket connections to receive real-time task updates. Includes automatic reconnection handling. """ def __init__(self, project_id: str): self.project_id = project_id self.channel = get_project_channel_name(project_id) self.pubsub: Optional[aioredis.client.PubSub] = None self._running = False self._reconnect_attempts = 0 async def start(self) -> None: """Start subscribing to the project's task channel with retry logic.""" for attempt in range(MAX_REDIS_RETRIES): try: redis_client = await get_pubsub_redis() # Test connection health await redis_client.ping() self.pubsub = redis_client.pubsub() await self.pubsub.subscribe(self.channel) self._running = True self._reconnect_attempts = 0 logger.debug(f"Subscribed to project task channel {self.channel}") return except Exception as e: logger.warning(f"Redis subscribe attempt {attempt + 1}/{MAX_REDIS_RETRIES} failed: {e}") if attempt < MAX_REDIS_RETRIES - 1: await asyncio.sleep(REDIS_RETRY_DELAY * (attempt + 1)) await _reset_pubsub_redis() else: logger.error(f"Failed to subscribe to channel {self.channel} after {MAX_REDIS_RETRIES} attempts") raise async def _reconnect(self) -> bool: """Attempt to reconnect to Redis and resubscribe.""" self._reconnect_attempts += 1 if self._reconnect_attempts > MAX_REDIS_RETRIES: logger.error(f"Max reconnection attempts reached for channel {self.channel}") return False logger.info(f"Attempting to reconnect to Redis (attempt {self._reconnect_attempts}/{MAX_REDIS_RETRIES})") # Clean up old pubsub if self.pubsub: try: await self.pubsub.close() except Exception: pass self.pubsub = None # Reset global connection await _reset_pubsub_redis() # Wait with exponential backoff await asyncio.sleep(REDIS_RETRY_DELAY * self._reconnect_attempts) try: redis_client = await get_pubsub_redis() await redis_client.ping() self.pubsub = redis_client.pubsub() await self.pubsub.subscribe(self.channel) self._reconnect_attempts = 0 logger.info(f"Successfully reconnected to channel {self.channel}") return True except Exception as e: logger.warning(f"Reconnection attempt failed: {e}") return False async def stop(self) -> None: """Stop subscribing and clean up.""" self._running = False if self.pubsub: try: await self.pubsub.unsubscribe(self.channel) await self.pubsub.close() except Exception as e: logger.warning(f"Error during pubsub cleanup: {e}") self.pubsub = None logger.debug(f"Unsubscribed from project task channel {self.channel}") async def listen(self, callback: Callable[[dict], Any]) -> None: """ Listen for task events and call the callback for each event. Includes automatic reconnection on connection failures. Args: callback: Async function to call with each task event dict. The dict contains: type, data, triggered_by """ if not self.pubsub: raise RuntimeError("Subscriber not started. Call start() first.") while self._running: try: async for message in self.pubsub.listen(): if not self._running: break if message["type"] == "message": try: data = json.loads(message["data"]) await callback(data) except json.JSONDecodeError: logger.warning(f"Invalid JSON in task event: {message['data']}") except Exception as e: logger.error(f"Error processing task event: {e}") except Exception as e: if not self._running: break logger.warning(f"Redis connection error in task listener: {e}") # Attempt to reconnect if await self._reconnect(): continue # Resume listening after successful reconnection else: logger.error(f"Failed to recover connection for channel {self.channel}") break @property def is_running(self) -> bool: return self._running