fix(security): 重構 table query 至標準架構,修復 SQL injection 與 regex 安全問題

- 重構 get_table_data/get_table_columns 使用 TABLES_CONFIG 白名單 + QueryBuilder + read_sql_df
- 移除 get_db_connection() 直連,改用連線池 + 熔斷器 + 慢查詢監控
- get_engine() 從 Flask Config 讀取 DB_POOL_SIZE/DB_MAX_OVERFLOW
- query_table limit 上限 10,000 防止記憶體溢出
- wip_service 6 處 str.contains 加 regex=False 防止 ReDoS

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
beabigegg
2026-02-09 07:35:07 +08:00
parent 32f3e18e9d
commit 21ec1ea541
3 changed files with 103 additions and 103 deletions

View File

@@ -92,7 +92,7 @@ def create_app(config_name: str | None = None) -> Flask:
# Initialize database teardown and pool # Initialize database teardown and pool
init_db(app) init_db(app)
with app.app_context(): with app.app_context():
get_engine() get_engine(app.config) # Use config for pool_size/max_overflow
start_keepalive() # Keep database connections alive start_keepalive() # Keep database connections alive
start_cache_updater() # Start Redis cache updater start_cache_updater() # Start Redis cache updater
init_realtime_equipment_cache(app) # Start realtime equipment status cache init_realtime_equipment_cache(app) # Start realtime equipment status cache
@@ -230,7 +230,7 @@ def create_app(config_name: str | None = None) -> Flask:
"""API: query table data with optional column filters.""" """API: query table data with optional column filters."""
data = request.get_json() data = request.get_json()
table_name = data.get('table_name') table_name = data.get('table_name')
limit = data.get('limit', 1000) limit = min(data.get('limit', 1000), 10000) # Cap at 10,000
time_field = data.get('time_field') time_field = data.get('time_field')
filters = data.get('filters') filters = data.get('filters')

View File

