fix(security): 重構 table query 至標準架構,修復 SQL injection 與 regex 安全問題
- 重構 get_table_data/get_table_columns 使用 TABLES_CONFIG 白名單 + QueryBuilder + read_sql_df
- 移除 get_db_connection() 直連,改用連線池 + 熔斷器 + 慢查詢監控
- get_engine() 從 Flask Config 讀取 DB_POOL_SIZE/DB_MAX_OVERFLOW
- query_table limit 上限 10,000 防止記憶體溢出
- wip_service 6 處 str.contains 加 regex=False 防止 ReDoS

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -92,7 +92,7 @@ def create_app(config_name: str | None = None) -> Flask:
|
||||
# Initialize database teardown and pool
|
||||
init_db(app)
|
||||
with app.app_context():
|
||||
get_engine()
|
||||
get_engine(app.config) # Use config for pool_size/max_overflow
|
||||
start_keepalive() # Keep database connections alive
|
||||
start_cache_updater() # Start Redis cache updater
|
||||
init_realtime_equipment_cache(app) # Start realtime equipment status cache
|
||||
@@ -230,7 +230,7 @@ def create_app(config_name: str | None = None) -> Flask:
|
||||
"""API: query table data with optional column filters."""
|
||||
data = request.get_json()
|
||||
table_name = data.get('table_name')
|
||||
limit = data.get('limit', 1000)
|
||||
limit = min(data.get('limit', 1000), 10000) # Cap at 10,000
|
||||
time_field = data.get('time_field')
|
||||
filters = data.get('filters')
|
||||
|
||||
|
||||
@@ -59,23 +59,24 @@ logger = logging.getLogger('mes_dashboard.database')
|
||||
_ENGINE = None
|
||||
|
||||
|
||||
def get_engine():
|
||||
def get_engine(app_config=None):
|
||||
"""Get SQLAlchemy engine with connection pooling.
|
||||
|
||||
Uses QueuePool for connection reuse and better performance.
|
||||
- pool_size: Base number of persistent connections
|
||||
- max_overflow: Additional connections during peak load
|
||||
- pool_timeout: Max wait time for available connection
|
||||
- pool_recycle: Recycle connections after 30 minutes
|
||||
- pool_pre_ping: Validate connection before checkout
|
||||
|
||||
Args:
|
||||
app_config: Optional Flask app.config dict. If provided,
|
||||
reads DB_POOL_SIZE and DB_MAX_OVERFLOW from config.
|
||||
"""
|
||||
global _ENGINE
|
||||
if _ENGINE is None:
|
||||
pool_size = app_config.get('DB_POOL_SIZE', 5) if app_config else 5
|
||||
max_overflow = app_config.get('DB_MAX_OVERFLOW', 10) if app_config else 10
|
||||
_ENGINE = create_engine(
|
||||
CONNECTION_STRING,
|
||||
poolclass=QueuePool,
|
||||
pool_size=5, # Base connections
|
||||
max_overflow=10, # Peak extra connections (total max: 15)
|
||||
pool_size=pool_size,
|
||||
max_overflow=max_overflow,
|
||||
pool_timeout=30, # Wait up to 30s for connection
|
||||
pool_recycle=1800, # Recycle connections every 30 minutes
|
||||
pool_pre_ping=True, # Validate connection before use
|
||||
@@ -89,7 +90,7 @@ def get_engine():
|
||||
_register_pool_events(_ENGINE)
|
||||
logger.info(
|
||||
"Database engine created with QueuePool "
|
||||
f"(pool_size=5, max_overflow=10, pool_recycle=1800)"
|
||||
f"(pool_size={pool_size}, max_overflow={max_overflow}, pool_recycle=1800)"
|
||||
)
|
||||
return _ENGINE
|
||||
|
||||
@@ -437,23 +438,43 @@ def read_sql_df_slow(
|
||||
# Table Utilities
|
||||
# ============================================================
|
||||
|
||||
# Whitelist cache: maps uppercase table name → TABLES_CONFIG entry
|
||||
_ALLOWED_TABLES: Dict[str, dict] = {}
|
||||
|
||||
|
||||
def _get_table_config(table_name: str) -> Optional[dict]:
|
||||
"""Look up table in TABLES_CONFIG whitelist.
|
||||
|
||||
Returns the config dict if found, None otherwise.
|
||||
"""
|
||||
global _ALLOWED_TABLES
|
||||
if not _ALLOWED_TABLES:
|
||||
from mes_dashboard.config.tables import TABLES_CONFIG
|
||||
for tables in TABLES_CONFIG.values():
|
||||
for t in tables:
|
||||
_ALLOWED_TABLES[t['name'].upper()] = t
|
||||
return _ALLOWED_TABLES.get(table_name.upper())
|
||||
|
||||
|
||||
def _validate_column_name(col_name: str) -> bool:
|
||||
"""Validate column name format (alphanumeric + underscore only)."""
|
||||
return bool(re.match(r'^[A-Za-z_][A-Za-z0-9_]*$', col_name))
|
||||
|
||||
|
||||
def get_table_columns(table_name: str) -> list:
|
||||
"""Get column names for a table."""
|
||||
connection = get_db_connection()
|
||||
if not connection:
|
||||
"""Get column names for a whitelisted table.
|
||||
|
||||
Uses read_sql_df() for connection pooling and circuit breaker.
|
||||
"""
|
||||
if not _get_table_config(table_name):
|
||||
logger.warning(f"Table not in whitelist: {table_name}")
|
||||
return []
|
||||
|
||||
try:
|
||||
cursor = connection.cursor()
|
||||
cursor.execute(f"SELECT * FROM {table_name} WHERE ROWNUM <= 1")
|
||||
columns = [desc[0] for desc in cursor.description]
|
||||
cursor.close()
|
||||
connection.close()
|
||||
return columns
|
||||
# table_name validated against whitelist — safe to embed
|
||||
df = read_sql_df(f"SELECT * FROM {table_name} WHERE ROWNUM <= 1")
|
||||
return list(df.columns)
|
||||
except Exception:
|
||||
if connection:
|
||||
connection.close()
|
||||
return []
|
||||
|
||||
|
||||
@@ -463,87 +484,61 @@ def get_table_data(
|
||||
time_field: Optional[str] = None,
|
||||
filters: Optional[Dict[str, str]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Fetch rows from a table with optional filtering and sorting."""
|
||||
from datetime import datetime
|
||||
"""Fetch rows from a whitelisted table with optional filtering and sorting.
|
||||
|
||||
connection = get_db_connection()
|
||||
if not connection:
|
||||
return {'error': 'Database connection failed'}
|
||||
Uses TABLES_CONFIG whitelist + QueryBuilder for safe parameterized queries.
|
||||
Executes via read_sql_df() (connection pool + circuit breaker + metrics).
|
||||
"""
|
||||
from mes_dashboard.sql.builder import QueryBuilder
|
||||
|
||||
# 1. Whitelist validation
|
||||
table_cfg = _get_table_config(table_name)
|
||||
if not table_cfg:
|
||||
return {'error': f'不允許查詢的表: {table_name}'}
|
||||
|
||||
# 2. time_field validation: only allow the value defined in TABLES_CONFIG
|
||||
if time_field:
|
||||
allowed_tf = table_cfg.get('time_field')
|
||||
if not allowed_tf or time_field.upper() != allowed_tf.upper():
|
||||
return {'error': f'不允許的時間欄位: {time_field}'}
|
||||
|
||||
# 3. Build WHERE clause with QueryBuilder (parameterized)
|
||||
builder = QueryBuilder()
|
||||
if time_field:
|
||||
builder.add_is_not_null(time_field)
|
||||
if filters:
|
||||
for col, val in filters.items():
|
||||
if val and val.strip() and _validate_column_name(col):
|
||||
builder.add_like_condition(
|
||||
f"UPPER(TO_CHAR({col}))", val.strip(), position="both"
|
||||
)
|
||||
where_clause, params = builder.build_where_only()
|
||||
|
||||
# 4. Build SQL (table_name and time_field are whitelist-validated)
|
||||
order_by = f"ORDER BY {time_field} DESC" if time_field else ""
|
||||
if where_clause or order_by:
|
||||
sql = (
|
||||
f"SELECT * FROM ("
|
||||
f"SELECT * FROM {table_name} {where_clause} {order_by}"
|
||||
f") WHERE ROWNUM <= :row_limit"
|
||||
)
|
||||
else:
|
||||
sql = f"SELECT * FROM {table_name} WHERE ROWNUM <= :row_limit"
|
||||
params['row_limit'] = limit
|
||||
|
||||
# 5. Execute via read_sql_df (connection pool + circuit breaker + metrics)
|
||||
try:
|
||||
cursor = connection.cursor()
|
||||
|
||||
where_conditions = []
|
||||
bind_params = {}
|
||||
|
||||
if filters:
|
||||
for col, val in filters.items():
|
||||
if val and val.strip():
|
||||
safe_col = ''.join(c for c in col if c.isalnum() or c == '_')
|
||||
param_name = f"p_{safe_col}"
|
||||
where_conditions.append(
|
||||
f"UPPER(TO_CHAR({safe_col})) LIKE UPPER(:{param_name})"
|
||||
)
|
||||
bind_params[param_name] = f"%{val.strip()}%"
|
||||
|
||||
if time_field:
|
||||
time_condition = f"{time_field} IS NOT NULL"
|
||||
if where_conditions:
|
||||
all_conditions = " AND ".join([time_condition] + where_conditions)
|
||||
else:
|
||||
all_conditions = time_condition
|
||||
|
||||
sql = f"""
|
||||
SELECT * FROM (
|
||||
SELECT * FROM {table_name}
|
||||
WHERE {all_conditions}
|
||||
ORDER BY {time_field} DESC
|
||||
) WHERE ROWNUM <= :row_limit
|
||||
"""
|
||||
else:
|
||||
if where_conditions:
|
||||
all_conditions = " AND ".join(where_conditions)
|
||||
sql = f"""
|
||||
SELECT * FROM (
|
||||
SELECT * FROM {table_name}
|
||||
WHERE {all_conditions}
|
||||
) WHERE ROWNUM <= :row_limit
|
||||
"""
|
||||
else:
|
||||
sql = f"""
|
||||
SELECT * FROM {table_name}
|
||||
WHERE ROWNUM <= :row_limit
|
||||
"""
|
||||
|
||||
bind_params['row_limit'] = limit
|
||||
cursor.execute(sql, bind_params)
|
||||
columns = [desc[0] for desc in cursor.description]
|
||||
rows = cursor.fetchall()
|
||||
|
||||
data = []
|
||||
for row in rows:
|
||||
row_dict = {}
|
||||
for i, col in enumerate(columns):
|
||||
value = row[i]
|
||||
if isinstance(value, datetime):
|
||||
row_dict[col] = value.strftime('%Y-%m-%d %H:%M:%S')
|
||||
else:
|
||||
row_dict[col] = value
|
||||
data.append(row_dict)
|
||||
|
||||
cursor.close()
|
||||
connection.close()
|
||||
|
||||
df = read_sql_df(sql, params)
|
||||
# Convert datetime columns to string for JSON serialization
|
||||
for col in df.select_dtypes(include=['datetime64']).columns:
|
||||
df[col] = df[col].dt.strftime('%Y-%m-%d %H:%M:%S')
|
||||
return {
|
||||
'columns': columns,
|
||||
'data': data,
|
||||
'row_count': len(data)
|
||||
'columns': list(df.columns),
|
||||
'data': df.to_dict('records'),
|
||||
'row_count': len(df),
|
||||
}
|
||||
except Exception as exc:
|
||||
ora_code = _extract_ora_code(exc)
|
||||
logger.error(f"get_table_data failed - ORA-{ora_code}: {exc}")
|
||||
if connection:
|
||||
connection.close()
|
||||
logger.error(f"get_table_data failed: {exc}")
|
||||
return {'error': f'查詢失敗: {str(exc)}'}
|
||||
|
||||
|
||||
@@ -552,6 +547,7 @@ def get_table_column_metadata(table_name: str) -> Dict[str, Any]:
|
||||
|
||||
Args:
|
||||
table_name: Table name in format 'SCHEMA.TABLE' or 'TABLE'
|
||||
Must be in TABLES_CONFIG whitelist.
|
||||
|
||||
Returns:
|
||||
Dict with 'columns' list containing column info:
|
||||
@@ -563,6 +559,10 @@ def get_table_column_metadata(table_name: str) -> Dict[str, Any]:
|
||||
- is_date: True if column is DATE or TIMESTAMP type
|
||||
- is_number: True if column is NUMBER type
|
||||
"""
|
||||
if not _get_table_config(table_name):
|
||||
logger.warning(f"Table not in whitelist: {table_name}")
|
||||
return {'error': f'不允許查詢的表: {table_name}', 'columns': []}
|
||||
|
||||
connection = get_db_connection()
|
||||
if not connection:
|
||||
return {'error': 'Database connection failed', 'columns': []}
|
||||
|
||||
@@ -195,11 +195,11 @@ def _filter_base_conditions(
|
||||
|
||||
# WORKORDER filter (fuzzy match)
|
||||
if workorder:
|
||||
df = df[df['WORKORDER'].str.contains(workorder, case=False, na=False)]
|
||||
df = df[df['WORKORDER'].str.contains(workorder, case=False, na=False, regex=False)]
|
||||
|
||||
# LOTID filter (fuzzy match)
|
||||
if lotid:
|
||||
df = df[df['LOTID'].str.contains(lotid, case=False, na=False)]
|
||||
df = df[df['LOTID'].str.contains(lotid, case=False, na=False, regex=False)]
|
||||
|
||||
return df
|
||||
|
||||
@@ -1160,7 +1160,7 @@ def search_workorders(
|
||||
df = df[df['PJ_TYPE'] == pj_type]
|
||||
|
||||
# Filter by search query (case-insensitive)
|
||||
df = df[df['WORKORDER'].str.contains(q, case=False, na=False)]
|
||||
df = df[df['WORKORDER'].str.contains(q, case=False, na=False, regex=False)]
|
||||
|
||||
if df.empty:
|
||||
return []
|
||||
@@ -1258,7 +1258,7 @@ def search_lot_ids(
|
||||
df = df[df['PJ_TYPE'] == pj_type]
|
||||
|
||||
# Filter by search query (case-insensitive)
|
||||
df = df[df['LOTID'].str.contains(q, case=False, na=False)]
|
||||
df = df[df['LOTID'].str.contains(q, case=False, na=False, regex=False)]
|
||||
|
||||
if df.empty:
|
||||
return []
|
||||
@@ -1360,7 +1360,7 @@ def search_packages(
|
||||
df = df[df['PJ_TYPE'] == pj_type]
|
||||
|
||||
# Filter by search query (case-insensitive)
|
||||
df = df[df['PACKAGE_LEF'].str.contains(q, case=False, na=False)]
|
||||
df = df[df['PACKAGE_LEF'].str.contains(q, case=False, na=False, regex=False)]
|
||||
|
||||
if df.empty:
|
||||
return []
|
||||
@@ -1461,7 +1461,7 @@ def search_types(
|
||||
df = df[df['PACKAGE_LEF'] == package]
|
||||
|
||||
# Filter by search query (case-insensitive)
|
||||
df = df[df['PJ_TYPE'].str.contains(q, case=False, na=False)]
|
||||
df = df[df['PJ_TYPE'].str.contains(q, case=False, na=False, regex=False)]
|
||||
|
||||
if df.empty:
|
||||
return []
|
||||
|
||||
Reference in New Issue
Block a user