Files
PROJECT-CONTORL/backend/app/services/notification_service.py
beabigegg 679b89ae4c feat: implement security, error resilience, and query optimization proposals
Security Validation (enhance-security-validation):
- JWT secret validation with entropy checking and pattern detection
- CSRF protection middleware with token generation/validation
- Frontend CSRF token auto-injection for DELETE/PUT/PATCH requests
- MIME type validation with magic bytes detection for file uploads

Error Resilience (add-error-resilience):
- React ErrorBoundary component with fallback UI and retry functionality
- ErrorBoundaryWithI18n wrapper for internationalization support
- Page-level and section-level error boundaries in App.tsx

Query Performance (optimize-query-performance):
- Query monitoring utility with threshold warnings
- N+1 query fixes using joinedload/selectinload
- Optimized project members, tasks, and subtasks endpoints

Bug Fixes:
- WebSocket session management (P0): Return primitives instead of ORM objects
- LIKE query injection (P1): Escape special characters in search queries

Tests: 543 backend tests, 56 frontend tests passing

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-11 18:41:19 +08:00

483 lines
17 KiB
Python

import json
import uuid
import re
import asyncio
import logging
import threading
import os
from datetime import datetime, timezone
from typing import List, Optional, Dict, Set
from collections import deque
from sqlalchemy.orm import Session
from sqlalchemy import event
from app.models import User, Notification, Task, Comment, Mention
from app.core.redis_pubsub import publish_notification as redis_publish, get_channel_name
from app.core.redis import get_redis_sync
from app.core.database import escape_like
logger = logging.getLogger(__name__)
# Thread-safe lock for module-level state: guards _pending_publish and
# _registered_sessions (acquired by the session-tracking helpers and by
# NotificationService._queue_for_publish).
_lock = threading.Lock()
# Module-level queue for notifications pending publish after commit.
# Keyed by id(session); each entry is {"user_id": ..., "data": <notification dict>}.
# Entries are published by the session's after_commit listener and discarded
# by its rollback listeners.
_pending_publish: Dict[int, List[dict]] = {}
# Track which sessions have handlers registered (by id(session)), so the
# commit/rollback listeners are attached at most once per session.
_registered_sessions: Set[int] = set()
# Redis fallback queue configuration, read from the environment at import time.
# NOTE: a non-integer value makes int() raise ValueError while importing this module.
REDIS_FALLBACK_MAX_QUEUE_SIZE: int = int(os.getenv("REDIS_FALLBACK_MAX_QUEUE_SIZE", "1000"))
REDIS_FALLBACK_RETRY_INTERVAL: int = int(os.getenv("REDIS_FALLBACK_RETRY_INTERVAL", "5")) # seconds
# A queued notification is dropped once its retry_count reaches this value.
REDIS_FALLBACK_MAX_RETRIES: int = int(os.getenv("REDIS_FALLBACK_MAX_RETRIES", "10"))
# Redis fallback queue for failed publishes.
# _redis_fallback_lock guards the queue, _redis_available and
# _redis_consecutive_failures.
_redis_fallback_lock = threading.Lock()
# maxlen is only a backstop: code that appends checks the size first, so
# items are refused rather than silently evicted from the left end.
_redis_fallback_queue: deque = deque(maxlen=REDIS_FALLBACK_MAX_QUEUE_SIZE)
# Timer that runs the next retry pass; may be None (or a finished timer)
# when no retry is pending.
_redis_retry_timer: Optional[threading.Timer] = None
# Health flag reported by get_redis_fallback_status(); set False on a failed
# publish and True again once a retry pass empties the queue.
_redis_available: bool = True
# Incremented for every notification added to the fallback queue and reset
# to 0 when a retry pass publishes at least one notification.
_redis_consecutive_failures: int = 0
def _add_to_fallback_queue(user_id: str, data: dict, retry_count: int = 0) -> bool:
    """
    Put a notification whose Redis publish failed onto the fallback retry queue.

    Args:
        user_id: Recipient of the notification.
        data: Notification payload to publish on a later retry.
        retry_count: Number of retries already attempted for this payload.

    Returns:
        True when the notification was queued (a retry timer is then ensured),
        False when the queue already held REDIS_FALLBACK_MAX_QUEUE_SIZE items
        and the notification was dropped.
    """
    global _redis_consecutive_failures
    with _redis_fallback_lock:
        queue_full = len(_redis_fallback_queue) >= REDIS_FALLBACK_MAX_QUEUE_SIZE
        if not queue_full:
            _redis_fallback_queue.append(
                dict(
                    user_id=user_id,
                    data=data,
                    retry_count=retry_count,
                    queued_at=datetime.now(timezone.utc).isoformat(),
                )
            )
            _redis_consecutive_failures += 1
            size_after = len(_redis_fallback_queue)

    if queue_full:
        logger.warning(
            "Redis fallback queue is full (%d items), dropping notification for user %s",
            REDIS_FALLBACK_MAX_QUEUE_SIZE, user_id
        )
        return False

    logger.debug("Added notification to fallback queue (size: %d)", size_after)
    # A retry pass must be pending for the entry just added to ever be sent.
    _ensure_retry_timer_running()
    return True