@@ -59,23 +59,24 @@ logger = logging.getLogger('mes_dashboard.database')
_ENGINE = None _ENGINE = None
def get_engine(): def get_engine(app_config=None):
"""Get SQLAlchemy engine with connection pooling. """Get SQLAlchemy engine with connection pooling.
Uses QueuePool for connection reuse and better performance. Uses QueuePool for connection reuse and better performance.
- pool_size: Base number of persistent connections
- max_overflow: Additional connections during peak load Args:
- pool_timeout: Max wait time for available connection app_config: Optional Flask app.config dict. If provided,
- pool_recycle: Recycle connections after 30 minutes reads DB_POOL_SIZE and DB_MAX_OVERFLOW from config.
- pool_pre_ping: Validate connection before checkout
""" """
global _ENGINE global _ENGINE
if _ENGINE is None: if _ENGINE is None:
pool_size = app_config.get('DB_POOL_SIZE', 5) if app_config else 5
max_overflow = app_config.get('DB_MAX_OVERFLOW', 10) if app_config else 10
_ENGINE = create_engine( _ENGINE = create_engine(
CONNECTION_STRING, CONNECTION_STRING,
poolclass=QueuePool, poolclass=QueuePool,
pool_size=5, # Base connections pool_size=pool_size,
max_overflow=10, # Peak extra connections (total max: 15) max_overflow=max_overflow,
pool_timeout=30, # Wait up to 30s for connection pool_timeout=30, # Wait up to 30s for connection
pool_recycle=1800, # Recycle connections every 30 minutes pool_recycle=1800, # Recycle connections every 30 minutes
pool_pre_ping=True, # Validate connection before use pool_pre_ping=True, # Validate connection before use
@@ -89,7 +90,7 @@ def get_engine():
_register_pool_events(_ENGINE) _register_pool_events(_ENGINE)
logger.info( logger.info(
"Database engine created with QueuePool " "Database engine created with QueuePool "
f"(pool_size=5, max_overflow=10, pool_recycle=1800)" f"(pool_size={pool_size}, max_overflow={max_overflow}, pool_recycle=1800)"
) )
return _ENGINE return _ENGINE
@@ -437,23 +438,43 @@ def read_sql_df_slow(
# Table Utilities # Table Utilities
# ============================================================ # ============================================================
# Whitelist cache: maps uppercase table name -> TABLES_CONFIG entry
_ALLOWED_TABLES: Dict[str, dict] = {}


def _get_table_config(table_name: str) -> Optional[dict]:
    """Look up *table_name* in the TABLES_CONFIG whitelist.

    The whitelist is built lazily on first call and cached at module
    level. The mapping is built locally and then swapped in with a
    single rebind, so a concurrent reader never observes a
    half-populated cache (in-place population could make an allowed
    table look rejected mid-build under a threaded server).

    Args:
        table_name: Table name to look up (matched case-insensitively).

    Returns:
        The config dict for the table if whitelisted, otherwise None.
    """
    global _ALLOWED_TABLES
    if not _ALLOWED_TABLES:
        # Lazy import — presumably avoids a circular import at module
        # load time; confirm against mes_dashboard.config.tables.
        from mes_dashboard.config.tables import TABLES_CONFIG
        allowed = {
            t['name'].upper(): t
            for tables in TABLES_CONFIG.values()
            for t in tables
        }
        # Atomic rebind: readers see either the empty dict or the full one.
        _ALLOWED_TABLES = allowed
    return _ALLOWED_TABLES.get(table_name.upper())
def _validate_column_name(col_name: str) -> bool:
"""Validate column name format (alphanumeric + underscore only)."""
return bool(re.match(r'^[A-Za-z_][A-Za-z0-9_]*$', col_name))
def get_table_columns(table_name: str) -> list:
    """Get column names for a whitelisted table.

    Runs a one-row probe query through read_sql_df(), which routes the
    query through the connection pool and circuit breaker.

    Args:
        table_name: Table name; must be present in the TABLES_CONFIG
            whitelist, otherwise an empty list is returned.

    Returns:
        List of column names, or [] when the table is not whitelisted
        or the probe query fails (best-effort contract).
    """
    if not _get_table_config(table_name):
        logger.warning(f"Table not in whitelist: {table_name}")
        return []
    try:
        # table_name is whitelist-validated above, so embedding it in
        # the SQL string is safe from injection.
        df = read_sql_df(f"SELECT * FROM {table_name} WHERE ROWNUM <= 1")
        return list(df.columns)
    except Exception as exc:
        # Preserve the best-effort [] contract, but don't swallow the
        # failure silently — log it so pool/query problems are visible.
        logger.error(f"get_table_columns failed for {table_name}: {exc}")
        return []
@@ -463,87 +484,61 @@ def get_table_data(
time_field: Optional[str] = None, time_field: Optional[str] = None,
filters: Optional[Dict[str, str]] = None, filters: Optional[Dict[str, str]] = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Fetch rows from a table with optional filtering and sorting.""" """Fetch rows from a whitelisted table with optional filtering and sorting.
from datetime import datetime
connection = get_db_connection() Uses TABLES_CONFIG whitelist + QueryBuilder for safe parameterized queries.
if not connection: Executes via read_sql_df() (connection pool + circuit breaker + metrics).
return {'error': 'Database connection failed'} """
from mes_dashboard.sql.builder import QueryBuilder
# 1. Whitelist validation
table_cfg = _get_table_config(table_name)
if not table_cfg:
return {'error': f'不允許查詢的表: {table_name}'}
# 2. time_field validation: only allow the value defined in TABLES_CONFIG
if time_field:
allowed_tf = table_cfg.get('time_field')
if not allowed_tf or time_field.upper() != allowed_tf.upper():
return {'error': f'不允許的時間欄位: {time_field}'}
# 3. Build WHERE clause with QueryBuilder (parameterized)
builder = QueryBuilder()
if time_field:
builder.add_is_not_null(time_field)
if filters:
for col, val in filters.items():
if val and val.strip() and _validate_column_name(col):
builder.add_like_condition(
f"UPPER(TO_CHAR({col}))", val.strip(), position="both"
)
where_clause, params = builder.build_where_only()
# 4. Build SQL (table_name and time_field are whitelist-validated)
order_by = f"ORDER BY {time_field} DESC" if time_field else ""
if where_clause or order_by:
sql = (
f"SELECT * FROM ("
f"SELECT * FROM {table_name} {where_clause} {order_by}"
f") WHERE ROWNUM <= :row_limit"
)
else:
sql = f"SELECT * FROM {table_name} WHERE ROWNUM <= :row_limit"
params['row_limit'] = limit
# 5. Execute via read_sql_df (connection pool + circuit breaker + metrics)
try: try:
cursor = connection.cursor() df = read_sql_df(sql, params)
# Convert datetime columns to string for JSON serialization
where_conditions = [] for col in df.select_dtypes(include=['datetime64']).columns:
bind_params = {} df[col] = df[col].dt.strftime('%Y-%m-%d %H:%M:%S')
if filters:
for col, val in filters.items():
if val and val.strip():
safe_col = ''.join(c for c in col if c.isalnum() or c == '_')
param_name = f"p_{safe_col}"
where_conditions.append(
f"UPPER(TO_CHAR({safe_col})) LIKE UPPER(:{param_name})"
)
bind_params[param_name] = f"%{val.strip()}%"
if time_field:
time_condition = f"{time_field} IS NOT NULL"
if where_conditions:
all_conditions = " AND ".join([time_condition] + where_conditions)
else:
all_conditions = time_condition
sql = f"""
SELECT * FROM (
SELECT * FROM {table_name}
WHERE {all_conditions}
ORDER BY {time_field} DESC
) WHERE ROWNUM <= :row_limit
"""
else:
if where_conditions:
all_conditions = " AND ".join(where_conditions)
sql = f"""
SELECT * FROM (
SELECT * FROM {table_name}
WHERE {all_conditions}
) WHERE ROWNUM <= :row_limit
"""
else:
sql = f"""
SELECT * FROM {table_name}
WHERE ROWNUM <= :row_limit
"""
bind_params['row_limit'] = limit
cursor.execute(sql, bind_params)
columns = [desc[0] for desc in cursor.description]
rows = cursor.fetchall()
data = []
for row in rows:
row_dict = {}
for i, col in enumerate(columns):
value = row[i]
if isinstance(value, datetime):
row_dict[col] = value.strftime('%Y-%m-%d %H:%M:%S')
else:
row_dict[col] = value
data.append(row_dict)
cursor.close()
connection.close()
return { return {
'columns': columns, 'columns': list(df.columns),
'data': data, 'data': df.to_dict('records'),
'row_count': len(data) 'row_count': len(df),
} }
except Exception as exc: except Exception as exc:
ora_code = _extract_ora_code(exc) logger.error(f"get_table_data failed: {exc}")
logger.error(f"get_table_data failed - ORA-{ora_code}: {exc}")
if connection:
connection.close()
return {'error': f'查詢失敗: {str(exc)}'} return {'error': f'查詢失敗: {str(exc)}'}
@@ -552,6 +547,7 @@ def get_table_column_metadata(table_name: str) -> Dict[str, Any]:
Args: Args:
table_name: Table name in format 'SCHEMA.TABLE' or 'TABLE' table_name: Table name in format 'SCHEMA.TABLE' or 'TABLE'
Must be in TABLES_CONFIG whitelist.
Returns: Returns:
Dict with 'columns' list containing column info: Dict with 'columns' list containing column info:
@@ -563,6 +559,10 @@ def get_table_column_metadata(table_name: str) -> Dict[str, Any]:
- is_date: True if column is DATE or TIMESTAMP type - is_date: True if column is DATE or TIMESTAMP type
- is_number: True if column is NUMBER type - is_number: True if column is NUMBER type
""" """
if not _get_table_config(table_name):
logger.warning(f"Table not in whitelist: {table_name}")
return {'error': f'不允許查詢的表: {table_name}', 'columns': []}
connection = get_db_connection() connection = get_db_connection()
if not connection: if not connection:
return {'error': 'Database connection failed', 'columns': []} return {'error': 'Database connection failed', 'columns': []}

View File

@@ -195,11 +195,11 @@ def _filter_base_conditions(
# WORKORDER filter (fuzzy match) # WORKORDER filter (fuzzy match)
if workorder: if workorder:
df = df[df['WORKORDER'].str.contains(workorder, case=False, na=False)] df = df[df['WORKORDER'].str.contains(workorder, case=False, na=False, regex=False)]
# LOTID filter (fuzzy match) # LOTID filter (fuzzy match)
if lotid: if lotid:
df = df[df['LOTID'].str.contains(lotid, case=False, na=False)] df = df[df['LOTID'].str.contains(lotid, case=False, na=False, regex=False)]
return df return df
@@ -1160,7 +1160,7 @@ def search_workorders(
df = df[df['PJ_TYPE'] == pj_type] df = df[df['PJ_TYPE'] == pj_type]
# Filter by search query (case-insensitive) # Filter by search query (case-insensitive)
df = df[df['WORKORDER'].str.contains(q, case=False, na=False)] df = df[df['WORKORDER'].str.contains(q, case=False, na=False, regex=False)]
if df.empty: if df.empty:
return [] return []
@@ -1258,7 +1258,7 @@ def search_lot_ids(
df = df[df['PJ_TYPE'] == pj_type] df = df[df['PJ_TYPE'] == pj_type]
# Filter by search query (case-insensitive) # Filter by search query (case-insensitive)
df = df[df['LOTID'].str.contains(q, case=False, na=False)] df = df[df['LOTID'].str.contains(q, case=False, na=False, regex=False)]
if df.empty: if df.empty:
return [] return []
@@ -1360,7 +1360,7 @@ def search_packages(
df = df[df['PJ_TYPE'] == pj_type] df = df[df['PJ_TYPE'] == pj_type]
# Filter by search query (case-insensitive) # Filter by search query (case-insensitive)
df = df[df['PACKAGE_LEF'].str.contains(q, case=False, na=False)] df = df[df['PACKAGE_LEF'].str.contains(q, case=False, na=False, regex=False)]
if df.empty: if df.empty:
return [] return []
@@ -1461,7 +1461,7 @@ def search_types(
df = df[df['PACKAGE_LEF'] == package] df = df[df['PACKAGE_LEF'] == package]
# Filter by search query (case-insensitive) # Filter by search query (case-insensitive)
df = df[df['PJ_TYPE'].str.contains(q, case=False, na=False)] df = df[df['PJ_TYPE'].str.contains(q, case=False, na=False, regex=False)]
if df.empty: if df.empty:
return [] return []