refactor: complete V1 to V2 migration and remove legacy architecture

Remove all V1 architecture components and promote V2 to primary:
- Delete all paddle_ocr_* table models (export, ocr, translation, user)
- Delete legacy routers (auth, export, ocr, translation)
- Delete legacy schemas and services
- Promote user_v2.py to user.py as primary user model
- Update all imports and dependencies to use V2 models only
- Update main.py version to 2.0.0

Database changes:
- Fix SQLAlchemy reserved word: rename audit_log.metadata to extra_data
- Add migration to drop all paddle_ocr_* tables
- Update alembic env to only import V2 models

Frontend fixes:
- Fix Select component exports in TaskHistoryPage.tsx
- Update to use simplified Select API with options prop
- Fix AxiosInstance TypeScript import syntax

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
--- a/.claude/settings.local.json
+++ b/.claude/settings.local.json
@@ -68,7 +68,10 @@
       "Bash(ss:*)",
       "Bash(pip index:*)",
       "Bash(timeout 10 python:*)",
-      "Bash(alembic current:*)"
+      "Bash(alembic current:*)",
+      "Bash(git clean:*)",
+      "Bash(npx tsc:*)",
+      "Bash(./node_modules/.bin/tsc:*)"
     ],
     "deny": [],
     "ask": []
--- a/alembic/env.py
+++ b/alembic/env.py
@@ -15,14 +15,8 @@ from app.core.config import settings
 from app.core.database import Base
 
 # Import all models to ensure they're registered with Base.metadata
-# Import old User model for legacy tables
-from app.models.user import User as OldUser
-# Import new models
-from app.models.user_v2 import User as NewUser
-from app.models.task import Task, TaskFile, TaskStatus
-from app.models.session import Session
-# Import legacy models
-from app.models import OCRBatch, OCRFile, OCRResult, ExportRule, TranslationConfig
+# Import V2 models
+from app.models import User, Task, TaskFile, TaskStatus, Session, AuditLog
 
 # this is the Alembic Config object, which provides
 # access to the values within the .ini file in use.
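Alembic's autogenerate works by diffing the live database against `Base.metadata`, so every model module has to be imported in `env.py` before `target_metadata` is assigned, even though the imported names are never used directly. A minimal sketch of the wiring this hunk relies on (assuming the standard Alembic template):

    # alembic/env.py (sketch): importing the models registers their tables
    # on Base.metadata; without this, `alembic revision --autogenerate`
    # would see the tables as missing and emit drop_table operations.
    from app.core.database import Base
    from app.models import User, Task, TaskFile, TaskStatus, Session, AuditLog  # noqa: F401

    target_metadata = Base.metadata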
--- /dev/null
+++ b/alembic/versions/4d37f412d37a_drop_old_tables_add_audit_logs.py
@@ -0,0 +1,52 @@
+"""drop_old_tables_add_audit_logs
+
+Revision ID: 4d37f412d37a
+Revises: 5e75a59fb763
+Create Date: 2025-11-14 21:13:08.003723
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.dialects import mysql
+
+
+# revision identifiers, used by Alembic.
+revision: str = '4d37f412d37a'
+down_revision: Union[str, None] = '5e75a59fb763'
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    """Upgrade schema: drop old paddle_ocr_* tables if they exist."""
+
+    # Use raw SQL to drop tables with IF EXISTS
+    connection = op.get_bind()
+
+    # Drop old paddle_ocr_* tables (with foreign key dependencies in correct order)
+    old_tables = [
+        'paddle_ocr_results',
+        'paddle_ocr_files',
+        'paddle_ocr_batches',
+        'paddle_ocr_export_rules',
+        'paddle_ocr_translation_configs',
+        'paddle_ocr_users',
+    ]
+
+    for table in old_tables:
+        try:
+            connection.execute(sa.text(f"DROP TABLE IF EXISTS {table}"))
+            print(f"✓ Dropped table: {table}")
+        except Exception as e:
+            print(f"  Warning: Could not drop {table}: {e}")
+
+    print("\n✓ All old paddle_ocr_* tables have been removed")
+    print("✓ Migration complete - V2 schema is now active")
+
+
+def downgrade() -> None:
+    """Downgrade schema."""
+    # Note: Downgrade not supported as it would require recreating old tables
+    pass
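The drop order above matters: `paddle_ocr_results` references `paddle_ocr_files`, which references `paddle_ocr_batches`, and the remaining tables all reference `paddle_ocr_users`, so children must go first or MySQL's foreign-key checks reject the drop. An order-independent, MySQL-only alternative (a sketch, not what this migration does) would toggle the checks around the loop:

    def upgrade() -> None:
        connection = op.get_bind()
        # MySQL-specific: disable FK checks so the drop order is irrelevant.
        connection.execute(sa.text("SET FOREIGN_KEY_CHECKS = 0"))
        try:
            for table in old_tables:  # same list as in the migration above
                connection.execute(sa.text(f"DROP TABLE IF EXISTS {table}"))
        finally:
            connection.execute(sa.text("SET FOREIGN_KEY_CHECKS = 1"))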
--- a/app/core/deps.py
+++ b/app/core/deps.py
@@ -1,6 +1,6 @@
 """
-Tool_OCR - FastAPI Dependencies
-Authentication and database session dependencies
+Tool_OCR - FastAPI Dependencies (V2)
+Authentication and database session dependencies with external authentication
 """
 
 from typing import Generator, Optional
@@ -13,7 +13,6 @@ from sqlalchemy.orm import Session
 from app.core.database import SessionLocal
 from app.core.security import decode_access_token
 from app.models.user import User
-from app.models.user_v2 import User as UserV2
 from app.models.session import Session as UserSession
 from app.services.admin_service import admin_service
 
@@ -44,7 +43,7 @@ def get_current_user(
     db: Session = Depends(get_db)
 ) -> User:
     """
-    Get current authenticated user from JWT token
+    Get current authenticated user from JWT token (External Authentication)
 
     Args:
         credentials: HTTP Bearer credentials
@@ -65,110 +64,6 @@ def get_current_user(
     # Extract token
     token = credentials.credentials
 
-    # Decode token
-    payload = decode_access_token(token)
-    if payload is None:
-        raise credentials_exception
-
-    # Extract user ID from token (convert from string to int)
-    user_id_str: Optional[str] = payload.get("sub")
-    if user_id_str is None:
-        raise credentials_exception
-
-    try:
-        user_id: int = int(user_id_str)
-    except (ValueError, TypeError):
-        raise credentials_exception
-
-    # Query user from database
-    user = db.query(User).filter(User.id == user_id).first()
-    if user is None:
-        raise credentials_exception
-
-    # Check if user is active
-    if not user.is_active:
-        raise HTTPException(
-            status_code=status.HTTP_403_FORBIDDEN,
-            detail="Inactive user"
-        )
-
-    return user
-
-
-def get_current_active_user(
-    current_user: User = Depends(get_current_user)
-) -> User:
-    """
-    Get current active user
-
-    Args:
-        current_user: Current user from get_current_user
-
-    Returns:
-        User: Current active user
-
-    Raises:
-        HTTPException: If user is inactive
-    """
-    if not current_user.is_active:
-        raise HTTPException(
-            status_code=status.HTTP_403_FORBIDDEN,
-            detail="Inactive user"
-        )
-    return current_user
-
-
-def get_current_admin_user(
-    current_user: User = Depends(get_current_user)
-) -> User:
-    """
-    Get current admin user
-
-    Args:
-        current_user: Current user from get_current_user
-
-    Returns:
-        User: Current admin user
-
-    Raises:
-        HTTPException: If user is not admin
-    """
-    if not current_user.is_admin:
-        raise HTTPException(
-            status_code=status.HTTP_403_FORBIDDEN,
-            detail="Not enough privileges"
-        )
-    return current_user
-
-
-# ===== V2 Dependencies for External Authentication =====
-
-def get_current_user_v2(
-    credentials: HTTPAuthorizationCredentials = Depends(security),
-    db: Session = Depends(get_db)
-) -> UserV2:
-    """
-    Get current authenticated user from JWT token (V2 with external auth)
-
-    Args:
-        credentials: HTTP Bearer credentials
-        db: Database session
-
-    Returns:
-        UserV2: Current user object
-
-    Raises:
-        HTTPException: If token is invalid or user not found
-    """
-    credentials_exception = HTTPException(
-        status_code=status.HTTP_401_UNAUTHORIZED,
-        detail="Could not validate credentials",
-        headers={"WWW-Authenticate": "Bearer"},
-    )
-
-    # Extract token
-    token = credentials.credentials
-
     # Decode token
     payload = decode_access_token(token)
     if payload is None:
@@ -187,10 +82,10 @@ def get_current_user_v2(
     # Extract session ID from token (optional)
     session_id: Optional[int] = payload.get("session_id")
 
-    # Query user from database (using V2 model)
-    user = db.query(UserV2).filter(UserV2.id == user_id).first()
+    # Query user from database
+    user = db.query(User).filter(User.id == user_id).first()
     if user is None:
-        logger.warning(f"User {user_id} not found in V2 table")
+        logger.warning(f"User {user_id} not found")
         raise credentials_exception
 
     # Check if user is active
@@ -234,17 +129,17 @@ def get_current_user_v2(
     return user
 
 
-def get_current_active_user_v2(
-    current_user: UserV2 = Depends(get_current_user_v2)
-) -> UserV2:
+def get_current_active_user(
+    current_user: User = Depends(get_current_user)
+) -> User:
     """
-    Get current active user (V2)
+    Get current active user
 
     Args:
-        current_user: Current user from get_current_user_v2
+        current_user: Current user from get_current_user
 
     Returns:
-        UserV2: Current active user
+        User: Current active user
 
     Raises:
         HTTPException: If user is inactive
@@ -257,17 +152,17 @@ def get_current_active_user_v2(
     return current_user
 
 
-def get_current_admin_user_v2(
-    current_user: UserV2 = Depends(get_current_user_v2)
-) -> UserV2:
+def get_current_admin_user(
+    current_user: User = Depends(get_current_user)
+) -> User:
     """
-    Get current admin user (V2)
+    Get current admin user
 
     Args:
-        current_user: Current user from get_current_user_v2
+        current_user: Current user from get_current_user
 
     Returns:
-        UserV2: Current admin user
+        User: Current admin user
 
     Raises:
         HTTPException: If user is not admin
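With the `_v2` suffixes gone, `get_current_user`, `get_current_active_user`, and `get_current_admin_user` are again the only spellings, and every router depends on them by their canonical names. A sketch of a consumer (hypothetical route, not part of this commit):

    from fastapi import APIRouter, Depends
    from app.core.deps import get_current_user
    from app.models.user import User

    router = APIRouter(prefix="/api/v2/example", tags=["Example"])

    @router.get("/whoami")
    async def whoami(current_user: User = Depends(get_current_user)):
        # The dependency has already validated the JWT, loaded the user row,
        # and rejected inactive accounts before this body runs.
        return current_user.to_dict()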
--- a/app/main.py
+++ b/app/main.py
@@ -1,5 +1,5 @@
 """
-Tool_OCR - FastAPI Application Entry Point
+Tool_OCR - FastAPI Application Entry Point (V2)
 Main application setup with CORS, routes, and startup/shutdown events
 """
 
@@ -7,11 +7,9 @@ from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from contextlib import asynccontextmanager
 import logging
-import asyncio
 from pathlib import Path
 
 from app.core.config import settings
-from app.services.background_tasks import task_manager
 
 # Ensure log directory exists before configuring logging
 Path(settings.log_file).parent.mkdir(parents=True, exist_ok=True)
@@ -32,19 +30,12 @@ logger = logging.getLogger(__name__)
 async def lifespan(app: FastAPI):
     """Application lifespan events"""
     # Startup
-    logger.info("Starting Tool_OCR application...")
+    logger.info("Starting Tool_OCR V2 application...")
 
     # Ensure all directories exist
     settings.ensure_directories()
     logger.info("All directories created/verified")
 
-    # Start cleanup scheduler as background task
-    cleanup_task = asyncio.create_task(task_manager.start_cleanup_scheduler())
-    logger.info("Started cleanup scheduler for expired files")
-
-    # TODO: Initialize database connection pool
-    # TODO: Load PaddleOCR models
-
     logger.info("Application startup complete")
 
     yield
@@ -52,21 +43,12 @@ async def lifespan(app: FastAPI):
     # Shutdown
     logger.info("Shutting down Tool_OCR application...")
 
-    # Cancel cleanup task
-    cleanup_task.cancel()
-    try:
-        await cleanup_task
-    except asyncio.CancelledError:
-        logger.info("Cleanup scheduler stopped")
-
-    # TODO: Close database connections
-
 
 # Create FastAPI application
 app = FastAPI(
-    title="Tool_OCR",
-    description="OCR Batch Processing System with Structure Extraction",
-    version="0.1.0",
+    title="Tool_OCR V2",
+    description="OCR Processing System with External Authentication & Task Isolation",
+    version="2.0.0",
     lifespan=lifespan,
 )
 
@@ -88,8 +70,8 @@ async def health_check():
 
     response = {
         "status": "healthy",
-        "service": "Tool_OCR",
-        "version": "0.1.0",
+        "service": "Tool_OCR V2",
+        "version": "2.0.0",
     }
 
     # Add GPU status information
@@ -134,26 +116,17 @@ async def health_check():
 async def root():
     """Root endpoint with API information"""
     return {
-        "message": "Tool_OCR API",
-        "version": "0.1.0",
+        "message": "Tool_OCR API V2 - External Authentication",
+        "version": "2.0.0",
         "docs_url": "/docs",
         "health_check": "/health",
     }
 
 
-# Include API routers
-from app.routers import auth, ocr, export, translation
-# V2 routers with external authentication
-from app.routers import auth_v2, tasks, admin
+# Include V2 API routers
+from app.routers import auth, tasks, admin
 
-# Legacy V1 routers
 app.include_router(auth.router)
-app.include_router(ocr.router)
-app.include_router(export.router)
-app.include_router(translation.router)  # RESERVED for Phase 5
-
-# New V2 routers with external authentication
-app.include_router(auth_v2.router)
 app.include_router(tasks.router)
 app.include_router(admin.router)
 
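A quick end-to-end check that the promotion took: hit the two unauthenticated endpoints and verify the reported version (a smoke-test sketch; assumes the dev server is running on localhost:8000):

    import requests

    base = "http://localhost:8000"  # assumed dev address

    health = requests.get(f"{base}/health").json()
    root = requests.get(f"{base}/").json()

    # Both payloads should now report the V2 identity.
    assert health["service"] == "Tool_OCR V2" and health["version"] == "2.0.0", health
    assert root["version"] == "2.0.0", root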
--- a/app/models/__init__.py
+++ b/app/models/__init__.py
@@ -1,31 +1,20 @@
 """
-Tool_OCR - Database Models
+Tool_OCR - Database Models (V2)
 
-New schema with external API authentication and user task isolation.
+External API authentication with user task isolation.
 All tables use 'tool_ocr_' prefix for namespace separation.
 """
 
-# New models for external authentication system
-from app.models.user_v2 import User
+from app.models.user import User
 from app.models.task import Task, TaskFile, TaskStatus
 from app.models.session import Session
-
-# Legacy models (will be deprecated after migration)
-from app.models.ocr import OCRBatch, OCRFile, OCRResult
-from app.models.export import ExportRule
-from app.models.translation import TranslationConfig
+from app.models.audit_log import AuditLog
 
 __all__ = [
-    # New authentication and task models
     "User",
     "Task",
     "TaskFile",
     "TaskStatus",
     "Session",
-    # Legacy models (deprecated)
-    "OCRBatch",
-    "OCRFile",
-    "OCRResult",
-    "ExportRule",
-    "TranslationConfig",
+    "AuditLog",
 ]
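After this change `from app.models import User` resolves to the V2 `tool_ocr_users` model everywhere, and the legacy names disappear from the package's public surface, so stale call sites fail loudly at import time. A quick interpreter check (sketch):

    from app.models import User

    assert User.__tablename__ == "tool_ocr_users"

    try:
        from app.models import OCRBatch  # noqa: F401
    except ImportError:
        print("legacy models removed as expected")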
--- a/app/models/audit_log.py
+++ b/app/models/audit_log.py
@@ -67,7 +67,7 @@ class AuditLog(Base):
         comment="1 for success, 0 for failure"
     )
     error_message = Column(Text, nullable=True, comment="Error details if failed")
-    metadata = Column(Text, nullable=True, comment="Additional JSON metadata")
+    extra_data = Column(Text, nullable=True, comment="Additional JSON metadata")
     created_at = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
 
     # Relationships
@@ -90,6 +90,6 @@ class AuditLog(Base):
             "resource_id": self.resource_id,
             "success": bool(self.success),
             "error_message": self.error_message,
-            "metadata": self.metadata,
+            "extra_data": self.extra_data,
             "created_at": self.created_at.isoformat() if self.created_at else None
         }
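`metadata` is reserved on declarative models because every mapped class inherits `Base.metadata`, the `MetaData` registry its `Table` is recorded in; shadowing it with a `Column` raises `sqlalchemy.exc.InvalidRequestError` ("Attribute name 'metadata' is reserved when using the Declarative API.") at class-definition time. A minimal reproduction:

    from sqlalchemy import Column, Integer, Text
    from sqlalchemy.orm import declarative_base

    Base = declarative_base()

    class Broken(Base):
        __tablename__ = "broken"
        id = Column(Integer, primary_key=True)
        metadata = Column(Text)  # raises InvalidRequestError immediately

Renaming the attribute, as this commit does, is the usual fix; an alternative that keeps the database column named `metadata` would be `extra_data = Column("metadata", Text, ...)`, mapping a differently named attribute onto the old column name.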
--- a/app/models/export.py
+++ /dev/null
@@ -1,55 +0,0 @@
-"""
-Tool_OCR - Export Rule Model
-User-defined export rules and formatting configurations
-"""
-
-from sqlalchemy import Column, Integer, String, DateTime, Text, ForeignKey, JSON
-from sqlalchemy.orm import relationship
-from datetime import datetime
-
-from app.core.database import Base
-
-
-class ExportRule(Base):
-    """Export rule configuration for customized output formatting"""
-
-    __tablename__ = "paddle_ocr_export_rules"
-
-    id = Column(Integer, primary_key=True, index=True)
-    user_id = Column(Integer, ForeignKey("paddle_ocr_users.id", ondelete="CASCADE"), nullable=False, index=True)
-    rule_name = Column(String(100), nullable=False)
-    description = Column(Text, nullable=True)
-
-    # Rule configuration stored as JSON
-    # {
-    #   "filters": {
-    #     "confidence_threshold": 0.8,
-    #     "filename_pattern": "invoice_*",
-    #     "language": "ch"
-    #   },
-    #   "formatting": {
-    #     "add_line_numbers": true,
-    #     "sort_by_position": true,
-    #     "group_by_filename": false
-    #   },
-    #   "export_options": {
-    #     "include_metadata": true,
-    #     "include_confidence": true,
-    #     "include_bounding_boxes": false
-    #   }
-    # }
-    config_json = Column(JSON, nullable=False)
-
-    # CSS template for PDF export (optional)
-    # Can reference predefined templates: "default", "academic", "business", "report"
-    # Or store custom CSS
-    css_template = Column(Text, nullable=True)
-
-    created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
-    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
-
-    # Relationships
-    user = relationship("User", back_populates="export_rules")
-
-    def __repr__(self):
-        return f"<ExportRule(id={self.id}, name='{self.rule_name}', user_id={self.user_id})>"
--- a/app/models/ocr.py
+++ /dev/null
@@ -1,122 +0,0 @@
-"""
-Tool_OCR - OCR Models
-Database models for OCR batches, files, and results
-"""
-
-from sqlalchemy import Column, Integer, String, DateTime, Float, Text, ForeignKey, Enum, JSON
-from sqlalchemy.orm import relationship
-from datetime import datetime
-import enum
-
-from app.core.database import Base
-
-
-class BatchStatus(str, enum.Enum):
-    """Batch processing status"""
-    PENDING = "pending"
-    PROCESSING = "processing"
-    COMPLETED = "completed"
-    PARTIAL = "partial"  # Some files failed
-    FAILED = "failed"
-
-
-class FileStatus(str, enum.Enum):
-    """Individual file processing status"""
-    PENDING = "pending"
-    PROCESSING = "processing"
-    COMPLETED = "completed"
-    FAILED = "failed"
-
-
-class OCRBatch(Base):
-    """OCR batch processing tracking"""
-
-    __tablename__ = "paddle_ocr_batches"
-
-    id = Column(Integer, primary_key=True, index=True)
-    user_id = Column(Integer, ForeignKey("paddle_ocr_users.id", ondelete="CASCADE"), nullable=False, index=True)
-    batch_name = Column(String(255), nullable=True)
-    status = Column(Enum(BatchStatus), default=BatchStatus.PENDING, nullable=False, index=True)
-    total_files = Column(Integer, default=0, nullable=False)
-    completed_files = Column(Integer, default=0, nullable=False)
-    failed_files = Column(Integer, default=0, nullable=False)
-    created_at = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
-    started_at = Column(DateTime, nullable=True)
-    completed_at = Column(DateTime, nullable=True)
-
-    # Relationships
-    user = relationship("User", back_populates="ocr_batches")
-    files = relationship("OCRFile", back_populates="batch", cascade="all, delete-orphan")
-
-    @property
-    def progress_percentage(self) -> float:
-        """Calculate progress percentage"""
-        if self.total_files == 0:
-            return 0.0
-        return (self.completed_files / self.total_files) * 100
-
-    def __repr__(self):
-        return f"<OCRBatch(id={self.id}, status='{self.status}', progress={self.progress_percentage:.1f}%)>"
-
-
-class OCRFile(Base):
-    """Individual file in an OCR batch"""
-
-    __tablename__ = "paddle_ocr_files"
-
-    id = Column(Integer, primary_key=True, index=True)
-    batch_id = Column(Integer, ForeignKey("paddle_ocr_batches.id", ondelete="CASCADE"), nullable=False, index=True)
-    filename = Column(String(255), nullable=False)
-    original_filename = Column(String(255), nullable=False)
-    file_path = Column(String(512), nullable=False)
-    file_size = Column(Integer, nullable=False)  # Size in bytes
-    file_format = Column(String(20), nullable=False)  # png, jpg, pdf, etc.
-    status = Column(Enum(FileStatus), default=FileStatus.PENDING, nullable=False, index=True)
-    error_message = Column(Text, nullable=True)
-    retry_count = Column(Integer, default=0, nullable=False)  # Number of retry attempts
-    created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
-    started_at = Column(DateTime, nullable=True)
-    completed_at = Column(DateTime, nullable=True)
-    processing_time = Column(Float, nullable=True)  # Processing time in seconds
-
-    # Relationships
-    batch = relationship("OCRBatch", back_populates="files")
-    result = relationship("OCRResult", back_populates="file", uselist=False, cascade="all, delete-orphan")
-
-    def __repr__(self):
-        return f"<OCRFile(id={self.id}, filename='{self.filename}', status='{self.status}')>"
-
-
-class OCRResult(Base):
-    """OCR processing result with structure and images"""
-
-    __tablename__ = "paddle_ocr_results"
-
-    id = Column(Integer, primary_key=True, index=True)
-    file_id = Column(Integer, ForeignKey("paddle_ocr_files.id", ondelete="CASCADE"), unique=True, nullable=False, index=True)
-
-    # Output file paths
-    markdown_path = Column(String(512), nullable=True)  # Path to Markdown file
-    json_path = Column(String(512), nullable=True)  # Path to JSON file
-    images_dir = Column(String(512), nullable=True)  # Directory containing extracted images
-
-    # OCR metadata
-    detected_language = Column(String(20), nullable=True)  # ch, en, japan, korean
-    total_text_regions = Column(Integer, default=0, nullable=False)
-    average_confidence = Column(Float, nullable=True)
-
-    # Layout structure data (stored as JSON)
-    # Contains: layout elements (title, paragraph, table, image, formula), reading order, bounding boxes
-    layout_data = Column(JSON, nullable=True)
-
-    # Extracted images metadata (stored as JSON)
-    # Contains: list of {image_path, bbox, element_type}
-    images_metadata = Column(JSON, nullable=True)
-
-    created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
-
-    # Relationships
-    file = relationship("OCRFile", back_populates="result")
-
-    def __repr__(self):
-        return f"<OCRResult(id={self.id}, file_id={self.file_id}, language='{self.detected_language}')>"
--- a/app/models/translation.py
+++ /dev/null
@@ -1,43 +0,0 @@
-"""
-Tool_OCR - Translation Config Model (RESERVED)
-Reserved for future translation feature implementation
-"""
-
-from sqlalchemy import Column, Integer, String, DateTime, ForeignKey, JSON
-from sqlalchemy.orm import relationship
-from datetime import datetime
-
-from app.core.database import Base
-
-
-class TranslationConfig(Base):
-    """
-    Translation configuration (RESERVED for future implementation)
-
-    This table is created but not actively used until translation feature is implemented.
-    """
-
-    __tablename__ = "paddle_ocr_translation_configs"
-
-    id = Column(Integer, primary_key=True, index=True)
-    user_id = Column(Integer, ForeignKey("paddle_ocr_users.id", ondelete="CASCADE"), nullable=False, index=True)
-
-    source_lang = Column(String(20), nullable=False)  # ch, en, japan, korean, etc.
-    target_lang = Column(String(20), nullable=False)  # en, ch, japan, korean, etc.
-
-    # Translation engine type: "offline" (argostranslate), "ernie", "google", "deepl"
-    engine_type = Column(String(50), nullable=False, default="offline")
-
-    # Engine-specific configuration stored as JSON
-    # For offline (argostranslate): {"model_path": "/path/to/model"}
-    # For API-based: {"api_key": "xxx", "endpoint": "https://..."}
-    engine_config = Column(JSON, nullable=True)
-
-    created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
-    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
-
-    # Relationships
-    user = relationship("User", back_populates="translation_configs")
-
-    def __repr__(self):
-        return f"<TranslationConfig(id={self.id}, {self.source_lang}->{self.target_lang}, engine='{self.engine_type}')>"
--- a/app/models/user.py
+++ b/app/models/user.py
@@ -1,6 +1,6 @@
 """
-Tool_OCR - User Model
-User authentication and management
+Tool_OCR - User Model v2.0
+External API authentication with simplified schema
 """
 
 from sqlalchemy import Column, Integer, String, DateTime, Boolean
@@ -11,24 +11,39 @@ from app.core.database import Base
 
 
 class User(Base):
-    """User model for JWT authentication"""
+    """
+    User model for external API authentication
 
-    __tablename__ = "paddle_ocr_users"
+    Uses email as primary identifier from Azure AD.
+    No password storage - authentication via external API only.
+    """
 
-    id = Column(Integer, primary_key=True, index=True)
-    username = Column(String(50), unique=True, nullable=False, index=True)
-    email = Column(String(100), unique=True, nullable=False, index=True)
-    password_hash = Column(String(255), nullable=False)
-    full_name = Column(String(100), nullable=True)
-    is_active = Column(Boolean, default=True, nullable=False)
-    is_admin = Column(Boolean, default=False, nullable=False)
+    __tablename__ = "tool_ocr_users"
+
+    id = Column(Integer, primary_key=True, index=True, autoincrement=True)
+    email = Column(String(255), unique=True, nullable=False, index=True,
+                   comment="Primary identifier from Azure AD")
+    display_name = Column(String(255), nullable=True,
+                          comment="Display name from API response")
     created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
-    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
+    last_login = Column(DateTime, nullable=True)
+    is_active = Column(Boolean, default=True, nullable=False, index=True)
 
     # Relationships
-    ocr_batches = relationship("OCRBatch", back_populates="user", cascade="all, delete-orphan")
-    export_rules = relationship("ExportRule", back_populates="user", cascade="all, delete-orphan")
-    translation_configs = relationship("TranslationConfig", back_populates="user", cascade="all, delete-orphan")
+    tasks = relationship("Task", back_populates="user", cascade="all, delete-orphan")
+    sessions = relationship("Session", back_populates="user", cascade="all, delete-orphan")
+    audit_logs = relationship("AuditLog", back_populates="user")
 
     def __repr__(self):
-        return f"<User(id={self.id}, username='{self.username}', email='{self.email}')>"
+        return f"<User(id={self.id}, email='{self.email}', display_name='{self.display_name}')>"
+
+    def to_dict(self):
+        """Convert user to dictionary"""
+        return {
+            "id": self.id,
+            "email": self.email,
+            "display_name": self.display_name,
+            "created_at": self.created_at.isoformat() if self.created_at else None,
+            "last_login": self.last_login.isoformat() if self.last_login else None,
+            "is_active": self.is_active
+        }
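Note that nothing carries rows from `paddle_ocr_users` into `tool_ocr_users`: the V2 table is populated lazily at first login via the external API. Had V1 accounts needed preserving, a copy step could have gone into the drop migration before the tables vanish, along these lines (hypothetical, not in this commit; assumes `tool_ocr_users` is still empty so the unique email constraint cannot fire):

    # Map the V1 columns onto their V2 counterparts before dropping the source.
    op.execute(sa.text(
        "INSERT INTO tool_ocr_users (email, display_name, created_at, is_active) "
        "SELECT email, full_name, created_at, is_active FROM paddle_ocr_users"
    ))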
--- a/app/models/user_v2.py
+++ /dev/null
@@ -1,49 +0,0 @@
-"""
-Tool_OCR - User Model v2.0
-External API authentication with simplified schema
-"""
-
-from sqlalchemy import Column, Integer, String, DateTime, Boolean
-from sqlalchemy.orm import relationship
-from datetime import datetime
-
-from app.core.database import Base
-
-
-class User(Base):
-    """
-    User model for external API authentication
-
-    Uses email as primary identifier from Azure AD.
-    No password storage - authentication via external API only.
-    """
-
-    __tablename__ = "tool_ocr_users"
-
-    id = Column(Integer, primary_key=True, index=True, autoincrement=True)
-    email = Column(String(255), unique=True, nullable=False, index=True,
-                   comment="Primary identifier from Azure AD")
-    display_name = Column(String(255), nullable=True,
-                          comment="Display name from API response")
-    created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
-    last_login = Column(DateTime, nullable=True)
-    is_active = Column(Boolean, default=True, nullable=False, index=True)
-
-    # Relationships
-    tasks = relationship("Task", back_populates="user", cascade="all, delete-orphan")
-    sessions = relationship("Session", back_populates="user", cascade="all, delete-orphan")
-    audit_logs = relationship("AuditLog", back_populates="user")
-
-    def __repr__(self):
-        return f"<User(id={self.id}, email='{self.email}', display_name='{self.display_name}')>"
-
-    def to_dict(self):
-        """Convert user to dictionary"""
-        return {
-            "id": self.id,
-            "email": self.email,
-            "display_name": self.display_name,
-            "created_at": self.created_at.isoformat() if self.created_at else None,
-            "last_login": self.last_login.isoformat() if self.last_login else None,
-            "is_active": self.is_active
-        }
--- a/app/routers/__init__.py
+++ b/app/routers/__init__.py
@@ -1,7 +1,7 @@
 """
-Tool_OCR - API Routers
+Tool_OCR - API Routers (V2)
 """
 
-from app.routers import auth, ocr, export, translation
+from app.routers import auth, tasks, admin
 
-__all__ = ["auth", "ocr", "export", "translation"]
+__all__ = ["auth", "tasks", "admin"]
--- a/app/routers/admin.py
+++ b/app/routers/admin.py
@@ -10,8 +10,8 @@ from datetime import datetime
 from fastapi import APIRouter, Depends, HTTPException, status, Query
 from sqlalchemy.orm import Session
 
-from app.core.deps import get_db, get_current_admin_user_v2
-from app.models.user_v2 import User
+from app.core.deps import get_db, get_current_admin_user
+from app.models.user import User
 from app.services.admin_service import admin_service
 from app.services.audit_service import audit_service
 
@@ -23,7 +23,7 @@ router = APIRouter(prefix="/api/v2/admin", tags=["Admin"])
 @router.get("/stats", summary="Get system statistics")
 async def get_system_stats(
     db: Session = Depends(get_db),
-    admin_user: User = Depends(get_current_admin_user_v2)
+    admin_user: User = Depends(get_current_admin_user)
 ):
     """
     Get overall system statistics
@@ -47,7 +47,7 @@ async def list_users(
     page: int = Query(1, ge=1),
     page_size: int = Query(50, ge=1, le=100),
     db: Session = Depends(get_db),
-    admin_user: User = Depends(get_current_admin_user_v2)
+    admin_user: User = Depends(get_current_admin_user)
 ):
     """
     Get list of all users with statistics
@@ -79,7 +79,7 @@ async def get_top_users(
     metric: str = Query("tasks", regex="^(tasks|completed_tasks)$"),
     limit: int = Query(10, ge=1, le=50),
     db: Session = Depends(get_db),
-    admin_user: User = Depends(get_current_admin_user_v2)
+    admin_user: User = Depends(get_current_admin_user)
 ):
     """
     Get top users by metric
@@ -115,7 +115,7 @@ async def get_audit_logs(
     page: int = Query(1, ge=1),
     page_size: int = Query(100, ge=1, le=500),
     db: Session = Depends(get_db),
-    admin_user: User = Depends(get_current_admin_user_v2)
+    admin_user: User = Depends(get_current_admin_user)
 ):
     """
     Get audit logs with filtering
@@ -169,7 +169,7 @@ async def get_user_activity_summary(
     user_id: int,
     days: int = Query(30, ge=1, le=365),
     db: Session = Depends(get_db),
-    admin_user: User = Depends(get_current_admin_user_v2)
+    admin_user: User = Depends(get_current_admin_user)
 ):
     """
     Get user activity summary for the last N days
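Every admin endpoint guards itself with the same `get_current_admin_user` dependency, so a non-admin token is rejected with 403 before any handler body runs. A quick client-side check (sketch; assumes a valid admin JWT and the dev server on localhost:8000):

    import requests

    token = "..."  # admin JWT obtained from /api/v2/auth/login
    resp = requests.get(
        "http://localhost:8000/api/v2/admin/stats",
        headers={"Authorization": f"Bearer {token}"},
    )
    # Expect 200 for admins, 401 for invalid tokens, 403 for non-admins.
    print(resp.status_code, resp.json())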
--- a/app/routers/auth.py
+++ b/app/routers/auth.py
@@ -1,70 +1,347 @@
 """
-Tool_OCR - Authentication Router
-JWT login endpoint
+Tool_OCR - External Authentication Router (V2)
+Handles authentication via external Microsoft Azure AD API
 """
 
-from datetime import timedelta
+from datetime import datetime, timedelta
 import logging
+from typing import Optional
 
-from fastapi import APIRouter, Depends, HTTPException, status
+from fastapi import APIRouter, Depends, HTTPException, status, Request
 from sqlalchemy.orm import Session
 
 from app.core.config import settings
-from app.core.deps import get_db
-from app.core.security import verify_password, create_access_token
+from app.core.deps import get_db, get_current_user
+from app.core.security import create_access_token
 from app.models.user import User
-from app.schemas.auth import LoginRequest, Token
+from app.models.session import Session as UserSession
+from app.schemas.auth import LoginRequest, Token, UserResponse
+from app.services.external_auth_service import external_auth_service
 
 logger = logging.getLogger(__name__)
 
-router = APIRouter(prefix="/api/v1/auth", tags=["Authentication"])
+router = APIRouter(prefix="/api/v2/auth", tags=["Authentication V2"])
 
 
-@router.post("/login", response_model=Token, summary="User login")
+def get_client_ip(request: Request) -> str:
+    """Extract client IP address from request"""
+    # Check X-Forwarded-For header (for proxies)
+    forwarded = request.headers.get("X-Forwarded-For")
+    if forwarded:
+        return forwarded.split(",")[0].strip()
+    # Check X-Real-IP header
+    real_ip = request.headers.get("X-Real-IP")
+    if real_ip:
+        return real_ip
+    # Fallback to direct client
+    return request.client.host if request.client else "unknown"
+
+
+def get_user_agent(request: Request) -> str:
+    """Extract user agent from request"""
+    return request.headers.get("User-Agent", "unknown")[:500]
+
+
+@router.post("/login", response_model=Token, summary="External API login")
 async def login(
     login_data: LoginRequest,
+    request: Request,
     db: Session = Depends(get_db)
 ):
     """
-    User login with username and password
+    User login via external Microsoft Azure AD API
 
-    Returns JWT access token for authentication
+    Returns JWT access token and stores session information
 
-    - **username**: User's username
+    - **username**: User's email address
     - **password**: User's password
     """
-    # Query user by username
-    user = db.query(User).filter(User.username == login_data.username).first()
+    # Call external authentication API
+    success, auth_response, error_msg = await external_auth_service.authenticate_user(
+        username=login_data.username,
+        password=login_data.password
+    )
 
-    # Verify user exists and password is correct
-    if not user or not verify_password(login_data.password, user.password_hash):
-        logger.warning(f"Failed login attempt for username: {login_data.username}")
+    if not success or not auth_response:
+        logger.warning(
+            f"External auth failed for user {login_data.username}: {error_msg}"
+        )
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
-            detail="Incorrect username or password",
+            detail=error_msg or "Authentication failed",
             headers={"WWW-Authenticate": "Bearer"},
         )
 
+    # Extract user info from external API response
+    user_info = auth_response.user_info
+    email = user_info.email
+    display_name = user_info.name
+
+    # Find or create user in database
+    user = db.query(User).filter(User.email == email).first()
+
+    if not user:
+        # Create new user
+        user = User(
+            email=email,
+            display_name=display_name,
+            is_active=True,
+            last_login=datetime.utcnow()
+        )
+        db.add(user)
+        db.commit()
+        db.refresh(user)
+        logger.info(f"Created new user: {email} (ID: {user.id})")
+    else:
+        # Update existing user
+        user.display_name = display_name
+        user.last_login = datetime.utcnow()
+
     # Check if user is active
     if not user.is_active:
-        logger.warning(f"Inactive user login attempt: {login_data.username}")
+        logger.warning(f"Inactive user login attempt: {email}")
         raise HTTPException(
             status_code=status.HTTP_403_FORBIDDEN,
             detail="User account is inactive"
         )
 
-    # Create access token
-    access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
-    access_token = create_access_token(
-        data={"sub": str(user.id), "username": user.username},
-        expires_delta=access_token_expires
-    )
+    db.commit()
+    db.refresh(user)
+    logger.info(f"Updated existing user: {email} (ID: {user.id})")
+
+    # Parse token expiration
+    try:
+        expires_at = datetime.fromisoformat(auth_response.expires_at.replace('Z', '+00:00'))
+        issued_at = datetime.fromisoformat(auth_response.issued_at.replace('Z', '+00:00'))
+    except Exception as e:
+        logger.error(f"Failed to parse token timestamps: {e}")
+        expires_at = datetime.utcnow() + timedelta(seconds=auth_response.expires_in)
+        issued_at = datetime.utcnow()
+
+    # Create session in database
+    # TODO: Implement token encryption before storing
+    session = UserSession(
+        user_id=user.id,
+        access_token=auth_response.access_token,  # Should be encrypted
+        id_token=auth_response.id_token,  # Should be encrypted
+        token_type=auth_response.token_type,
+        expires_at=expires_at,
+        issued_at=issued_at,
+        ip_address=get_client_ip(request),
+        user_agent=get_user_agent(request)
+    )
+    db.add(session)
+    db.commit()
+    db.refresh(session)
+
+    logger.info(
+        f"Created session {session.id} for user {user.email} "
+        f"(expires: {expires_at})"
+    )
+
+    # Create internal JWT token for API access
+    # This token contains user ID and session ID
+    internal_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
+    internal_access_token = create_access_token(
+        data={
+            "sub": str(user.id),
+            "email": user.email,
+            "session_id": session.id
+        },
+        expires_delta=internal_token_expires
+    )
 
-    logger.info(f"Successful login: {user.username} (ID: {user.id})")
-
     return {
-        "access_token": access_token,
+        "access_token": internal_access_token,
         "token_type": "bearer",
-        "expires_in": settings.access_token_expire_minutes * 60  # Convert to seconds
+        "expires_in": int(internal_token_expires.total_seconds()),
+        "user": {
+            "id": user.id,
+            "email": user.email,
+            "display_name": user.display_name
+        }
     }
+
+
+@router.post("/logout", summary="User logout")
+async def logout(
+    session_id: Optional[int] = None,
+    db: Session = Depends(get_db),
+    current_user: User = Depends(get_current_user)
+):
+    """
+    User logout - invalidates session
+
+    - **session_id**: Session ID to logout (optional, logs out all if not provided)
+    """
+    # TODO: Implement proper current_user dependency from JWT token
+    # For now, this is a placeholder
+
+    if session_id:
+        # Logout specific session
+        session = db.query(UserSession).filter(
+            UserSession.id == session_id,
+            UserSession.user_id == current_user.id
+        ).first()
+
+        if session:
+            db.delete(session)
+            db.commit()
+            logger.info(f"Logged out session {session_id} for user {current_user.email}")
+            return {"message": "Logged out successfully"}
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_404_NOT_FOUND,
+                detail="Session not found"
+            )
+    else:
+        # Logout all sessions
+        sessions = db.query(UserSession).filter(
+            UserSession.user_id == current_user.id
+        ).all()
+
+        count = len(sessions)
+        for session in sessions:
+            db.delete(session)
+
+        db.commit()
+        logger.info(f"Logged out all {count} sessions for user {current_user.email}")
+        return {"message": f"Logged out {count} sessions"}
+
+
+@router.get("/me", response_model=UserResponse, summary="Get current user")
+async def get_me(
+    current_user: User = Depends(get_current_user)
+):
+    """
+    Get current authenticated user information
+    """
+    # TODO: Implement proper current_user dependency from JWT token
+    return {
+        "id": current_user.id,
+        "email": current_user.email,
+        "display_name": current_user.display_name,
+        "created_at": current_user.created_at,
+        "last_login": current_user.last_login,
+        "is_active": current_user.is_active
+    }
+
+
+@router.get("/sessions", summary="List user sessions")
+async def list_sessions(
+    db: Session = Depends(get_db),
+    current_user: User = Depends(get_current_user)
+):
+    """
+    List all active sessions for current user
+    """
+    sessions = db.query(UserSession).filter(
+        UserSession.user_id == current_user.id
+    ).order_by(UserSession.created_at.desc()).all()
+
+    return {
+        "sessions": [
+            {
+                "id": s.id,
+                "token_type": s.token_type,
+                "expires_at": s.expires_at,
+                "issued_at": s.issued_at,
+                "ip_address": s.ip_address,
+                "user_agent": s.user_agent,
+                "created_at": s.created_at,
+                "last_accessed_at": s.last_accessed_at,
+                "is_expired": s.is_expired,
+                "time_until_expiry": s.time_until_expiry
+            }
+            for s in sessions
+        ]
+    }
+
+
+@router.post("/refresh", response_model=Token, summary="Refresh access token")
+async def refresh_token(
+    request: Request,
+    db: Session = Depends(get_db),
+    current_user: User = Depends(get_current_user)
+):
+    """
+    Refresh access token before expiration
+
+    Re-authenticates with external API using stored session.
+    Note: Since external API doesn't provide refresh tokens,
+    we re-issue internal JWT tokens with extended expiry.
+    """
+    try:
+        # Find user's most recent session
+        session = db.query(UserSession).filter(
+            UserSession.user_id == current_user.id
+        ).order_by(UserSession.created_at.desc()).first()
+
+        if not session:
+            raise HTTPException(
+                status_code=status.HTTP_401_UNAUTHORIZED,
+                detail="No active session found"
+            )
+
+        # Check if token is expiring soon (within TOKEN_REFRESH_BUFFER)
+        if not external_auth_service.is_token_expiring_soon(session.expires_at):
+            # Token still valid for a while, just issue new internal JWT
+            internal_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
+            internal_access_token = create_access_token(
+                data={
+                    "sub": str(current_user.id),
+                    "email": current_user.email,
+                    "session_id": session.id
+                },
+                expires_delta=internal_token_expires
+            )
+
+            logger.info(f"Refreshed internal token for user {current_user.email}")
+
+            return {
+                "access_token": internal_access_token,
+                "token_type": "bearer",
+                "expires_in": int(internal_token_expires.total_seconds()),
+                "user": {
+                    "id": current_user.id,
+                    "email": current_user.email,
+                    "display_name": current_user.display_name
+                }
+            }
+
+        # External token expiring soon - would need re-authentication
+        # For now, we extend internal token and log a warning
+        logger.warning(
+            f"External token expiring soon for user {current_user.email}. "
+            "User should re-authenticate."
+        )
+
+        internal_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
+        internal_access_token = create_access_token(
+            data={
+                "sub": str(current_user.id),
+                "email": current_user.email,
+                "session_id": session.id
+            },
+            expires_delta=internal_token_expires
+        )
+
+        return {
+            "access_token": internal_access_token,
+            "token_type": "bearer",
+            "expires_in": int(internal_token_expires.total_seconds()),
+            "user": {
+                "id": current_user.id,
+                "email": current_user.email,
+                "display_name": current_user.display_name
+            }
+        }
+
+    except HTTPException:
+        raise
+    except Exception as e:
+        logger.exception(f"Token refresh failed for user {current_user.id}")
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"Token refresh failed: {str(e)}"
+        )
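The `# Should be encrypted` markers flag that external bearer tokens are persisted in plaintext. One way to discharge that TODO (a sketch, not part of this commit) is symmetric encryption with Fernet from the `cryptography` package, keyed from application settings:

    from cryptography.fernet import Fernet

    # Hypothetical setting: generate once with Fernet.generate_key() and
    # keep it out of source control.
    fernet = Fernet(settings.session_token_key)

    def encrypt_token(raw: str) -> str:
        """Encrypt a bearer token before persisting it in the sessions table."""
        return fernet.encrypt(raw.encode()).decode()

    def decrypt_token(stored: str) -> str:
        """Recover the plaintext token when calling the external API."""
        return fernet.decrypt(stored.encode()).decode()

    # At session creation time:
    #   access_token=encrypt_token(auth_response.access_token),
    #   id_token=encrypt_token(auth_response.id_token),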
@@ -1,347 +0,0 @@
"""
Tool_OCR - External Authentication Router (V2)
Handles authentication via external Microsoft Azure AD API
"""

from datetime import datetime, timedelta
import logging
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, status, Request
from sqlalchemy.orm import Session

from app.core.config import settings
from app.core.deps import get_db, get_current_user_v2
from app.core.security import create_access_token
from app.models.user_v2 import User
from app.models.session import Session as UserSession
from app.schemas.auth import LoginRequest, Token, UserResponse
from app.services.external_auth_service import external_auth_service

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/v2/auth", tags=["Authentication V2"])


def get_client_ip(request: Request) -> str:
    """Extract client IP address from request"""
    # Check X-Forwarded-For header (for proxies)
    forwarded = request.headers.get("X-Forwarded-For")
    if forwarded:
        return forwarded.split(",")[0].strip()
    # Check X-Real-IP header
    real_ip = request.headers.get("X-Real-IP")
    if real_ip:
        return real_ip
    # Fallback to direct client
    return request.client.host if request.client else "unknown"


def get_user_agent(request: Request) -> str:
    """Extract user agent from request"""
    return request.headers.get("User-Agent", "unknown")[:500]


@router.post("/login", response_model=Token, summary="External API login")
async def login(
    login_data: LoginRequest,
    request: Request,
    db: Session = Depends(get_db)
):
    """
    User login via external Microsoft Azure AD API

    Returns JWT access token and stores session information

    - **username**: User's email address
    - **password**: User's password
    """
    # Call external authentication API
    success, auth_response, error_msg = await external_auth_service.authenticate_user(
        username=login_data.username,
        password=login_data.password
    )

    if not success or not auth_response:
        logger.warning(
            f"External auth failed for user {login_data.username}: {error_msg}"
        )
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=error_msg or "Authentication failed",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Extract user info from external API response
    user_info = auth_response.user_info
    email = user_info.email
    display_name = user_info.name

    # Find or create user in database
    user = db.query(User).filter(User.email == email).first()

    if not user:
        # Create new user
        user = User(
            email=email,
            display_name=display_name,
            is_active=True,
            last_login=datetime.utcnow()
        )
        db.add(user)
        db.commit()
        db.refresh(user)
        logger.info(f"Created new user: {email} (ID: {user.id})")
    else:
        # Update existing user
        user.display_name = display_name
        user.last_login = datetime.utcnow()

        # Check if user is active
        if not user.is_active:
            logger.warning(f"Inactive user login attempt: {email}")
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="User account is inactive"
            )

        db.commit()
        db.refresh(user)
        logger.info(f"Updated existing user: {email} (ID: {user.id})")

    # Parse token expiration
    try:
        expires_at = datetime.fromisoformat(auth_response.expires_at.replace('Z', '+00:00'))
        issued_at = datetime.fromisoformat(auth_response.issued_at.replace('Z', '+00:00'))
    except Exception as e:
        logger.error(f"Failed to parse token timestamps: {e}")
        expires_at = datetime.utcnow() + timedelta(seconds=auth_response.expires_in)
        issued_at = datetime.utcnow()

    # Create session in database
    # TODO: Implement token encryption before storing
    session = UserSession(
        user_id=user.id,
        access_token=auth_response.access_token,  # Should be encrypted
        id_token=auth_response.id_token,  # Should be encrypted
        token_type=auth_response.token_type,
        expires_at=expires_at,
        issued_at=issued_at,
        ip_address=get_client_ip(request),
        user_agent=get_user_agent(request)
    )
    db.add(session)
    db.commit()
    db.refresh(session)

    logger.info(
        f"Created session {session.id} for user {user.email} "
        f"(expires: {expires_at})"
    )

    # Create internal JWT token for API access
    # This token contains user ID and session ID
    internal_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
    internal_access_token = create_access_token(
        data={
            "sub": str(user.id),
            "email": user.email,
            "session_id": session.id
        },
        expires_delta=internal_token_expires
    )

    return {
        "access_token": internal_access_token,
        "token_type": "bearer",
        "expires_in": int(internal_token_expires.total_seconds()),
        "user": {
            "id": user.id,
            "email": user.email,
            "display_name": user.display_name
        }
    }


@router.post("/logout", summary="User logout")
async def logout(
    session_id: Optional[int] = None,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user_v2)
):
    """
    User logout - invalidates session

    - **session_id**: Session ID to logout (optional, logs out all if not provided)
    """
    # TODO: Implement proper current_user dependency from JWT token
    # For now, this is a placeholder

    if session_id:
        # Logout specific session
        session = db.query(UserSession).filter(
            UserSession.id == session_id,
            UserSession.user_id == current_user.id
        ).first()

        if session:
            db.delete(session)
            db.commit()
            logger.info(f"Logged out session {session_id} for user {current_user.email}")
            return {"message": "Logged out successfully"}
        else:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Session not found"
            )
    else:
        # Logout all sessions
        sessions = db.query(UserSession).filter(
            UserSession.user_id == current_user.id
        ).all()

        count = len(sessions)
        for session in sessions:
            db.delete(session)

        db.commit()
        logger.info(f"Logged out all {count} sessions for user {current_user.email}")
        return {"message": f"Logged out {count} sessions"}


@router.get("/me", response_model=UserResponse, summary="Get current user")
async def get_me(
    current_user: User = Depends(get_current_user_v2)
):
    """
    Get current authenticated user information
    """
    # TODO: Implement proper current_user dependency from JWT token
    return {
        "id": current_user.id,
        "email": current_user.email,
        "display_name": current_user.display_name,
        "created_at": current_user.created_at,
        "last_login": current_user.last_login,
        "is_active": current_user.is_active
    }


@router.get("/sessions", summary="List user sessions")
async def list_sessions(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user_v2)
):
    """
    List all active sessions for current user
    """
    sessions = db.query(UserSession).filter(
        UserSession.user_id == current_user.id
    ).order_by(UserSession.created_at.desc()).all()

    return {
        "sessions": [
            {
                "id": s.id,
                "token_type": s.token_type,
                "expires_at": s.expires_at,
                "issued_at": s.issued_at,
                "ip_address": s.ip_address,
                "user_agent": s.user_agent,
                "created_at": s.created_at,
                "last_accessed_at": s.last_accessed_at,
                "is_expired": s.is_expired,
                "time_until_expiry": s.time_until_expiry
            }
            for s in sessions
        ]
    }


@router.post("/refresh", response_model=Token, summary="Refresh access token")
async def refresh_token(
    request: Request,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user_v2)
):
    """
    Refresh access token before expiration

    Re-authenticates with external API using stored session.
    Note: Since external API doesn't provide refresh tokens,
    we re-issue internal JWT tokens with extended expiry.
    """
    try:
        # Find user's most recent session
        session = db.query(UserSession).filter(
            UserSession.user_id == current_user.id
        ).order_by(UserSession.created_at.desc()).first()

        if not session:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="No active session found"
            )

        # Check if token is expiring soon (within TOKEN_REFRESH_BUFFER)
        if not external_auth_service.is_token_expiring_soon(session.expires_at):
            # Token still valid for a while, just issue new internal JWT
            internal_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
            internal_access_token = create_access_token(
                data={
                    "sub": str(current_user.id),
                    "email": current_user.email,
                    "session_id": session.id
                },
                expires_delta=internal_token_expires
            )

            logger.info(f"Refreshed internal token for user {current_user.email}")

            return {
                "access_token": internal_access_token,
                "token_type": "bearer",
                "expires_in": int(internal_token_expires.total_seconds()),
                "user": {
                    "id": current_user.id,
                    "email": current_user.email,
                    "display_name": current_user.display_name
                }
            }

        # External token expiring soon - would need re-authentication
        # For now, we extend internal token and log a warning
        logger.warning(
            f"External token expiring soon for user {current_user.email}. "
            "User should re-authenticate."
        )

        internal_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
        internal_access_token = create_access_token(
            data={
                "sub": str(current_user.id),
                "email": current_user.email,
                "session_id": session.id
            },
            expires_delta=internal_token_expires
        )

        return {
            "access_token": internal_access_token,
            "token_type": "bearer",
            "expires_in": int(internal_token_expires.total_seconds()),
            "user": {
                "id": current_user.id,
                "email": current_user.email,
                "display_name": current_user.display_name
            }
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Token refresh failed for user {current_user.id}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Token refresh failed: {str(e)}"
        )
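For reference, a minimal client-side sketch of the login flow the router above implemented. The endpoint paths and the token response fields (`access_token`, `token_type`, `expires_in`, `user`) come straight from the code; the base URL, credentials, and the choice of `httpx` are illustrative assumptions, not part of the project.

```python
# Hypothetical client for the /api/v2/auth flow above.
# BASE_URL and credentials are placeholders.
import httpx

BASE_URL = "http://localhost:8000"  # assumption: local dev server

def login_and_fetch_me(username: str, password: str) -> dict:
    # POST /api/v2/auth/login with the LoginRequest fields
    resp = httpx.post(
        f"{BASE_URL}/api/v2/auth/login",
        json={"username": username, "password": password},
    )
    resp.raise_for_status()
    token = resp.json()["access_token"]

    # Use the internal JWT as a Bearer token on subsequent requests
    me = httpx.get(
        f"{BASE_URL}/api/v2/auth/me",
        headers={"Authorization": f"Bearer {token}"},
    )
    me.raise_for_status()
    return me.json()
```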
@@ -1,338 +0,0 @@
"""
Tool_OCR - Export Router
Export results in multiple formats
"""

import logging
from typing import List
from pathlib import Path

from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import FileResponse
from sqlalchemy.orm import Session

from app.core.deps import get_db, get_current_active_user
from app.models.user import User
from app.models.ocr import OCRBatch, OCRFile, OCRResult, FileStatus
from app.models.export import ExportRule
from app.schemas.export import (
    ExportRequest,
    ExportRuleCreate,
    ExportRuleUpdate,
    ExportRuleResponse,
    CSSTemplateResponse,
)
from app.services.export_service import ExportService, ExportError
from app.services.pdf_generator import PDFGenerator


logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/v1/export", tags=["Export"])

# Initialize services
export_service = ExportService()
pdf_generator = PDFGenerator()


@router.post("", summary="Export OCR results")
async def export_results(
    request: ExportRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_active_user)
):
    """
    Export OCR results in specified format

    Supports multiple export formats: txt, json, excel, markdown, pdf, zip

    - **batch_id**: Batch ID to export
    - **format**: Export format (txt, json, excel, markdown, pdf, zip)
    - **rule_id**: Optional export rule ID to apply filters
    - **css_template**: CSS template for PDF export (default, academic, business)
    - **include_formats**: Formats to include in ZIP export
    """
    # Verify batch ownership
    batch = db.query(OCRBatch).filter(
        OCRBatch.id == request.batch_id,
        OCRBatch.user_id == current_user.id
    ).first()

    if not batch:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Batch not found"
        )

    # Get completed results
    results = db.query(OCRResult).join(OCRFile).filter(
        OCRFile.batch_id == request.batch_id,
        OCRFile.status == FileStatus.COMPLETED
    ).all()

    if not results:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="No completed results found for this batch"
        )

    # Apply export rule if specified
    if request.rule_id:
        try:
            results = export_service.apply_export_rule(db, results, request.rule_id)
        except ExportError as e:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=str(e)
            )

    try:
        # Generate export based on format
        export_dir = Path(f"uploads/batches/{batch.id}/exports")
        export_dir.mkdir(parents=True, exist_ok=True)

        if request.format == "txt":
            output_path = export_dir / f"batch_{batch.id}_export.txt"
            export_service.export_to_txt(results, output_path)

        elif request.format == "json":
            output_path = export_dir / f"batch_{batch.id}_export.json"
            export_service.export_to_json(results, output_path)

        elif request.format == "excel":
            output_path = export_dir / f"batch_{batch.id}_export.xlsx"
            export_service.export_to_excel(results, output_path)

        elif request.format == "markdown":
            output_path = export_dir / f"batch_{batch.id}_export.md"
            export_service.export_to_markdown(results, output_path, combine=True)

        elif request.format == "zip":
            output_path = export_dir / f"batch_{batch.id}_export.zip"
            include_formats = request.include_formats or ["markdown", "json"]
            export_service.export_batch_to_zip(db, batch.id, output_path, include_formats)

        else:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Unsupported export format: {request.format}"
            )

        logger.info(f"Exported batch {batch.id} to {request.format} format: {output_path}")

        # Return file for download
        return FileResponse(
            path=str(output_path),
            filename=output_path.name,
            media_type="application/octet-stream"
        )

    except ExportError as e:
        logger.error(f"Export error: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(e)
        )
    except Exception as e:
        logger.error(f"Unexpected export error: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Export failed"
        )


@router.get("/pdf/{file_id}", summary="Generate PDF for single file")
async def generate_pdf(
    file_id: int,
    css_template: str = "default",
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_active_user)
):
    """
    Generate layout-preserved PDF for a single file

    - **file_id**: File ID
    - **css_template**: CSS template (default, academic, business)
    """
    # Get file and verify ownership
    ocr_file = db.query(OCRFile).join(OCRBatch).filter(
        OCRFile.id == file_id,
        OCRBatch.user_id == current_user.id
    ).first()

    if not ocr_file:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="File not found"
        )

    # Get result
    result = db.query(OCRResult).filter(OCRResult.file_id == file_id).first()
    if not result:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="OCR result not found"
        )

    try:
        # Generate PDF
        export_dir = Path(f"uploads/batches/{ocr_file.batch_id}/exports")
        export_dir.mkdir(parents=True, exist_ok=True)
        output_path = export_dir / f"file_{file_id}_export.pdf"

        export_service.export_to_pdf(
            result=result,
            output_path=output_path,
            css_template=css_template,
            metadata={"title": ocr_file.original_filename}
        )

        logger.info(f"Generated PDF for file {file_id}: {output_path}")

        return FileResponse(
            path=str(output_path),
            filename=f"{Path(ocr_file.original_filename).stem}.pdf",
            media_type="application/pdf"
        )

    except ExportError as e:
        logger.error(f"PDF generation error: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(e)
        )


@router.get("/rules", response_model=List[ExportRuleResponse], summary="List export rules")
async def list_export_rules(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_active_user)
):
    """
    List all export rules for current user

    Returns list of saved export rules
    """
    rules = db.query(ExportRule).filter(ExportRule.user_id == current_user.id).all()
    return rules


@router.post("/rules", response_model=ExportRuleResponse, summary="Create export rule")
async def create_export_rule(
    rule: ExportRuleCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_active_user)
):
    """
    Create new export rule

    Saves custom export configuration for reuse

    - **rule_name**: Rule name
    - **description**: Optional description
    - **config_json**: Rule configuration (filters, formatting, export_options)
    - **css_template**: Optional custom CSS for PDF export
    """
    # Create rule
    new_rule = ExportRule(
        user_id=current_user.id,
        rule_name=rule.rule_name,
        description=rule.description,
        config_json=rule.config_json,
        css_template=rule.css_template
    )

    db.add(new_rule)
    db.commit()
    db.refresh(new_rule)

    logger.info(f"Created export rule {new_rule.id} for user {current_user.id}")

    return new_rule


@router.put("/rules/{rule_id}", response_model=ExportRuleResponse, summary="Update export rule")
async def update_export_rule(
    rule_id: int,
    rule: ExportRuleUpdate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_active_user)
):
    """
    Update existing export rule

    - **rule_id**: Rule ID to update
    - **rule_name**: Optional new rule name
    - **description**: Optional new description
    - **config_json**: Optional new configuration
    - **css_template**: Optional new CSS template
    """
    # Get rule and verify ownership
    db_rule = db.query(ExportRule).filter(
        ExportRule.id == rule_id,
        ExportRule.user_id == current_user.id
    ).first()

    if not db_rule:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Export rule not found"
        )

    # Update fields
    update_data = rule.dict(exclude_unset=True)
    for field, value in update_data.items():
        setattr(db_rule, field, value)

    db.commit()
    db.refresh(db_rule)

    logger.info(f"Updated export rule {rule_id}")

    return db_rule


@router.delete("/rules/{rule_id}", summary="Delete export rule")
async def delete_export_rule(
    rule_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_active_user)
):
    """
    Delete export rule

    - **rule_id**: Rule ID to delete
    """
    # Get rule and verify ownership
    db_rule = db.query(ExportRule).filter(
        ExportRule.id == rule_id,
        ExportRule.user_id == current_user.id
    ).first()

    if not db_rule:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Export rule not found"
        )

    db.delete(db_rule)
    db.commit()

    logger.info(f"Deleted export rule {rule_id}")

    return {"message": "Export rule deleted successfully"}


@router.get("/css-templates", response_model=List[CSSTemplateResponse], summary="List CSS templates")
async def list_css_templates():
    """
    List available CSS templates for PDF generation

    Returns list of predefined CSS templates with descriptions
    """
    templates = pdf_generator.get_available_templates()

    return [
        {"name": name, "description": desc}
        for name, desc in templates.items()
    ]
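As a usage note for the removed V1 endpoint above: POST /api/v1/export took an `ExportRequest` body and streamed the generated file back as a download. A hedged sketch, assuming a local server and a valid bearer token (field names mirror the `ExportRequest` schema later in this diff):

```python
# Sketch of calling the removed POST /api/v1/export endpoint.
# URL and token are placeholders.
import httpx

payload = {
    "batch_id": 1,
    "format": "zip",                          # txt, json, excel, markdown, pdf, zip
    "rule_id": None,                          # optional saved export rule
    "css_template": "default",                # used for PDF output
    "include_formats": ["markdown", "json"],  # only consulted for zip export
}

resp = httpx.post(
    "http://localhost:8000/api/v1/export",    # assumption: local dev server
    json=payload,
    headers={"Authorization": "Bearer <token>"},
    timeout=60.0,
)
resp.raise_for_status()
with open("batch_1_export.zip", "wb") as f:
    f.write(resp.content)  # endpoint returns the file as a download
```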
@@ -1,244 +0,0 @@
"""
Tool_OCR - OCR Router
File upload, OCR processing, and status endpoints
"""

import logging
from typing import List
from pathlib import Path

from fastapi import APIRouter, Depends, HTTPException, status, UploadFile, File, BackgroundTasks
from sqlalchemy.orm import Session

from app.core.deps import get_db, get_current_active_user
from app.models.user import User
from app.models.ocr import OCRBatch, OCRFile, OCRResult, BatchStatus, FileStatus
from app.schemas.ocr import (
    OCRBatchResponse,
    BatchStatusResponse,
    FileStatusResponse,
    OCRResultDetailResponse,
    UploadBatchResponse,
    ProcessRequest,
    ProcessResponse,
)
from app.services.file_manager import FileManager, FileManagementError
from app.services.ocr_service import OCRService
from app.services.background_tasks import process_batch_files_with_retry


logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/v1", tags=["OCR"])

# Initialize services
file_manager = FileManager()
ocr_service = OCRService()


@router.post("/upload", response_model=UploadBatchResponse, summary="Upload files for OCR")
async def upload_files(
    files: List[UploadFile] = File(..., description="Files to upload (PNG, JPG, PDF)"),
    batch_name: str = None,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_active_user)
):
    """
    Upload files for OCR processing

    Creates a new batch and uploads files to it

    - **files**: List of files to upload (PNG, JPG, JPEG, PDF)
    - **batch_name**: Optional name for the batch
    """
    if not files:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="No files provided"
        )

    try:
        # Create batch
        batch = file_manager.create_batch(db, current_user.id, batch_name)

        # Upload files
        uploaded_files = file_manager.add_files_to_batch(db, batch.id, files)

        logger.info(f"Uploaded {len(uploaded_files)} files to batch {batch.id} for user {current_user.id}")

        # Refresh batch to get updated counts
        db.refresh(batch)

        # Return response matching frontend expectations
        return {
            "batch_id": batch.id,
            "files": uploaded_files
        }

    except FileManagementError as e:
        logger.error(f"File upload error: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        )
    except Exception as e:
        logger.error(f"Unexpected error during upload: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to upload files"
        )


# NOTE: process_batch_files function moved to app.services.background_tasks
# Now using process_batch_files_with_retry with retry logic

@router.post("/ocr/process", response_model=ProcessResponse, summary="Trigger OCR processing")
async def process_ocr(
    request: ProcessRequest,
    background_tasks: BackgroundTasks,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_active_user)
):
    """
    Trigger OCR processing for a batch

    Starts background processing of all files in the batch

    - **batch_id**: Batch ID to process
    - **lang**: Language code (ch, en, japan, korean)
    - **detect_layout**: Enable layout detection
    """
    # Verify batch ownership
    batch = db.query(OCRBatch).filter(
        OCRBatch.id == request.batch_id,
        OCRBatch.user_id == current_user.id
    ).first()

    if not batch:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Batch not found"
        )

    if batch.status != BatchStatus.PENDING:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Batch is already {batch.status.value}"
        )

    # Start background processing with retry logic
    background_tasks.add_task(
        process_batch_files_with_retry,
        batch_id=batch.id,
        lang=request.lang,
        detect_layout=request.detect_layout,
        db=SessionLocal()  # Create new session for background task
    )

    logger.info(f"Started OCR processing for batch {batch.id}")

    return {
        "message": "OCR processing started",
        "batch_id": batch.id,
        "total_files": batch.total_files,
        "status": "processing"
    }


@router.get("/batch/{batch_id}/status", response_model=BatchStatusResponse, summary="Get batch status")
async def get_batch_status(
    batch_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_active_user)
):
    """
    Get batch processing status

    Returns batch information and all files in the batch

    - **batch_id**: Batch ID
    """
    # Verify batch ownership
    batch = db.query(OCRBatch).filter(
        OCRBatch.id == batch_id,
        OCRBatch.user_id == current_user.id
    ).first()

    if not batch:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Batch not found"
        )

    # Get all files in batch
    files = db.query(OCRFile).filter(OCRFile.batch_id == batch_id).all()

    return {
        "batch": batch,
        "files": files
    }


@router.get("/ocr/result/{file_id}", response_model=OCRResultDetailResponse, summary="Get OCR result")
async def get_ocr_result(
    file_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_active_user)
):
    """
    Get OCR result for a file

    Returns flattened file and OCR result information for frontend preview

    - **file_id**: File ID
    """
    # Get file
    ocr_file = db.query(OCRFile).join(OCRBatch).filter(
        OCRFile.id == file_id,
        OCRBatch.user_id == current_user.id
    ).first()

    if not ocr_file:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="File not found"
        )

    # Get result if exists
    result = db.query(OCRResult).filter(OCRResult.file_id == file_id).first()

    # Read markdown content if result exists
    markdown_content = None
    if result and result.markdown_path:
        markdown_file = Path(result.markdown_path)
        if markdown_file.exists():
            try:
                markdown_content = markdown_file.read_text(encoding='utf-8')
            except Exception as e:
                logger.warning(f"Failed to read markdown file {result.markdown_path}: {e}")

    # Build JSON data from result if available
    json_data = None
    if result:
        json_data = {
            "total_text_regions": result.total_text_regions,
            "average_confidence": result.average_confidence,
            "detected_language": result.detected_language,
            "layout_data": result.layout_data,
            "images_metadata": result.images_metadata,
        }

    # Return flattened structure matching frontend expectations
    return {
        "file_id": ocr_file.id,
        "filename": ocr_file.filename,
        "status": ocr_file.status.value,
        "markdown_content": markdown_content,
        "json_data": json_data,
        "confidence": result.average_confidence if result else None,
        "processing_time": ocr_file.processing_time,
    }


# Import SessionLocal for background tasks
from app.core.database import SessionLocal
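One detail worth noting in the removed OCR router: it hands a fresh `SessionLocal()` to the background task, so the task, not the router, owns that session's lifecycle. The real `process_batch_files_with_retry` lives in `app.services.background_tasks` and is not shown in this diff; the following is only a sketch of the close-in-`finally` pattern such a task needs, not the actual implementation.

```python
# Illustrative pattern only - the real implementation is in
# app.services.background_tasks and is not part of this diff.
from sqlalchemy.orm import Session

def process_batch_files_with_retry(batch_id: int, lang: str,
                                   detect_layout: bool, db: Session) -> None:
    try:
        # ... run OCR over each file in the batch, committing per file ...
        pass
    finally:
        db.close()  # release the connection created just for this task
```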
@@ -10,8 +10,8 @@ from fastapi import APIRouter, Depends, HTTPException, status, Query
 from fastapi.responses import FileResponse
 from sqlalchemy.orm import Session
 
-from app.core.deps import get_db, get_current_user_v2
-from app.models.user_v2 import User
+from app.core.deps import get_db, get_current_user
+from app.models.user import User
 from app.models.task import TaskStatus
 from app.schemas.task import (
     TaskCreate,
@@ -34,7 +34,7 @@ router = APIRouter(prefix="/api/v2/tasks", tags=["Tasks"])
 async def create_task(
     task_data: TaskCreate,
     db: Session = Depends(get_db),
-    current_user: User = Depends(get_current_user_v2)
+    current_user: User = Depends(get_current_user)
 ):
     """
     Create a new OCR task
@@ -72,7 +72,7 @@ async def list_tasks(
     order_by: str = Query("created_at"),
     order_desc: bool = Query(True),
     db: Session = Depends(get_db),
-    current_user: User = Depends(get_current_user_v2)
+    current_user: User = Depends(get_current_user)
 ):
     """
     List user's tasks with pagination and filtering
@@ -134,7 +134,7 @@ async def list_tasks(
 @router.get("/stats", response_model=TaskStatsResponse)
 async def get_task_stats(
     db: Session = Depends(get_db),
-    current_user: User = Depends(get_current_user_v2)
+    current_user: User = Depends(get_current_user)
 ):
     """
     Get task statistics for current user
@@ -157,7 +157,7 @@ async def get_task_stats(
 async def get_task(
     task_id: str,
     db: Session = Depends(get_db),
-    current_user: User = Depends(get_current_user_v2)
+    current_user: User = Depends(get_current_user)
 ):
     """
     Get task details by ID
@@ -184,7 +184,7 @@ async def update_task(
     task_id: str,
     task_update: TaskUpdate,
     db: Session = Depends(get_db),
-    current_user: User = Depends(get_current_user_v2)
+    current_user: User = Depends(get_current_user)
 ):
     """
     Update task status and results
@@ -253,7 +253,7 @@ async def update_task(
 async def delete_task(
     task_id: str,
     db: Session = Depends(get_db),
-    current_user: User = Depends(get_current_user_v2)
+    current_user: User = Depends(get_current_user)
 ):
     """
     Delete a task
@@ -280,7 +280,7 @@ async def delete_task(
 async def download_json(
     task_id: str,
     db: Session = Depends(get_db),
-    current_user: User = Depends(get_current_user_v2)
+    current_user: User = Depends(get_current_user)
 ):
     """
     Download task result as JSON file
@@ -327,7 +327,7 @@ async def download_json(
 async def download_markdown(
     task_id: str,
     db: Session = Depends(get_db),
-    current_user: User = Depends(get_current_user_v2)
+    current_user: User = Depends(get_current_user)
 ):
     """
     Download task result as Markdown file
@@ -374,7 +374,7 @@ async def download_markdown(
 async def download_pdf(
     task_id: str,
     db: Session = Depends(get_db),
-    current_user: User = Depends(get_current_user_v2)
+    current_user: User = Depends(get_current_user)
 ):
     """
     Download task result as searchable PDF file
@@ -421,7 +421,7 @@ async def download_pdf(
 async def start_task(
     task_id: str,
     db: Session = Depends(get_db),
-    current_user: User = Depends(get_current_user_v2)
+    current_user: User = Depends(get_current_user)
 ):
     """
     Start processing a pending task
@@ -459,7 +459,7 @@ async def start_task(
 async def cancel_task(
     task_id: str,
     db: Session = Depends(get_db),
-    current_user: User = Depends(get_current_user_v2)
+    current_user: User = Depends(get_current_user)
 ):
     """
     Cancel a pending or processing task
@@ -513,7 +513,7 @@ async def cancel_task(
 async def retry_task(
     task_id: str,
     db: Session = Depends(get_db),
-    current_user: User = Depends(get_current_user_v2)
+    current_user: User = Depends(get_current_user)
 ):
     """
     Retry a failed task
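The hunks above only rename the auth dependency from `get_current_user_v2` to `get_current_user`; its definition in `app.core.deps` is not part of this diff. For orientation, a FastAPI dependency of this shape typically decodes the internal JWT (whose `sub` and `session_id` claims are set by `create_access_token` in the auth router shown earlier) and loads the user. Everything below is an assumption-level sketch, not the project's actual code:

```python
# Hypothetical sketch of get_current_user (the real one lives in
# app.core.deps). Claim names "sub" and "session_id" match the token
# built above; secret name and algorithm are assumptions.
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt  # assumption: python-jose is used
from sqlalchemy.orm import Session

from app.core.config import settings
from app.core.database import SessionLocal
from app.models.user import User

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v2/auth/login")

def get_db():
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()

def get_current_user(token: str = Depends(oauth2_scheme),
                     db: Session = Depends(get_db)) -> User:
    try:
        payload = jwt.decode(token, settings.secret_key, algorithms=["HS256"])
        user_id = int(payload["sub"])
    except (JWTError, KeyError, ValueError):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
                            detail="Invalid or expired token")
    user = db.query(User).filter(User.id == user_id).first()
    if user is None or not user.is_active:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
                            detail="User not found or inactive")
    return user
```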
@@ -1,189 +0,0 @@
"""
Tool_OCR - Translation Router (RESERVED)
Stub endpoints for future translation feature
"""

import logging
from typing import List

from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session

from app.core.deps import get_db, get_current_active_user
from app.models.user import User
from app.schemas.translation import (
    TranslationRequest,
    TranslationResponse,
    TranslationFeatureStatus,
    LanguageInfo,
)
from app.services.translation_service import StubTranslationService


logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/v1/translate", tags=["Translation (RESERVED)"])


@router.get("/status", response_model=TranslationFeatureStatus, summary="Get translation feature status")
async def get_translation_status():
    """
    Get translation feature status

    Returns current implementation status and roadmap for translation feature.
    This is a RESERVED feature that will be implemented in Phase 5.

    **Status**: RESERVED - Not yet implemented
    **Phase**: Phase 5 (Post-production)
    **Priority**: Implemented after production deployment and user feedback
    """
    return StubTranslationService.get_feature_status()


@router.get("/languages", response_model=List[LanguageInfo], summary="Get supported languages")
async def get_supported_languages():
    """
    Get list of languages planned for translation support

    Returns list of languages that will be supported when translation
    feature is implemented.

    **Status**: RESERVED - Planning phase
    """
    return StubTranslationService.get_supported_languages()


@router.post("/document", response_model=TranslationResponse, summary="Translate document (RESERVED)")
async def translate_document(
    request: TranslationRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_active_user)
):
    """
    Translate OCR document (RESERVED - NOT IMPLEMENTED)

    This endpoint is reserved for future translation functionality.
    Returns 501 Not Implemented status.

    **Expected Functionality** (when implemented):
    - Translate markdown documents while preserving structure
    - Support multiple translation engines (offline, ERNIE, Google, DeepL)
    - Maintain layout and formatting
    - Handle technical terminology

    **Planned Features**:
    - Offline translation (Argos Translate)
    - Cloud API integration (ERNIE, Google, DeepL)
    - Batch translation support
    - Translation memory
    - Glossary support

    **Current Status**: RESERVED for Phase 5 implementation

    ---

    **Request Parameters** (planned):
    - **file_id**: ID of OCR result file to translate
    - **source_lang**: Source language code (zh, en, ja, ko)
    - **target_lang**: Target language code (zh, en, ja, ko)
    - **engine_type**: Translation engine (offline, ernie, google, deepl)
    - **preserve_structure**: Whether to preserve markdown structure
    - **engine_config**: Engine-specific configuration

    **Response** (planned):
    - **task_id**: Translation task ID for tracking progress
    - **status**: Translation status
    - **translated_file_path**: Path to translated file (when completed)
    """
    logger.info(f"Translation request received from user {current_user.id} (stub endpoint)")

    # Return 501 Not Implemented with informative message
    raise HTTPException(
        status_code=status.HTTP_501_NOT_IMPLEMENTED,
        detail={
            "error": "Translation feature not implemented",
            "message": "This feature is reserved for future development (Phase 5)",
            "status": "RESERVED",
            "roadmap": {
                "phase": "Phase 5",
                "priority": "Implemented after production deployment",
                "planned_features": [
                    "Offline translation (Argos Translate)",
                    "Cloud API integration (ERNIE, Google, DeepL)",
                    "Structure-preserving markdown translation",
                    "Batch translation support"
                ]
            },
            "request_received": {
                "file_id": request.file_id,
                "source_lang": request.source_lang,
                "target_lang": request.target_lang,
                "engine_type": request.engine_type
            },
            "action": "Please check back in a future release or contact support for updates"
        }
    )


@router.get("/task/{task_id}", summary="Get translation task status (RESERVED)")
async def get_translation_task_status(
    task_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_active_user)
):
    """
    Get translation task status (RESERVED - NOT IMPLEMENTED)

    This endpoint would track translation task progress.
    Returns 501 Not Implemented status.

    **Planned Functionality**:
    - Real-time translation progress
    - Status updates (pending, processing, completed, failed)
    - Estimated completion time
    - Error reporting

    **Current Status**: RESERVED for Phase 5 implementation
    """
    logger.info(f"Translation status check for task {task_id} from user {current_user.id} (stub endpoint)")

    raise HTTPException(
        status_code=status.HTTP_501_NOT_IMPLEMENTED,
        detail={
            "error": "Translation feature not implemented",
            "message": "Translation task tracking is reserved for Phase 5",
            "task_id": task_id,
            "status": "RESERVED"
        }
    )


@router.delete("/task/{task_id}", summary="Cancel translation task (RESERVED)")
async def cancel_translation_task(
    task_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_active_user)
):
    """
    Cancel ongoing translation task (RESERVED - NOT IMPLEMENTED)

    This endpoint would allow cancellation of translation tasks.
    Returns 501 Not Implemented status.

    **Planned Functionality**:
    - Cancel in-progress translations
    - Clean up temporary files
    - Refund credits (if applicable)

    **Current Status**: RESERVED for Phase 5 implementation
    """
    logger.info(f"Translation cancellation request for task {task_id} from user {current_user.id} (stub endpoint)")

    raise HTTPException(
        status_code=status.HTTP_501_NOT_IMPLEMENTED,
        detail={
            "error": "Translation feature not implemented",
            "message": "This feature is reserved for Phase 5",
            "status": "RESERVED"
        }
    )
@@ -1,59 +1,30 @@
 """
-Tool_OCR - API Schemas
+Tool_OCR - API Schemas (V2)
 Pydantic models for request/response validation
 """
 
-from app.schemas.auth import Token, TokenData, LoginRequest
-from app.schemas.user import UserBase, UserCreate, UserResponse
-from app.schemas.ocr import (
-    OCRBatchResponse,
-    OCRFileResponse,
-    OCRResultResponse,
-    BatchStatusResponse,
-    FileStatusResponse,
-    ProcessRequest,
-    ProcessResponse,
-)
-from app.schemas.export import (
-    ExportRequest,
-    ExportRuleCreate,
-    ExportRuleUpdate,
-    ExportRuleResponse,
-    CSSTemplateResponse,
-)
-from app.schemas.translation import (
-    TranslationRequest,
-    TranslationResponse,
-    TranslationFeatureStatus,
-    LanguageInfo,
+from app.schemas.auth import LoginRequest, Token, UserResponse
+from app.schemas.task import (
+    TaskCreate,
+    TaskUpdate,
+    TaskResponse,
+    TaskDetailResponse,
+    TaskListResponse,
+    TaskStatsResponse,
+    TaskStatusEnum,
 )
 
 __all__ = [
     # Auth
-    "Token",
-    "TokenData",
     "LoginRequest",
-    # User
-    "UserBase",
-    "UserCreate",
+    "Token",
     "UserResponse",
-    # OCR
-    "OCRBatchResponse",
-    "OCRFileResponse",
-    "OCRResultResponse",
-    "BatchStatusResponse",
-    "FileStatusResponse",
-    "ProcessRequest",
-    "ProcessResponse",
-    # Export
-    "ExportRequest",
-    "ExportRuleCreate",
-    "ExportRuleUpdate",
-    "ExportRuleResponse",
-    "CSSTemplateResponse",
-    # Translation (RESERVED)
-    "TranslationRequest",
-    "TranslationResponse",
-    "TranslationFeatureStatus",
-    "LanguageInfo",
+    # Task
+    "TaskCreate",
+    "TaskUpdate",
+    "TaskResponse",
+    "TaskDetailResponse",
+    "TaskListResponse",
+    "TaskStatsResponse",
+    "TaskStatusEnum",
 ]
@@ -1,104 +0,0 @@
"""
Tool_OCR - Export Schemas
"""

from datetime import datetime
from typing import Optional, Dict, Any, List
from pydantic import BaseModel, Field


class ExportOptions(BaseModel):
    """Export options schema"""
    confidence_threshold: Optional[float] = Field(None, description="Minimum confidence threshold")
    include_metadata: Optional[bool] = Field(True, description="Include metadata in export")
    filename_pattern: Optional[str] = Field(None, description="Filename pattern for export")
    css_template: Optional[str] = Field(None, description="CSS template for PDF export")


class ExportRequest(BaseModel):
    """Export request schema"""
    batch_id: int = Field(..., description="Batch ID to export")
    format: str = Field(..., description="Export format (txt, json, excel, markdown, pdf, zip)")
    rule_id: Optional[int] = Field(None, description="Optional export rule ID to apply")
    css_template: Optional[str] = Field("default", description="CSS template for PDF export")
    include_formats: Optional[List[str]] = Field(None, description="Formats to include in ZIP export")
    options: Optional[ExportOptions] = Field(None, description="Additional export options")

    class Config:
        json_schema_extra = {
            "example": {
                "batch_id": 1,
                "format": "pdf",
                "rule_id": None,
                "css_template": "default",
                "include_formats": ["markdown", "json"],
                "options": {
                    "confidence_threshold": 0.8,
                    "include_metadata": True
                }
            }
        }


class ExportRuleCreate(BaseModel):
    """Export rule creation schema"""
    rule_name: str = Field(..., max_length=100, description="Rule name")
    description: Optional[str] = Field(None, description="Rule description")
    config_json: Dict[str, Any] = Field(..., description="Rule configuration as JSON")
    css_template: Optional[str] = Field(None, description="Custom CSS template")

    class Config:
        json_schema_extra = {
            "example": {
                "rule_name": "High Confidence Only",
                "description": "Export only results with confidence > 0.8",
                "config_json": {
                    "filters": {
                        "confidence_threshold": 0.8
                    },
                    "formatting": {
                        "add_line_numbers": True
                    }
                },
                "css_template": None
            }
        }


class ExportRuleUpdate(BaseModel):
    """Export rule update schema"""
    rule_name: Optional[str] = Field(None, max_length=100)
    description: Optional[str] = None
    config_json: Optional[Dict[str, Any]] = None
    css_template: Optional[str] = None


class ExportRuleResponse(BaseModel):
    """Export rule response schema"""
    id: int
    user_id: int
    rule_name: str
    description: Optional[str] = None
    config_json: Dict[str, Any]
    css_template: Optional[str] = None
    created_at: datetime
    updated_at: datetime

    class Config:
        from_attributes = True


class CSSTemplateResponse(BaseModel):
    """CSS template response schema"""
    name: str = Field(..., description="Template name")
    description: str = Field(..., description="Template description")
    filename: str = Field(..., description="Template filename")

    class Config:
        json_schema_extra = {
            "example": {
                "name": "default",
                "description": "通用排版模板,適合大多數文檔",
                "filename": "default.css"
            }
        }
@@ -1,151 +0,0 @@
"""
Tool_OCR - OCR Schemas
"""

from datetime import datetime
from typing import Optional, Dict, List, Any
from pydantic import BaseModel, Field

from app.models.ocr import BatchStatus, FileStatus


class OCRFileResponse(BaseModel):
    """OCR file response schema"""
    id: int
    batch_id: int
    filename: str
    original_filename: str
    file_size: int
    file_format: str
    status: FileStatus
    error: Optional[str] = Field(None, validation_alias='error_message')  # Map from error_message to error
    created_at: datetime
    processing_time: Optional[float] = None

    class Config:
        from_attributes = True
        populate_by_name = True


class OCRResultResponse(BaseModel):
    """OCR result response schema"""
    id: int
    file_id: int
    markdown_path: Optional[str] = None
    markdown_content: Optional[str] = None  # Added for frontend preview
    json_path: Optional[str] = None
    images_dir: Optional[str] = None
    detected_language: Optional[str] = None
    total_text_regions: int
    average_confidence: Optional[float] = None
    layout_data: Optional[Dict[str, Any]] = None
    images_metadata: Optional[List[Dict[str, Any]]] = None
    created_at: datetime

    class Config:
        from_attributes = True


class OCRBatchResponse(BaseModel):
    """OCR batch response schema"""
    id: int
    user_id: int
    batch_name: Optional[str] = None
    status: BatchStatus
    total_files: int
    completed_files: int
    failed_files: int
    progress_percentage: float
    created_at: datetime
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None

    class Config:
        from_attributes = True


class BatchStatusResponse(BaseModel):
    """Batch status with file details"""
    batch: OCRBatchResponse
    files: List[OCRFileResponse]


class FileStatusResponse(BaseModel):
    """File status with result details"""
    file: OCRFileResponse
    result: Optional[OCRResultResponse] = None


class OCRResultDetailResponse(BaseModel):
    """OCR result detail response for frontend preview - flattened structure"""
    file_id: int
    filename: str
    status: str
    markdown_content: Optional[str] = None
    json_data: Optional[Dict[str, Any]] = None
    confidence: Optional[float] = None
    processing_time: Optional[float] = None

    class Config:
        from_attributes = True


class UploadBatchResponse(BaseModel):
    """Upload response schema matching frontend expectations"""
    batch_id: int = Field(..., description="Batch ID")
    files: List[OCRFileResponse] = Field(..., description="Uploaded files")

    class Config:
        json_schema_extra = {
            "example": {
                "batch_id": 1,
                "files": [
                    {
                        "id": 1,
                        "batch_id": 1,
                        "filename": "doc_1.png",
                        "original_filename": "document.png",
                        "file_size": 1024000,
                        "file_format": "png",
                        "status": "pending",
                        "error_message": None,
                        "created_at": "2025-01-01T00:00:00",
                        "processing_time": None
                    }
                ]
            }
        }


class ProcessRequest(BaseModel):
    """OCR process request schema"""
    batch_id: int = Field(..., description="Batch ID to process")
    lang: str = Field(default="ch", description="Language code (ch, en, japan, korean)")
    detect_layout: bool = Field(default=True, description="Enable layout detection")

    class Config:
        json_schema_extra = {
            "example": {
                "batch_id": 1,
                "lang": "ch",
                "detect_layout": True
            }
        }


class ProcessResponse(BaseModel):
    """OCR process response schema"""
    message: str
    batch_id: int
    total_files: int
    status: str

    class Config:
        json_schema_extra = {
            "example": {
                "message": "OCR processing started",
                "batch_id": 1,
                "total_files": 5,
                "status": "processing"
            }
        }
@@ -1,124 +0,0 @@
"""
Tool_OCR - Translation Schemas (RESERVED)
Request/response models for translation endpoints
"""

from typing import Optional, Dict, List, Any
from pydantic import BaseModel, Field


class TranslationRequest(BaseModel):
    """
    Translation request schema (RESERVED)

    Expected format for document translation requests
    """
    file_id: int = Field(..., description="File ID to translate")
    source_lang: str = Field(..., description="Source language code (zh, en, ja, ko)")
    target_lang: str = Field(..., description="Target language code (zh, en, ja, ko)")
    engine_type: Optional[str] = Field("offline", description="Translation engine (offline, ernie, google, deepl)")
    preserve_structure: bool = Field(True, description="Preserve markdown structure")
    engine_config: Optional[Dict[str, Any]] = Field(None, description="Engine-specific configuration")

    class Config:
        json_schema_extra = {
            "example": {
                "file_id": 1,
                "source_lang": "zh",
                "target_lang": "en",
                "engine_type": "offline",
                "preserve_structure": True,
                "engine_config": {}
            }
        }


class TranslationResponse(BaseModel):
    """
    Translation response schema (RESERVED)

    Expected format for translation results
    """
    task_id: int = Field(..., description="Translation task ID")
    file_id: int
    source_lang: str
    target_lang: str
    engine_type: str
    status: str = Field(..., description="Translation status (pending, processing, completed, failed)")
    translated_file_path: Optional[str] = Field(None, description="Path to translated markdown file")
    progress: float = Field(0.0, description="Translation progress (0.0-1.0)")
    error_message: Optional[str] = None

    class Config:
        json_schema_extra = {
            "example": {
                "task_id": 1,
                "file_id": 1,
                "source_lang": "zh",
                "target_lang": "en",
                "engine_type": "offline",
                "status": "processing",
                "translated_file_path": None,
                "progress": 0.5,
                "error_message": None
            }
        }


class TranslationStatusResponse(BaseModel):
    """Translation task status response (RESERVED)"""
    task_id: int
    status: str
    progress: float
    created_at: str
    completed_at: Optional[str] = None
    error_message: Optional[str] = None


class TranslationConfigRequest(BaseModel):
    """Translation configuration request (RESERVED)"""
    source_lang: str = Field(..., max_length=20)
    target_lang: str = Field(..., max_length=20)
    engine_type: str = Field(..., max_length=50)
    engine_config: Optional[Dict[str, Any]] = None

    class Config:
        json_schema_extra = {
            "example": {
                "source_lang": "zh",
                "target_lang": "en",
                "engine_type": "offline",
                "engine_config": {
                    "model_path": "/path/to/model"
                }
            }
        }


class TranslationConfigResponse(BaseModel):
    """Translation configuration response (RESERVED)"""
    id: int
    user_id: int
    source_lang: str
    target_lang: str
    engine_type: str
    engine_config: Optional[Dict[str, Any]] = None
    created_at: str
    updated_at: str


class TranslationFeatureStatus(BaseModel):
    """Translation feature status response"""
    available: bool = Field(..., description="Whether translation is available")
    status: str = Field(..., description="Feature status (reserved, planned, implemented)")
    message: str = Field(..., description="Status message")
    supported_engines: List[str] = Field(default_factory=list, description="Currently supported engines")
    planned_engines: List[Dict[str, str]] = Field(default_factory=list, description="Planned engines")
    roadmap: Dict[str, Any] = Field(default_factory=dict, description="Implementation roadmap")


class LanguageInfo(BaseModel):
    """Language information"""
    code: str = Field(..., description="Language code (ISO 639-1)")
    name: str = Field(..., description="Language name")
    status: str = Field(..., description="Support status (planned, supported)")
@@ -1,53 +0,0 @@
"""
Tool_OCR - User Schemas
"""

from datetime import datetime
from typing import Optional
from pydantic import BaseModel, EmailStr, Field


class UserBase(BaseModel):
    """Base user schema"""
    username: str = Field(..., min_length=3, max_length=50)
    email: EmailStr
    full_name: Optional[str] = Field(None, max_length=100)


class UserCreate(UserBase):
    """User creation schema"""
    password: str = Field(..., min_length=6, description="Password (min 6 characters)")

    class Config:
        json_schema_extra = {
            "example": {
                "username": "johndoe",
                "email": "john@example.com",
                "full_name": "John Doe",
                "password": "secret123"
            }
        }


class UserResponse(UserBase):
    """User response schema"""
    id: int
    is_active: bool
    is_admin: bool
    created_at: datetime
    updated_at: datetime

    class Config:
        from_attributes = True
        json_schema_extra = {
            "example": {
                "id": 1,
                "username": "johndoe",
                "email": "john@example.com",
                "full_name": "John Doe",
                "is_active": True,
                "is_admin": False,
                "created_at": "2025-01-01T00:00:00",
                "updated_at": "2025-01-01T00:00:00"
            }
        }
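A minimal sketch of UserCreate validation (the app.schemas.user module path is an assumption): field constraints such as min_length are enforced at parse time.

# Hypothetical usage sketch; module path assumed.
from pydantic import ValidationError
from app.schemas.user import UserCreate

try:
    UserCreate(username="jo", email="john@example.com", password="123")
except ValidationError as exc:
    # Rejected: username shorter than 3 chars, password shorter than 6
    print(exc.error_count(), "validation errors")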
@@ -9,7 +9,7 @@ from sqlalchemy.orm import Session
 from sqlalchemy import func, and_
 from datetime import datetime, timedelta

-from app.models.user_v2 import User
+from app.models.user import User
 from app.models.task import Task, TaskStatus
 from app.models.session import Session as UserSession
 from app.models.audit_log import AuditLog
@@ -1,421 +0,0 @@
"""
Tool_OCR - Background Tasks Service
Handles async processing, cleanup, and scheduled tasks
"""

import logging
import asyncio
import time
from datetime import datetime, timedelta
from pathlib import Path
from typing import Optional, Callable, Any
from sqlalchemy.orm import Session

from app.core.database import SessionLocal
from app.models.ocr import OCRBatch, OCRFile, OCRResult, BatchStatus, FileStatus
from app.services.ocr_service import OCRService
from app.services.file_manager import FileManager
from app.services.pdf_generator import PDFGenerator


logger = logging.getLogger(__name__)


class BackgroundTaskManager:
    """
    Manages background tasks including retry logic, cleanup, and scheduled jobs
    """

    def __init__(
        self,
        max_retries: int = 3,
        retry_delay: int = 5,
        cleanup_interval: int = 3600,  # 1 hour
        file_retention_hours: int = 24
    ):
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.cleanup_interval = cleanup_interval
        self.file_retention_hours = file_retention_hours
        self.ocr_service = OCRService()
        self.file_manager = FileManager()
        self.pdf_generator = PDFGenerator()

    async def execute_with_retry(
        self,
        func: Callable,
        *args,
        max_retries: Optional[int] = None,
        retry_delay: Optional[int] = None,
        **kwargs
    ) -> Any:
        """
        Execute a function with retry logic

        Args:
            func: Function to execute
            args: Positional arguments for func
            max_retries: Maximum retry attempts (overrides default)
            retry_delay: Delay between retries in seconds (overrides default)
            kwargs: Keyword arguments for func

        Returns:
            Function result

        Raises:
            Exception: If all retries are exhausted
        """
        max_retries = max_retries or self.max_retries
        retry_delay = retry_delay or self.retry_delay

        last_exception = None
        for attempt in range(max_retries + 1):
            try:
                if asyncio.iscoroutinefunction(func):
                    return await func(*args, **kwargs)
                else:
                    return func(*args, **kwargs)
            except Exception as e:
                last_exception = e
                if attempt < max_retries:
                    logger.warning(
                        f"Attempt {attempt + 1}/{max_retries + 1} failed for {func.__name__}: {e}. "
                        f"Retrying in {retry_delay}s..."
                    )
                    await asyncio.sleep(retry_delay)
                else:
                    logger.error(
                        f"All {max_retries + 1} attempts failed for {func.__name__}: {e}"
                    )

        raise last_exception

    def process_single_file_with_retry(
        self,
        ocr_file: OCRFile,
        batch_id: int,
        lang: str,
        detect_layout: bool,
        db: Session
    ) -> bool:
        """
        Process a single file with retry logic

        Args:
            ocr_file: OCRFile instance
            batch_id: Batch ID
            lang: Language code
            detect_layout: Whether to detect layout
            db: Database session

        Returns:
            bool: True if successful, False otherwise
        """
        for attempt in range(self.max_retries + 1):
            try:
                # Update file status
                ocr_file.status = FileStatus.PROCESSING
                ocr_file.started_at = datetime.utcnow()
                ocr_file.retry_count = attempt
                db.commit()

                # Get file paths
                file_path = Path(ocr_file.file_path)
                paths = self.file_manager.get_file_paths(batch_id, ocr_file.id)

                # Process OCR
                result = self.ocr_service.process_image(
                    file_path,
                    lang=lang,
                    detect_layout=detect_layout
                )

                # Check if processing was successful
                if result['status'] != 'success':
                    raise Exception(result.get('error_message', 'Unknown error during OCR processing'))

                # Save results
                json_path, markdown_path = self.ocr_service.save_results(
                    result=result,
                    output_dir=paths["output_dir"],
                    file_id=str(ocr_file.id)
                )

                # Extract data from result
                text_regions = result.get('text_regions', [])
                layout_data = result.get('layout_data')
                images_metadata = result.get('images_metadata', [])

                # Calculate average confidence (or use from result)
                avg_confidence = result.get('average_confidence')

                # Create OCR result record
                ocr_result = OCRResult(
                    file_id=ocr_file.id,
                    markdown_path=str(markdown_path) if markdown_path else None,
                    json_path=str(json_path) if json_path else None,
                    images_dir=None,  # Images dir not used in current implementation
                    detected_language=lang,
                    total_text_regions=len(text_regions),
                    average_confidence=avg_confidence,
                    layout_data=layout_data,
                    images_metadata=images_metadata
                )
                db.add(ocr_result)

                # Update file status
                ocr_file.status = FileStatus.COMPLETED
                ocr_file.completed_at = datetime.utcnow()
                ocr_file.processing_time = (ocr_file.completed_at - ocr_file.started_at).total_seconds()

                # Commit with retry on connection errors
                try:
                    db.commit()
                except Exception as commit_error:
                    logger.warning(f"Commit failed, rolling back and retrying: {commit_error}")
                    db.rollback()
                    db.refresh(ocr_file)
                    ocr_file.status = FileStatus.COMPLETED
                    ocr_file.completed_at = datetime.utcnow()
                    ocr_file.processing_time = (ocr_file.completed_at - ocr_file.started_at).total_seconds()
                    db.commit()

                logger.info(f"Successfully processed file {ocr_file.id} ({ocr_file.original_filename})")
                return True

            except Exception as e:
                logger.error(f"Attempt {attempt + 1}/{self.max_retries + 1} failed for file {ocr_file.id}: {e}")
                db.rollback()  # Rollback failed transaction

                if attempt < self.max_retries:
                    # Wait before retry
                    time.sleep(self.retry_delay)
                else:
                    # Final failure
                    try:
                        ocr_file.status = FileStatus.FAILED
                        ocr_file.error_message = f"Failed after {self.max_retries + 1} attempts: {str(e)}"
                        ocr_file.completed_at = datetime.utcnow()
                        ocr_file.retry_count = attempt
                        db.commit()
                    except Exception as final_error:
                        logger.error(f"Failed to update error status: {final_error}")
                        db.rollback()
                    return False

        return False

    async def cleanup_expired_files(self, db: Session):
        """
        Clean up files and batches older than retention period

        Args:
            db: Database session
        """
        try:
            cutoff_time = datetime.utcnow() - timedelta(hours=self.file_retention_hours)

            # Find expired batches
            expired_batches = db.query(OCRBatch).filter(
                OCRBatch.created_at < cutoff_time,
                OCRBatch.status.in_([BatchStatus.COMPLETED, BatchStatus.FAILED, BatchStatus.PARTIAL])
            ).all()

            logger.info(f"Found {len(expired_batches)} expired batches to clean up")

            for batch in expired_batches:
                try:
                    # Get batch directory
                    batch_dir = self.file_manager.base_upload_dir / "batches" / str(batch.id)

                    # Delete physical files
                    if batch_dir.exists():
                        import shutil
                        shutil.rmtree(batch_dir)
                        logger.info(f"Deleted batch directory: {batch_dir}")

                    # Delete database records
                    # Delete results first (foreign key constraint)
                    db.query(OCRResult).filter(
                        OCRResult.file_id.in_(
                            db.query(OCRFile.id).filter(OCRFile.batch_id == batch.id)
                        )
                    ).delete(synchronize_session=False)

                    # Delete files
                    db.query(OCRFile).filter(OCRFile.batch_id == batch.id).delete()

                    # Delete batch
                    db.delete(batch)
                    db.commit()

                    logger.info(f"Cleaned up expired batch {batch.id}")

                except Exception as e:
                    logger.error(f"Error cleaning up batch {batch.id}: {e}")
                    db.rollback()

        except Exception as e:
            logger.error(f"Error in cleanup_expired_files: {e}")

    async def generate_pdf_background(
        self,
        result_id: int,
        output_path: str,
        css_template: str = "default",
        db: Session = None
    ):
        """
        Generate PDF in background with retry logic

        Args:
            result_id: OCR result ID
            output_path: Output PDF path
            css_template: CSS template name
            db: Database session
        """
        should_close_db = False
        if db is None:
            db = SessionLocal()
            should_close_db = True

        try:
            # Get result
            result = db.query(OCRResult).filter(OCRResult.id == result_id).first()
            if not result:
                logger.error(f"Result {result_id} not found")
                return

            # Generate PDF with retry
            await self.execute_with_retry(
                self.pdf_generator.generate_pdf,
                markdown_path=result.markdown_path,
                output_path=output_path,
                css_template=css_template,
                max_retries=2,
                retry_delay=3
            )

            logger.info(f"Successfully generated PDF for result {result_id}: {output_path}")

        except Exception as e:
            logger.error(f"Failed to generate PDF for result {result_id}: {e}")
        finally:
            if should_close_db:
                db.close()

    async def start_cleanup_scheduler(self):
        """
        Start periodic cleanup scheduler

        Runs cleanup task at specified intervals
        """
        logger.info(f"Starting cleanup scheduler (interval: {self.cleanup_interval}s, retention: {self.file_retention_hours}h)")

        while True:
            try:
                db = SessionLocal()
                await self.cleanup_expired_files(db)
                db.close()
            except Exception as e:
                logger.error(f"Error in cleanup scheduler: {e}")

            # Wait for next interval
            await asyncio.sleep(self.cleanup_interval)


# Global task manager instance
task_manager = BackgroundTaskManager()


def process_batch_files_with_retry(
    batch_id: int,
    lang: str,
    detect_layout: bool,
    db: Session
):
    """
    Process all files in a batch with retry logic

    Args:
        batch_id: Batch ID
        lang: Language code
        detect_layout: Whether to detect layout
        db: Database session
    """
    try:
        # Get batch
        batch = db.query(OCRBatch).filter(OCRBatch.id == batch_id).first()
        if not batch:
            logger.error(f"Batch {batch_id} not found")
            return

        # Update batch status
        batch.status = BatchStatus.PROCESSING
        batch.started_at = datetime.utcnow()
        db.commit()

        # Get pending files
        files = db.query(OCRFile).filter(
            OCRFile.batch_id == batch_id,
            OCRFile.status == FileStatus.PENDING
        ).all()

        logger.info(f"Processing {len(files)} files in batch {batch_id} with retry logic")

        # Process each file with retry
        for ocr_file in files:
            success = task_manager.process_single_file_with_retry(
                ocr_file=ocr_file,
                batch_id=batch_id,
                lang=lang,
                detect_layout=detect_layout,
                db=db
            )

            # Update batch progress
            if success:
                batch.completed_files += 1
            else:
                batch.failed_files += 1

            db.commit()

        # Update batch final status
        if batch.failed_files == 0:
            batch.status = BatchStatus.COMPLETED
        elif batch.completed_files > 0:
            batch.status = BatchStatus.PARTIAL
        else:
            batch.status = BatchStatus.FAILED

        batch.completed_at = datetime.utcnow()

        # Commit with retry on connection errors
        try:
            db.commit()
        except Exception as commit_error:
            logger.warning(f"Batch commit failed, rolling back and retrying: {commit_error}")
            db.rollback()
            batch = db.query(OCRBatch).filter(OCRBatch.id == batch_id).first()
            if batch:
                batch.completed_at = datetime.utcnow()
                db.commit()

        logger.info(
            f"Batch {batch_id} processing complete: "
            f"{batch.completed_files} succeeded, {batch.failed_files} failed"
        )

    except Exception as e:
        logger.error(f"Fatal error processing batch {batch_id}: {e}")
        db.rollback()  # Rollback any failed transaction
        try:
            batch = db.query(OCRBatch).filter(OCRBatch.id == batch_id).first()
            if batch:
                batch.status = BatchStatus.FAILED
                batch.completed_at = datetime.utcnow()
                db.commit()
        except Exception as commit_error:
            logger.error(f"Error updating batch status: {commit_error}")
            db.rollback()
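A minimal sketch of the retry helper above: execute_with_retry accepts both sync and async callables, so a plain function works (the flaky() helper and the import path are hypothetical):

# Hypothetical usage sketch; module path assumed.
import asyncio
from app.services.background_tasks import task_manager  # path assumed

state = {"attempts": 0}

def flaky():
    state["attempts"] += 1
    if state["attempts"] < 3:
        raise RuntimeError("transient failure")
    return "ok"

async def demo():
    # Succeeds on the third attempt; raises only once retries are exhausted.
    return await task_manager.execute_with_retry(flaky, max_retries=3, retry_delay=1)

print(asyncio.run(demo()))  # "ok"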
@@ -1,512 +0,0 @@
"""
Tool_OCR - Export Service
Handles OCR result export in multiple formats with filtering and formatting rules
"""

import json
import logging
import zipfile
from pathlib import Path
from typing import List, Dict, Optional, Any
from datetime import datetime

import pandas as pd
from sqlalchemy.orm import Session

from app.core.config import settings
from app.models.ocr import OCRBatch, OCRFile, OCRResult, FileStatus
from app.models.export import ExportRule
from app.services.pdf_generator import PDFGenerator, PDFGenerationError


logger = logging.getLogger(__name__)


class ExportError(Exception):
    """Exception raised for export errors"""
    pass


class ExportService:
    """
    Export service for OCR results

    Supported formats:
    - TXT: Plain text export
    - JSON: Full metadata export
    - Excel: Tabular data export
    - Markdown: Direct Markdown export
    - PDF: Layout-preserved PDF export
    - ZIP: Batch export archive
    """

    def __init__(self):
        """Initialize export service"""
        self.pdf_generator = PDFGenerator()

    def apply_filters(
        self,
        results: List[OCRResult],
        filters: Dict[str, Any]
    ) -> List[OCRResult]:
        """
        Apply filters to OCR results

        Args:
            results: List of OCR results
            filters: Filter configuration
                - confidence_threshold: Minimum confidence (0.0-1.0)
                - filename_pattern: Glob pattern for filename matching
                - language: Filter by detected language

        Returns:
            List[OCRResult]: Filtered results
        """
        filtered = results

        # Confidence threshold filter
        if "confidence_threshold" in filters:
            threshold = filters["confidence_threshold"]
            filtered = [r for r in filtered if r.average_confidence and r.average_confidence >= threshold]

        # Filename pattern filter (using simple substring match)
        if "filename_pattern" in filters:
            pattern = filters["filename_pattern"].lower()
            filtered = [
                r for r in filtered
                if pattern in r.file.original_filename.lower()
            ]

        # Language filter
        if "language" in filters:
            lang = filters["language"]
            filtered = [r for r in filtered if r.detected_language == lang]

        return filtered

    def export_to_txt(
        self,
        results: List[OCRResult],
        output_path: Path,
        formatting: Optional[Dict] = None
    ) -> Path:
        """
        Export results to plain text file

        Args:
            results: List of OCR results
            output_path: Output file path
            formatting: Formatting options
                - add_line_numbers: Add line numbers
                - group_by_filename: Group text by source file
                - include_metadata: Add file metadata headers

        Returns:
            Path: Output file path

        Raises:
            ExportError: If export fails
        """
        try:
            formatting = formatting or {}
            output_lines = []

            for idx, result in enumerate(results, 1):
                # Read Markdown file
                if not result.markdown_path or not Path(result.markdown_path).exists():
                    logger.warning(f"Markdown file not found for result {result.id}")
                    continue

                markdown_content = Path(result.markdown_path).read_text(encoding="utf-8")

                # Add metadata header if requested
                if formatting.get("include_metadata", False):
                    output_lines.append("=" * 80)
                    output_lines.append(f"文件: {result.file.original_filename}")
                    output_lines.append(f"語言: {result.detected_language or '未知'}")
                    output_lines.append(f"信心度: {result.average_confidence:.2%}" if result.average_confidence else "信心度: N/A")
                    output_lines.append("=" * 80)
                    output_lines.append("")

                # Add content with optional line numbers
                if formatting.get("add_line_numbers", False):
                    for line_num, line in enumerate(markdown_content.split('\n'), 1):
                        output_lines.append(f"{line_num:4d} | {line}")
                else:
                    output_lines.append(markdown_content)

                # Add separator between files if grouping
                if formatting.get("group_by_filename", False) and idx < len(results):
                    output_lines.append("\n" + "-" * 80 + "\n")

            # Write to file
            output_path.parent.mkdir(parents=True, exist_ok=True)
            output_path.write_text("\n".join(output_lines), encoding="utf-8")

            logger.info(f"Exported {len(results)} results to TXT: {output_path}")
            return output_path

        except Exception as e:
            raise ExportError(f"TXT export failed: {str(e)}")

    def export_to_json(
        self,
        results: List[OCRResult],
        output_path: Path,
        include_layout: bool = True,
        include_images: bool = True
    ) -> Path:
        """
        Export results to JSON file with full metadata

        Args:
            results: List of OCR results
            output_path: Output file path
            include_layout: Include layout data
            include_images: Include images metadata

        Returns:
            Path: Output file path

        Raises:
            ExportError: If export fails
        """
        try:
            export_data = {
                "export_time": datetime.utcnow().isoformat(),
                "total_files": len(results),
                "results": []
            }

            for result in results:
                result_data = {
                    "file_id": result.file.id,
                    "filename": result.file.original_filename,
                    "file_format": result.file.file_format,
                    "file_size": result.file.file_size,
                    "processing_time": result.file.processing_time,
                    "detected_language": result.detected_language,
                    "total_text_regions": result.total_text_regions,
                    "average_confidence": result.average_confidence,
                    "markdown_path": result.markdown_path,
                }

                # Include layout data if requested
                if include_layout and result.layout_data:
                    result_data["layout_data"] = result.layout_data

                # Include images metadata if requested
                if include_images and result.images_metadata:
                    result_data["images_metadata"] = result.images_metadata

                export_data["results"].append(result_data)

            # Write to file
            output_path.parent.mkdir(parents=True, exist_ok=True)
            output_path.write_text(
                json.dumps(export_data, ensure_ascii=False, indent=2),
                encoding="utf-8"
            )

            logger.info(f"Exported {len(results)} results to JSON: {output_path}")
            return output_path

        except Exception as e:
            raise ExportError(f"JSON export failed: {str(e)}")

    def export_to_excel(
        self,
        results: List[OCRResult],
        output_path: Path,
        include_confidence: bool = True,
        include_processing_time: bool = True
    ) -> Path:
        """
        Export results to Excel file

        Args:
            results: List of OCR results
            output_path: Output file path
            include_confidence: Include confidence scores
            include_processing_time: Include processing time

        Returns:
            Path: Output file path

        Raises:
            ExportError: If export fails
        """
        try:
            rows = []

            for result in results:
                # Read Markdown content
                text_content = ""
                if result.markdown_path and Path(result.markdown_path).exists():
                    text_content = Path(result.markdown_path).read_text(encoding="utf-8")

                row = {
                    "文件名": result.file.original_filename,
                    "格式": result.file.file_format,
                    "大小(字節)": result.file.file_size,
                    "語言": result.detected_language or "未知",
                    "文本區域數": result.total_text_regions,
                    "提取內容": text_content[:1000] + "..." if len(text_content) > 1000 else text_content,
                }

                if include_confidence:
                    row["平均信心度"] = f"{result.average_confidence:.2%}" if result.average_confidence else "N/A"

                if include_processing_time:
                    row["處理時間(秒)"] = f"{result.file.processing_time:.2f}" if result.file.processing_time else "N/A"

                rows.append(row)

            # Create DataFrame and export
            df = pd.DataFrame(rows)
            output_path.parent.mkdir(parents=True, exist_ok=True)
            df.to_excel(output_path, index=False, engine='openpyxl')

            logger.info(f"Exported {len(results)} results to Excel: {output_path}")
            return output_path

        except Exception as e:
            raise ExportError(f"Excel export failed: {str(e)}")

    def export_to_markdown(
        self,
        results: List[OCRResult],
        output_path: Path,
        combine: bool = True
    ) -> Path:
        """
        Export results to Markdown file(s)

        Args:
            results: List of OCR results
            output_path: Output file path (or directory if not combining)
            combine: Combine all results into one file

        Returns:
            Path: Output file/directory path

        Raises:
            ExportError: If export fails
        """
        try:
            if combine:
                # Combine all Markdown files into one
                combined_content = []

                for result in results:
                    if not result.markdown_path or not Path(result.markdown_path).exists():
                        continue

                    markdown_content = Path(result.markdown_path).read_text(encoding="utf-8")

                    # Add header
                    combined_content.append(f"# {result.file.original_filename}\n")
                    combined_content.append(markdown_content)
                    combined_content.append("\n---\n")  # Separator

                output_path.parent.mkdir(parents=True, exist_ok=True)
                output_path.write_text("\n".join(combined_content), encoding="utf-8")

                logger.info(f"Exported {len(results)} results to combined Markdown: {output_path}")
                return output_path

            else:
                # Export each result to separate file
                output_path.mkdir(parents=True, exist_ok=True)

                for result in results:
                    if not result.markdown_path or not Path(result.markdown_path).exists():
                        continue

                    # Copy Markdown file to output directory
                    src_path = Path(result.markdown_path)
                    dst_path = output_path / f"{result.file.original_filename}.md"
                    dst_path.write_text(src_path.read_text(encoding="utf-8"), encoding="utf-8")

                logger.info(f"Exported {len(results)} results to separate Markdown files: {output_path}")
                return output_path

        except Exception as e:
            raise ExportError(f"Markdown export failed: {str(e)}")

    def export_to_pdf(
        self,
        result: OCRResult,
        output_path: Path,
        css_template: str = "default",
        metadata: Optional[Dict] = None
    ) -> Path:
        """
        Export single result to PDF with layout preservation

        Args:
            result: OCR result
            output_path: Output PDF path
            css_template: CSS template name or custom CSS
            metadata: Optional PDF metadata

        Returns:
            Path: Output PDF path

        Raises:
            ExportError: If export fails
        """
        try:
            if not result.markdown_path or not Path(result.markdown_path).exists():
                raise ExportError(f"Markdown file not found for result {result.id}")

            markdown_path = Path(result.markdown_path)

            # Prepare metadata
            pdf_metadata = metadata or {}
            if "title" not in pdf_metadata:
                pdf_metadata["title"] = result.file.original_filename

            # Generate PDF
            self.pdf_generator.generate_pdf(
                markdown_path=markdown_path,
                output_path=output_path,
                css_template=css_template,
                metadata=pdf_metadata
            )

            logger.info(f"Exported result {result.id} to PDF: {output_path}")
            return output_path

        except PDFGenerationError as e:
            raise ExportError(f"PDF generation failed: {str(e)}")
        except Exception as e:
            raise ExportError(f"PDF export failed: {str(e)}")

    def export_batch_to_zip(
        self,
        db: Session,
        batch_id: int,
        output_path: Path,
        include_formats: Optional[List[str]] = None
    ) -> Path:
        """
        Export entire batch to ZIP archive

        Args:
            db: Database session
            batch_id: Batch ID
            output_path: Output ZIP path
            include_formats: List of formats to include (markdown, json, txt, excel, pdf)

        Returns:
            Path: Output ZIP path

        Raises:
            ExportError: If export fails
        """
        try:
            include_formats = include_formats or ["markdown", "json"]

            # Get batch and results
            batch = db.query(OCRBatch).filter(OCRBatch.id == batch_id).first()
            if not batch:
                raise ExportError(f"Batch {batch_id} not found")

            results = db.query(OCRResult).join(OCRFile).filter(
                OCRFile.batch_id == batch_id,
                OCRFile.status == FileStatus.COMPLETED
            ).all()

            if not results:
                raise ExportError(f"No completed results found for batch {batch_id}")

            # Create temporary export directory
            temp_dir = output_path.parent / f"temp_export_{batch_id}"
            temp_dir.mkdir(parents=True, exist_ok=True)

            try:
                # Export in requested formats
                if "markdown" in include_formats:
                    md_dir = temp_dir / "markdown"
                    self.export_to_markdown(results, md_dir, combine=False)

                if "json" in include_formats:
                    json_path = temp_dir / "batch_results.json"
                    self.export_to_json(results, json_path)

                if "txt" in include_formats:
                    txt_path = temp_dir / "batch_results.txt"
                    self.export_to_txt(results, txt_path)

                if "excel" in include_formats:
                    excel_path = temp_dir / "batch_results.xlsx"
                    self.export_to_excel(results, excel_path)

                # Create ZIP archive
                output_path.parent.mkdir(parents=True, exist_ok=True)
                with zipfile.ZipFile(output_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
                    for file_path in temp_dir.rglob('*'):
                        if file_path.is_file():
                            arcname = file_path.relative_to(temp_dir)
                            zipf.write(file_path, arcname)

                logger.info(f"Exported batch {batch_id} to ZIP: {output_path}")
                return output_path

            finally:
                # Clean up temporary directory
                import shutil
                shutil.rmtree(temp_dir, ignore_errors=True)

        except Exception as e:
            raise ExportError(f"Batch ZIP export failed: {str(e)}")

    def apply_export_rule(
        self,
        db: Session,
        results: List[OCRResult],
        rule_id: int
    ) -> List[OCRResult]:
        """
        Apply export rule to filter and format results

        Args:
            db: Database session
            results: List of OCR results
            rule_id: Export rule ID

        Returns:
            List[OCRResult]: Filtered results

        Raises:
            ExportError: If rule not found
        """
        rule = db.query(ExportRule).filter(ExportRule.id == rule_id).first()
        if not rule:
            raise ExportError(f"Export rule {rule_id} not found")

        config = rule.config_json

        # Apply filters
        if "filters" in config:
            results = self.apply_filters(results, config["filters"])

        # Note: Formatting options are applied in individual export methods
        return results

    def get_export_formats(self) -> Dict[str, str]:
        """
        Get available export formats

        Returns:
            Dict mapping format codes to descriptions
        """
        return {
            "txt": "純文本格式 (.txt)",
            "json": "JSON 格式 - 包含完整元數據 (.json)",
            "excel": "Excel 表格格式 (.xlsx)",
            "markdown": "Markdown 格式 (.md)",
            "pdf": "版面保留 PDF 格式 (.pdf)",
            "zip": "批次打包格式 (.zip)",
        }
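A minimal sketch of a batch ZIP export built on the service above (the app.services.export_service module path and database wiring are assumptions):

# Hypothetical usage sketch; module path and db wiring assumed.
from pathlib import Path
from app.core.database import SessionLocal
from app.services.export_service import ExportService  # path assumed

db = SessionLocal()
try:
    zip_path = ExportService().export_batch_to_zip(
        db=db,
        batch_id=1,
        output_path=Path("uploads/batches/1/exports/batch_1.zip"),
        include_formats=["markdown", "json", "excel"],
    )
    print(f"Archive written to {zip_path}")
finally:
    db.close()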
@@ -1,420 +0,0 @@
"""
Tool_OCR - File Management Service
Handles file uploads, storage, validation, and cleanup
"""

import logging
import shutil
import uuid
from pathlib import Path
from typing import List, Tuple, Optional
from datetime import datetime, timedelta

from fastapi import UploadFile
from sqlalchemy.orm import Session

from app.core.config import settings
from app.models.ocr import OCRBatch, OCRFile, FileStatus
from app.services.preprocessor import DocumentPreprocessor


logger = logging.getLogger(__name__)


class FileManagementError(Exception):
    """Exception raised for file management errors"""
    pass


class FileManager:
    """
    File management service for upload, storage, and cleanup

    Directory structure:
    uploads/
    ├── batches/
    │   └── {batch_id}/
    │       ├── inputs/        # Original uploaded files
    │       ├── outputs/       # OCR results
    │       │   ├── markdown/  # Markdown files
    │       │   ├── json/      # JSON files
    │       │   └── images/    # Extracted images
    │       └── exports/       # Export files (PDF, Excel, etc.)
    """

    def __init__(self):
        """Initialize file manager"""
        self.preprocessor = DocumentPreprocessor()
        self.base_upload_dir = Path(settings.upload_dir)
        self.base_upload_dir.mkdir(parents=True, exist_ok=True)

    def create_batch_directory(self, batch_id: int) -> Path:
        """
        Create directory structure for a batch

        Args:
            batch_id: Batch ID

        Returns:
            Path: Batch directory path
        """
        batch_dir = self.base_upload_dir / "batches" / str(batch_id)

        # Create subdirectories
        (batch_dir / "inputs").mkdir(parents=True, exist_ok=True)
        (batch_dir / "outputs" / "markdown").mkdir(parents=True, exist_ok=True)
        (batch_dir / "outputs" / "json").mkdir(parents=True, exist_ok=True)
        (batch_dir / "outputs" / "images").mkdir(parents=True, exist_ok=True)
        (batch_dir / "exports").mkdir(parents=True, exist_ok=True)

        logger.info(f"Created batch directory: {batch_dir}")
        return batch_dir

    def get_batch_directory(self, batch_id: int) -> Path:
        """
        Get batch directory path

        Args:
            batch_id: Batch ID

        Returns:
            Path: Batch directory path
        """
        return self.base_upload_dir / "batches" / str(batch_id)

    def validate_upload(self, file: UploadFile) -> Tuple[bool, Optional[str]]:
        """
        Validate uploaded file before saving

        Args:
            file: Uploaded file

        Returns:
            Tuple of (is_valid, error_message)
        """
        # Check filename
        if not file.filename:
            return False, "文件名不能為空"

        # Check file size (read content size)
        file.file.seek(0, 2)  # Seek to end
        file_size = file.file.tell()
        file.file.seek(0)  # Reset to beginning

        if file_size == 0:
            return False, "文件為空"

        if file_size > settings.max_upload_size:
            max_mb = settings.max_upload_size / (1024 * 1024)
            return False, f"文件大小超過限制 ({max_mb}MB)"

        # Check file extension
        file_ext = Path(file.filename).suffix.lower()
        allowed_extensions = {'.png', '.jpg', '.jpeg', '.pdf', '.doc', '.docx', '.ppt', '.pptx'}
        if file_ext not in allowed_extensions:
            return False, f"不支持的文件格式 ({file_ext}),僅支持: {', '.join(allowed_extensions)}"

        return True, None

    def save_upload(
        self,
        file: UploadFile,
        batch_id: int,
        validate: bool = True
    ) -> Tuple[Path, str]:
        """
        Save uploaded file to batch directory

        Args:
            file: Uploaded file
            batch_id: Batch ID
            validate: Whether to validate file

        Returns:
            Tuple of (file_path, original_filename)

        Raises:
            FileManagementError: If file validation or saving fails
        """
        # Validate if requested
        if validate:
            is_valid, error_msg = self.validate_upload(file)
            if not is_valid:
                raise FileManagementError(error_msg)

        # Generate unique filename to avoid conflicts
        original_filename = file.filename
        file_ext = Path(original_filename).suffix
        unique_filename = f"{uuid.uuid4()}{file_ext}"

        # Get batch input directory
        batch_dir = self.get_batch_directory(batch_id)
        input_dir = batch_dir / "inputs"
        input_dir.mkdir(parents=True, exist_ok=True)

        # Save file
        file_path = input_dir / unique_filename
        try:
            with file_path.open("wb") as buffer:
                shutil.copyfileobj(file.file, buffer)

            logger.info(f"Saved upload: {file_path} (original: {original_filename})")
            return file_path, original_filename

        except Exception as e:
            # Clean up partial file if exists
            file_path.unlink(missing_ok=True)
            raise FileManagementError(f"保存文件失敗: {str(e)}")

    def validate_saved_file(self, file_path: Path) -> Tuple[bool, Optional[str], Optional[str]]:
        """
        Validate saved file using preprocessor

        Args:
            file_path: Path to saved file

        Returns:
            Tuple of (is_valid, error_message, detected_format)
        """
        return self.preprocessor.validate_file(file_path)

    def create_batch(
        self,
        db: Session,
        user_id: int,
        batch_name: Optional[str] = None
    ) -> OCRBatch:
        """
        Create new OCR batch

        Args:
            db: Database session
            user_id: User ID
            batch_name: Optional batch name

        Returns:
            OCRBatch: Created batch object
        """
        # Create batch record
        batch = OCRBatch(
            user_id=user_id,
            batch_name=batch_name or f"Batch_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        )
        db.add(batch)
        db.commit()
        db.refresh(batch)

        # Create directory structure
        self.create_batch_directory(batch.id)

        logger.info(f"Created batch: {batch.id} for user {user_id}")
        return batch

    def add_file_to_batch(
        self,
        db: Session,
        batch_id: int,
        file: UploadFile
    ) -> OCRFile:
        """
        Add file to batch and save to disk

        Args:
            db: Database session
            batch_id: Batch ID
            file: Uploaded file

        Returns:
            OCRFile: Created file record

        Raises:
            FileManagementError: If file operations fail
        """
        # Save file to disk
        file_path, original_filename = self.save_upload(file, batch_id)

        # Validate saved file; unpack in the documented
        # (is_valid, error_message, detected_format) order
        is_valid, error_msg, detected_format = self.validate_saved_file(file_path)

        # Create file record
        ocr_file = OCRFile(
            batch_id=batch_id,
            filename=file_path.name,
            original_filename=original_filename,
            file_path=str(file_path),
            file_size=file_path.stat().st_size,
            file_format=detected_format or Path(original_filename).suffix.lower().lstrip('.'),
            status=FileStatus.PENDING if is_valid else FileStatus.FAILED,
            error_message=error_msg if not is_valid else None
        )

        db.add(ocr_file)

        # Update batch total_files count
        batch = db.query(OCRBatch).filter(OCRBatch.id == batch_id).first()
        if batch:
            batch.total_files += 1
            if not is_valid:
                batch.failed_files += 1

        db.commit()
        db.refresh(ocr_file)

        logger.info(f"Added file to batch {batch_id}: {ocr_file.id} (status: {ocr_file.status})")
        return ocr_file

    def add_files_to_batch(
        self,
        db: Session,
        batch_id: int,
        files: List[UploadFile]
    ) -> List[OCRFile]:
        """
        Add multiple files to batch

        Args:
            db: Database session
            batch_id: Batch ID
            files: List of uploaded files

        Returns:
            List[OCRFile]: List of created file records
        """
        ocr_files = []
        for file in files:
            try:
                ocr_file = self.add_file_to_batch(db, batch_id, file)
                ocr_files.append(ocr_file)
            except FileManagementError as e:
                logger.error(f"Failed to add file {file.filename} to batch {batch_id}: {e}")
                # Continue with other files
                continue

        return ocr_files

    def get_file_paths(self, batch_id: int, file_id: int) -> dict:
        """
        Get all paths for a file in a batch

        Args:
            batch_id: Batch ID
            file_id: File ID

        Returns:
            Dict containing all relevant paths
        """
        batch_dir = self.get_batch_directory(batch_id)

        return {
            "input_dir": batch_dir / "inputs",
            "output_dir": batch_dir / "outputs",
            "markdown_dir": batch_dir / "outputs" / "markdown",
            "json_dir": batch_dir / "outputs" / "json",
            "images_dir": batch_dir / "outputs" / "images" / str(file_id),
            "export_dir": batch_dir / "exports",
        }

    def cleanup_expired_batches(self, db: Session, retention_hours: int = 24) -> int:
        """
        Clean up expired batch files

        Args:
            db: Database session
            retention_hours: Number of hours to retain files

        Returns:
            int: Number of batches cleaned up
        """
        cutoff_time = datetime.utcnow() - timedelta(hours=retention_hours)

        # Find expired batches
        expired_batches = db.query(OCRBatch).filter(
            OCRBatch.created_at < cutoff_time
        ).all()

        cleaned_count = 0
        for batch in expired_batches:
            try:
                # Delete batch directory
                batch_dir = self.get_batch_directory(batch.id)
                if batch_dir.exists():
                    shutil.rmtree(batch_dir)
                    logger.info(f"Deleted batch directory: {batch_dir}")

                # Delete database records (cascade will handle related records)
                db.delete(batch)
                cleaned_count += 1

            except Exception as e:
                logger.error(f"Failed to cleanup batch {batch.id}: {e}")
                continue

        if cleaned_count > 0:
            db.commit()
            logger.info(f"Cleaned up {cleaned_count} expired batches")

        return cleaned_count

    def verify_file_ownership(
        self,
        db: Session,
        user_id: int,
        batch_id: int
    ) -> bool:
        """
        Verify user owns the batch

        Args:
            db: Database session
            user_id: User ID
            batch_id: Batch ID

        Returns:
            bool: True if user owns batch, False otherwise
        """
        batch = db.query(OCRBatch).filter(
            OCRBatch.id == batch_id,
            OCRBatch.user_id == user_id
        ).first()

        return batch is not None

    def get_batch_statistics(self, db: Session, batch_id: int) -> dict:
        """
        Get statistics for a batch

        Args:
            db: Database session
            batch_id: Batch ID

        Returns:
            Dict containing batch statistics
        """
        batch = db.query(OCRBatch).filter(OCRBatch.id == batch_id).first()
        if not batch:
            return {}

        # Calculate total file size
        total_size = sum(f.file_size for f in batch.files)

        # Calculate processing time
        processing_time = None
        if batch.completed_at and batch.started_at:
            processing_time = (batch.completed_at - batch.started_at).total_seconds()

        return {
            "batch_id": batch.id,
            "batch_name": batch.batch_name,
            "status": batch.status,
            "total_files": batch.total_files,
            "completed_files": batch.completed_files,
            "failed_files": batch.failed_files,
            "pending_files": batch.total_files - batch.completed_files - batch.failed_files,
            "progress_percentage": batch.progress_percentage,
            "total_file_size": total_size,
            "total_file_size_mb": round(total_size / (1024 * 1024), 2),
            "created_at": batch.created_at.isoformat(),
            "started_at": batch.started_at.isoformat() if batch.started_at else None,
            "completed_at": batch.completed_at.isoformat() if batch.completed_at else None,
            "processing_time": processing_time,
        }
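A minimal sketch of the upload flow built on FileManager (FastAPI supplies the UploadFile objects; the handler name and request wiring are hypothetical):

# Hypothetical usage sketch; request wiring assumed.
from typing import List
from fastapi import UploadFile
from sqlalchemy.orm import Session
from app.services.file_manager import FileManager

def handle_upload(db: Session, user_id: int, files: List[UploadFile]):
    manager = FileManager()
    batch = manager.create_batch(db, user_id=user_id)
    ocr_files = manager.add_files_to_batch(db, batch.id, files)
    return batch, ocr_files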
@@ -1,282 +0,0 @@
"""
Tool_OCR - Translation Service (RESERVED)
Abstract interface and stub implementation for future translation feature
"""

from abc import ABC, abstractmethod
from typing import Dict, Optional, List
from enum import Enum
import logging


logger = logging.getLogger(__name__)


class TranslationEngine(str, Enum):
    """Supported translation engines"""
    OFFLINE = "offline"  # Argos Translate (offline)
    ERNIE = "ernie"      # Baidu ERNIE API
    GOOGLE = "google"    # Google Translate API
    DEEPL = "deepl"      # DeepL API


class LanguageCode(str, Enum):
    """Supported language codes"""
    CHINESE = "zh"
    ENGLISH = "en"
    JAPANESE = "ja"
    KOREAN = "ko"
    FRENCH = "fr"
    GERMAN = "de"
    SPANISH = "es"


class TranslationServiceInterface(ABC):
    """
    Abstract interface for translation services

    This interface defines the contract for all translation engine implementations.
    Future implementations should inherit from this class.
    """

    @abstractmethod
    def translate_text(
        self,
        text: str,
        source_lang: str,
        target_lang: str,
        **kwargs
    ) -> str:
        """
        Translate a single text string

        Args:
            text: Text to translate
            source_lang: Source language code
            target_lang: Target language code
            **kwargs: Engine-specific parameters

        Returns:
            str: Translated text
        """
        pass

    @abstractmethod
    def translate_document(
        self,
        markdown_content: str,
        source_lang: str,
        target_lang: str,
        preserve_structure: bool = True,
        **kwargs
    ) -> Dict[str, any]:
        """
        Translate a Markdown document while preserving structure

        Args:
            markdown_content: Markdown content to translate
            source_lang: Source language code
            target_lang: Target language code
            preserve_structure: Whether to preserve markdown structure
            **kwargs: Engine-specific parameters

        Returns:
            Dict containing:
                - translated_content: Translated markdown
                - metadata: Translation metadata (engine, time, etc.)
        """
        pass

    @abstractmethod
    def batch_translate(
        self,
        texts: List[str],
        source_lang: str,
        target_lang: str,
        **kwargs
    ) -> List[str]:
        """
        Translate multiple texts in batch

        Args:
            texts: List of texts to translate
            source_lang: Source language code
            target_lang: Target language code
            **kwargs: Engine-specific parameters

        Returns:
            List[str]: List of translated texts
        """
        pass

    @abstractmethod
    def get_supported_languages(self) -> List[str]:
        """
        Get list of supported language codes for this engine

        Returns:
            List[str]: List of supported language codes
        """
        pass

    @abstractmethod
    def validate_config(self) -> bool:
        """
        Validate engine configuration (API keys, model files, etc.)

        Returns:
            bool: True if configuration is valid
        """
        pass


class TranslationEngineFactory:
    """
    Factory for creating translation engine instances

    RESERVED: This is a placeholder for future implementation.
    When translation feature is implemented, this factory will instantiate
    the appropriate translation engine based on configuration.
    """

    @staticmethod
    def create_engine(
        engine_type: TranslationEngine,
        config: Optional[Dict] = None
    ) -> TranslationServiceInterface:
        """
        Create a translation engine instance

        Args:
            engine_type: Type of translation engine
            config: Engine-specific configuration

        Returns:
            TranslationServiceInterface: Translation engine instance

        Raises:
            NotImplementedError: Always raised (stub implementation)
        """
        raise NotImplementedError(
            "Translation feature is not yet implemented. "
            "This is a reserved placeholder for future development."
        )

    @staticmethod
    def get_available_engines() -> List[str]:
        """
        Get list of available translation engines

        Returns:
            List[str]: List of engine types (currently empty)
        """
        return []

    @staticmethod
    def is_engine_available(engine_type: TranslationEngine) -> bool:
        """
        Check if a specific engine is available

        Args:
            engine_type: Engine type to check

        Returns:
            bool: Always False (stub implementation)
        """
        return False


class StubTranslationService:
    """
    Stub translation service for API endpoints

    This service provides placeholder responses for translation endpoints
    until the feature is fully implemented.
    """

    @staticmethod
    def get_feature_status() -> Dict[str, any]:
        """
        Get translation feature status

        Returns:
            Dict with feature status information
        """
        return {
            "available": False,
            "status": "reserved",
            "message": "Translation feature is reserved for future implementation",
            "supported_engines": [],
            "planned_engines": [
                {
                    "type": "offline",
                    "name": "Argos Translate",
                    "description": "Offline neural translation",
                    "status": "planned"
                },
                {
                    "type": "ernie",
                    "name": "Baidu ERNIE",
                    "description": "Baidu AI translation API",
                    "status": "planned"
                },
                {
                    "type": "google",
                    "name": "Google Translate",
                    "description": "Google Cloud Translation API",
                    "status": "planned"
                },
                {
                    "type": "deepl",
                    "name": "DeepL",
                    "description": "DeepL translation API",
                    "status": "planned"
                }
            ],
            "roadmap": {
                "phase": "Phase 5",
                "priority": "low",
                "implementation_after": "Production deployment and user feedback"
            }
        }

    @staticmethod
    def get_supported_languages() -> List[Dict[str, str]]:
        """
        Get list of languages planned for translation support

        Returns:
            List of language info dicts
        """
        return [
            {"code": "zh", "name": "Chinese (Simplified)", "status": "planned"},
            {"code": "en", "name": "English", "status": "planned"},
            {"code": "ja", "name": "Japanese", "status": "planned"},
            {"code": "ko", "name": "Korean", "status": "planned"},
            {"code": "fr", "name": "French", "status": "planned"},
            {"code": "de", "name": "German", "status": "planned"},
            {"code": "es", "name": "Spanish", "status": "planned"},
        ]


# Example placeholder for future engine implementations:
#
# class ArgosTranslationEngine(TranslationServiceInterface):
#     """Offline translation using Argos Translate"""
#     def __init__(self, model_path: str):
#         self.model_path = model_path
#         # Initialize Argos models
#
#     def translate_text(self, text, source_lang, target_lang, **kwargs):
#         # Implementation here
#         pass
#
# class ERNIETranslationEngine(TranslationServiceInterface):
#     """Baidu ERNIE API translation"""
#     def __init__(self, api_key: str, api_secret: str):
#         self.api_key = api_key
#         self.api_secret = api_secret
#
#     def translate_text(self, text, source_lang, target_lang, **kwargs):
#         # Implementation here
#         pass
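The deleted stub never dispatched to a real engine: the factory raises unconditionally and every availability check returns False or empty. For reference, a minimal sketch of what a caller would have observed (the import's module path is assumed — the diff does not show this file's path):

# Sketch only — exercising the removed stub's contract; module path is assumed.
from app.services.translation import (
    StubTranslationService,
    TranslationEngine,
    TranslationEngineFactory,
)

status = StubTranslationService.get_feature_status()
print(status["available"])                                # False
print(TranslationEngineFactory.get_available_engines())   # []
print(TranslationEngineFactory.is_engine_available(TranslationEngine.DEEPL))  # False

try:
    TranslationEngineFactory.create_engine(TranslationEngine.OFFLINE)
except NotImplementedError as exc:
    print(exc)  # "Translation feature is not yet implemented. ..."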
backend/check_tables.py (new file, 32 lines)
@@ -0,0 +1,32 @@
#!/usr/bin/env python3
"""Check existing tables"""

from sqlalchemy import create_engine, text
from app.core.config import settings

engine = create_engine(settings.database_url)

with engine.connect() as conn:
    # Get all tables
    result = conn.execute(text("SHOW TABLES"))
    tables = [row[0] for row in result.fetchall()]

    print("Existing tables:")
    for table in sorted(tables):
        print(f"  - {table}")

    # Check which V2 tables exist
    v2_tables = ['tool_ocr_users', 'tool_ocr_sessions', 'tool_ocr_tasks',
                 'tool_ocr_task_files', 'tool_ocr_audit_logs']
    print("\nV2 Tables status:")
    for table in v2_tables:
        exists = table in tables
        print(f"  {'✓' if exists else '✗'} {table}")

    # Check which old tables exist
    old_tables = ['paddle_ocr_users', 'paddle_ocr_batches', 'paddle_ocr_files',
                  'paddle_ocr_results', 'paddle_ocr_export_rules', 'paddle_ocr_translation_configs']
    print("\nOld Tables status:")
    for table in old_tables:
        exists = table in tables
        print(f"  {'✓' if exists else '✗'} {table}")
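Note that SHOW TABLES is MySQL-specific, which is fine against this project's database but would break on other backends. A dialect-agnostic variant (a sketch, not part of this commit) could use SQLAlchemy's runtime inspector instead:

# Sketch: portable table listing via SQLAlchemy's inspector.
from sqlalchemy import create_engine, inspect

from app.core.config import settings

engine = create_engine(settings.database_url)
for table in sorted(inspect(engine).get_table_names()):  # works across dialects
    print(f"  - {table}")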
backend/fix_alembic_version.py (new file, 29 lines)
@@ -0,0 +1,29 @@
#!/usr/bin/env python3
"""Fix alembic version in database"""

from sqlalchemy import create_engine, text
from app.core.config import settings

# Create database connection
engine = create_engine(settings.database_url)

with engine.connect() as conn:
    # Delete the problematic version
    conn.execute(text("DELETE FROM alembic_version WHERE version_num = '3ede847231ff'"))
    conn.commit()
    print("✓ Removed problematic alembic version")

    # Check current version
    result = conn.execute(text("SELECT version_num FROM alembic_version"))
    versions = result.fetchall()

    if versions:
        print(f"Current version(s): {[v[0] for v in versions]}")
    else:
        print("No alembic version found in database")
        # Set to the base version before our new migrations
        conn.execute(text("INSERT INTO alembic_version (version_num) VALUES ('271dc036ea80')"))
        conn.commit()
        print("✓ Set alembic version to 271dc036ea80")

print("\nDone!")
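Editing alembic_version by hand works, but Alembic ships this operation as a command: `alembic stamp 271dc036ea80` from the CLI, or programmatically via its command API. A rough equivalent (a sketch — the alembic.ini path is assumed; stamp overwrites whatever revision is recorded):

# Sketch: pin the alembic_version table with Alembic's own API instead of raw SQL.
from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")         # assumed to sit in the backend/ working directory
command.stamp(cfg, "271dc036ea80")  # rewrites alembic_version; runs no migrations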
@@ -28,13 +28,7 @@ import {
   TableHeader,
   TableRow,
 } from '@/components/ui/table'
-import {
-  Select,
-  SelectContent,
-  SelectItem,
-  SelectTrigger,
-  SelectValue,
-} from '@/components/ui/select'
+import { Select } from '@/components/ui/select'
 import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card'
 
 export default function TaskHistoryPage() {
@@ -296,22 +290,18 @@ export default function TaskHistoryPage() {
         <label className="block text-sm font-medium text-gray-700 mb-2">狀態</label>
         <Select
           value={statusFilter}
-          onValueChange={(value) => {
-            setStatusFilter(value as any)
+          onChange={(e) => {
+            setStatusFilter(e.target.value as any)
             handleFilterChange()
           }}
-        >
-          <SelectTrigger>
-            <SelectValue />
-          </SelectTrigger>
-          <SelectContent>
-            <SelectItem value="all">全部</SelectItem>
-            <SelectItem value="pending">待處理</SelectItem>
-            <SelectItem value="processing">處理中</SelectItem>
-            <SelectItem value="completed">已完成</SelectItem>
-            <SelectItem value="failed">失敗</SelectItem>
-          </SelectContent>
-        </Select>
+          options={[
+            { value: 'all', label: '全部' },
+            { value: 'pending', label: '待處理' },
+            { value: 'processing', label: '處理中' },
+            { value: 'completed', label: '已完成' },
+            { value: 'failed', label: '失敗' },
+          ]}
+        />
       </div>
 
       <div>
@@ -8,7 +8,8 @@
  * - Session management
  */
 
-import axios, { AxiosError, AxiosInstance } from 'axios'
+import axios, { AxiosError } from 'axios'
+import type { AxiosInstance } from 'axios'
 import type {
   LoginRequest,
   ApiError,