def _ensure_retry_timer_running():
    """Start the fallback-queue retry timer unless one is already pending or running.

    The timer is a daemon ``threading.Timer`` that calls
    ``_process_fallback_queue`` after ``REDIS_FALLBACK_RETRY_INTERVAL`` seconds.
    """
    global _redis_retry_timer
    # Read the global exactly once: _process_fallback_queue may set it to None
    # from the timer thread, and re-reading it between the None check and
    # is_alive() could otherwise raise AttributeError into the caller.
    current = _redis_retry_timer
    if current is not None and current.is_alive():
        return
    # NOTE(review): this check-and-start is not done under _redis_fallback_lock,
    # so two threads calling it concurrently can still each start a timer.
    new_timer = threading.Timer(REDIS_FALLBACK_RETRY_INTERVAL, _process_fallback_queue)
    new_timer.daemon = True
    _redis_retry_timer = new_timer
    new_timer.start()
def _process_fallback_queue():
    """Retry publishing every queued notification to Redis (runs on the retry timer thread).

    The whole queue is drained under ``_redis_fallback_lock``, then each entry
    is published outside the lock so slow Redis calls never block producers.

    - Entries already retried ``REDIS_FALLBACK_MAX_RETRIES`` times are dropped
      with a warning.
    - Entries that fail again are re-queued with an incremented
      ``retry_count``. If the queue has filled up meanwhile, they are dropped
      and the drop is logged.
    - Another pass is scheduled while anything remains queued; otherwise
      ``_redis_retry_timer`` is cleared.
    """
    global _redis_available, _redis_consecutive_failures, _redis_retry_timer

    # Take ownership of everything queued so far.
    with _redis_fallback_lock:
        items_to_retry = list(_redis_fallback_queue)
        _redis_fallback_queue.clear()
        if not items_to_retry:
            _redis_retry_timer = None
            return

    logger.info("Processing %d items from Redis fallback queue", len(items_to_retry))

    failed_items = []
    success_count = 0
    for item in items_to_retry:
        user_id = item["user_id"]
        retry_count = item["retry_count"]
        if retry_count >= REDIS_FALLBACK_MAX_RETRIES:
            logger.warning(
                "Notification for user %s exceeded max retries (%d), dropping",
                user_id, REDIS_FALLBACK_MAX_RETRIES
            )
            continue
        try:
            redis_client = get_redis_sync()
            channel = get_channel_name(user_id)
            message = json.dumps(item["data"], default=str)
            redis_client.publish(channel, message)
        except Exception as e:
            logger.debug("Retry failed for user %s: %s", user_id, e)
            failed_items.append({**item, "retry_count": retry_count + 1})
        else:
            success_count += 1

    dropped_count = 0
    with _redis_fallback_lock:
        # Re-queue failures. Producers may have refilled the queue while we were
        # publishing; anything that no longer fits is counted so it can be logged
        # instead of disappearing silently.
        for item in failed_items:
            if len(_redis_fallback_queue) < REDIS_FALLBACK_MAX_QUEUE_SIZE:
                _redis_fallback_queue.append(item)
            else:
                dropped_count += 1

        if success_count > 0:
            # At least one publish went through, so Redis is reachable again.
            _redis_consecutive_failures = 0
            if not _redis_fallback_queue:
                _redis_available = True

        # Keep retrying while anything is left; otherwise end the chain so the
        # next failed publish starts a fresh timer.
        if _redis_fallback_queue:
            _redis_retry_timer = threading.Timer(REDIS_FALLBACK_RETRY_INTERVAL, _process_fallback_queue)
            _redis_retry_timer.daemon = True
            _redis_retry_timer.start()
        else:
            _redis_retry_timer = None

    if dropped_count:
        logger.warning(
            "Redis fallback queue is full, dropping %d notification(s) that failed retry",
            dropped_count,
        )
    if success_count > 0:
        logger.info(
            "Redis connection recovered. Successfully processed %d notifications from fallback queue",
            success_count
        )
