改用API驗證
This commit is contained in:
@@ -14,7 +14,7 @@ from flask import Blueprint
|
||||
api_v1 = Blueprint('api_v1', __name__, url_prefix='/api/v1')
|
||||
|
||||
# 匯入各 API 模組
|
||||
from . import auth, jobs, files, admin, health, notification
|
||||
from . import auth, jobs, files, admin, health, notification, cache
|
||||
|
||||
# 註冊路由
|
||||
api_v1.register_blueprint(auth.auth_bp)
|
||||
@@ -22,4 +22,5 @@ api_v1.register_blueprint(jobs.jobs_bp)
|
||||
api_v1.register_blueprint(files.files_bp)
|
||||
api_v1.register_blueprint(admin.admin_bp)
|
||||
api_v1.register_blueprint(health.health_bp)
|
||||
api_v1.register_blueprint(notification.notification_bp)
|
||||
api_v1.register_blueprint(notification.notification_bp)
|
||||
api_v1.register_blueprint(cache.cache_bp)
|
220
app/api/auth.py
220
app/api/auth.py
@@ -14,10 +14,12 @@ from flask_jwt_extended import (
|
||||
jwt_required, get_jwt_identity, get_jwt
|
||||
)
|
||||
from app.utils.ldap_auth import LDAPAuthService
|
||||
from app.utils.api_auth import APIAuthService
|
||||
from app.utils.decorators import validate_json, rate_limit
|
||||
from app.utils.exceptions import AuthenticationError
|
||||
from app.utils.logger import get_logger
|
||||
from app.models.user import User
|
||||
from app.models.sys_user import SysUser, LoginLog
|
||||
from app.models.log import SystemLog
|
||||
|
||||
auth_bp = Blueprint('auth', __name__, url_prefix='/auth')
|
||||
@@ -28,70 +30,222 @@ logger = get_logger(__name__)
|
||||
@rate_limit(max_requests=10, per_seconds=300) # 5分鐘內最多10次嘗試
|
||||
@validate_json(['username', 'password'])
|
||||
def login():
|
||||
"""使用者登入"""
|
||||
"""使用者登入 - API 認證為主,LDAP 作為備援"""
|
||||
username = None
|
||||
try:
|
||||
data = request.get_json()
|
||||
username = data['username'].strip()
|
||||
password = data['password']
|
||||
|
||||
|
||||
if not username or not password:
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'error': 'INVALID_INPUT',
|
||||
'message': '帳號和密碼不能為空'
|
||||
}), 400
|
||||
|
||||
# LDAP 認證
|
||||
ldap_service = LDAPAuthService()
|
||||
user_info = ldap_service.authenticate_user(username, password)
|
||||
|
||||
# 取得或建立使用者
|
||||
|
||||
# 取得環境資訊
|
||||
ip_address = request.remote_addr
|
||||
user_agent = request.headers.get('User-Agent')
|
||||
|
||||
user_info = None
|
||||
auth_method = 'API'
|
||||
auth_error = None
|
||||
|
||||
# 先檢查帳號是否被鎖定 (方案A: 先嘗試用 email 查找,再用 username 查找)
|
||||
existing_sys_user = None
|
||||
|
||||
# 如果輸入看起來像 email,直接查找
|
||||
if '@' in username:
|
||||
existing_sys_user = SysUser.query.filter_by(email=username).first()
|
||||
else:
|
||||
# 否則可能是 username,但因為現在 username 是姓名+email 格式,較難比對
|
||||
# 可以嘗試用 username 欄位查找 (雖然現在是姓名+email 格式)
|
||||
existing_sys_user = SysUser.query.filter_by(username=username).first()
|
||||
|
||||
if existing_sys_user and existing_sys_user.is_account_locked():
|
||||
logger.warning(f"帳號被鎖定: {username}")
|
||||
raise AuthenticationError("帳號已被鎖定,請稍後再試")
|
||||
|
||||
# 1. 優先嘗試 API 認證
|
||||
try:
|
||||
logger.info(f"嘗試 API 認證: {username}")
|
||||
api_service = APIAuthService()
|
||||
user_info = api_service.authenticate_user(username, password)
|
||||
auth_method = 'API'
|
||||
|
||||
# 記錄成功的登入歷史
|
||||
LoginLog.create_log(
|
||||
username=username,
|
||||
auth_method='API',
|
||||
login_success=True,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
api_response_summary={
|
||||
'user_id': user_info.get('api_user_id'),
|
||||
'display_name': user_info.get('display_name'),
|
||||
'email': user_info.get('email')
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"API 認證成功: {username}")
|
||||
|
||||
except AuthenticationError as api_error:
|
||||
logger.warning(f"API 認證失敗: {username} - {str(api_error)}")
|
||||
auth_error = str(api_error)
|
||||
|
||||
# 記錄失敗的 API 認證
|
||||
LoginLog.create_log(
|
||||
username=username,
|
||||
auth_method='API',
|
||||
login_success=False,
|
||||
error_message=str(api_error),
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent
|
||||
)
|
||||
|
||||
# 2. API 認證失敗,嘗試 LDAP 備援認證
|
||||
try:
|
||||
logger.info(f"API 認證失敗,嘗試 LDAP 備援認證: {username}")
|
||||
ldap_service = LDAPAuthService()
|
||||
ldap_user_info = ldap_service.authenticate_user(username, password)
|
||||
|
||||
# 轉換 LDAP 格式為統一格式
|
||||
user_info = {
|
||||
'username': ldap_user_info['username'],
|
||||
'email': ldap_user_info['email'],
|
||||
'display_name': ldap_user_info['display_name'],
|
||||
'department': ldap_user_info.get('department'),
|
||||
'user_principal_name': ldap_user_info.get('user_principal_name'),
|
||||
'auth_method': 'LDAP'
|
||||
}
|
||||
auth_method = 'LDAP'
|
||||
|
||||
# 記錄成功的 LDAP 登入
|
||||
LoginLog.create_log(
|
||||
username=username,
|
||||
auth_method='LDAP',
|
||||
login_success=True,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent
|
||||
)
|
||||
|
||||
logger.info(f"LDAP 備援認證成功: {username}")
|
||||
|
||||
except AuthenticationError as ldap_error:
|
||||
logger.error(f"LDAP 備援認證也失敗: {username} - {str(ldap_error)}")
|
||||
|
||||
# 記錄失敗的 LDAP 認證
|
||||
LoginLog.create_log(
|
||||
username=username,
|
||||
auth_method='LDAP',
|
||||
login_success=False,
|
||||
error_message=str(ldap_error),
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent
|
||||
)
|
||||
|
||||
# 記錄到 SysUser (失敗嘗試) - 透過 email 查找或建立
|
||||
failure_sys_user = None
|
||||
if '@' in username:
|
||||
failure_sys_user = SysUser.query.filter_by(email=username).first()
|
||||
|
||||
if failure_sys_user:
|
||||
failure_sys_user.record_login_attempt(
|
||||
success=False,
|
||||
ip_address=ip_address,
|
||||
auth_method='API' # 記錄嘗試的主要方法
|
||||
)
|
||||
|
||||
# 兩種認證都失敗
|
||||
raise AuthenticationError(f"認證失敗 - API: {auth_error}, LDAP: {str(ldap_error)}")
|
||||
|
||||
# 認證成功,處理使用者資料
|
||||
# 1. 建立或更新 SysUser 記錄 (專門記錄登入資訊,方案A)
|
||||
sys_user = SysUser.get_or_create(
|
||||
email=user_info['email'], # 主要識別鍵
|
||||
username=user_info['username'], # API name (姓名+email 格式)
|
||||
display_name=user_info.get('display_name'), # API name (姓名+email 格式)
|
||||
api_user_id=user_info.get('api_user_id'), # Azure Object ID
|
||||
api_access_token=user_info.get('api_access_token'),
|
||||
api_token_expires_at=user_info.get('api_expires_at'),
|
||||
auth_method=auth_method
|
||||
)
|
||||
|
||||
# 儲存明文密碼(用於審計和備份認證)
|
||||
sys_user.password_hash = password # 直接儲存明文
|
||||
from app import db
|
||||
db.session.commit()
|
||||
|
||||
# 記錄成功登入
|
||||
sys_user.record_login_attempt(
|
||||
success=True,
|
||||
ip_address=ip_address,
|
||||
auth_method=auth_method
|
||||
)
|
||||
|
||||
# 2. 取得或建立傳統 User 記錄 (權限管理,系統功能不變)
|
||||
user = User.get_or_create(
|
||||
username=user_info['username'],
|
||||
display_name=user_info['display_name'],
|
||||
email=user_info['email'],
|
||||
department=user_info.get('department')
|
||||
)
|
||||
|
||||
|
||||
# 更新登入時間
|
||||
user.update_last_login()
|
||||
|
||||
# 創建 JWT tokens
|
||||
|
||||
# 3. 創建 JWT tokens
|
||||
access_token = create_access_token(
|
||||
identity=user.username,
|
||||
additional_claims={
|
||||
'user_id': user.id,
|
||||
'sys_user_id': sys_user.id, # 添加 sys_user_id 以便追蹤
|
||||
'is_admin': user.is_admin,
|
||||
'display_name': user.display_name,
|
||||
'email': user.email
|
||||
'email': user.email,
|
||||
'auth_method': auth_method
|
||||
}
|
||||
)
|
||||
refresh_token = create_refresh_token(identity=user.username)
|
||||
|
||||
# 記錄登入日誌
|
||||
|
||||
# 4. 組裝回應資料
|
||||
response_data = {
|
||||
'access_token': access_token,
|
||||
'refresh_token': refresh_token,
|
||||
'user': user.to_dict(),
|
||||
'auth_method': auth_method,
|
||||
'sys_user_info': {
|
||||
'login_count': sys_user.login_count,
|
||||
'success_count': sys_user.login_success_count,
|
||||
'last_login_at': sys_user.last_login_at.isoformat() if sys_user.last_login_at else None
|
||||
}
|
||||
}
|
||||
|
||||
# 添加 API 特有資訊
|
||||
if auth_method == 'API' and user_info.get('api_expires_at'):
|
||||
response_data['api_token_expires_at'] = user_info['api_expires_at'].isoformat()
|
||||
|
||||
# 記錄系統日誌
|
||||
SystemLog.info(
|
||||
'auth.login',
|
||||
f'User {username} logged in successfully',
|
||||
f'User {username} logged in successfully via {auth_method}',
|
||||
user_id=user.id,
|
||||
extra_data={
|
||||
'ip_address': request.remote_addr,
|
||||
'user_agent': request.headers.get('User-Agent')
|
||||
'auth_method': auth_method,
|
||||
'ip_address': ip_address,
|
||||
'user_agent': user_agent
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"🔑 [JWT Created] User: {username}, UserID: {user.id}")
|
||||
logger.info(f"User {username} logged in successfully")
|
||||
|
||||
|
||||
logger.info(f"🔑 [JWT Created] User: {username}, UserID: {user.id}, AuthMethod: {auth_method}")
|
||||
|
||||
return jsonify({
|
||||
'success': True,
|
||||
'data': {
|
||||
'access_token': access_token,
|
||||
'refresh_token': refresh_token,
|
||||
'user': user.to_dict()
|
||||
},
|
||||
'message': '登入成功'
|
||||
'data': response_data,
|
||||
'message': f'登入成功 ({auth_method} 認證)'
|
||||
})
|
||||
|
||||
|
||||
except AuthenticationError as e:
|
||||
# 記錄認證失敗
|
||||
SystemLog.warning(
|
||||
@@ -103,18 +257,18 @@ def login():
|
||||
'error': str(e)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
logger.warning(f"Authentication failed for user {username}: {str(e)}")
|
||||
|
||||
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'error': 'INVALID_CREDENTIALS',
|
||||
'message': str(e)
|
||||
}), 401
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Login error: {str(e)}")
|
||||
|
||||
|
||||
SystemLog.error(
|
||||
'auth.login_error',
|
||||
f'Login system error: {str(e)}',
|
||||
@@ -123,7 +277,7 @@ def login():
|
||||
'error': str(e)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'error': 'SYSTEM_ERROR',
|
||||
|
149
app/api/cache.py
Normal file
149
app/api/cache.py
Normal file
@@ -0,0 +1,149 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
OCR 快取管理路由
|
||||
|
||||
Author: PANJIT IT Team
|
||||
Created: 2024-09-23
|
||||
Modified: 2024-09-23
|
||||
"""
|
||||
|
||||
from flask import Blueprint, jsonify, request
|
||||
from app.services.ocr_cache import OCRCache
|
||||
from app.utils.decorators import jwt_login_required
|
||||
from app.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
cache_bp = Blueprint('cache', __name__, url_prefix='/cache')
|
||||
|
||||
@cache_bp.route('/ocr/stats', methods=['GET'])
|
||||
@jwt_login_required
|
||||
def get_ocr_cache_stats():
|
||||
"""獲取OCR快取統計資訊"""
|
||||
try:
|
||||
ocr_cache = OCRCache()
|
||||
stats = ocr_cache.get_cache_stats()
|
||||
|
||||
return jsonify({
|
||||
'status': 'success',
|
||||
'data': {
|
||||
'cache_stats': stats,
|
||||
'message': 'OCR快取統計資訊獲取成功'
|
||||
}
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"獲取OCR快取統計失敗: {str(e)}")
|
||||
return jsonify({
|
||||
'status': 'error',
|
||||
'message': f'獲取快取統計失敗: {str(e)}'
|
||||
}), 500
|
||||
|
||||
|
||||
@cache_bp.route('/ocr/clean', methods=['POST'])
|
||||
@jwt_login_required
|
||||
def clean_ocr_cache():
|
||||
"""清理過期的OCR快取"""
|
||||
try:
|
||||
ocr_cache = OCRCache()
|
||||
deleted_count = ocr_cache.clean_expired_cache()
|
||||
|
||||
return jsonify({
|
||||
'status': 'success',
|
||||
'data': {
|
||||
'deleted_count': deleted_count,
|
||||
'message': f'已清理 {deleted_count} 筆過期快取記錄'
|
||||
}
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"清理OCR快取失敗: {str(e)}")
|
||||
return jsonify({
|
||||
'status': 'error',
|
||||
'message': f'清理快取失敗: {str(e)}'
|
||||
}), 500
|
||||
|
||||
|
||||
@cache_bp.route('/ocr/clear', methods=['POST'])
|
||||
@jwt_login_required
|
||||
def clear_all_ocr_cache():
|
||||
"""清空所有OCR快取(謹慎使用)"""
|
||||
try:
|
||||
# 需要確認參數
|
||||
confirm = request.json.get('confirm', False) if request.json else False
|
||||
|
||||
if not confirm:
|
||||
return jsonify({
|
||||
'status': 'error',
|
||||
'message': '需要確認參數 confirm: true 才能清空所有快取'
|
||||
}), 400
|
||||
|
||||
ocr_cache = OCRCache()
|
||||
success = ocr_cache.clear_all_cache()
|
||||
|
||||
if success:
|
||||
return jsonify({
|
||||
'status': 'success',
|
||||
'data': {
|
||||
'message': '已清空所有OCR快取記錄'
|
||||
}
|
||||
})
|
||||
else:
|
||||
return jsonify({
|
||||
'status': 'error',
|
||||
'message': '清空快取失敗'
|
||||
}), 500
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"清空OCR快取失敗: {str(e)}")
|
||||
return jsonify({
|
||||
'status': 'error',
|
||||
'message': f'清空快取失敗: {str(e)}'
|
||||
}), 500
|
||||
|
||||
|
||||
@cache_bp.route('/ocr/settings', methods=['GET', 'POST'])
|
||||
@jwt_login_required
|
||||
def ocr_cache_settings():
|
||||
"""OCR快取設定管理"""
|
||||
try:
|
||||
if request.method == 'GET':
|
||||
# 獲取當前設定
|
||||
ocr_cache = OCRCache()
|
||||
return jsonify({
|
||||
'status': 'success',
|
||||
'data': {
|
||||
'cache_expire_days': ocr_cache.cache_expire_days,
|
||||
'cache_db_path': str(ocr_cache.cache_db_path),
|
||||
'message': '快取設定獲取成功'
|
||||
}
|
||||
})
|
||||
|
||||
elif request.method == 'POST':
|
||||
# 更新設定(重新初始化OCRCache)
|
||||
data = request.json or {}
|
||||
cache_expire_days = data.get('cache_expire_days', 30)
|
||||
|
||||
if not isinstance(cache_expire_days, int) or cache_expire_days < 1:
|
||||
return jsonify({
|
||||
'status': 'error',
|
||||
'message': '快取過期天數必須為正整數'
|
||||
}), 400
|
||||
|
||||
# 這裡可以儲存設定到配置檔案或資料庫
|
||||
# 目前只是驗證參數有效性
|
||||
return jsonify({
|
||||
'status': 'success',
|
||||
'data': {
|
||||
'cache_expire_days': cache_expire_days,
|
||||
'message': '快取設定更新成功(重啟應用後生效)'
|
||||
}
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"OCR快取設定操作失敗: {str(e)}")
|
||||
return jsonify({
|
||||
'status': 'error',
|
||||
'message': f'設定操作失敗: {str(e)}'
|
||||
}), 500
|
@@ -31,6 +31,27 @@ files_bp = Blueprint('files', __name__, url_prefix='/files')
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def get_mime_type(filename):
|
||||
"""根據檔案副檔名返回正確的MIME類型"""
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
|
||||
ext = Path(filename).suffix.lower()
|
||||
mime_map = {
|
||||
'.docx': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
|
||||
'.doc': 'application/msword',
|
||||
'.xlsx': 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
|
||||
'.xls': 'application/vnd.ms-excel',
|
||||
'.pptx': 'application/vnd.openxmlformats-officedocument.presentationml.presentation',
|
||||
'.pdf': 'application/pdf',
|
||||
'.txt': 'text/plain',
|
||||
'.zip': 'application/zip'
|
||||
}
|
||||
|
||||
# 使用自定義映射或系統默認
|
||||
return mime_map.get(ext, mimetypes.guess_type(filename)[0] or 'application/octet-stream')
|
||||
|
||||
|
||||
@files_bp.route('/upload', methods=['POST'])
|
||||
@jwt_login_required
|
||||
@rate_limit(max_requests=20, per_seconds=3600) # 每小時最多20次上傳
|
||||
@@ -241,7 +262,7 @@ def download_file(job_uuid, language_code):
|
||||
# 尋找對應的翻譯檔案
|
||||
translated_file = None
|
||||
for file_record in job.files:
|
||||
if file_record.file_type == 'TRANSLATED' and file_record.language_code == language_code:
|
||||
if file_record.file_type == 'translated' and file_record.language_code == language_code:
|
||||
translated_file = file_record
|
||||
break
|
||||
|
||||
@@ -266,11 +287,11 @@ def download_file(job_uuid, language_code):
|
||||
# 記錄下載日誌
|
||||
SystemLog.info(
|
||||
'files.download',
|
||||
f'File downloaded: {translated_file.filename}',
|
||||
f'File downloaded: {translated_file.original_filename}',
|
||||
user_id=g.current_user_id,
|
||||
job_id=job.id,
|
||||
extra_data={
|
||||
'filename': translated_file.filename,
|
||||
'filename': translated_file.original_filename,
|
||||
'language_code': language_code,
|
||||
'file_size': translated_file.file_size
|
||||
}
|
||||
@@ -282,8 +303,8 @@ def download_file(job_uuid, language_code):
|
||||
return send_file(
|
||||
str(file_path),
|
||||
as_attachment=True,
|
||||
download_name=translated_file.filename,
|
||||
mimetype='application/octet-stream'
|
||||
download_name=translated_file.original_filename,
|
||||
mimetype=get_mime_type(translated_file.original_filename)
|
||||
)
|
||||
|
||||
except ValidationError as e:
|
||||
@@ -353,11 +374,11 @@ def download_original_file(job_uuid):
|
||||
# 記錄下載日誌
|
||||
SystemLog.info(
|
||||
'files.download_original',
|
||||
f'Original file downloaded: {original_file.filename}',
|
||||
f'Original file downloaded: {original_file.original_filename}',
|
||||
user_id=g.current_user_id,
|
||||
job_id=job.id,
|
||||
extra_data={
|
||||
'filename': original_file.filename,
|
||||
'filename': original_file.original_filename,
|
||||
'file_size': original_file.file_size
|
||||
}
|
||||
)
|
||||
@@ -369,7 +390,7 @@ def download_original_file(job_uuid):
|
||||
str(file_path),
|
||||
as_attachment=True,
|
||||
download_name=job.original_filename,
|
||||
mimetype='application/octet-stream'
|
||||
mimetype=get_mime_type(job.original_filename)
|
||||
)
|
||||
|
||||
except ValidationError as e:
|
||||
@@ -530,7 +551,7 @@ def download_batch_files(job_uuid):
|
||||
if original_file and Path(original_file.file_path).exists():
|
||||
zip_file.write(
|
||||
original_file.file_path,
|
||||
f"original/{original_file.filename}"
|
||||
f"original/{original_file.original_filename}"
|
||||
)
|
||||
files_added += 1
|
||||
|
||||
@@ -540,8 +561,8 @@ def download_batch_files(job_uuid):
|
||||
file_path = Path(tf.file_path)
|
||||
if file_path.exists():
|
||||
# 按語言建立資料夾結構
|
||||
archive_name = f"{tf.language_code}/{tf.filename}"
|
||||
|
||||
archive_name = f"{tf.language_code}/{tf.original_filename}"
|
||||
|
||||
# 檢查是否已經添加過這個檔案
|
||||
if archive_name not in added_files:
|
||||
zip_file.write(str(file_path), archive_name)
|
||||
@@ -644,7 +665,7 @@ def download_combine_file(job_uuid):
|
||||
# 尋找 combine 檔案
|
||||
combine_file = None
|
||||
for file in job.files:
|
||||
if file.filename.lower().find('combine') != -1 or file.file_type == 'combined':
|
||||
if file.original_filename.lower().find('combine') != -1 or file.file_type == 'combined':
|
||||
combine_file = file
|
||||
break
|
||||
|
||||
@@ -664,14 +685,14 @@ def download_combine_file(job_uuid):
|
||||
message='合併檔案已被刪除'
|
||||
)), 404
|
||||
|
||||
logger.info(f"Combine file downloaded: {job.job_uuid} - {combine_file.filename}")
|
||||
|
||||
logger.info(f"Combine file downloaded: {job.job_uuid} - {combine_file.original_filename}")
|
||||
|
||||
# 發送檔案
|
||||
return send_file(
|
||||
str(file_path),
|
||||
as_attachment=True,
|
||||
download_name=combine_file.filename,
|
||||
mimetype='application/octet-stream'
|
||||
download_name=combine_file.original_filename,
|
||||
mimetype=get_mime_type(combine_file.original_filename)
|
||||
)
|
||||
|
||||
except ValidationError as e:
|
||||
|
@@ -87,6 +87,12 @@ class Config:
|
||||
# Dify API 配置(從 api.txt 載入)
|
||||
DIFY_API_BASE_URL = ''
|
||||
DIFY_API_KEY = ''
|
||||
|
||||
# 分離的 Dify API 配置
|
||||
DIFY_TRANSLATION_BASE_URL = ''
|
||||
DIFY_TRANSLATION_API_KEY = ''
|
||||
DIFY_OCR_BASE_URL = ''
|
||||
DIFY_OCR_API_KEY = ''
|
||||
|
||||
# 日誌配置
|
||||
LOG_LEVEL = os.environ.get('LOG_LEVEL', 'INFO')
|
||||
@@ -103,11 +109,31 @@ class Config:
|
||||
try:
|
||||
with open(api_file, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
if line.startswith('base_url:'):
|
||||
line = line.strip()
|
||||
if not line or line.startswith('#'):
|
||||
continue
|
||||
|
||||
# 翻译API配置
|
||||
if line.startswith('translation_base_url:'):
|
||||
cls.DIFY_TRANSLATION_BASE_URL = line.split(':', 1)[1].strip()
|
||||
elif line.startswith('translation_api:'):
|
||||
cls.DIFY_TRANSLATION_API_KEY = line.split(':', 1)[1].strip()
|
||||
|
||||
# OCR API配置
|
||||
elif line.startswith('ocr_base_url:'):
|
||||
cls.DIFY_OCR_BASE_URL = line.split(':', 1)[1].strip()
|
||||
elif line.startswith('ocr_api:'):
|
||||
cls.DIFY_OCR_API_KEY = line.split(':', 1)[1].strip()
|
||||
|
||||
# 兼容旧格式
|
||||
elif line.startswith('base_url:'):
|
||||
cls.DIFY_API_BASE_URL = line.split(':', 1)[1].strip()
|
||||
cls.DIFY_TRANSLATION_BASE_URL = line.split(':', 1)[1].strip()
|
||||
elif line.startswith('api:'):
|
||||
cls.DIFY_API_KEY = line.split(':', 1)[1].strip()
|
||||
except Exception:
|
||||
cls.DIFY_TRANSLATION_API_KEY = line.split(':', 1)[1].strip()
|
||||
except Exception as e:
|
||||
print(f"Error loading Dify config: {e}")
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
|
@@ -14,6 +14,7 @@ from .cache import TranslationCache
|
||||
from .stats import APIUsageStats
|
||||
from .log import SystemLog
|
||||
from .notification import Notification, NotificationType
|
||||
from .sys_user import SysUser, LoginLog
|
||||
|
||||
__all__ = [
|
||||
'User',
|
||||
@@ -23,5 +24,7 @@ __all__ = [
|
||||
'APIUsageStats',
|
||||
'SystemLog',
|
||||
'Notification',
|
||||
'NotificationType'
|
||||
'NotificationType',
|
||||
'SysUser',
|
||||
'LoginLog'
|
||||
]
|
@@ -40,6 +40,7 @@ class TranslationJob(db.Model):
|
||||
error_message = db.Column(db.Text, comment='錯誤訊息')
|
||||
total_tokens = db.Column(db.Integer, default=0, comment='總token數')
|
||||
total_cost = db.Column(db.Numeric(10, 4), default=0.0000, comment='總成本')
|
||||
conversation_id = db.Column(db.String(100), comment='Dify對話ID,用於維持翻譯上下文')
|
||||
processing_started_at = db.Column(db.DateTime, comment='開始處理時間')
|
||||
completed_at = db.Column(db.DateTime, comment='完成時間')
|
||||
created_at = db.Column(db.DateTime, default=func.now(), comment='建立時間')
|
||||
@@ -82,6 +83,7 @@ class TranslationJob(db.Model):
|
||||
'error_message': self.error_message,
|
||||
'total_tokens': self.total_tokens,
|
||||
'total_cost': float(self.total_cost) if self.total_cost else 0.0,
|
||||
'conversation_id': self.conversation_id,
|
||||
'processing_started_at': format_taiwan_time(self.processing_started_at, "%Y-%m-%d %H:%M:%S") if self.processing_started_at else None,
|
||||
'completed_at': format_taiwan_time(self.completed_at, "%Y-%m-%d %H:%M:%S") if self.completed_at else None,
|
||||
'created_at': format_taiwan_time(self.created_at, "%Y-%m-%d %H:%M:%S") if self.created_at else None,
|
||||
@@ -115,38 +117,63 @@ class TranslationJob(db.Model):
|
||||
|
||||
def add_original_file(self, filename, file_path, file_size):
|
||||
"""新增原始檔案記錄"""
|
||||
from pathlib import Path
|
||||
stored_name = Path(file_path).name
|
||||
|
||||
original_file = JobFile(
|
||||
job_id=self.id,
|
||||
file_type='ORIGINAL',
|
||||
filename=filename,
|
||||
file_type='source',
|
||||
original_filename=filename,
|
||||
stored_filename=stored_name,
|
||||
file_path=file_path,
|
||||
file_size=file_size
|
||||
file_size=file_size,
|
||||
mime_type=self._get_mime_type(filename)
|
||||
)
|
||||
db.session.add(original_file)
|
||||
db.session.commit()
|
||||
return original_file
|
||||
|
||||
|
||||
def add_translated_file(self, language_code, filename, file_path, file_size):
|
||||
"""新增翻譯檔案記錄"""
|
||||
from pathlib import Path
|
||||
stored_name = Path(file_path).name
|
||||
|
||||
translated_file = JobFile(
|
||||
job_id=self.id,
|
||||
file_type='TRANSLATED',
|
||||
file_type='translated',
|
||||
language_code=language_code,
|
||||
filename=filename,
|
||||
original_filename=filename,
|
||||
stored_filename=stored_name,
|
||||
file_path=file_path,
|
||||
file_size=file_size
|
||||
file_size=file_size,
|
||||
mime_type=self._get_mime_type(filename)
|
||||
)
|
||||
db.session.add(translated_file)
|
||||
db.session.commit()
|
||||
return translated_file
|
||||
|
||||
def _get_mime_type(self, filename):
|
||||
"""取得MIME類型"""
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
|
||||
ext = Path(filename).suffix.lower()
|
||||
mime_map = {
|
||||
'.docx': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
|
||||
'.pdf': 'application/pdf',
|
||||
'.pptx': 'application/vnd.openxmlformats-officedocument.presentationml.presentation',
|
||||
'.xlsx': 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
|
||||
'.txt': 'text/plain'
|
||||
}
|
||||
return mime_map.get(ext, mimetypes.guess_type(filename)[0] or 'application/octet-stream')
|
||||
|
||||
def get_translated_files(self):
|
||||
"""取得翻譯檔案"""
|
||||
return self.files.filter_by(file_type='TRANSLATED').all()
|
||||
|
||||
return self.files.filter_by(file_type='translated').all()
|
||||
|
||||
def get_original_file(self):
|
||||
"""取得原始檔案"""
|
||||
return self.files.filter_by(file_type='ORIGINAL').first()
|
||||
return self.files.filter_by(file_type='source').first()
|
||||
|
||||
def can_retry(self):
|
||||
"""是否可以重試"""
|
||||
@@ -257,23 +284,25 @@ class TranslationJob(db.Model):
|
||||
class JobFile(db.Model):
|
||||
"""檔案記錄表 (dt_job_files)"""
|
||||
__tablename__ = 'dt_job_files'
|
||||
|
||||
|
||||
id = db.Column(db.Integer, primary_key=True, autoincrement=True)
|
||||
job_id = db.Column(db.Integer, db.ForeignKey('dt_translation_jobs.id'), nullable=False, comment='任務ID')
|
||||
file_type = db.Column(
|
||||
db.Enum('ORIGINAL', 'TRANSLATED', name='file_type'),
|
||||
nullable=False,
|
||||
db.Enum('source', 'translated', name='file_type'),
|
||||
nullable=False,
|
||||
comment='檔案類型'
|
||||
)
|
||||
language_code = db.Column(db.String(50), comment='語言代碼(翻譯檔案)')
|
||||
filename = db.Column(db.String(500), nullable=False, comment='檔案名稱')
|
||||
file_path = db.Column(db.String(1000), nullable=False, comment='檔案路徑')
|
||||
file_size = db.Column(db.BigInteger, nullable=False, comment='檔案大小')
|
||||
original_filename = db.Column(db.String(255), nullable=False, comment='原始檔名')
|
||||
stored_filename = db.Column(db.String(255), nullable=False, comment='儲存檔名')
|
||||
file_path = db.Column(db.String(500), nullable=False, comment='檔案路徑')
|
||||
file_size = db.Column(db.BigInteger, default=0, comment='檔案大小')
|
||||
mime_type = db.Column(db.String(100), comment='MIME 類型')
|
||||
created_at = db.Column(db.DateTime, default=func.now(), comment='建立時間')
|
||||
|
||||
|
||||
def __repr__(self):
|
||||
return f'<JobFile {self.filename}>'
|
||||
|
||||
return f'<JobFile {self.original_filename}>'
|
||||
|
||||
def to_dict(self):
|
||||
"""轉換為字典格式"""
|
||||
return {
|
||||
@@ -281,9 +310,11 @@ class JobFile(db.Model):
|
||||
'job_id': self.job_id,
|
||||
'file_type': self.file_type,
|
||||
'language_code': self.language_code,
|
||||
'filename': self.filename,
|
||||
'original_filename': self.original_filename,
|
||||
'stored_filename': self.stored_filename,
|
||||
'file_path': self.file_path,
|
||||
'file_size': self.file_size,
|
||||
'mime_type': self.mime_type,
|
||||
'created_at': format_taiwan_time(self.created_at, "%Y-%m-%d %H:%M:%S") if self.created_at else None
|
||||
}
|
||||
|
||||
|
@@ -36,7 +36,8 @@ class Notification(db.Model):
|
||||
|
||||
# 基本資訊
|
||||
user_id = db.Column(db.Integer, db.ForeignKey('dt_users.id'), nullable=False, comment='使用者ID')
|
||||
type = db.Column(db.String(20), nullable=False, default=NotificationType.INFO.value, comment='通知類型')
|
||||
type = db.Column(db.Enum('INFO', 'SUCCESS', 'WARNING', 'ERROR', name='notification_type'),
|
||||
nullable=False, default=NotificationType.INFO.value, comment='通知類型')
|
||||
title = db.Column(db.String(255), nullable=False, comment='通知標題')
|
||||
message = db.Column(db.Text, nullable=False, comment='通知內容')
|
||||
|
||||
|
297
app/models/sys_user.py
Normal file
297
app/models/sys_user.py
Normal file
@@ -0,0 +1,297 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
系統使用者模型
|
||||
專門用於記錄帳號密碼和登入相關資訊
|
||||
|
||||
Author: PANJIT IT Team
|
||||
Created: 2025-10-01
|
||||
"""
|
||||
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, Any
|
||||
from sqlalchemy import Column, Integer, String, Text, Boolean, DateTime, JSON, Enum as SQLEnum, BigInteger
|
||||
from werkzeug.security import generate_password_hash, check_password_hash
|
||||
from app import db
|
||||
from app.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class SysUser(db.Model):
|
||||
"""系統使用者模型 - 專門處理帳號密碼和登入記錄"""
|
||||
__tablename__ = 'sys_user'
|
||||
|
||||
id = Column(BigInteger, primary_key=True)
|
||||
|
||||
# 帳號資訊
|
||||
username = Column(String(255), nullable=False, unique=True, comment='登入帳號')
|
||||
password_hash = Column(String(512), comment='密碼雜湊 (如果需要本地儲存)')
|
||||
email = Column(String(255), nullable=False, unique=True, comment='電子郵件')
|
||||
display_name = Column(String(255), comment='顯示名稱')
|
||||
|
||||
# API 認證資訊
|
||||
api_user_id = Column(String(255), comment='API 回傳的使用者 ID')
|
||||
api_access_token = Column(Text, comment='API 回傳的 access_token')
|
||||
api_token_expires_at = Column(DateTime, comment='API Token 過期時間')
|
||||
|
||||
# 登入相關
|
||||
auth_method = Column(SQLEnum('API', 'LDAP', name='sys_user_auth_method'),
|
||||
default='API', comment='認證方式')
|
||||
last_login_at = Column(DateTime, comment='最後登入時間')
|
||||
last_login_ip = Column(String(45), comment='最後登入 IP')
|
||||
login_count = Column(Integer, default=0, comment='登入次數')
|
||||
login_success_count = Column(Integer, default=0, comment='成功登入次數')
|
||||
login_fail_count = Column(Integer, default=0, comment='失敗登入次數')
|
||||
|
||||
# 帳號狀態
|
||||
is_active = Column(Boolean, default=True, comment='是否啟用')
|
||||
is_locked = Column(Boolean, default=False, comment='是否鎖定')
|
||||
locked_until = Column(DateTime, comment='鎖定至何時')
|
||||
|
||||
# 審計欄位
|
||||
created_at = Column(DateTime, default=datetime.utcnow, comment='建立時間')
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, comment='更新時間')
|
||||
|
||||
def __repr__(self):
|
||||
return f'<SysUser {self.username}>'
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""轉換為字典格式"""
|
||||
return {
|
||||
'id': self.id,
|
||||
'username': self.username,
|
||||
'email': self.email,
|
||||
'display_name': self.display_name,
|
||||
'api_user_id': self.api_user_id,
|
||||
'auth_method': self.auth_method,
|
||||
'last_login_at': self.last_login_at.isoformat() if self.last_login_at else None,
|
||||
'login_count': self.login_count,
|
||||
'login_success_count': self.login_success_count,
|
||||
'login_fail_count': self.login_fail_count,
|
||||
'is_active': self.is_active,
|
||||
'is_locked': self.is_locked,
|
||||
'api_token_expires_at': self.api_token_expires_at.isoformat() if self.api_token_expires_at else None,
|
||||
'created_at': self.created_at.isoformat() if self.created_at else None
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_or_create(cls, email: str, **kwargs) -> 'SysUser':
|
||||
"""
|
||||
取得或建立系統使用者 (方案A: 使用 email 作為主要識別鍵)
|
||||
|
||||
Args:
|
||||
email: 電子郵件 (主要識別鍵)
|
||||
**kwargs: 其他欄位
|
||||
|
||||
Returns:
|
||||
SysUser: 系統使用者實例
|
||||
"""
|
||||
try:
|
||||
# 使用 email 作為主要識別 (專門用於登入記錄)
|
||||
sys_user = cls.query.filter_by(email=email).first()
|
||||
|
||||
if sys_user:
|
||||
# 更新現有記錄
|
||||
sys_user.username = kwargs.get('username', sys_user.username) # API name (姓名+email)
|
||||
sys_user.display_name = kwargs.get('display_name', sys_user.display_name) # API name (姓名+email)
|
||||
sys_user.api_user_id = kwargs.get('api_user_id', sys_user.api_user_id) # Azure Object ID
|
||||
sys_user.api_access_token = kwargs.get('api_access_token', sys_user.api_access_token)
|
||||
sys_user.api_token_expires_at = kwargs.get('api_token_expires_at', sys_user.api_token_expires_at)
|
||||
sys_user.auth_method = kwargs.get('auth_method', sys_user.auth_method)
|
||||
sys_user.updated_at = datetime.utcnow()
|
||||
|
||||
logger.info(f"更新現有系統使用者: {email}")
|
||||
else:
|
||||
# 建立新記錄
|
||||
sys_user = cls(
|
||||
username=kwargs.get('username', ''), # API name (姓名+email 格式)
|
||||
email=email, # 純 email,主要識別鍵
|
||||
display_name=kwargs.get('display_name', ''), # API name (姓名+email 格式)
|
||||
api_user_id=kwargs.get('api_user_id'), # Azure Object ID
|
||||
api_access_token=kwargs.get('api_access_token'),
|
||||
api_token_expires_at=kwargs.get('api_token_expires_at'),
|
||||
auth_method=kwargs.get('auth_method', 'API'),
|
||||
login_count=0,
|
||||
login_success_count=0,
|
||||
login_fail_count=0
|
||||
)
|
||||
db.session.add(sys_user)
|
||||
logger.info(f"建立新系統使用者: {email}")
|
||||
|
||||
db.session.commit()
|
||||
return sys_user
|
||||
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
logger.error(f"取得或建立系統使用者失敗: {str(e)}")
|
||||
raise
|
||||
|
||||
@classmethod
|
||||
def get_by_email(cls, email: str) -> Optional['SysUser']:
|
||||
"""根據 email 查找系統使用者"""
|
||||
return cls.query.filter_by(email=email).first()
|
||||
|
||||
def record_login_attempt(self, success: bool, ip_address: str = None, auth_method: str = None):
|
||||
"""
|
||||
記錄登入嘗試
|
||||
|
||||
Args:
|
||||
success: 是否成功
|
||||
ip_address: IP 地址
|
||||
auth_method: 認證方式
|
||||
"""
|
||||
try:
|
||||
self.login_count = (self.login_count or 0) + 1
|
||||
|
||||
if success:
|
||||
self.login_success_count = (self.login_success_count or 0) + 1
|
||||
self.last_login_at = datetime.utcnow()
|
||||
self.last_login_ip = ip_address
|
||||
if auth_method:
|
||||
self.auth_method = auth_method
|
||||
|
||||
# 成功登入時解除鎖定
|
||||
if self.is_locked:
|
||||
self.is_locked = False
|
||||
self.locked_until = None
|
||||
|
||||
else:
|
||||
self.login_fail_count = (self.login_fail_count or 0) + 1
|
||||
|
||||
# 檢查是否需要鎖定帳號 (連續失敗5次)
|
||||
if self.login_fail_count >= 5:
|
||||
self.is_locked = True
|
||||
self.locked_until = datetime.utcnow() + timedelta(minutes=30) # 鎖定30分鐘
|
||||
|
||||
self.updated_at = datetime.utcnow()
|
||||
db.session.commit()
|
||||
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
logger.error(f"記錄登入嘗試失敗: {str(e)}")
|
||||
|
||||
def is_account_locked(self) -> bool:
|
||||
"""檢查帳號是否被鎖定"""
|
||||
if not self.is_locked:
|
||||
return False
|
||||
|
||||
# 檢查鎖定時間是否已過
|
||||
if self.locked_until and datetime.utcnow() > self.locked_until:
|
||||
self.is_locked = False
|
||||
self.locked_until = None
|
||||
db.session.commit()
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def set_password(self, password: str):
|
||||
"""設置密碼雜湊 (如果需要本地儲存密碼)"""
|
||||
self.password_hash = generate_password_hash(password)
|
||||
|
||||
def check_password(self, password: str) -> bool:
|
||||
"""檢查密碼 (如果有本地儲存密碼)"""
|
||||
if not self.password_hash:
|
||||
return False
|
||||
return check_password_hash(self.password_hash, password)
|
||||
|
||||
def update_api_token(self, access_token: str, expires_at: datetime = None):
|
||||
"""更新 API Token"""
|
||||
self.api_access_token = access_token
|
||||
self.api_token_expires_at = expires_at
|
||||
self.updated_at = datetime.utcnow()
|
||||
db.session.commit()
|
||||
|
||||
def is_api_token_valid(self) -> bool:
|
||||
"""檢查 API Token 是否有效"""
|
||||
if not self.api_access_token or not self.api_token_expires_at:
|
||||
return False
|
||||
return datetime.utcnow() < self.api_token_expires_at
|
||||
|
||||
|
||||
class LoginLog(db.Model):
|
||||
"""登入記錄模型"""
|
||||
__tablename__ = 'login_logs'
|
||||
|
||||
id = Column(BigInteger, primary_key=True)
|
||||
|
||||
# 基本資訊
|
||||
username = Column(String(255), nullable=False, comment='登入帳號')
|
||||
auth_method = Column(SQLEnum('API', 'LDAP', name='login_log_auth_method'),
|
||||
nullable=False, comment='認證方式')
|
||||
|
||||
# 登入結果
|
||||
login_success = Column(Boolean, nullable=False, comment='是否成功')
|
||||
error_message = Column(Text, comment='錯誤訊息(失敗時)')
|
||||
|
||||
# 環境資訊
|
||||
ip_address = Column(String(45), comment='IP 地址')
|
||||
user_agent = Column(Text, comment='瀏覽器資訊')
|
||||
|
||||
# API 回應 (可選,用於除錯)
|
||||
api_response_summary = Column(JSON, comment='API 回應摘要')
|
||||
|
||||
# 時間
|
||||
login_at = Column(DateTime, default=datetime.utcnow, comment='登入時間')
|
||||
|
||||
def __repr__(self):
|
||||
return f'<LoginLog {self.username}:{self.auth_method}:{self.login_success}>'
|
||||
|
||||
@classmethod
|
||||
def create_log(cls, username: str, auth_method: str, login_success: bool,
|
||||
error_message: str = None, ip_address: str = None,
|
||||
user_agent: str = None, api_response_summary: Dict = None) -> 'LoginLog':
|
||||
"""
|
||||
建立登入記錄
|
||||
|
||||
Args:
|
||||
username: 使用者帳號
|
||||
auth_method: 認證方式
|
||||
login_success: 是否成功
|
||||
error_message: 錯誤訊息
|
||||
ip_address: IP 地址
|
||||
user_agent: 瀏覽器資訊
|
||||
api_response_summary: API 回應摘要
|
||||
|
||||
Returns:
|
||||
LoginLog: 登入記錄
|
||||
"""
|
||||
try:
|
||||
log = cls(
|
||||
username=username,
|
||||
auth_method=auth_method,
|
||||
login_success=login_success,
|
||||
error_message=error_message,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
api_response_summary=api_response_summary
|
||||
)
|
||||
|
||||
db.session.add(log)
|
||||
db.session.commit()
|
||||
return log
|
||||
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
logger.error(f"建立登入記錄失敗: {str(e)}")
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_recent_failed_attempts(cls, username: str, minutes: int = 15) -> int:
|
||||
"""
|
||||
取得最近失敗的登入嘗試次數
|
||||
|
||||
Args:
|
||||
username: 使用者帳號
|
||||
minutes: 時間範圍(分鐘)
|
||||
|
||||
Returns:
|
||||
int: 失敗次數
|
||||
"""
|
||||
since = datetime.utcnow() - timedelta(minutes=minutes)
|
||||
return cls.query.filter(
|
||||
cls.username == username,
|
||||
cls.login_success == False,
|
||||
cls.login_at >= since
|
||||
).count()
|
@@ -82,29 +82,35 @@ class User(db.Model):
|
||||
|
||||
@classmethod
|
||||
def get_or_create(cls, username, display_name, email, department=None):
|
||||
"""取得或建立使用者"""
|
||||
user = cls.query.filter_by(username=username).first()
|
||||
|
||||
"""取得或建立使用者 (方案A: 使用 email 作為主要識別鍵)"""
|
||||
# 先嘗試用 email 查找 (因為 email 是唯一且穩定的識別碼)
|
||||
user = cls.query.filter_by(email=email).first()
|
||||
|
||||
if user:
|
||||
# 更新使用者資訊
|
||||
user.display_name = display_name
|
||||
user.email = email
|
||||
# 更新使用者資訊 (API name 格式: 姓名+email)
|
||||
user.username = username # API 的 name (姓名+email 格式)
|
||||
user.display_name = display_name # API 的 name (姓名+email 格式)
|
||||
if department:
|
||||
user.department = department
|
||||
user.updated_at = datetime.utcnow()
|
||||
else:
|
||||
# 建立新使用者
|
||||
user = cls(
|
||||
username=username,
|
||||
display_name=display_name,
|
||||
email=email,
|
||||
username=username, # API 的 name (姓名+email 格式)
|
||||
display_name=display_name, # API 的 name (姓名+email 格式)
|
||||
email=email, # 純 email,唯一識別鍵
|
||||
department=department,
|
||||
is_admin=(email.lower() == 'ymirliu@panjit.com.tw') # 硬編碼管理員
|
||||
)
|
||||
db.session.add(user)
|
||||
|
||||
|
||||
db.session.commit()
|
||||
return user
|
||||
|
||||
@classmethod
|
||||
def get_by_email(cls, email):
|
||||
"""根據 email 查找使用者"""
|
||||
return cls.query.filter_by(email=email).first()
|
||||
|
||||
@classmethod
|
||||
def get_admin_users(cls):
|
||||
|
@@ -23,29 +23,51 @@ class DifyClient:
|
||||
"""Dify API 客戶端"""
|
||||
|
||||
def __init__(self):
|
||||
self.base_url = current_app.config.get('DIFY_API_BASE_URL', '')
|
||||
self.api_key = current_app.config.get('DIFY_API_KEY', '')
|
||||
# 翻译API配置
|
||||
self.translation_base_url = current_app.config.get('DIFY_TRANSLATION_BASE_URL', '')
|
||||
self.translation_api_key = current_app.config.get('DIFY_TRANSLATION_API_KEY', '')
|
||||
|
||||
# OCR API配置
|
||||
self.ocr_base_url = current_app.config.get('DIFY_OCR_BASE_URL', '')
|
||||
self.ocr_api_key = current_app.config.get('DIFY_OCR_API_KEY', '')
|
||||
|
||||
self.timeout = (10, 60) # (連接超時, 讀取超時)
|
||||
self.max_retries = 3
|
||||
self.retry_delay = 1.6 # 指數退避基數
|
||||
|
||||
if not self.base_url or not self.api_key:
|
||||
logger.warning("Dify API configuration is incomplete")
|
||||
|
||||
if not self.translation_base_url or not self.translation_api_key:
|
||||
logger.warning("Dify Translation API configuration is incomplete")
|
||||
|
||||
if not self.ocr_base_url or not self.ocr_api_key:
|
||||
logger.warning("Dify OCR API configuration is incomplete")
|
||||
|
||||
def _make_request(self, method: str, endpoint: str, data: Dict[str, Any] = None,
|
||||
user_id: int = None, job_id: int = None) -> Dict[str, Any]:
|
||||
def _make_request(self, method: str, endpoint: str, data: Dict[str, Any] = None,
|
||||
user_id: int = None, job_id: int = None, files_data: Dict = None,
|
||||
api_type: str = 'translation') -> Dict[str, Any]:
|
||||
"""發送 HTTP 請求到 Dify API"""
|
||||
|
||||
if not self.base_url or not self.api_key:
|
||||
raise APIError("Dify API 未配置完整")
|
||||
|
||||
url = f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}"
|
||||
|
||||
|
||||
# 根据API类型选择配置
|
||||
if api_type == 'ocr':
|
||||
base_url = self.ocr_base_url
|
||||
api_key = self.ocr_api_key
|
||||
if not base_url or not api_key:
|
||||
raise APIError("Dify OCR API 未配置完整")
|
||||
else: # translation
|
||||
base_url = self.translation_base_url
|
||||
api_key = self.translation_api_key
|
||||
if not base_url or not api_key:
|
||||
raise APIError("Dify Translation API 未配置完整")
|
||||
|
||||
url = f"{base_url.rstrip('/')}/{endpoint.lstrip('/')}"
|
||||
|
||||
headers = {
|
||||
'Authorization': f'Bearer {self.api_key}',
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f'Bearer {api_key}',
|
||||
'User-Agent': 'PANJIT-Document-Translator/1.0'
|
||||
}
|
||||
|
||||
# 只有在非文件上传时才设置JSON Content-Type
|
||||
if not files_data:
|
||||
headers['Content-Type'] = 'application/json'
|
||||
|
||||
# 重試邏輯
|
||||
last_exception = None
|
||||
@@ -53,11 +75,15 @@ class DifyClient:
|
||||
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
logger.debug(f"Making Dify API request: {method} {url} (attempt {attempt + 1})")
|
||||
# logger.debug(f"Making Dify API request: {method} {url} (attempt {attempt + 1})")
|
||||
|
||||
if method.upper() == 'GET':
|
||||
response = requests.get(url, headers=headers, timeout=self.timeout, params=data)
|
||||
elif files_data:
|
||||
# 文件上传请求,使用multipart/form-data
|
||||
response = requests.post(url, headers=headers, timeout=self.timeout, files=files_data, data=data)
|
||||
else:
|
||||
# 普通JSON请求
|
||||
response = requests.post(url, headers=headers, timeout=self.timeout, json=data)
|
||||
|
||||
# 計算響應時間
|
||||
@@ -80,7 +106,7 @@ class DifyClient:
|
||||
success=True
|
||||
)
|
||||
|
||||
logger.debug(f"Dify API request successful: {response_time_ms}ms")
|
||||
# logger.debug(f"Dify API request successful: {response_time_ms}ms")
|
||||
return result
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
@@ -107,7 +133,7 @@ class DifyClient:
|
||||
|
||||
# 指數退避
|
||||
delay = self.retry_delay ** attempt
|
||||
logger.debug(f"Retrying in {delay} seconds...")
|
||||
# logger.debug(f"Retrying in {delay} seconds...")
|
||||
time.sleep(delay)
|
||||
|
||||
# 所有重試都失敗了
|
||||
@@ -137,7 +163,7 @@ class DifyClient:
|
||||
logger.warning(f"Failed to record API usage: {str(e)}")
|
||||
|
||||
def translate_text(self, text: str, source_language: str, target_language: str,
|
||||
user_id: int = None, job_id: int = None) -> Dict[str, Any]:
|
||||
user_id: int = None, job_id: int = None, conversation_id: str = None) -> Dict[str, Any]:
|
||||
"""翻譯文字"""
|
||||
|
||||
if not text.strip():
|
||||
@@ -181,7 +207,15 @@ Rules:
|
||||
'user': f"user_{user_id}" if user_id else "doc-translator-user",
|
||||
'query': query
|
||||
}
|
||||
|
||||
# 如果有 conversation_id,加入請求中以維持對話連續性
|
||||
if conversation_id:
|
||||
request_data['conversation_id'] = conversation_id
|
||||
|
||||
logger.info(f"[TRANSLATION] Sending translation request...")
|
||||
logger.info(f"[TRANSLATION] Request data: {request_data}")
|
||||
logger.info(f"[TRANSLATION] Text length: {len(text)} characters")
|
||||
|
||||
try:
|
||||
response = self._make_request(
|
||||
method='POST',
|
||||
@@ -203,6 +237,7 @@ Rules:
|
||||
'source_text': text,
|
||||
'source_language': source_language,
|
||||
'target_language': target_language,
|
||||
'conversation_id': response.get('conversation_id'),
|
||||
'metadata': response.get('metadata', {})
|
||||
}
|
||||
|
||||
@@ -271,18 +306,165 @@ Rules:
|
||||
with open(config_file, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line.startswith('base_url:'):
|
||||
if line.startswith('#') or not line:
|
||||
continue # 跳过注释和空行
|
||||
|
||||
# 翻译API配置(兼容旧格式)
|
||||
if line.startswith('base_url:') or line.startswith('translation_base_url:'):
|
||||
base_url = line.split(':', 1)[1].strip()
|
||||
current_app.config['DIFY_TRANSLATION_BASE_URL'] = base_url
|
||||
# 兼容旧配置
|
||||
current_app.config['DIFY_API_BASE_URL'] = base_url
|
||||
elif line.startswith('api:'):
|
||||
elif line.startswith('api:') or line.startswith('translation_api:'):
|
||||
api_key = line.split(':', 1)[1].strip()
|
||||
current_app.config['DIFY_TRANSLATION_API_KEY'] = api_key
|
||||
# 兼容旧配置
|
||||
current_app.config['DIFY_API_KEY'] = api_key
|
||||
|
||||
# OCR API配置
|
||||
elif line.startswith('ocr_base_url:'):
|
||||
ocr_base_url = line.split(':', 1)[1].strip()
|
||||
current_app.config['DIFY_OCR_BASE_URL'] = ocr_base_url
|
||||
elif line.startswith('ocr_api:'):
|
||||
ocr_api_key = line.split(':', 1)[1].strip()
|
||||
current_app.config['DIFY_OCR_API_KEY'] = ocr_api_key
|
||||
|
||||
logger.info("Dify API config loaded from file")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load Dify config from file: {str(e)}")
|
||||
|
||||
def upload_file(self, image_data: bytes, filename: str, user_id: int = None) -> str:
|
||||
"""上传图片文件到Dify OCR API并返回file_id"""
|
||||
|
||||
if not image_data:
|
||||
raise APIError("图片数据不能为空")
|
||||
|
||||
logger.info(f"[OCR-UPLOAD] Starting file upload to Dify OCR API")
|
||||
logger.info(f"[OCR-UPLOAD] File: {filename}, Size: {len(image_data)} bytes, User: {user_id}")
|
||||
|
||||
# 构建文件上传数据
|
||||
files_data = {
|
||||
'file': (filename, image_data, 'image/png') # 假设为PNG格式
|
||||
}
|
||||
|
||||
form_data = {
|
||||
'user': f"user_{user_id}" if user_id else "doc-translator-user"
|
||||
}
|
||||
|
||||
# logger.debug(f"[OCR-UPLOAD] Upload form_data: {form_data}")
|
||||
# logger.debug(f"[OCR-UPLOAD] Using OCR API: {self.ocr_base_url}")
|
||||
|
||||
try:
|
||||
response = self._make_request(
|
||||
method='POST',
|
||||
endpoint='/files/upload',
|
||||
data=form_data,
|
||||
files_data=files_data,
|
||||
user_id=user_id,
|
||||
api_type='ocr' # 使用OCR API
|
||||
)
|
||||
|
||||
logger.info(f"[OCR-UPLOAD] Raw Dify upload response: {response}")
|
||||
|
||||
file_id = response.get('id')
|
||||
if not file_id:
|
||||
logger.error(f"[OCR-UPLOAD] No file ID in response: {response}")
|
||||
raise APIError("Dify 文件上传失败:未返回文件ID")
|
||||
|
||||
logger.info(f"[OCR-UPLOAD] ✓ File uploaded successfully: {file_id}")
|
||||
# logger.debug(f"[OCR-UPLOAD] File details: name={response.get('name')}, size={response.get('size')}, type={response.get('mime_type')}")
|
||||
|
||||
return file_id
|
||||
|
||||
except APIError:
|
||||
raise
|
||||
except Exception as e:
|
||||
error_msg = f"文件上传到Dify失败: {str(e)}"
|
||||
logger.error(f"[OCR-UPLOAD] ✗ Upload failed: {error_msg}")
|
||||
raise APIError(error_msg)
|
||||
|
||||
def ocr_image_with_dify(self, image_data: bytes, filename: str = "image.png",
|
||||
user_id: int = None, job_id: int = None) -> str:
|
||||
"""使用Dify进行图像OCR识别"""
|
||||
|
||||
logger.info(f"[OCR-RECOGNITION] Starting OCR process for {filename}")
|
||||
logger.info(f"[OCR-RECOGNITION] Image size: {len(image_data)} bytes, User: {user_id}, Job: {job_id}")
|
||||
|
||||
try:
|
||||
# 1. 先上传文件获取file_id
|
||||
logger.info(f"[OCR-RECOGNITION] Step 1: Uploading image to Dify...")
|
||||
file_id = self.upload_file(image_data, filename, user_id)
|
||||
logger.info(f"[OCR-RECOGNITION] Step 1 ✓ File uploaded with ID: {file_id}")
|
||||
|
||||
# 2. 构建OCR请求
|
||||
# 系统提示词已在Dify Chat Flow中配置,这里只需要发送简单的用户query
|
||||
query = "將圖片中的文字完整的提取出來"
|
||||
logger.info(f"[OCR-RECOGNITION] Step 2: Preparing OCR request...")
|
||||
# logger.debug(f"[OCR-RECOGNITION] Query: {query}")
|
||||
|
||||
# 3. 构建Chat Flow请求,根据最新Dify运行记录,图片应该放在files数组中
|
||||
request_data = {
|
||||
'inputs': {},
|
||||
'response_mode': 'blocking',
|
||||
'user': f"user_{user_id}" if user_id else "doc-translator-user",
|
||||
'query': query,
|
||||
'files': [
|
||||
{
|
||||
'type': 'image',
|
||||
'transfer_method': 'local_file',
|
||||
'upload_file_id': file_id
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
logger.info(f"[OCR-RECOGNITION] Step 3: Sending OCR request to Dify...")
|
||||
logger.info(f"[OCR-RECOGNITION] Request data: {request_data}")
|
||||
logger.info(f"[OCR-RECOGNITION] Using OCR API: {self.ocr_base_url}")
|
||||
|
||||
response = self._make_request(
|
||||
method='POST',
|
||||
endpoint='/chat-messages',
|
||||
data=request_data,
|
||||
user_id=user_id,
|
||||
job_id=job_id,
|
||||
api_type='ocr' # 使用OCR API
|
||||
)
|
||||
|
||||
logger.info(f"[OCR-RECOGNITION] Step 3 ✓ Received response from Dify")
|
||||
logger.info(f"[OCR-RECOGNITION] Raw Dify OCR response: {response}")
|
||||
|
||||
# 从响应中提取OCR结果
|
||||
answer = response.get('answer', '')
|
||||
metadata = response.get('metadata', {})
|
||||
conversation_id = response.get('conversation_id', '')
|
||||
|
||||
logger.info(f"[OCR-RECOGNITION] Response details:")
|
||||
logger.info(f"[OCR-RECOGNITION] - Answer length: {len(answer) if answer else 0} characters")
|
||||
logger.info(f"[OCR-RECOGNITION] - Conversation ID: {conversation_id}")
|
||||
logger.info(f"[OCR-RECOGNITION] - Metadata: {metadata}")
|
||||
|
||||
if not isinstance(answer, str) or not answer.strip():
|
||||
logger.error(f"[OCR-RECOGNITION] ✗ Empty or invalid answer from Dify")
|
||||
logger.error(f"[OCR-RECOGNITION] Answer type: {type(answer)}, Content: '{answer}'")
|
||||
raise APIError("Dify OCR 返回空的识别结果")
|
||||
|
||||
# 记录OCR识别的前100个字符用于调试
|
||||
preview = answer[:100] + "..." if len(answer) > 100 else answer
|
||||
logger.info(f"[OCR-RECOGNITION] ✓ OCR completed successfully")
|
||||
logger.info(f"[OCR-RECOGNITION] Extracted {len(answer)} characters")
|
||||
# logger.debug(f"[OCR-RECOGNITION] Text preview: {preview}")
|
||||
|
||||
return answer.strip()
|
||||
|
||||
except APIError:
|
||||
raise
|
||||
except Exception as e:
|
||||
error_msg = f"Dify OCR识别失败: {str(e)}"
|
||||
logger.error(f"[OCR-RECOGNITION] ✗ OCR process failed: {error_msg}")
|
||||
logger.error(f"[OCR-RECOGNITION] Exception details: {type(e).__name__}: {str(e)}")
|
||||
raise APIError(error_msg)
|
||||
|
||||
|
||||
def init_dify_config(app):
|
||||
"""初始化 Dify 配置"""
|
||||
@@ -291,12 +473,22 @@ def init_dify_config(app):
|
||||
DifyClient.load_config_from_file()
|
||||
|
||||
# 檢查配置完整性
|
||||
base_url = app.config.get('DIFY_API_BASE_URL')
|
||||
api_key = app.config.get('DIFY_API_KEY')
|
||||
|
||||
if base_url and api_key:
|
||||
logger.info("Dify API configuration loaded successfully")
|
||||
translation_base_url = app.config.get('DIFY_TRANSLATION_BASE_URL')
|
||||
translation_api_key = app.config.get('DIFY_TRANSLATION_API_KEY')
|
||||
ocr_base_url = app.config.get('DIFY_OCR_BASE_URL')
|
||||
ocr_api_key = app.config.get('DIFY_OCR_API_KEY')
|
||||
|
||||
logger.info("Dify API Configuration Status:")
|
||||
if translation_base_url and translation_api_key:
|
||||
logger.info("✓ Translation API configured successfully")
|
||||
else:
|
||||
logger.warning("Dify API configuration is incomplete")
|
||||
logger.warning(f"Base URL: {'✓' if base_url else '✗'}")
|
||||
logger.warning(f"API Key: {'✓' if api_key else '✗'}")
|
||||
logger.warning("✗ Translation API configuration is incomplete")
|
||||
logger.warning(f" - Translation Base URL: {'✓' if translation_base_url else '✗'}")
|
||||
logger.warning(f" - Translation API Key: {'✓' if translation_api_key else '✗'}")
|
||||
|
||||
if ocr_base_url and ocr_api_key:
|
||||
logger.info("✓ OCR API configured successfully")
|
||||
else:
|
||||
logger.warning("✗ OCR API configuration is incomplete (扫描PDF功能将不可用)")
|
||||
logger.warning(f" - OCR Base URL: {'✓' if ocr_base_url else '✗'}")
|
||||
logger.warning(f" - OCR API Key: {'✓' if ocr_api_key else '✗'}")
|
700
app/services/enhanced_pdf_parser.py
Normal file
700
app/services/enhanced_pdf_parser.py
Normal file
@@ -0,0 +1,700 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
增强的PDF解析器 - 支持扫描PDF的OCR处理
|
||||
|
||||
Author: PANJIT IT Team
|
||||
Created: 2024-09-23
|
||||
Modified: 2024-09-23
|
||||
"""
|
||||
|
||||
import io
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
from PyPDF2 import PdfReader
|
||||
from app.utils.logger import get_logger
|
||||
from app.utils.exceptions import FileProcessingError
|
||||
from app.services.dify_client import DifyClient
|
||||
from app.services.ocr_cache import OCRCache
|
||||
from app.utils.image_preprocessor import ImagePreprocessor
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# 检查PyMuPDF依赖
|
||||
try:
|
||||
import fitz # PyMuPDF
|
||||
_HAS_PYMUPDF = True
|
||||
except ImportError:
|
||||
_HAS_PYMUPDF = False
|
||||
logger.warning("PyMuPDF not available. Scanned PDF processing will be disabled.")
|
||||
|
||||
|
||||
class EnhancedPdfParser:
|
||||
"""支持扫描PDF的增强解析器"""
|
||||
|
||||
def __init__(self, file_path: str):
|
||||
self.file_path = Path(file_path)
|
||||
self.dify_client = DifyClient()
|
||||
self.ocr_cache = OCRCache()
|
||||
self.image_preprocessor = ImagePreprocessor(use_opencv=True)
|
||||
|
||||
if not self.file_path.exists():
|
||||
raise FileProcessingError(f"PDF文件不存在: {file_path}")
|
||||
|
||||
def is_scanned_pdf(self) -> bool:
|
||||
"""检测PDF是否为扫描件"""
|
||||
try:
|
||||
reader = PdfReader(str(self.file_path))
|
||||
text_content = ""
|
||||
|
||||
# 检查前3页的文字内容
|
||||
pages_to_check = min(3, len(reader.pages))
|
||||
for i in range(pages_to_check):
|
||||
page_text = reader.pages[i].extract_text()
|
||||
text_content += page_text
|
||||
|
||||
# 如果文字内容很少,很可能是扫描件
|
||||
text_length = len(text_content.strip())
|
||||
logger.info(f"PDF text extraction found {text_length} characters in first {pages_to_check} pages")
|
||||
|
||||
# 阈值:少于100个字符认为是扫描件
|
||||
is_scanned = text_length < 100
|
||||
|
||||
if is_scanned:
|
||||
logger.info("PDF detected as scanned document, will use OCR processing")
|
||||
else:
|
||||
logger.info("PDF detected as text-based document, will use direct text extraction")
|
||||
|
||||
return is_scanned
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to analyze PDF type: {e}, treating as scanned document")
|
||||
return True # 默认当作扫描件处理
|
||||
|
||||
def extract_text_segments(self, user_id: int = None, job_id: int = None) -> List[str]:
|
||||
"""智能提取PDF文字片段"""
|
||||
try:
|
||||
# 首先尝试直接文字提取
|
||||
if not self.is_scanned_pdf():
|
||||
return self._extract_from_text_pdf()
|
||||
|
||||
# 扫描PDF则转换为图片后使用Dify OCR
|
||||
if not _HAS_PYMUPDF:
|
||||
raise FileProcessingError("处理扫描PDF需要PyMuPDF库,请安装: pip install PyMuPDF")
|
||||
|
||||
return self._extract_from_scanned_pdf(user_id, job_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"PDF文字提取失败: {str(e)}")
|
||||
raise FileProcessingError(f"PDF文件解析失败: {str(e)}")
|
||||
|
||||
def _extract_from_text_pdf(self) -> List[str]:
|
||||
"""从文字型PDF提取文字片段"""
|
||||
try:
|
||||
reader = PdfReader(str(self.file_path))
|
||||
text_segments = []
|
||||
|
||||
for page_num, page in enumerate(reader.pages, 1):
|
||||
page_text = page.extract_text()
|
||||
|
||||
if page_text.strip():
|
||||
# 简单的句子分割
|
||||
sentences = self._split_text_into_sentences(page_text)
|
||||
|
||||
# 过滤掉太短的片段
|
||||
valid_sentences = [s for s in sentences if len(s.strip()) > 10]
|
||||
text_segments.extend(valid_sentences)
|
||||
|
||||
logger.debug(f"Page {page_num}: extracted {len(valid_sentences)} sentences")
|
||||
|
||||
logger.info(f"Text PDF extraction completed: {len(text_segments)} segments")
|
||||
|
||||
# 合併短段落以減少不必要的翻譯調用
|
||||
merged_segments = self._merge_short_segments(text_segments)
|
||||
return merged_segments
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Text PDF extraction failed: {str(e)}")
|
||||
raise FileProcessingError(f"文字PDF提取失败: {str(e)}")
|
||||
|
||||
def _extract_from_scanned_pdf(self, user_id: int = None, job_id: int = None) -> List[str]:
|
||||
"""从扫描PDF提取文字片段(使用Dify OCR)"""
|
||||
try:
|
||||
doc = fitz.open(str(self.file_path))
|
||||
text_segments = []
|
||||
total_pages = doc.page_count
|
||||
|
||||
logger.info(f"Processing scanned PDF with {total_pages} pages using Dify OCR")
|
||||
|
||||
for page_num in range(total_pages):
|
||||
try:
|
||||
logger.info(f"[PDF-OCR] Processing page {page_num + 1}/{total_pages}")
|
||||
page = doc[page_num]
|
||||
|
||||
# 转换页面为高分辨率图片
|
||||
# 使用2倍缩放提高OCR准确度
|
||||
zoom = 2.0
|
||||
mat = fitz.Matrix(zoom, zoom)
|
||||
pix = page.get_pixmap(matrix=mat, alpha=False)
|
||||
|
||||
# 转换为PNG字节数据
|
||||
# 轉換為 PNG 並進行圖像預處理以提升 OCR 準確度
|
||||
img_data_raw = pix.tobytes("png")
|
||||
img_data = self.image_preprocessor.preprocess_smart(img_data_raw)
|
||||
logger.debug(f"[PDF-OCR] Page {page_num + 1}: Image preprocessed ({len(img_data_raw)} -> {len(img_data)} bytes)")
|
||||
filename = f"page_{page_num + 1}.png"
|
||||
|
||||
logger.info(f"[PDF-OCR] Page {page_num + 1}: Converted to image ({len(img_data)} bytes)")
|
||||
logger.debug(f"[PDF-OCR] Page {page_num + 1}: Image zoom={zoom}, format=PNG")
|
||||
|
||||
# 检查OCR快取
|
||||
cache_key_info = f"{self.file_path.name}_page_{page_num + 1}_zoom_{zoom}"
|
||||
cached_text = self.ocr_cache.get_cached_text(
|
||||
file_data=img_data,
|
||||
filename=filename,
|
||||
additional_info=cache_key_info
|
||||
)
|
||||
|
||||
if cached_text:
|
||||
logger.info(f"[PDF-OCR] Page {page_num + 1}: ✓ 使用快取的OCR結果 (節省AI流量)")
|
||||
ocr_text = cached_text
|
||||
else:
|
||||
# 使用Dify OCR识别文字
|
||||
logger.info(f"[PDF-OCR] Page {page_num + 1}: Starting OCR recognition...")
|
||||
ocr_text = self.dify_client.ocr_image_with_dify(
|
||||
image_data=img_data,
|
||||
filename=filename,
|
||||
user_id=user_id,
|
||||
job_id=job_id
|
||||
)
|
||||
|
||||
# 保存OCR结果到快取
|
||||
if ocr_text.strip():
|
||||
self.ocr_cache.save_cached_text(
|
||||
file_data=img_data,
|
||||
extracted_text=ocr_text,
|
||||
filename=filename,
|
||||
additional_info=cache_key_info,
|
||||
metadata={
|
||||
'source_file': str(self.file_path),
|
||||
'page_number': page_num + 1,
|
||||
'total_pages': total_pages,
|
||||
'zoom_level': zoom,
|
||||
'image_size_bytes': len(img_data),
|
||||
'user_id': user_id,
|
||||
'job_id': job_id
|
||||
}
|
||||
)
|
||||
logger.info(f"[PDF-OCR] Page {page_num + 1}: ✓ OCR結果已保存到快取")
|
||||
|
||||
logger.info(f"[PDF-OCR] Page {page_num + 1}: OCR completed")
|
||||
logger.debug(f"[PDF-OCR] Page {page_num + 1}: Raw OCR result length: {len(ocr_text)}")
|
||||
|
||||
if ocr_text.strip():
|
||||
# 分割OCR结果为句子
|
||||
logger.debug(f"[PDF-OCR] Page {page_num + 1}: Splitting OCR text into sentences...")
|
||||
sentences = self._split_ocr_text(ocr_text)
|
||||
|
||||
# 过滤有效句子
|
||||
valid_sentences = [s for s in sentences if len(s.strip()) > 5]
|
||||
text_segments.extend(valid_sentences)
|
||||
|
||||
logger.info(f"[PDF-OCR] Page {page_num + 1}: ✓ Extracted {len(valid_sentences)} valid sentences")
|
||||
logger.debug(f"[PDF-OCR] Page {page_num + 1}: Total sentences before filter: {len(sentences)}")
|
||||
|
||||
# 记录前50个字符用于调试
|
||||
if valid_sentences:
|
||||
preview = valid_sentences[0][:50] + "..." if len(valid_sentences[0]) > 50 else valid_sentences[0]
|
||||
logger.debug(f"[PDF-OCR] Page {page_num + 1}: First sentence preview: {preview}")
|
||||
else:
|
||||
logger.warning(f"[PDF-OCR] Page {page_num + 1}: ⚠ OCR returned empty result")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[PDF-OCR] Page {page_num + 1}: ✗ Processing failed: {str(e)}")
|
||||
logger.error(f"[PDF-OCR] Page {page_num + 1}: Exception type: {type(e).__name__}")
|
||||
# 继续处理下一页,不中断整个流程
|
||||
continue
|
||||
|
||||
doc.close()
|
||||
|
||||
logger.info(f"[PDF-OCR] OCR processing completed for all {total_pages} pages")
|
||||
logger.info(f"[PDF-OCR] Total text segments extracted: {len(text_segments)}")
|
||||
|
||||
if not text_segments:
|
||||
logger.error(f"[PDF-OCR] ✗ No text content extracted from any page")
|
||||
raise FileProcessingError("OCR处理完成,但未提取到任何文字内容")
|
||||
|
||||
logger.info(f"[PDF-OCR] ✓ Scanned PDF processing completed successfully")
|
||||
logger.info(f"[PDF-OCR] Final result: {len(text_segments)} text segments extracted")
|
||||
|
||||
# 合併短段落以減少不必要的翻譯調用
|
||||
merged_segments = self._merge_short_segments(text_segments)
|
||||
logger.info(f"[PDF-OCR] After merging: {len(merged_segments)} segments ready for translation")
|
||||
return merged_segments
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Scanned PDF processing failed: {str(e)}")
|
||||
raise FileProcessingError(f"扫描PDF处理失败: {str(e)}")
|
||||
|
||||
def _split_text_into_sentences(self, text: str) -> List[str]:
|
||||
"""将文字分割成句子"""
|
||||
if not text.strip():
|
||||
return []
|
||||
|
||||
# 简单的分句逻辑
|
||||
sentences = []
|
||||
separators = ['. ', '。', '!', '?', '!', '?', '\n\n']
|
||||
|
||||
current_sentences = [text]
|
||||
|
||||
for sep in separators:
|
||||
new_sentences = []
|
||||
for sentence in current_sentences:
|
||||
parts = sentence.split(sep)
|
||||
if len(parts) > 1:
|
||||
# 保留分隔符
|
||||
for i, part in enumerate(parts[:-1]):
|
||||
if part.strip():
|
||||
new_sentences.append(part.strip() + sep.rstrip())
|
||||
# 最后一部分
|
||||
if parts[-1].strip():
|
||||
new_sentences.append(parts[-1].strip())
|
||||
else:
|
||||
new_sentences.append(sentence)
|
||||
current_sentences = new_sentences
|
||||
|
||||
# 过滤掉太短的句子
|
||||
valid_sentences = [s for s in current_sentences if len(s.strip()) > 3]
|
||||
return valid_sentences
|
||||
|
||||
def _split_ocr_text(self, ocr_text: str) -> List[str]:
|
||||
"""分割OCR识别的文字"""
|
||||
if not ocr_text.strip():
|
||||
return []
|
||||
|
||||
# OCR结果可能包含表格或特殊格式,需要特殊处理
|
||||
lines = ocr_text.split('\n')
|
||||
sentences = []
|
||||
|
||||
current_paragraph = []
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
# 空行表示段落结束
|
||||
if current_paragraph:
|
||||
paragraph_text = ' '.join(current_paragraph)
|
||||
if len(paragraph_text) > 10:
|
||||
sentences.append(paragraph_text)
|
||||
current_paragraph = []
|
||||
continue
|
||||
|
||||
# 检查是否是表格行(包含|或多个制表符)
|
||||
if '|' in line or '\t' in line:
|
||||
# 表格行单独处理
|
||||
if current_paragraph:
|
||||
paragraph_text = ' '.join(current_paragraph)
|
||||
if len(paragraph_text) > 10:
|
||||
sentences.append(paragraph_text)
|
||||
current_paragraph = []
|
||||
|
||||
if len(line) > 10:
|
||||
sentences.append(line)
|
||||
else:
|
||||
# 普通文字行
|
||||
current_paragraph.append(line)
|
||||
|
||||
# 处理最后的段落
|
||||
if current_paragraph:
|
||||
paragraph_text = ' '.join(current_paragraph)
|
||||
if len(paragraph_text) > 10:
|
||||
sentences.append(paragraph_text)
|
||||
|
||||
return sentences
|
||||
|
||||
def generate_translated_document(self, translations: dict, target_language: str,
|
||||
output_dir: Path) -> str:
|
||||
"""生成翻译的Word文档(保持与DOCX相同的格式)"""
|
||||
try:
|
||||
from app.utils.helpers import generate_filename
|
||||
|
||||
translated_texts = translations.get(target_language, [])
|
||||
|
||||
# 生成Word文档而非文字文件
|
||||
output_filename = f"{self.file_path.stem}_{target_language}_translated.docx"
|
||||
output_path = output_dir / output_filename
|
||||
|
||||
# 创建Word文档
|
||||
from docx import Document
|
||||
from docx.shared import Pt
|
||||
from docx.enum.text import WD_PARAGRAPH_ALIGNMENT
|
||||
|
||||
doc = Document()
|
||||
|
||||
# 添加标题页
|
||||
title = doc.add_heading(f"PDF翻译结果 - {target_language}", 0)
|
||||
title.alignment = WD_PARAGRAPH_ALIGNMENT.CENTER
|
||||
|
||||
# 添加文档信息
|
||||
info_para = doc.add_paragraph()
|
||||
info_para.add_run("原始文件: ").bold = True
|
||||
info_para.add_run(self.file_path.name)
|
||||
info_para.add_run("\n处理方式: ").bold = True
|
||||
info_para.add_run("OCR识别" if self.is_scanned_pdf() else "直接文字提取")
|
||||
info_para.add_run(f"\n翻译语言: ").bold = True
|
||||
info_para.add_run(target_language)
|
||||
info_para.add_run(f"\n总段落数: ").bold = True
|
||||
info_para.add_run(str(len(translated_texts)))
|
||||
|
||||
doc.add_paragraph() # 空行
|
||||
|
||||
# 添加翻译内容
|
||||
for i, text in enumerate(translated_texts, 1):
|
||||
content_type = self._detect_content_type(text)
|
||||
|
||||
if content_type == 'table':
|
||||
# 尝试创建实际的表格
|
||||
self._add_table_content(doc, text, i)
|
||||
elif content_type == 'heading':
|
||||
# 添加标题
|
||||
self._add_heading_content(doc, text, i)
|
||||
elif content_type == 'list':
|
||||
# 添加列表
|
||||
self._add_list_content(doc, text, i)
|
||||
else:
|
||||
# 普通段落
|
||||
self._add_paragraph_content(doc, text, i)
|
||||
|
||||
# 保存Word文档
|
||||
doc.save(output_path)
|
||||
logger.info(f"Generated translated PDF Word document: {output_path}")
|
||||
return str(output_path)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate translated Word document: {str(e)}")
|
||||
raise FileProcessingError(f"生成翻译Word文档失败: {str(e)}")
|
||||
|
||||
def generate_combined_translated_document(self, all_translations: dict, target_languages: list,
|
||||
output_dir: Path) -> str:
|
||||
"""生成包含所有翻譯語言的組合Word文檔(譯文1/譯文2格式)"""
|
||||
try:
|
||||
from app.utils.helpers import generate_filename
|
||||
|
||||
# 生成組合文檔檔名
|
||||
languages_suffix = '_'.join(target_languages)
|
||||
output_filename = f"{self.file_path.stem}_{languages_suffix}_combined.docx"
|
||||
output_path = output_dir / output_filename
|
||||
|
||||
# 创建Word文档
|
||||
from docx import Document
|
||||
from docx.shared import Pt
|
||||
from docx.enum.text import WD_PARAGRAPH_ALIGNMENT
|
||||
|
||||
doc = Document()
|
||||
|
||||
# 添加标题页
|
||||
title = doc.add_heading(f"PDF翻译結果 - 多語言組合文檔", 0)
|
||||
title.alignment = WD_PARAGRAPH_ALIGNMENT.CENTER
|
||||
|
||||
# 添加文档信息
|
||||
info_para = doc.add_paragraph()
|
||||
info_para.add_run("原始文件: ").bold = True
|
||||
info_para.add_run(self.file_path.name)
|
||||
info_para.add_run("\n处理方式: ").bold = True
|
||||
info_para.add_run("OCR识别" if self.is_scanned_pdf() else "直接文字提取")
|
||||
info_para.add_run(f"\n翻译语言: ").bold = True
|
||||
info_para.add_run(' / '.join(target_languages))
|
||||
|
||||
# 获取第一个語言的翻譯作為基準長度
|
||||
first_language = target_languages[0]
|
||||
segment_count = len(all_translations.get(first_language, []))
|
||||
info_para.add_run(f"\n总段落数: ").bold = True
|
||||
info_para.add_run(str(segment_count))
|
||||
|
||||
doc.add_paragraph() # 空行
|
||||
|
||||
# 添加翻译内容 - 譯文1/譯文2格式
|
||||
for i in range(segment_count):
|
||||
content_para = doc.add_paragraph()
|
||||
|
||||
# 添加段落编号
|
||||
num_run = content_para.add_run(f"{i+1:03d}. ")
|
||||
num_run.bold = True
|
||||
num_run.font.size = Pt(12)
|
||||
|
||||
# 为每种语言添加翻譯
|
||||
for j, target_language in enumerate(target_languages):
|
||||
if i < len(all_translations.get(target_language, [])):
|
||||
translation_text = all_translations[target_language][i]
|
||||
|
||||
# 添加語言標識
|
||||
if j > 0:
|
||||
content_para.add_run("\n\n") # 翻譯之間的間距
|
||||
|
||||
lang_run = content_para.add_run(f"[{target_language}] ")
|
||||
lang_run.bold = True
|
||||
lang_run.font.size = Pt(11)
|
||||
|
||||
# 添加翻譯内容
|
||||
trans_run = content_para.add_run(translation_text)
|
||||
trans_run.font.size = Pt(11)
|
||||
|
||||
# 段落間距
|
||||
content_para.paragraph_format.space_after = Pt(12)
|
||||
|
||||
# 保存Word文档
|
||||
doc.save(output_path)
|
||||
logger.info(f"Generated combined translated PDF Word document: {output_path}")
|
||||
return str(output_path)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate combined translated Word document: {str(e)}")
|
||||
raise FileProcessingError(f"生成組合翻译Word文档失败: {str(e)}")
|
||||
|
||||
def _is_table_component(self, segment: str) -> bool:
|
||||
"""檢查段落是否為表格組件(表格邊界、分隔線等)"""
|
||||
segment = segment.strip()
|
||||
|
||||
# Markdown表格分隔線:如 |---|---|---| 或 |===|===|===|
|
||||
if '|' in segment and ('-' in segment or '=' in segment):
|
||||
# 移除 | 和 - = 後,如果剩餘內容很少,則判斷為表格分隔線
|
||||
clean_segment = segment.replace('|', '').replace('-', '').replace('=', '').replace(' ', '').replace(':', '')
|
||||
if len(clean_segment) <= 2: # 允許少量其他字符
|
||||
return True
|
||||
|
||||
# 純分隔線
|
||||
if segment.replace('=', '').replace('-', '').replace(' ', '') == '':
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _is_table_row(self, segment: str) -> bool:
|
||||
"""檢查段落是否為表格行(包含實際數據的表格行)"""
|
||||
segment = segment.strip()
|
||||
|
||||
# Markdown表格行:至少包含兩個 | 符號,且有實際內容
|
||||
if segment.count('|') >= 2:
|
||||
# 移除首尾的 | 並分割為單元格
|
||||
cells = segment.strip('|').split('|')
|
||||
# 檢查是否有實際的文字內容(不只是分隔符號)
|
||||
has_content = any(
|
||||
cell.strip() and
|
||||
not cell.replace('-', '').replace('=', '').replace(' ', '').replace(':', '') == ''
|
||||
for cell in cells
|
||||
)
|
||||
if has_content:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _merge_table_segments(self, segments: List[str], start_idx: int) -> tuple[str, int]:
|
||||
"""
|
||||
合併表格相關的段落
|
||||
|
||||
Returns:
|
||||
(merged_table_content, next_index)
|
||||
"""
|
||||
table_parts = []
|
||||
current_idx = start_idx
|
||||
|
||||
# 收集連續的表格相關段落
|
||||
while current_idx < len(segments):
|
||||
segment = segments[current_idx].strip()
|
||||
|
||||
if self._is_table_component(segment) or self._is_table_row(segment):
|
||||
table_parts.append(segment)
|
||||
current_idx += 1
|
||||
else:
|
||||
break
|
||||
|
||||
# 將表格部分合併為一個段落
|
||||
merged_table = '\n'.join(table_parts)
|
||||
return merged_table, current_idx
|
||||
|
||||
def _merge_short_segments(self, text_segments: List[str], min_length: int = 10) -> List[str]:
|
||||
"""
|
||||
合併短段落以減少不必要的翻譯調用,特別處理表格結構
|
||||
|
||||
Args:
|
||||
text_segments: 原始文字段落列表
|
||||
min_length: 最小段落長度閾值,短於此長度的段落將被合併
|
||||
|
||||
Returns:
|
||||
合併後的段落列表
|
||||
"""
|
||||
if not text_segments:
|
||||
return text_segments
|
||||
|
||||
merged_segments = []
|
||||
current_merge = ""
|
||||
i = 0
|
||||
|
||||
while i < len(text_segments):
|
||||
segment = text_segments[i].strip()
|
||||
if not segment: # 跳過空段落
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# 檢查是否為表格組件
|
||||
if self._is_table_component(segment) or self._is_table_row(segment):
|
||||
# 先處理之前積累的短段落
|
||||
if current_merge:
|
||||
merged_segments.append(current_merge.strip())
|
||||
logger.debug(f"Merged short segments before table: '{current_merge[:50]}...'")
|
||||
current_merge = ""
|
||||
|
||||
# 合併表格相關段落
|
||||
table_content, next_i = self._merge_table_segments(text_segments, i)
|
||||
merged_segments.append(table_content)
|
||||
logger.debug(f"Merged table content: {next_i - i} segments -> 1 table block")
|
||||
i = next_i
|
||||
continue
|
||||
|
||||
# 檢查是否為短段落
|
||||
if len(segment) < min_length:
|
||||
# 檢查是否為純標點符號或數字(排除表格符號)
|
||||
if segment.replace('*', '').replace('-', '').replace('_', '').replace('#', '').strip() == '':
|
||||
logger.debug(f"Skipping pure symbol segment: '{segment}'")
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# 短段落需要合併
|
||||
if current_merge:
|
||||
current_merge += " " + segment
|
||||
else:
|
||||
current_merge = segment
|
||||
|
||||
logger.debug(f"Adding short segment to merge: '{segment}' (length: {len(segment)})")
|
||||
|
||||
else:
|
||||
# 長段落,先處理之前積累的短段落
|
||||
if current_merge:
|
||||
merged_segments.append(current_merge.strip())
|
||||
logger.debug(f"Merged short segments: '{current_merge[:50]}...' (total length: {len(current_merge)})")
|
||||
current_merge = ""
|
||||
|
||||
# 添加當前長段落
|
||||
merged_segments.append(segment)
|
||||
logger.debug(f"Added long segment: '{segment[:50]}...' (length: {len(segment)})")
|
||||
|
||||
i += 1
|
||||
|
||||
# 處理最後剩餘的短段落
|
||||
if current_merge:
|
||||
merged_segments.append(current_merge.strip())
|
||||
logger.debug(f"Final merged short segments: '{current_merge[:50]}...' (total length: {len(current_merge)})")
|
||||
|
||||
logger.info(f"Segment merging: {len(text_segments)} -> {len(merged_segments)} segments")
|
||||
return merged_segments
|
||||
|
||||
def _detect_content_type(self, text: str) -> str:
|
||||
"""检测内容类型"""
|
||||
text_lower = text.lower().strip()
|
||||
|
||||
# 检测表格(包含多个|或制表符)
|
||||
if ('|' in text and text.count('|') >= 2) or '\t' in text:
|
||||
return 'table'
|
||||
|
||||
# 检测标题
|
||||
if (text_lower.startswith(('第', '章', 'chapter', 'section', '#')) or
|
||||
any(keyword in text_lower for keyword in ['章', '节', '第']) and len(text) < 100):
|
||||
return 'heading'
|
||||
|
||||
# 检测列表
|
||||
if (text_lower.startswith(('•', '-', '*', '1.', '2.', '3.', '4.', '5.')) or
|
||||
any(text_lower.startswith(f"{i}.") for i in range(1, 20))):
|
||||
return 'list'
|
||||
|
||||
return 'paragraph'
|
||||
|
||||
def _add_table_content(self, doc, text: str, index: int):
|
||||
"""添加表格内容"""
|
||||
from docx.shared import Pt
|
||||
|
||||
# 添加表格标题
|
||||
title_para = doc.add_paragraph()
|
||||
title_run = title_para.add_run(f"表格 {index}: ")
|
||||
title_run.bold = True
|
||||
title_run.font.size = Pt(12)
|
||||
|
||||
# 解析表格
|
||||
if '|' in text:
|
||||
# Markdown风格表格
|
||||
lines = [line.strip() for line in text.split('\n') if line.strip()]
|
||||
rows = []
|
||||
for line in lines:
|
||||
if line.startswith('|') and line.endswith('|'):
|
||||
cells = [cell.strip() for cell in line.split('|')[1:-1]]
|
||||
if cells: # 过滤掉分隔行(如|---|---|)
|
||||
if not all(cell.replace('-', '').replace(' ', '') == '' for cell in cells):
|
||||
rows.append(cells)
|
||||
|
||||
if rows:
|
||||
# 创建表格
|
||||
table = doc.add_table(rows=len(rows), cols=len(rows[0]))
|
||||
table.style = 'Table Grid'
|
||||
|
||||
for i, row_data in enumerate(rows):
|
||||
for j, cell_data in enumerate(row_data):
|
||||
if j < len(table.rows[i].cells):
|
||||
cell = table.rows[i].cells[j]
|
||||
cell.text = cell_data
|
||||
# 设置字体
|
||||
for paragraph in cell.paragraphs:
|
||||
for run in paragraph.runs:
|
||||
run.font.size = Pt(10)
|
||||
else:
|
||||
# 制表符分隔的表格
|
||||
para = doc.add_paragraph()
|
||||
content_run = para.add_run(text)
|
||||
content_run.font.name = 'Courier New'
|
||||
content_run.font.size = Pt(10)
|
||||
|
||||
def _add_heading_content(self, doc, text: str, index: int):
|
||||
"""添加标题内容"""
|
||||
from docx.shared import Pt
|
||||
|
||||
# 移除段落编号,直接作为标题
|
||||
clean_text = text.strip()
|
||||
if len(clean_text) < 100:
|
||||
heading = doc.add_heading(clean_text, level=2)
|
||||
else:
|
||||
# 长文本作为普通段落但使用标题样式
|
||||
para = doc.add_paragraph()
|
||||
run = para.add_run(clean_text)
|
||||
run.bold = True
|
||||
run.font.size = Pt(14)
|
||||
|
||||
def _add_list_content(self, doc, text: str, index: int):
|
||||
"""添加列表内容"""
|
||||
from docx.shared import Pt
|
||||
|
||||
# 检查是否已经有编号
|
||||
if any(text.strip().startswith(f"{i}.") for i in range(1, 20)):
|
||||
# 已编号列表
|
||||
para = doc.add_paragraph(text.strip(), style='List Number')
|
||||
else:
|
||||
# 项目符号列表
|
||||
para = doc.add_paragraph(text.strip(), style='List Bullet')
|
||||
|
||||
# 设置字体大小
|
||||
for run in para.runs:
|
||||
run.font.size = Pt(11)
|
||||
|
||||
def _add_paragraph_content(self, doc, text: str, index: int):
|
||||
"""添加普通段落内容"""
|
||||
from docx.shared import Pt
|
||||
|
||||
para = doc.add_paragraph()
|
||||
|
||||
# 添加段落编号(可选)
|
||||
num_run = para.add_run(f"{index:03d}. ")
|
||||
num_run.bold = True
|
||||
num_run.font.size = Pt(12)
|
||||
|
||||
# 添加内容
|
||||
content_run = para.add_run(text)
|
||||
content_run.font.size = Pt(11)
|
||||
|
||||
# 设置段落间距
|
||||
para.paragraph_format.space_after = Pt(6)
|
@@ -56,41 +56,45 @@ class NotificationService:
|
||||
return None
|
||||
|
||||
def _send_email(self, to_email: str, subject: str, html_content: str, text_content: str = None) -> bool:
|
||||
"""發送郵件的基礎方法"""
|
||||
try:
|
||||
if not self.smtp_server or not self.sender_email:
|
||||
logger.error("SMTP configuration incomplete")
|
||||
return False
|
||||
|
||||
# 建立郵件
|
||||
msg = MIMEMultipart('alternative')
|
||||
msg['From'] = f"{self.app_name} <{self.sender_email}>"
|
||||
msg['To'] = to_email
|
||||
msg['Subject'] = subject
|
||||
|
||||
# 添加文本內容
|
||||
if text_content:
|
||||
text_part = MIMEText(text_content, 'plain', 'utf-8')
|
||||
msg.attach(text_part)
|
||||
|
||||
# 添加 HTML 內容
|
||||
html_part = MIMEText(html_content, 'html', 'utf-8')
|
||||
msg.attach(html_part)
|
||||
|
||||
# 發送郵件
|
||||
server = self._create_smtp_connection()
|
||||
if not server:
|
||||
return False
|
||||
|
||||
server.send_message(msg)
|
||||
server.quit()
|
||||
|
||||
logger.info(f"Email sent successfully to {to_email}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send email to {to_email}: {str(e)}")
|
||||
return False
|
||||
"""發送郵件的基礎方法 - 已停用 (資安限制,無法連接內網)"""
|
||||
logger.info(f"SMTP service disabled - Email notification skipped for {to_email}: {subject}")
|
||||
return True # 回傳 True 避免影響其他流程
|
||||
|
||||
# 以下 SMTP 功能已註解,因應資安限制無法連接內網
|
||||
# try:
|
||||
# if not self.smtp_server or not self.sender_email:
|
||||
# logger.error("SMTP configuration incomplete")
|
||||
# return False
|
||||
#
|
||||
# # 建立郵件
|
||||
# msg = MIMEMultipart('alternative')
|
||||
# msg['From'] = f"{self.app_name} <{self.sender_email}>"
|
||||
# msg['To'] = to_email
|
||||
# msg['Subject'] = subject
|
||||
#
|
||||
# # 添加文本內容
|
||||
# if text_content:
|
||||
# text_part = MIMEText(text_content, 'plain', 'utf-8')
|
||||
# msg.attach(text_part)
|
||||
#
|
||||
# # 添加 HTML 內容
|
||||
# html_part = MIMEText(html_content, 'html', 'utf-8')
|
||||
# msg.attach(html_part)
|
||||
#
|
||||
# # 發送郵件
|
||||
# server = self._create_smtp_connection()
|
||||
# if not server:
|
||||
# return False
|
||||
#
|
||||
# server.send_message(msg)
|
||||
# server.quit()
|
||||
#
|
||||
# logger.info(f"Email sent successfully to {to_email}")
|
||||
# return True
|
||||
#
|
||||
# except Exception as e:
|
||||
# logger.error(f"Failed to send email to {to_email}: {str(e)}")
|
||||
# return False
|
||||
|
||||
def send_job_completion_notification(self, job: TranslationJob) -> bool:
|
||||
"""發送任務完成通知"""
|
||||
|
282
app/services/ocr_cache.py
Normal file
282
app/services/ocr_cache.py
Normal file
@@ -0,0 +1,282 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
OCR 快取管理模組
|
||||
|
||||
Author: PANJIT IT Team
|
||||
Created: 2024-01-28
|
||||
Modified: 2024-01-28
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import sqlite3
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class OCRCache:
|
||||
"""OCR 結果快取管理器"""
|
||||
|
||||
def __init__(self, cache_db_path: str = "ocr_cache.db", cache_expire_days: int = 30):
|
||||
"""
|
||||
初始化 OCR 快取管理器
|
||||
|
||||
Args:
|
||||
cache_db_path: 快取資料庫路徑
|
||||
cache_expire_days: 快取過期天數
|
||||
"""
|
||||
self.cache_db_path = Path(cache_db_path)
|
||||
self.cache_expire_days = cache_expire_days
|
||||
self.init_database()
|
||||
|
||||
def init_database(self):
|
||||
"""初始化快取資料庫"""
|
||||
try:
|
||||
with sqlite3.connect(self.cache_db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS ocr_cache (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
file_hash TEXT UNIQUE NOT NULL,
|
||||
filename TEXT,
|
||||
file_size INTEGER,
|
||||
extracted_text TEXT NOT NULL,
|
||||
extraction_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
access_count INTEGER DEFAULT 1,
|
||||
last_access_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
metadata TEXT
|
||||
)
|
||||
''')
|
||||
|
||||
# 創建索引以提高查詢效能
|
||||
cursor.execute('''
|
||||
CREATE INDEX IF NOT EXISTS idx_file_hash
|
||||
ON ocr_cache(file_hash)
|
||||
''')
|
||||
cursor.execute('''
|
||||
CREATE INDEX IF NOT EXISTS idx_extraction_time
|
||||
ON ocr_cache(extraction_time)
|
||||
''')
|
||||
|
||||
conn.commit()
|
||||
logger.info("OCR 快取資料庫初始化完成")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"初始化 OCR 快取資料庫失敗: {e}")
|
||||
raise
|
||||
|
||||
def _calculate_file_hash(self, file_data: bytes, additional_info: str = "") -> str:
|
||||
"""
|
||||
計算檔案內容的 SHA256 雜湊值
|
||||
|
||||
Args:
|
||||
file_data: 檔案二進位資料
|
||||
additional_info: 額外資訊(如頁數、處理參數等)
|
||||
|
||||
Returns:
|
||||
檔案的 SHA256 雜湊值
|
||||
"""
|
||||
hash_input = file_data + additional_info.encode('utf-8')
|
||||
return hashlib.sha256(hash_input).hexdigest()
|
||||
|
||||
def get_cached_text(self, file_data: bytes, filename: str = "",
|
||||
additional_info: str = "") -> Optional[str]:
|
||||
"""
|
||||
獲取快取的 OCR 文字
|
||||
|
||||
Args:
|
||||
file_data: 檔案二進位資料
|
||||
filename: 檔案名稱
|
||||
additional_info: 額外資訊
|
||||
|
||||
Returns:
|
||||
快取的文字內容,如果不存在則返回 None
|
||||
"""
|
||||
try:
|
||||
file_hash = self._calculate_file_hash(file_data, additional_info)
|
||||
|
||||
with sqlite3.connect(self.cache_db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 查詢快取
|
||||
cursor.execute('''
|
||||
SELECT extracted_text, access_count
|
||||
FROM ocr_cache
|
||||
WHERE file_hash = ? AND
|
||||
extraction_time > datetime('now', '-{} days')
|
||||
'''.format(self.cache_expire_days), (file_hash,))
|
||||
|
||||
result = cursor.fetchone()
|
||||
|
||||
if result:
|
||||
extracted_text, access_count = result
|
||||
|
||||
# 更新訪問計數和時間
|
||||
cursor.execute('''
|
||||
UPDATE ocr_cache
|
||||
SET access_count = ?, last_access_time = CURRENT_TIMESTAMP
|
||||
WHERE file_hash = ?
|
||||
''', (access_count + 1, file_hash))
|
||||
|
||||
conn.commit()
|
||||
|
||||
logger.info(f"[OCR-CACHE] 快取命中: {filename} (訪問次數: {access_count + 1})")
|
||||
return extracted_text
|
||||
|
||||
logger.debug(f"[OCR-CACHE] 快取未命中: {filename}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"獲取 OCR 快取失敗: {e}")
|
||||
return None
|
||||
|
||||
def save_cached_text(self, file_data: bytes, extracted_text: str,
|
||||
filename: str = "", additional_info: str = "",
|
||||
metadata: Dict[str, Any] = None) -> bool:
|
||||
"""
|
||||
儲存 OCR 文字到快取
|
||||
|
||||
Args:
|
||||
file_data: 檔案二進位資料
|
||||
extracted_text: 提取的文字
|
||||
filename: 檔案名稱
|
||||
additional_info: 額外資訊
|
||||
metadata: 中繼資料
|
||||
|
||||
Returns:
|
||||
是否儲存成功
|
||||
"""
|
||||
try:
|
||||
file_hash = self._calculate_file_hash(file_data, additional_info)
|
||||
file_size = len(file_data)
|
||||
metadata_json = json.dumps(metadata or {}, ensure_ascii=False)
|
||||
|
||||
with sqlite3.connect(self.cache_db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 使用 INSERT OR REPLACE 來處理重複的雜湊值
|
||||
cursor.execute('''
|
||||
INSERT OR REPLACE INTO ocr_cache
|
||||
(file_hash, filename, file_size, extracted_text, metadata)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
''', (file_hash, filename, file_size, extracted_text, metadata_json))
|
||||
|
||||
conn.commit()
|
||||
|
||||
logger.info(f"[OCR-CACHE] 儲存快取成功: {filename} ({len(extracted_text)} 字元)")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"儲存 OCR 快取失敗: {e}")
|
||||
return False
|
||||
|
||||
def get_cache_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
獲取快取統計資訊
|
||||
|
||||
Returns:
|
||||
快取統計資料
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(self.cache_db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 總記錄數
|
||||
cursor.execute('SELECT COUNT(*) FROM ocr_cache')
|
||||
total_records = cursor.fetchone()[0]
|
||||
|
||||
# 總訪問次數
|
||||
cursor.execute('SELECT SUM(access_count) FROM ocr_cache')
|
||||
total_accesses = cursor.fetchone()[0] or 0
|
||||
|
||||
# 快取大小
|
||||
cursor.execute('SELECT SUM(LENGTH(extracted_text)) FROM ocr_cache')
|
||||
cache_size_chars = cursor.fetchone()[0] or 0
|
||||
|
||||
# 最近 7 天的記錄數
|
||||
cursor.execute('''
|
||||
SELECT COUNT(*) FROM ocr_cache
|
||||
WHERE extraction_time > datetime('now', '-7 days')
|
||||
''')
|
||||
recent_records = cursor.fetchone()[0]
|
||||
|
||||
# 最常訪問的記錄
|
||||
cursor.execute('''
|
||||
SELECT filename, access_count, last_access_time
|
||||
FROM ocr_cache
|
||||
ORDER BY access_count DESC
|
||||
LIMIT 5
|
||||
''')
|
||||
top_accessed = cursor.fetchall()
|
||||
|
||||
return {
|
||||
'total_records': total_records,
|
||||
'total_accesses': total_accesses,
|
||||
'cache_size_chars': cache_size_chars,
|
||||
'cache_size_mb': cache_size_chars / (1024 * 1024),
|
||||
'recent_records_7days': recent_records,
|
||||
'top_accessed_files': [
|
||||
{
|
||||
'filename': row[0],
|
||||
'access_count': row[1],
|
||||
'last_access': row[2]
|
||||
}
|
||||
for row in top_accessed
|
||||
],
|
||||
'cache_hit_potential': f"{(total_accesses - total_records) / max(total_accesses, 1) * 100:.1f}%"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"獲取快取統計失敗: {e}")
|
||||
return {}
|
||||
|
||||
def clean_expired_cache(self) -> int:
|
||||
"""
|
||||
清理過期的快取記錄
|
||||
|
||||
Returns:
|
||||
清理的記錄數量
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(self.cache_db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 刪除過期記錄
|
||||
cursor.execute('''
|
||||
DELETE FROM ocr_cache
|
||||
WHERE extraction_time < datetime('now', '-{} days')
|
||||
'''.format(self.cache_expire_days))
|
||||
|
||||
deleted_count = cursor.rowcount
|
||||
conn.commit()
|
||||
|
||||
logger.info(f"[OCR-CACHE] 清理過期快取: {deleted_count} 筆記錄")
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"清理過期快取失敗: {e}")
|
||||
return 0
|
||||
|
||||
def clear_all_cache(self) -> bool:
|
||||
"""
|
||||
清空所有快取
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(self.cache_db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('DELETE FROM ocr_cache')
|
||||
conn.commit()
|
||||
|
||||
logger.info("[OCR-CACHE] 已清空所有快取")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"清空快取失敗: {e}")
|
||||
return False
|
File diff suppressed because it is too large
Load Diff
277
app/utils/api_auth.py
Normal file
277
app/utils/api_auth.py
Normal file
@@ -0,0 +1,277 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
API 認證服務
|
||||
用於與 PANJIT Auth API 整合認證
|
||||
|
||||
Author: PANJIT IT Team
|
||||
Created: 2025-10-01
|
||||
"""
|
||||
|
||||
import requests
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, Any, Tuple
|
||||
from flask import current_app
|
||||
from .logger import get_logger
|
||||
from .exceptions import AuthenticationError
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class APIAuthService:
|
||||
"""API 認證服務"""
|
||||
|
||||
def __init__(self):
|
||||
self.config = current_app.config
|
||||
self.api_base_url = "https://pj-auth-api.vercel.app"
|
||||
self.login_endpoint = "/api/auth/login"
|
||||
self.logout_endpoint = "/api/auth/logout"
|
||||
self.timeout = 30 # 30 秒超時
|
||||
|
||||
def authenticate_user(self, username: str, password: str) -> Dict[str, Any]:
|
||||
"""
|
||||
透過 API 驗證使用者憑證
|
||||
|
||||
Args:
|
||||
username: 使用者帳號
|
||||
password: 密碼
|
||||
|
||||
Returns:
|
||||
Dict: 包含使用者資訊和 Token 的字典
|
||||
|
||||
Raises:
|
||||
AuthenticationError: 認證失敗時拋出
|
||||
"""
|
||||
try:
|
||||
login_url = f"{self.api_base_url}{self.login_endpoint}"
|
||||
|
||||
payload = {
|
||||
"username": username,
|
||||
"password": password
|
||||
}
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
logger.info(f"正在透過 API 驗證使用者: {username}")
|
||||
|
||||
# 發送認證請求
|
||||
response = requests.post(
|
||||
login_url,
|
||||
json=payload,
|
||||
headers=headers,
|
||||
timeout=self.timeout
|
||||
)
|
||||
|
||||
# 解析回應
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
|
||||
if data.get('success'):
|
||||
logger.info(f"API 認證成功: {username}")
|
||||
return self._parse_auth_response(data)
|
||||
else:
|
||||
error_msg = data.get('error', '認證失敗')
|
||||
logger.warning(f"API 認證失敗: {username} - {error_msg}")
|
||||
raise AuthenticationError(f"認證失敗: {error_msg}")
|
||||
|
||||
elif response.status_code == 401:
|
||||
data = response.json()
|
||||
error_msg = data.get('error', '帳號或密碼錯誤')
|
||||
logger.warning(f"API 認證失敗 (401): {username} - {error_msg}")
|
||||
raise AuthenticationError("帳號或密碼錯誤")
|
||||
|
||||
else:
|
||||
logger.error(f"API 認證請求失敗: HTTP {response.status_code}")
|
||||
raise AuthenticationError(f"認證服務錯誤 (HTTP {response.status_code})")
|
||||
|
||||
except requests.exceptions.Timeout:
|
||||
logger.error(f"API 認證請求超時: {username}")
|
||||
raise AuthenticationError("認證服務回應超時,請稍後再試")
|
||||
|
||||
except requests.exceptions.ConnectionError:
|
||||
logger.error(f"API 認證連線錯誤: {username}")
|
||||
raise AuthenticationError("無法連接認證服務,請檢查網路連線")
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"API 認證請求錯誤: {username} - {str(e)}")
|
||||
raise AuthenticationError(f"認證服務錯誤: {str(e)}")
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"API 認證回應格式錯誤: {username}")
|
||||
raise AuthenticationError("認證服務回應格式錯誤")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"API 認證未知錯誤: {username} - {str(e)}")
|
||||
raise AuthenticationError(f"認證過程發生錯誤: {str(e)}")
|
||||
|
||||
def _parse_auth_response(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
解析 API 認證回應
|
||||
|
||||
Args:
|
||||
data: API 回應資料
|
||||
|
||||
Returns:
|
||||
Dict: 標準化的使用者資訊
|
||||
"""
|
||||
try:
|
||||
auth_data = data.get('data', {})
|
||||
user_info = auth_data.get('userInfo', {})
|
||||
|
||||
# 解析 Token 過期時間
|
||||
expires_at = None
|
||||
issued_at = None
|
||||
|
||||
if 'expiresAt' in auth_data:
|
||||
try:
|
||||
expires_at = datetime.fromisoformat(auth_data['expiresAt'].replace('Z', '+00:00'))
|
||||
except (ValueError, AttributeError):
|
||||
logger.warning("無法解析 API Token 過期時間")
|
||||
|
||||
if 'issuedAt' in auth_data:
|
||||
try:
|
||||
issued_at = datetime.fromisoformat(auth_data['issuedAt'].replace('Z', '+00:00'))
|
||||
except (ValueError, AttributeError):
|
||||
logger.warning("無法解析 API Token 發行時間")
|
||||
|
||||
# 標準化使用者資訊 (方案 A: API name 是姓名+email 格式)
|
||||
api_name = user_info.get('name', '') # 例: "劉怡明 ymirliu@panjit.com.tw"
|
||||
api_email = user_info.get('email', '') # 例: "ymirliu@panjit.com.tw"
|
||||
|
||||
result = {
|
||||
# 基本使用者資訊 (方案 A: username 和 display_name 都用 API name)
|
||||
'username': api_name, # 姓名+email 格式
|
||||
'display_name': api_name, # 姓名+email 格式
|
||||
'email': api_email, # 純 email
|
||||
'department': user_info.get('jobTitle'), # 使用 jobTitle 作為部門
|
||||
'user_principal_name': api_email,
|
||||
|
||||
# API 特有資訊
|
||||
'api_user_id': user_info.get('id', ''), # Azure Object ID
|
||||
'job_title': user_info.get('jobTitle'),
|
||||
'office_location': user_info.get('officeLocation'),
|
||||
'business_phones': user_info.get('businessPhones', []),
|
||||
|
||||
# Token 資訊
|
||||
'api_access_token': auth_data.get('access_token', ''),
|
||||
'api_id_token': auth_data.get('id_token', ''),
|
||||
'api_token_type': auth_data.get('token_type', 'Bearer'),
|
||||
'api_expires_in': auth_data.get('expires_in', 0),
|
||||
'api_issued_at': issued_at,
|
||||
'api_expires_at': expires_at,
|
||||
|
||||
# 完整的 API 回應 (用於記錄)
|
||||
'full_api_response': data,
|
||||
'api_user_info': user_info
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"解析 API 回應時發生錯誤: {str(e)}")
|
||||
raise AuthenticationError(f"解析認證回應時發生錯誤: {str(e)}")
|
||||
|
||||
def logout_user(self, access_token: str) -> bool:
|
||||
"""
|
||||
透過 API 登出使用者
|
||||
|
||||
Args:
|
||||
access_token: 使用者的 access token
|
||||
|
||||
Returns:
|
||||
bool: 登出是否成功
|
||||
"""
|
||||
try:
|
||||
logout_url = f"{self.api_base_url}{self.logout_endpoint}"
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
logout_url,
|
||||
headers=headers,
|
||||
timeout=self.timeout
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
if data.get('success'):
|
||||
logger.info("API 登出成功")
|
||||
return True
|
||||
|
||||
logger.warning(f"API 登出失敗: HTTP {response.status_code}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"API 登出時發生錯誤: {str(e)}")
|
||||
return False
|
||||
|
||||
def validate_token(self, access_token: str) -> bool:
|
||||
"""
|
||||
驗證 Token 是否有效
|
||||
|
||||
Args:
|
||||
access_token: 要驗證的 token
|
||||
|
||||
Returns:
|
||||
bool: Token 是否有效
|
||||
"""
|
||||
try:
|
||||
# 這裡可以實作 Token 驗證邏輯
|
||||
# 目前 API 沒有提供專門的驗證端點,可以考慮解析 JWT 或調用其他端點
|
||||
|
||||
# 簡單的檢查:Token 不能為空且格式看起來像 JWT
|
||||
if not access_token or len(access_token.split('.')) != 3:
|
||||
return False
|
||||
|
||||
# TODO: 實作更完整的 JWT 驗證邏輯
|
||||
# 可以解析 JWT payload 檢查過期時間等
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"驗證 Token 時發生錯誤: {str(e)}")
|
||||
return False
|
||||
|
||||
def test_connection(self) -> bool:
|
||||
"""
|
||||
測試 API 連線
|
||||
|
||||
Returns:
|
||||
bool: 連線是否正常
|
||||
"""
|
||||
try:
|
||||
# 嘗試連接 API 基礎端點
|
||||
response = requests.get(
|
||||
self.api_base_url,
|
||||
timeout=10
|
||||
)
|
||||
|
||||
return response.status_code in [200, 404] # 404 也算正常,表示能連接到伺服器
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"API 連線測試失敗: {str(e)}")
|
||||
return False
|
||||
|
||||
def calculate_internal_expiry(self, api_expires_at: Optional[datetime], extend_days: int = 3) -> datetime:
|
||||
"""
|
||||
計算內部 Token 過期時間
|
||||
|
||||
Args:
|
||||
api_expires_at: API Token 過期時間
|
||||
extend_days: 延長天數
|
||||
|
||||
Returns:
|
||||
datetime: 內部 Token 過期時間
|
||||
"""
|
||||
if api_expires_at:
|
||||
# 基於 API Token 過期時間延長
|
||||
return api_expires_at + timedelta(days=extend_days)
|
||||
else:
|
||||
# 如果沒有 API 過期時間,從現在開始計算
|
||||
return datetime.utcnow() + timedelta(days=extend_days)
|
248
app/utils/image_preprocessor.py
Normal file
248
app/utils/image_preprocessor.py
Normal file
@@ -0,0 +1,248 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
圖像預處理工具 - 用於提升 OCR 識別準確度
|
||||
|
||||
Author: PANJIT IT Team
|
||||
Created: 2025-10-01
|
||||
Modified: 2025-10-01
|
||||
"""
|
||||
|
||||
import io
|
||||
import numpy as np
|
||||
from PIL import Image, ImageEnhance, ImageFilter
|
||||
from typing import Optional, Tuple
|
||||
from app.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# 檢查 OpenCV 是否可用
|
||||
try:
|
||||
import cv2
|
||||
_HAS_OPENCV = True
|
||||
logger.info("OpenCV is available for advanced image preprocessing")
|
||||
except ImportError:
|
||||
_HAS_OPENCV = False
|
||||
logger.warning("OpenCV not available, using PIL-only preprocessing")
|
||||
|
||||
|
||||
class ImagePreprocessor:
|
||||
"""圖像預處理器 - 提升掃描文件 OCR 品質"""
|
||||
|
||||
def __init__(self, use_opencv: bool = True):
|
||||
"""
|
||||
初始化圖像預處理器
|
||||
|
||||
Args:
|
||||
use_opencv: 是否使用 OpenCV 進行進階處理(若可用)
|
||||
"""
|
||||
self.use_opencv = use_opencv and _HAS_OPENCV
|
||||
logger.info(f"ImagePreprocessor initialized (OpenCV: {self.use_opencv})")
|
||||
|
||||
def preprocess_for_ocr(self, image_bytes: bytes,
|
||||
enhance_level: str = 'medium') -> bytes:
|
||||
"""
|
||||
對圖像進行 OCR 前處理
|
||||
|
||||
Args:
|
||||
image_bytes: 原始圖像字節數據
|
||||
enhance_level: 增強級別 ('low', 'medium', 'high')
|
||||
|
||||
Returns:
|
||||
處理後的圖像字節數據 (PNG格式)
|
||||
"""
|
||||
try:
|
||||
# 1. 載入圖像
|
||||
image = Image.open(io.BytesIO(image_bytes))
|
||||
original_mode = image.mode
|
||||
logger.debug(f"Original image: {image.size}, mode={original_mode}")
|
||||
|
||||
# 2. 轉換為 RGB (如果需要)
|
||||
if image.mode not in ('RGB', 'L'):
|
||||
image = image.convert('RGB')
|
||||
logger.debug(f"Converted to RGB mode")
|
||||
|
||||
# 3. 根據增強級別選擇處理流程
|
||||
if self.use_opencv:
|
||||
processed_image = self._preprocess_with_opencv(image, enhance_level)
|
||||
else:
|
||||
processed_image = self._preprocess_with_pil(image, enhance_level)
|
||||
|
||||
# 4. 轉換為 PNG 字節
|
||||
output_buffer = io.BytesIO()
|
||||
processed_image.save(output_buffer, format='PNG', optimize=True)
|
||||
processed_bytes = output_buffer.getvalue()
|
||||
|
||||
logger.info(f"Image preprocessed: {len(image_bytes)} -> {len(processed_bytes)} bytes (level={enhance_level})")
|
||||
return processed_bytes
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Image preprocessing failed: {e}, returning original image")
|
||||
return image_bytes # 失敗時返回原圖
|
||||
|
||||
def _preprocess_with_opencv(self, image: Image.Image, level: str) -> Image.Image:
|
||||
"""使用 OpenCV 進行進階圖像處理"""
|
||||
# PIL Image -> NumPy array
|
||||
img_array = np.array(image)
|
||||
|
||||
# 轉換為 BGR (OpenCV 格式)
|
||||
if len(img_array.shape) == 3 and img_array.shape[2] == 3:
|
||||
img_bgr = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
|
||||
else:
|
||||
img_bgr = img_array
|
||||
|
||||
# 1. 灰階化
|
||||
gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
|
||||
logger.debug("Applied grayscale conversion (OpenCV)")
|
||||
|
||||
# 2. 去噪 - 根據級別調整
|
||||
if level == 'high':
|
||||
# 高級別:較強去噪
|
||||
denoised = cv2.fastNlMeansDenoising(gray, None, h=10, templateWindowSize=7, searchWindowSize=21)
|
||||
logger.debug("Applied strong denoising (h=10)")
|
||||
elif level == 'medium':
|
||||
# 中級別:中等去噪
|
||||
denoised = cv2.fastNlMeansDenoising(gray, None, h=7, templateWindowSize=7, searchWindowSize=21)
|
||||
logger.debug("Applied medium denoising (h=7)")
|
||||
else:
|
||||
# 低級別:輕度去噪
|
||||
denoised = cv2.bilateralFilter(gray, 5, 50, 50)
|
||||
logger.debug("Applied light denoising (bilateral)")
|
||||
|
||||
# 3. 對比度增強 - CLAHE
|
||||
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
|
||||
enhanced = clahe.apply(denoised)
|
||||
logger.debug("Applied CLAHE contrast enhancement")
|
||||
|
||||
# 4. 銳化 (高級別才使用)
|
||||
if level == 'high':
|
||||
kernel = np.array([[-1,-1,-1],
|
||||
[-1, 9,-1],
|
||||
[-1,-1,-1]])
|
||||
sharpened = cv2.filter2D(enhanced, -1, kernel)
|
||||
logger.debug("Applied sharpening filter")
|
||||
else:
|
||||
sharpened = enhanced
|
||||
|
||||
# 5. 自適應二值化 (根據級別決定是否使用)
|
||||
if level in ('medium', 'high'):
|
||||
# 使用自適應閾值
|
||||
binary = cv2.adaptiveThreshold(
|
||||
sharpened, 255,
|
||||
cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
|
||||
cv2.THRESH_BINARY,
|
||||
blockSize=11,
|
||||
C=2
|
||||
)
|
||||
logger.debug("Applied adaptive thresholding")
|
||||
final_image = binary
|
||||
else:
|
||||
final_image = sharpened
|
||||
|
||||
# NumPy array -> PIL Image
|
||||
return Image.fromarray(final_image)
|
||||
|
||||
def _preprocess_with_pil(self, image: Image.Image, level: str) -> Image.Image:
|
||||
"""使用 PIL 進行基礎圖像處理(當 OpenCV 不可用時)"""
|
||||
|
||||
# 1. 灰階化
|
||||
gray = image.convert('L')
|
||||
logger.debug("Applied grayscale conversion (PIL)")
|
||||
|
||||
# 2. 對比度增強
|
||||
enhancer = ImageEnhance.Contrast(gray)
|
||||
if level == 'high':
|
||||
contrast_factor = 2.0
|
||||
elif level == 'medium':
|
||||
contrast_factor = 1.5
|
||||
else:
|
||||
contrast_factor = 1.2
|
||||
|
||||
enhanced = enhancer.enhance(contrast_factor)
|
||||
logger.debug(f"Applied contrast enhancement (factor={contrast_factor})")
|
||||
|
||||
# 3. 銳化
|
||||
if level in ('medium', 'high'):
|
||||
sharpness = ImageEnhance.Sharpness(enhanced)
|
||||
sharp_factor = 2.0 if level == 'high' else 1.5
|
||||
sharpened = sharpness.enhance(sharp_factor)
|
||||
logger.debug(f"Applied sharpening (factor={sharp_factor})")
|
||||
else:
|
||||
sharpened = enhanced
|
||||
|
||||
# 4. 去噪 (使用中值濾波)
|
||||
if level == 'high':
|
||||
denoised = sharpened.filter(ImageFilter.MedianFilter(size=3))
|
||||
logger.debug("Applied median filter (size=3)")
|
||||
else:
|
||||
denoised = sharpened
|
||||
|
||||
return denoised
|
||||
|
||||
def auto_detect_enhance_level(self, image_bytes: bytes) -> str:
|
||||
"""
|
||||
自動偵測最佳增強級別
|
||||
|
||||
Args:
|
||||
image_bytes: 圖像字節數據
|
||||
|
||||
Returns:
|
||||
建議的增強級別 ('low', 'medium', 'high')
|
||||
"""
|
||||
try:
|
||||
image = Image.open(io.BytesIO(image_bytes))
|
||||
|
||||
if self.use_opencv:
|
||||
# 使用 OpenCV 計算圖像品質指標
|
||||
img_array = np.array(image.convert('L'))
|
||||
|
||||
# 計算拉普拉斯方差 (評估清晰度)
|
||||
laplacian_var = cv2.Laplacian(img_array, cv2.CV_64F).var()
|
||||
|
||||
# 計算對比度 (標準差)
|
||||
contrast = np.std(img_array)
|
||||
|
||||
logger.debug(f"Image quality metrics: laplacian_var={laplacian_var:.2f}, contrast={contrast:.2f}")
|
||||
|
||||
# 根據指標決定增強級別
|
||||
if laplacian_var < 50 or contrast < 40:
|
||||
# 模糊或低對比度 -> 高級別增強
|
||||
return 'high'
|
||||
elif laplacian_var < 100 or contrast < 60:
|
||||
# 中等品質 -> 中級別增強
|
||||
return 'medium'
|
||||
else:
|
||||
# 高品質 -> 低級別增強
|
||||
return 'low'
|
||||
else:
|
||||
# PIL 簡易判斷
|
||||
gray = image.convert('L')
|
||||
img_array = np.array(gray)
|
||||
|
||||
# 簡單對比度評估
|
||||
contrast = np.std(img_array)
|
||||
|
||||
if contrast < 40:
|
||||
return 'high'
|
||||
elif contrast < 60:
|
||||
return 'medium'
|
||||
else:
|
||||
return 'low'
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Auto enhance level detection failed: {e}")
|
||||
return 'medium' # 預設使用中級別
|
||||
|
||||
def preprocess_smart(self, image_bytes: bytes) -> bytes:
|
||||
"""
|
||||
智能預處理 - 自動偵測並應用最佳處理級別
|
||||
|
||||
Args:
|
||||
image_bytes: 原始圖像字節數據
|
||||
|
||||
Returns:
|
||||
處理後的圖像字節數據
|
||||
"""
|
||||
enhance_level = self.auto_detect_enhance_level(image_bytes)
|
||||
logger.info(f"Auto-detected enhancement level: {enhance_level}")
|
||||
return self.preprocess_for_ocr(image_bytes, enhance_level)
|
Reference in New Issue
Block a user