""" Query monitoring utilities for detecting N+1 queries and performance issues. This module provides: 1. Query counting per request in development mode 2. SQLAlchemy event listeners for query logging 3. Threshold-based warnings for excessive queries """ import logging import threading import time from contextlib import contextmanager from typing import Optional, Callable, Any from sqlalchemy import event from sqlalchemy.engine import Engine from app.core.config import settings logger = logging.getLogger(__name__) # Thread-local storage for per-request query counting _query_context = threading.local() class QueryCounter: """ Context manager for counting database queries within a request. Usage: with QueryCounter() as counter: # ... execute queries ... print(f"Executed {counter.count} queries") """ def __init__(self, threshold: Optional[int] = None, context_name: str = "request"): self.threshold = threshold or settings.QUERY_COUNT_THRESHOLD self.context_name = context_name self.count = 0 self.queries = [] self.start_time = None self.total_time = 0.0 def __enter__(self): self.count = 0 self.queries = [] self.start_time = time.time() _query_context.counter = self return self def __exit__(self, exc_type, exc_val, exc_tb): self.total_time = time.time() - self.start_time _query_context.counter = None # Log warning if threshold exceeded if self.count > self.threshold: logger.warning( "Query count threshold exceeded in %s: %d queries (threshold: %d, time: %.3fs)", self.context_name, self.count, self.threshold, self.total_time, ) if settings.DEBUG: # In debug mode, also log the individual queries for i, (sql, duration) in enumerate(self.queries[:20], 1): logger.debug(" Query %d (%.3fs): %s", i, duration, sql[:200]) if len(self.queries) > 20: logger.debug(" ... and %d more queries", len(self.queries) - 20) elif settings.DEBUG and self.count > 0: logger.debug( "Query count for %s: %d queries in %.3fs", self.context_name, self.count, self.total_time, ) return False def record_query(self, statement: str, duration: float): """Record a query execution.""" self.count += 1 if settings.DEBUG: self.queries.append((statement, duration)) def get_current_counter() -> Optional[QueryCounter]: """Get the current request's query counter, if any.""" return getattr(_query_context, 'counter', None) def setup_query_logging(engine: Engine): """ Set up SQLAlchemy event listeners for query logging. This should be called once during application startup. Only activates if QUERY_LOGGING is enabled in settings. """ if not settings.QUERY_LOGGING: logger.info("Query logging is disabled") return logger.info("Setting up query logging with threshold=%d", settings.QUERY_COUNT_THRESHOLD) @event.listens_for(engine, "before_cursor_execute") def before_cursor_execute(conn, cursor, statement, parameters, context, executemany): conn.info.setdefault('query_start_time', []).append(time.time()) @event.listens_for(engine, "after_cursor_execute") def after_cursor_execute(conn, cursor, statement, parameters, context, executemany): start_times = conn.info.get('query_start_time', []) duration = time.time() - start_times.pop() if start_times else 0.0 # Record in current counter if active counter = get_current_counter() if counter: counter.record_query(statement, duration) # Also log individual queries if in debug mode if settings.DEBUG: logger.debug("SQL (%.3fs): %s", duration, statement[:500]) @contextmanager def count_queries(context_name: str = "operation", threshold: Optional[int] = None): """ Context manager to count queries for a specific operation. Args: context_name: Name for logging purposes threshold: Override the default query count threshold Usage: with count_queries("list_members") as counter: members = db.query(ProjectMember).all() for member in members: print(member.user.name) # N+1 query! # After block, logs warning if threshold exceeded print(f"Total queries: {counter.count}") """ with QueryCounter(threshold=threshold, context_name=context_name) as counter: yield counter def assert_query_count(max_queries: int): """ Decorator for testing that asserts maximum query count. Usage in tests: @assert_query_count(5) def test_list_members(): # Should use at most 5 queries response = client.get("/api/projects/xxx/members") """ def decorator(func: Callable) -> Callable: def wrapper(*args, **kwargs): with QueryCounter(threshold=max_queries, context_name=func.__name__) as counter: result = func(*args, **kwargs) if counter.count > max_queries: raise AssertionError( f"Query count {counter.count} exceeded maximum {max_queries} " f"in {func.__name__}" ) return result return wrapper return decorator