def get_redis_fallback_status() -> dict:
    """Return a snapshot of the Redis fallback queue for health checks.

    The runtime values are read together under ``_redis_fallback_lock`` so they
    are mutually consistent; the remaining fields are module configuration.
    """
    with _redis_fallback_lock:
        queue_size = len(_redis_fallback_queue)
        available = _redis_available
        failures = _redis_consecutive_failures
        return dict(
            queue_size=queue_size,
            max_queue_size=REDIS_FALLBACK_MAX_QUEUE_SIZE,
            redis_available=available,
            consecutive_failures=failures,
            retry_interval_seconds=REDIS_FALLBACK_RETRY_INTERVAL,
            max_retries=REDIS_FALLBACK_MAX_RETRIES,
        )
def _sync_publish(user_id: str, data: dict):
    """Publish one notification with the synchronous Redis client.

    This is the path used when no asyncio event loop is running. On any failure
    the error is logged, Redis is flagged unavailable, and the payload is
    handed to the fallback retry queue.
    """
    global _redis_available
    try:
        client = get_redis_sync()
        channel = get_channel_name(user_id)
        client.publish(channel, json.dumps(data, default=str))
    except Exception as e:
        logger.error(f"Failed to sync publish notification to Redis: {e}")
        # Flag the outage under the lock, then queue outside it:
        # _add_to_fallback_queue acquires the same non-reentrant lock itself.
        with _redis_fallback_lock:
            _redis_available = False
        _add_to_fallback_queue(user_id, data)
    else:
        logger.debug(f"Sync published notification to channel {channel}")
def _cleanup_session(session_id: int, remove_registration: bool = True):
    """Remove and return the notifications queued for a session. Thread-safe.

    Args:
        session_id: ``id()`` of the session whose state is being cleared.
        remove_registration: When True, also forget that listeners were
            registered for this session. Pass False from the soft-rollback
            listener to avoid handler stacking.

    Returns:
        The pending notification entries for the session (empty list if none).
    """
    with _lock:
        pending = _pending_publish.pop(session_id, [])
        if remove_registration:
            _registered_sessions.discard(session_id)
    return pending
# Strong references to publish tasks scheduled by after_commit listeners.
# The event loop keeps only weak references to tasks, so a fire-and-forget task
# with no other reference may be garbage-collected before it finishes; each
# task removes itself from this set when done.
_background_publish_tasks: Set[asyncio.Task] = set()


def _register_session_handlers(db: Session, session_id: int):
    """Register after_commit, after_rollback, and after_soft_rollback handlers for a session.

    Registration is skipped if ``session_id`` is already in
    ``_registered_sessions``. All three listeners are one-shot (``once=True``).

    Args:
        db: The SQLAlchemy session to attach the listeners to.
        session_id: ``id(db)``, the key used in ``_pending_publish``.
    """
    with _lock:
        if session_id in _registered_sessions:
            return
        _registered_sessions.add(session_id)

    @event.listens_for(db, "after_commit", once=True)
    def _after_commit(session):
        notifications = _cleanup_session(session_id)
        if not notifications:
            return
        # Keep only the loop lookup inside the try, so that "no running loop"
        # is never confused with a failure part-way through scheduling (which
        # would re-publish already-scheduled notifications via the sync path).
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No running event loop in this thread - use sync fallback
            logger.info(f"No event loop, using sync publish for {len(notifications)} notification(s)")
            for n in notifications:
                _sync_publish(n["user_id"], n["data"])
            return
        for n in notifications:
            task = loop.create_task(_async_publish(n["user_id"], n["data"]))
            _background_publish_tasks.add(task)
            task.add_done_callback(_background_publish_tasks.discard)

    @event.listens_for(db, "after_rollback", once=True)
    def _after_rollback(session):
        # Rollback drops the queued notifications and the registration.
        cleared = _cleanup_session(session_id)
        if cleared:
            logger.debug(f"Cleared {len(cleared)} pending notification(s) after rollback")

    @event.listens_for(db, "after_soft_rollback", once=True)
    def _after_soft_rollback(session, previous_transaction):
        # Only clear pending notifications, keep handler registration to avoid stacking
        cleared = _cleanup_session(session_id, remove_registration=False)
        if cleared:
            logger.debug(f"Cleared {len(cleared)} pending notification(s) after soft rollback")
async def _async_publish(user_id: str, data: dict):
    """Publish one notification through the async Redis publisher.

    On failure the error is logged, Redis is flagged unavailable, and the
    payload is handed to the fallback queue for a later retry.
    """
    global _redis_available
    try:
        await redis_publish(user_id, data)
    except Exception as e:
        logger.error(f"Failed to publish notification to Redis: {e}")
    else:
        return
    # Publish failed: record the outage, then queue the payload for retry.
    with _redis_fallback_lock:
        _redis_available = False
    _add_to_fallback_queue(user_id, data)
class NotificationService:
    """Service for creating and managing notifications.

    Notifications built by ``create_notification`` are queued per SQLAlchemy
    session and published to Redis automatically after that session commits.
    """

    MAX_MENTIONS_PER_COMMENT = 10
    # Maximum number of characters of user text (comment bodies, blocker
    # reasons) copied into a notification message before it is cut with "...".
    MESSAGE_PREVIEW_LENGTH = 100
    # One compiled pattern shared by count_mentions and parse_mentions so the
    # two always agree: "@name", optionally followed by "@domain"
    # (e.g. "@john@example.com" captures "john@example.com").
    _MENTION_PATTERN = re.compile(r'@([a-zA-Z0-9._-]+(?:@[a-zA-Z0-9.-]+)?)')

    @staticmethod
    def _truncate(text: Optional[str], limit: int = MESSAGE_PREVIEW_LENGTH) -> str:
        """Return *text* cut to *limit* characters, adding "..." only if it was cut.

        ``None`` is treated as an empty string.
        """
        text = text or ""
        return (text[:limit] + "...") if len(text) > limit else text

    @staticmethod
    def notification_to_dict(notification: Notification) -> dict:
        """Convert a Notification to a plain dict for publishing.

        ``create_notification`` serializes the notification right after
        ``db.add`` (before any flush), so database-side defaults may not be
        populated yet:

        - ``created_at`` falls back to the current naive UTC time.
        - A missing ``is_read`` is reported as False, since a freshly created
          notification is unread. NOTE(review): assumes ``is_read`` relies on
          a column default that is only applied at flush.
        """
        created_at = notification.created_at
        if created_at is None:
            created_at = datetime.now(timezone.utc).replace(tzinfo=None)
        return {
            "id": notification.id,
            "type": notification.type,
            "reference_type": notification.reference_type,
            "reference_id": notification.reference_id,
            "title": notification.title,
            "message": notification.message,
            "is_read": bool(notification.is_read),
            "created_at": created_at.isoformat(),
        }

    @staticmethod
    async def publish_notifications(notifications: List[Notification]) -> None:
        """Publish notifications to Redis for real-time WebSocket delivery.

        Each publish goes through ``_async_publish``, as the automatic
        after-commit path does. A Redis failure is logged and the notification
        goes to the fallback retry queue. It no longer aborts the remaining
        notifications with an exception.
        """
        for notification in notifications or []:
            if notification and notification.user_id:
                data = NotificationService.notification_to_dict(notification)
                await _async_publish(notification.user_id, data)

    @staticmethod
    async def publish_notification(notification: Optional[Notification]) -> None:
        """Publish a single notification to Redis (no-op for None)."""
        if notification:
            await NotificationService.publish_notifications([notification])

    @staticmethod
    def _queue_for_publish(db: Session, notification: Notification):
        """Queue notification for auto-publish after *db* commits. Thread-safe.

        The serialized payload (not the ORM object) is stored, keyed by ``id(db)``.
        NOTE(review): entries are only removed by the session's commit/rollback
        listeners; a session discarded without either leaves its entry behind,
        and its id() may later be reused by another session.
        """
        session_id = id(db)
        # Register handlers first (has its own lock)
        _register_session_handlers(db, session_id)
        notification_data = {
            "user_id": notification.user_id,
            "data": NotificationService.notification_to_dict(notification),
        }
        with _lock:
            _pending_publish.setdefault(session_id, []).append(notification_data)

    @staticmethod
    def create_notification(
        db: Session,
        user_id: str,
        notification_type: str,
        reference_type: str,
        reference_id: str,
        title: str,
        message: Optional[str] = None,
    ) -> Notification:
        """Create a notification for a user and add it to *db*.

        The caller owns the transaction; the notification is published via
        Redis automatically once *db* commits.
        """
        notification = Notification(
            id=str(uuid.uuid4()),
            user_id=user_id,
            type=notification_type,
            reference_type=reference_type,
            reference_id=reference_id,
            title=title,
            message=message,
        )
        db.add(notification)
        # Queue for auto-publish after commit
        NotificationService._queue_for_publish(db, notification)
        return notification

    @staticmethod
    def notify_task_assignment(
        db: Session,
        task: Task,
        assigned_by: User,
    ) -> Optional[Notification]:
        """Notify the task's assignee, unless unassigned or self-assigned."""
        if not task.assignee_id or task.assignee_id == assigned_by.id:
            return None
        return NotificationService.create_notification(
            db=db,
            user_id=task.assignee_id,
            notification_type="assignment",
            reference_type="task",
            reference_id=task.id,
            title=f"You've been assigned to: {task.title}",
            message=f"Assigned by {assigned_by.name}",
        )

    @staticmethod
    def notify_blocker(
        db: Session,
        task: Task,
        reported_by: User,
        reason: str,
    ) -> List[Notification]:
        """Notify the project owner when a task is blocked.

        Nobody is notified when the project has no owner or the owner reported
        the blocker. The reason is cut to MESSAGE_PREVIEW_LENGTH characters,
        with "..." appended only when something was actually removed.

        Returns:
            A list holding the created notification, or an empty list.
        """
        notifications = []
        project = task.project
        if project and project.owner_id and project.owner_id != reported_by.id:
            notification = NotificationService.create_notification(
                db=db,
                user_id=project.owner_id,
                notification_type="blocker",
                reference_type="task",
                reference_id=task.id,
                title=f"Task blocked: {task.title}",
                message=f"Reported by {reported_by.name}: {NotificationService._truncate(reason)}",
            )
            notifications.append(notification)
        return notifications

    @staticmethod
    def notify_blocker_resolved(
        db: Session,
        task: Task,
        resolved_by: User,
        reporter_id: str,
    ) -> Optional[Notification]:
        """Notify the original reporter when a blocker is resolved.

        Returns None when the reporter is unknown or resolved it themselves.
        """
        if not reporter_id or reporter_id == resolved_by.id:
            return None
        return NotificationService.create_notification(
            db=db,
            user_id=reporter_id,
            notification_type="blocker_resolved",
            reference_type="task",
            reference_id=task.id,
            title=f"Blocker resolved: {task.title}",
            message=f"Resolved by {resolved_by.name}",
        )

    @staticmethod
    def count_mentions(content: str) -> int:
        """Count every raw @mention match in *content* (no cap, no cleanup)."""
        if not content:
            return 0
        return len(NotificationService._MENTION_PATTERN.findall(content))

    @staticmethod
    def parse_mentions(content: str) -> List[str]:
        """Extract @mentions from comment content, capped at MAX_MENTIONS_PER_COMMENT.

        Trailing dots are stripped so sentence punctuation ("thanks @alice.")
        is not part of the name that is looked up. Mentions that become empty
        (e.g. "@...") are discarded. Otherwise they would turn into a
        match-everything LIKE pattern in ``process_mentions``.
        """
        if not content:
            return []
        mentions = []
        for raw in NotificationService._MENTION_PATTERN.findall(content):
            name = raw.rstrip(".")
            if name:
                mentions.append(name)
        return mentions[:NotificationService.MAX_MENTIONS_PER_COMMENT]

    @staticmethod
    def process_mentions(
        db: Session,
        comment: Comment,
        task: Task,
        author: User,
    ) -> List[Notification]:
        """Create Mention records and notifications for users @mentioned in a comment.

        Each mention is resolved to the first user whose email starts with it or
        whose name contains it (case-insensitive). Each mentioned user is
        recorded and notified at most once per comment, and the author is
        never notified about their own comment.
        """
        notifications = []
        mentioned_usernames = NotificationService.parse_mentions(comment.content)
        if not mentioned_usernames:
            return notifications

        notified_user_ids = set()
        # dict.fromkeys de-duplicates repeated mentions while keeping their order.
        for username in dict.fromkeys(mentioned_usernames):
            # Escape special LIKE characters to prevent injection
            escaped_username = escape_like(username)
            user = db.query(User).filter(
                (User.email.ilike(f"{escaped_username}%", escape="\\")) |
                (User.name.ilike(f"%{escaped_username}%", escape="\\"))
            ).first()
            # Different spellings can resolve to the same user; notify once.
            if not user or user.id == author.id or user.id in notified_user_ids:
                continue
            notified_user_ids.add(user.id)

            db.add(Mention(
                id=str(uuid.uuid4()),
                comment_id=comment.id,
                mentioned_user_id=user.id,
            ))
            notifications.append(NotificationService.create_notification(
                db=db,
                user_id=user.id,
                notification_type="mention",
                reference_type="comment",
                reference_id=comment.id,
                title=f"{author.name} mentioned you in: {task.title}",
                message=NotificationService._truncate(comment.content),
            ))
        return notifications

    @staticmethod
    def notify_comment_reply(
        db: Session,
        comment: Comment,
        task: Task,
        author: User,
        parent_author_id: str,
    ) -> Optional[Notification]:
        """Notify the original commenter when someone replies.

        Returns None when the parent author is unknown or is the replier.
        """
        if not parent_author_id or parent_author_id == author.id:
            return None
        return NotificationService.create_notification(
            db=db,
            user_id=parent_author_id,
            notification_type="comment",
            reference_type="comment",
            reference_id=comment.id,
            title=f"{author.name} replied to your comment on: {task.title}",
            message=NotificationService._truncate(comment.content),
        )