refactor: complete V1 to V2 migration and remove legacy architecture

Remove all V1 architecture components and promote V2 to primary:
- Delete all paddle_ocr_* table models (export, ocr, translation, user)
- Delete legacy routers (auth, export, ocr, translation)
- Delete legacy schemas and services
- Promote user_v2.py to user.py as the primary user model
- Update all imports and dependencies to use V2 models only (usage sketch below)
- Update main.py version to 2.0.0
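
As a usage sketch, an endpoint after the consolidation looks like this
(the route itself is illustrative; the dependency and model names are the
ones promoted by this change):

    from fastapi import APIRouter, Depends

    from app.core.deps import get_current_user
    from app.models.user import User

    router = APIRouter(prefix="/api/v2/example")

    @router.get("/whoami")
    async def whoami(current_user: User = Depends(get_current_user)):
        # get_current_user is the promoted V2 dependency (formerly
        # get_current_user_v2); it validates the internal JWT and loads
        # the user from the tool_ocr_users table.
        return current_user.to_dict()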

Database changes:
- Fix SQLAlchemy reserved word: rename audit_log.metadata to extra_data
- Add migration to drop all paddle_ocr_* tables (migration sketch after this list)
- Update alembic env to only import V2 models
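
A hedged sketch of the drop/rename migration (revision identifiers and the
audit-log table name are placeholders, and the rename assumes the old
column exists in the database; the actual migration may differ):

    from alembic import op
    import sqlalchemy as sa

    # Revision identifiers (placeholders, not the real IDs)
    revision = "drop_paddle_ocr_tables"
    down_revision = None
    branch_labels = None
    depends_on = None

    def upgrade() -> None:
        # Rename the reserved-word column (table name assumed from AuditLog)
        op.alter_column(
            "tool_ocr_audit_logs",
            "metadata",
            new_column_name="extra_data",
            existing_type=sa.Text(),
        )
        # Drop legacy V1 tables, children before parents to satisfy FKs
        for table in (
            "paddle_ocr_results",
            "paddle_ocr_files",
            "paddle_ocr_batches",
            "paddle_ocr_export_rules",
            "paddle_ocr_translation_configs",
            "paddle_ocr_users",
        ):
            op.drop_table(table)

    def downgrade() -> None:
        # Legacy tables are intentionally not recreated
        pass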

Frontend fixes:
- Fix Select component exports in TaskHistoryPage.tsx
- Update to use simplified Select API with options prop
- Fix AxiosInstance TypeScript import syntax

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Commit fd98018ddd (parent ad2b832fb6)
egg · 2025-11-14 21:27:39 +08:00
34 changed files with 554 additions and 3787 deletions

View File: app/core/deps.py

@@ -1,6 +1,6 @@
"""
Tool_OCR - FastAPI Dependencies
Authentication and database session dependencies
Tool_OCR - FastAPI Dependencies (V2)
Authentication and database session dependencies with external authentication
"""
from typing import Generator, Optional
@@ -13,7 +13,6 @@ from sqlalchemy.orm import Session
from app.core.database import SessionLocal
from app.core.security import decode_access_token
from app.models.user import User
from app.models.user_v2 import User as UserV2
from app.models.session import Session as UserSession
from app.services.admin_service import admin_service
@@ -44,7 +43,7 @@ def get_current_user(
db: Session = Depends(get_db)
) -> User:
"""
Get current authenticated user from JWT token
Get current authenticated user from JWT token (External Authentication)
Args:
credentials: HTTP Bearer credentials
@@ -65,110 +64,6 @@ def get_current_user(
# Extract token
token = credentials.credentials
# Decode token
payload = decode_access_token(token)
if payload is None:
raise credentials_exception
# Extract user ID from token (convert from string to int)
user_id_str: Optional[str] = payload.get("sub")
if user_id_str is None:
raise credentials_exception
try:
user_id: int = int(user_id_str)
except (ValueError, TypeError):
raise credentials_exception
# Query user from database
user = db.query(User).filter(User.id == user_id).first()
if user is None:
raise credentials_exception
# Check if user is active
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Inactive user"
)
return user
def get_current_active_user(
current_user: User = Depends(get_current_user)
) -> User:
"""
Get current active user
Args:
current_user: Current user from get_current_user
Returns:
User: Current active user
Raises:
HTTPException: If user is inactive
"""
if not current_user.is_active:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Inactive user"
)
return current_user
def get_current_admin_user(
current_user: User = Depends(get_current_user)
) -> User:
"""
Get current admin user
Args:
current_user: Current user from get_current_user
Returns:
User: Current admin user
Raises:
HTTPException: If user is not admin
"""
if not current_user.is_admin:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not enough privileges"
)
return current_user
# ===== V2 Dependencies for External Authentication =====
def get_current_user_v2(
credentials: HTTPAuthorizationCredentials = Depends(security),
db: Session = Depends(get_db)
) -> UserV2:
"""
Get current authenticated user from JWT token (V2 with external auth)
Args:
credentials: HTTP Bearer credentials
db: Database session
Returns:
UserV2: Current user object
Raises:
HTTPException: If token is invalid or user not found
"""
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
# Extract token
token = credentials.credentials
# Decode token
payload = decode_access_token(token)
if payload is None:
@@ -187,10 +82,10 @@ def get_current_user_v2(
# Extract session ID from token (optional)
session_id: Optional[int] = payload.get("session_id")
# Query user from database (using V2 model)
user = db.query(UserV2).filter(UserV2.id == user_id).first()
# Query user from database
user = db.query(User).filter(User.id == user_id).first()
if user is None:
logger.warning(f"User {user_id} not found in V2 table")
logger.warning(f"User {user_id} not found")
raise credentials_exception
# Check if user is active
@@ -234,17 +129,17 @@ def get_current_user_v2(
return user
def get_current_active_user_v2(
current_user: UserV2 = Depends(get_current_user_v2)
) -> UserV2:
def get_current_active_user(
current_user: User = Depends(get_current_user)
) -> User:
"""
Get current active user (V2)
Get current active user
Args:
current_user: Current user from get_current_user_v2
current_user: Current user from get_current_user
Returns:
UserV2: Current active user
User: Current active user
Raises:
HTTPException: If user is inactive
@@ -257,17 +152,17 @@ def get_current_active_user_v2(
return current_user
def get_current_admin_user_v2(
current_user: UserV2 = Depends(get_current_user_v2)
) -> UserV2:
def get_current_admin_user(
current_user: User = Depends(get_current_user)
) -> User:
"""
Get current admin user (V2)
Get current admin user
Args:
current_user: Current user from get_current_user_v2
current_user: Current user from get_current_user
Returns:
UserV2: Current admin user
User: Current admin user
Raises:
HTTPException: If user is not admin

View File: app/main.py

@@ -1,5 +1,5 @@
"""
Tool_OCR - FastAPI Application Entry Point
Tool_OCR - FastAPI Application Entry Point (V2)
Main application setup with CORS, routes, and startup/shutdown events
"""
@@ -7,11 +7,9 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
import logging
import asyncio
from pathlib import Path
from app.core.config import settings
from app.services.background_tasks import task_manager
# Ensure log directory exists before configuring logging
Path(settings.log_file).parent.mkdir(parents=True, exist_ok=True)
@@ -32,19 +30,12 @@ logger = logging.getLogger(__name__)
async def lifespan(app: FastAPI):
"""Application lifespan events"""
# Startup
logger.info("Starting Tool_OCR application...")
logger.info("Starting Tool_OCR V2 application...")
# Ensure all directories exist
settings.ensure_directories()
logger.info("All directories created/verified")
# Start cleanup scheduler as background task
cleanup_task = asyncio.create_task(task_manager.start_cleanup_scheduler())
logger.info("Started cleanup scheduler for expired files")
# TODO: Initialize database connection pool
# TODO: Load PaddleOCR models
logger.info("Application startup complete")
yield
@@ -52,21 +43,12 @@ async def lifespan(app: FastAPI):
# Shutdown
logger.info("Shutting down Tool_OCR application...")
# Cancel cleanup task
cleanup_task.cancel()
try:
await cleanup_task
except asyncio.CancelledError:
logger.info("Cleanup scheduler stopped")
# TODO: Close database connections
# Create FastAPI application
app = FastAPI(
title="Tool_OCR",
description="OCR Batch Processing System with Structure Extraction",
version="0.1.0",
title="Tool_OCR V2",
description="OCR Processing System with External Authentication & Task Isolation",
version="2.0.0",
lifespan=lifespan,
)
@@ -88,8 +70,8 @@ async def health_check():
response = {
"status": "healthy",
"service": "Tool_OCR",
"version": "0.1.0",
"service": "Tool_OCR V2",
"version": "2.0.0",
}
# Add GPU status information
@@ -134,26 +116,17 @@ async def health_check():
async def root():
"""Root endpoint with API information"""
return {
"message": "Tool_OCR API",
"version": "0.1.0",
"message": "Tool_OCR API V2 - External Authentication",
"version": "2.0.0",
"docs_url": "/docs",
"health_check": "/health",
}
# Include API routers
from app.routers import auth, ocr, export, translation
# V2 routers with external authentication
from app.routers import auth_v2, tasks, admin
# Include V2 API routers
from app.routers import auth, tasks, admin
# Legacy V1 routers
app.include_router(auth.router)
app.include_router(ocr.router)
app.include_router(export.router)
app.include_router(translation.router) # RESERVED for Phase 5
# New V2 routers with external authentication
app.include_router(auth_v2.router)
app.include_router(tasks.router)
app.include_router(admin.router)
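
A quick smoke test for the version bump (assumes a local server; the field
names come from the health check above):

    import httpx

    health = httpx.get("http://localhost:8000/health").json()
    assert health["service"] == "Tool_OCR V2"
    assert health["version"] == "2.0.0"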

View File: app/models/__init__.py

@@ -1,31 +1,20 @@
"""
Tool_OCR - Database Models
Tool_OCR - Database Models (V2)
New schema with external API authentication and user task isolation.
External API authentication with user task isolation.
All tables use 'tool_ocr_' prefix for namespace separation.
"""
# New models for external authentication system
from app.models.user_v2 import User
from app.models.user import User
from app.models.task import Task, TaskFile, TaskStatus
from app.models.session import Session
# Legacy models (will be deprecated after migration)
from app.models.ocr import OCRBatch, OCRFile, OCRResult
from app.models.export import ExportRule
from app.models.translation import TranslationConfig
from app.models.audit_log import AuditLog
__all__ = [
# New authentication and task models
"User",
"Task",
"TaskFile",
"TaskStatus",
"Session",
# Legacy models (deprecated)
"OCRBatch",
"OCRFile",
"OCRResult",
"ExportRule",
"TranslationConfig",
"AuditLog",
]

View File: app/models/audit_log.py

@@ -67,7 +67,7 @@ class AuditLog(Base):
comment="1 for success, 0 for failure"
)
error_message = Column(Text, nullable=True, comment="Error details if failed")
metadata = Column(Text, nullable=True, comment="Additional JSON metadata")
extra_data = Column(Text, nullable=True, comment="Additional JSON metadata")
created_at = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
# Relationships
@@ -90,6 +90,6 @@ class AuditLog(Base):
"resource_id": self.resource_id,
"success": bool(self.success),
"error_message": self.error_message,
"metadata": self.metadata,
"extra_data": self.extra_data,
"created_at": self.created_at.isoformat() if self.created_at else None
}
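
The rename is forced by SQLAlchemy's declarative base, which reserves the
name "metadata" for the table registry. A minimal standalone reproduction
(not part of this repository):

    from sqlalchemy import Column, Integer, Text
    from sqlalchemy.orm import declarative_base

    Base = declarative_base()

    class Broken(Base):
        __tablename__ = "broken"
        id = Column(Integer, primary_key=True)
        # Raises sqlalchemy.exc.InvalidRequestError: "Attribute name
        # 'metadata' is reserved for the MetaData instance ..."
        metadata = Column(Text)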

View File: app/models/export.py (deleted)

@@ -1,55 +0,0 @@
"""
Tool_OCR - Export Rule Model
User-defined export rules and formatting configurations
"""
from sqlalchemy import Column, Integer, String, DateTime, Text, ForeignKey, JSON
from sqlalchemy.orm import relationship
from datetime import datetime
from app.core.database import Base
class ExportRule(Base):
"""Export rule configuration for customized output formatting"""
__tablename__ = "paddle_ocr_export_rules"
id = Column(Integer, primary_key=True, index=True)
user_id = Column(Integer, ForeignKey("paddle_ocr_users.id", ondelete="CASCADE"), nullable=False, index=True)
rule_name = Column(String(100), nullable=False)
description = Column(Text, nullable=True)
# Rule configuration stored as JSON
# {
# "filters": {
# "confidence_threshold": 0.8,
# "filename_pattern": "invoice_*",
# "language": "ch"
# },
# "formatting": {
# "add_line_numbers": true,
# "sort_by_position": true,
# "group_by_filename": false
# },
# "export_options": {
# "include_metadata": true,
# "include_confidence": true,
# "include_bounding_boxes": false
# }
# }
config_json = Column(JSON, nullable=False)
# CSS template for PDF export (optional)
# Can reference predefined templates: "default", "academic", "business", "report"
# Or store custom CSS
css_template = Column(Text, nullable=True)
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
# Relationships
user = relationship("User", back_populates="export_rules")
def __repr__(self):
return f"<ExportRule(id={self.id}, name='{self.rule_name}', user_id={self.user_id})>"

View File: app/models/ocr.py (deleted)

@@ -1,122 +0,0 @@
"""
Tool_OCR - OCR Models
Database models for OCR batches, files, and results
"""
from sqlalchemy import Column, Integer, String, DateTime, Float, Text, ForeignKey, Enum, JSON
from sqlalchemy.orm import relationship
from datetime import datetime
import enum
from app.core.database import Base
class BatchStatus(str, enum.Enum):
"""Batch processing status"""
PENDING = "pending"
PROCESSING = "processing"
COMPLETED = "completed"
PARTIAL = "partial" # Some files failed
FAILED = "failed"
class FileStatus(str, enum.Enum):
"""Individual file processing status"""
PENDING = "pending"
PROCESSING = "processing"
COMPLETED = "completed"
FAILED = "failed"
class OCRBatch(Base):
"""OCR batch processing tracking"""
__tablename__ = "paddle_ocr_batches"
id = Column(Integer, primary_key=True, index=True)
user_id = Column(Integer, ForeignKey("paddle_ocr_users.id", ondelete="CASCADE"), nullable=False, index=True)
batch_name = Column(String(255), nullable=True)
status = Column(Enum(BatchStatus), default=BatchStatus.PENDING, nullable=False, index=True)
total_files = Column(Integer, default=0, nullable=False)
completed_files = Column(Integer, default=0, nullable=False)
failed_files = Column(Integer, default=0, nullable=False)
created_at = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
started_at = Column(DateTime, nullable=True)
completed_at = Column(DateTime, nullable=True)
# Relationships
user = relationship("User", back_populates="ocr_batches")
files = relationship("OCRFile", back_populates="batch", cascade="all, delete-orphan")
@property
def progress_percentage(self) -> float:
"""Calculate progress percentage"""
if self.total_files == 0:
return 0.0
return (self.completed_files / self.total_files) * 100
def __repr__(self):
return f"<OCRBatch(id={self.id}, status='{self.status}', progress={self.progress_percentage:.1f}%)>"
class OCRFile(Base):
"""Individual file in an OCR batch"""
__tablename__ = "paddle_ocr_files"
id = Column(Integer, primary_key=True, index=True)
batch_id = Column(Integer, ForeignKey("paddle_ocr_batches.id", ondelete="CASCADE"), nullable=False, index=True)
filename = Column(String(255), nullable=False)
original_filename = Column(String(255), nullable=False)
file_path = Column(String(512), nullable=False)
file_size = Column(Integer, nullable=False) # Size in bytes
file_format = Column(String(20), nullable=False) # png, jpg, pdf, etc.
status = Column(Enum(FileStatus), default=FileStatus.PENDING, nullable=False, index=True)
error_message = Column(Text, nullable=True)
retry_count = Column(Integer, default=0, nullable=False) # Number of retry attempts
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
started_at = Column(DateTime, nullable=True)
completed_at = Column(DateTime, nullable=True)
processing_time = Column(Float, nullable=True) # Processing time in seconds
# Relationships
batch = relationship("OCRBatch", back_populates="files")
result = relationship("OCRResult", back_populates="file", uselist=False, cascade="all, delete-orphan")
def __repr__(self):
return f"<OCRFile(id={self.id}, filename='{self.filename}', status='{self.status}')>"
class OCRResult(Base):
"""OCR processing result with structure and images"""
__tablename__ = "paddle_ocr_results"
id = Column(Integer, primary_key=True, index=True)
file_id = Column(Integer, ForeignKey("paddle_ocr_files.id", ondelete="CASCADE"), unique=True, nullable=False, index=True)
# Output file paths
markdown_path = Column(String(512), nullable=True) # Path to Markdown file
json_path = Column(String(512), nullable=True) # Path to JSON file
images_dir = Column(String(512), nullable=True) # Directory containing extracted images
# OCR metadata
detected_language = Column(String(20), nullable=True) # ch, en, japan, korean
total_text_regions = Column(Integer, default=0, nullable=False)
average_confidence = Column(Float, nullable=True)
# Layout structure data (stored as JSON)
# Contains: layout elements (title, paragraph, table, image, formula), reading order, bounding boxes
layout_data = Column(JSON, nullable=True)
# Extracted images metadata (stored as JSON)
# Contains: list of {image_path, bbox, element_type}
images_metadata = Column(JSON, nullable=True)
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
# Relationships
file = relationship("OCRFile", back_populates="result")
def __repr__(self):
return f"<OCRResult(id={self.id}, file_id={self.file_id}, language='{self.detected_language}')>"

View File: app/models/translation.py (deleted)

@@ -1,43 +0,0 @@
"""
Tool_OCR - Translation Config Model (RESERVED)
Reserved for future translation feature implementation
"""
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey, JSON
from sqlalchemy.orm import relationship
from datetime import datetime
from app.core.database import Base
class TranslationConfig(Base):
"""
Translation configuration (RESERVED for future implementation)
This table is created but not actively used until translation feature is implemented.
"""
__tablename__ = "paddle_ocr_translation_configs"
id = Column(Integer, primary_key=True, index=True)
user_id = Column(Integer, ForeignKey("paddle_ocr_users.id", ondelete="CASCADE"), nullable=False, index=True)
source_lang = Column(String(20), nullable=False) # ch, en, japan, korean, etc.
target_lang = Column(String(20), nullable=False) # en, ch, japan, korean, etc.
# Translation engine type: "offline" (argostranslate), "ernie", "google", "deepl"
engine_type = Column(String(50), nullable=False, default="offline")
# Engine-specific configuration stored as JSON
# For offline (argostranslate): {"model_path": "/path/to/model"}
# For API-based: {"api_key": "xxx", "endpoint": "https://..."}
engine_config = Column(JSON, nullable=True)
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
# Relationships
user = relationship("User", back_populates="translation_configs")
def __repr__(self):
return f"<TranslationConfig(id={self.id}, {self.source_lang}->{self.target_lang}, engine='{self.engine_type}')>"

View File: app/models/user.py

@@ -1,6 +1,6 @@
"""
Tool_OCR - User Model
User authentication and management
Tool_OCR - User Model v2.0
External API authentication with simplified schema
"""
from sqlalchemy import Column, Integer, String, DateTime, Boolean
@@ -11,24 +11,39 @@ from app.core.database import Base
class User(Base):
"""User model for JWT authentication"""
"""
User model for external API authentication
__tablename__ = "paddle_ocr_users"
Uses email as primary identifier from Azure AD.
No password storage - authentication via external API only.
"""
id = Column(Integer, primary_key=True, index=True)
username = Column(String(50), unique=True, nullable=False, index=True)
email = Column(String(100), unique=True, nullable=False, index=True)
password_hash = Column(String(255), nullable=False)
full_name = Column(String(100), nullable=True)
is_active = Column(Boolean, default=True, nullable=False)
is_admin = Column(Boolean, default=False, nullable=False)
__tablename__ = "tool_ocr_users"
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
email = Column(String(255), unique=True, nullable=False, index=True,
comment="Primary identifier from Azure AD")
display_name = Column(String(255), nullable=True,
comment="Display name from API response")
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
last_login = Column(DateTime, nullable=True)
is_active = Column(Boolean, default=True, nullable=False, index=True)
# Relationships
ocr_batches = relationship("OCRBatch", back_populates="user", cascade="all, delete-orphan")
export_rules = relationship("ExportRule", back_populates="user", cascade="all, delete-orphan")
translation_configs = relationship("TranslationConfig", back_populates="user", cascade="all, delete-orphan")
tasks = relationship("Task", back_populates="user", cascade="all, delete-orphan")
sessions = relationship("Session", back_populates="user", cascade="all, delete-orphan")
audit_logs = relationship("AuditLog", back_populates="user")
def __repr__(self):
return f"<User(id={self.id}, username='{self.username}', email='{self.email}')>"
return f"<User(id={self.id}, email='{self.email}', display_name='{self.display_name}')>"
def to_dict(self):
"""Convert user to dictionary"""
return {
"id": self.id,
"email": self.email,
"display_name": self.display_name,
"created_at": self.created_at.isoformat() if self.created_at else None,
"last_login": self.last_login.isoformat() if self.last_login else None,
"is_active": self.is_active
}

View File: app/models/user_v2.py (deleted)

@@ -1,49 +0,0 @@
"""
Tool_OCR - User Model v2.0
External API authentication with simplified schema
"""
from sqlalchemy import Column, Integer, String, DateTime, Boolean
from sqlalchemy.orm import relationship
from datetime import datetime
from app.core.database import Base
class User(Base):
"""
User model for external API authentication
Uses email as primary identifier from Azure AD.
No password storage - authentication via external API only.
"""
__tablename__ = "tool_ocr_users"
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
email = Column(String(255), unique=True, nullable=False, index=True,
comment="Primary identifier from Azure AD")
display_name = Column(String(255), nullable=True,
comment="Display name from API response")
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
last_login = Column(DateTime, nullable=True)
is_active = Column(Boolean, default=True, nullable=False, index=True)
# Relationships
tasks = relationship("Task", back_populates="user", cascade="all, delete-orphan")
sessions = relationship("Session", back_populates="user", cascade="all, delete-orphan")
audit_logs = relationship("AuditLog", back_populates="user")
def __repr__(self):
return f"<User(id={self.id}, email='{self.email}', display_name='{self.display_name}')>"
def to_dict(self):
"""Convert user to dictionary"""
return {
"id": self.id,
"email": self.email,
"display_name": self.display_name,
"created_at": self.created_at.isoformat() if self.created_at else None,
"last_login": self.last_login.isoformat() if self.last_login else None,
"is_active": self.is_active
}

View File: app/routers/__init__.py

@@ -1,7 +1,7 @@
"""
Tool_OCR - API Routers
Tool_OCR - API Routers (V2)
"""
from app.routers import auth, ocr, export, translation
from app.routers import auth, tasks, admin
__all__ = ["auth", "ocr", "export", "translation"]
__all__ = ["auth", "tasks", "admin"]

View File: app/routers/admin.py

@@ -10,8 +10,8 @@ from datetime import datetime
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.orm import Session
from app.core.deps import get_db, get_current_admin_user_v2
from app.models.user_v2 import User
from app.core.deps import get_db, get_current_admin_user
from app.models.user import User
from app.services.admin_service import admin_service
from app.services.audit_service import audit_service
@@ -23,7 +23,7 @@ router = APIRouter(prefix="/api/v2/admin", tags=["Admin"])
@router.get("/stats", summary="Get system statistics")
async def get_system_stats(
db: Session = Depends(get_db),
admin_user: User = Depends(get_current_admin_user_v2)
admin_user: User = Depends(get_current_admin_user)
):
"""
Get overall system statistics
@@ -47,7 +47,7 @@ async def list_users(
page: int = Query(1, ge=1),
page_size: int = Query(50, ge=1, le=100),
db: Session = Depends(get_db),
admin_user: User = Depends(get_current_admin_user_v2)
admin_user: User = Depends(get_current_admin_user)
):
"""
Get list of all users with statistics
@@ -79,7 +79,7 @@ async def get_top_users(
metric: str = Query("tasks", regex="^(tasks|completed_tasks)$"),
limit: int = Query(10, ge=1, le=50),
db: Session = Depends(get_db),
admin_user: User = Depends(get_current_admin_user_v2)
admin_user: User = Depends(get_current_admin_user)
):
"""
Get top users by metric
@@ -115,7 +115,7 @@ async def get_audit_logs(
page: int = Query(1, ge=1),
page_size: int = Query(100, ge=1, le=500),
db: Session = Depends(get_db),
admin_user: User = Depends(get_current_admin_user_v2)
admin_user: User = Depends(get_current_admin_user)
):
"""
Get audit logs with filtering
@@ -169,7 +169,7 @@ async def get_user_activity_summary(
user_id: int,
days: int = Query(30, ge=1, le=365),
db: Session = Depends(get_db),
admin_user: User = Depends(get_current_admin_user_v2)
admin_user: User = Depends(get_current_admin_user)
):
"""
Get user activity summary for the last N days

View File: app/routers/auth.py

@@ -1,70 +1,347 @@
"""
Tool_OCR - Authentication Router
JWT login endpoint
Tool_OCR - External Authentication Router (V2)
Handles authentication via external Microsoft Azure AD API
"""
from datetime import timedelta
from datetime import datetime, timedelta
import logging
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi import APIRouter, Depends, HTTPException, status, Request
from sqlalchemy.orm import Session
from app.core.config import settings
from app.core.deps import get_db
from app.core.security import verify_password, create_access_token
from app.core.deps import get_db, get_current_user
from app.core.security import create_access_token
from app.models.user import User
from app.schemas.auth import LoginRequest, Token
from app.models.session import Session as UserSession
from app.schemas.auth import LoginRequest, Token, UserResponse
from app.services.external_auth_service import external_auth_service
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/auth", tags=["Authentication"])
router = APIRouter(prefix="/api/v2/auth", tags=["Authentication V2"])
@router.post("/login", response_model=Token, summary="User login")
def get_client_ip(request: Request) -> str:
"""Extract client IP address from request"""
# Check X-Forwarded-For header (for proxies)
forwarded = request.headers.get("X-Forwarded-For")
if forwarded:
return forwarded.split(",")[0].strip()
# Check X-Real-IP header
real_ip = request.headers.get("X-Real-IP")
if real_ip:
return real_ip
# Fallback to direct client
return request.client.host if request.client else "unknown"
def get_user_agent(request: Request) -> str:
"""Extract user agent from request"""
return request.headers.get("User-Agent", "unknown")[:500]
@router.post("/login", response_model=Token, summary="External API login")
async def login(
login_data: LoginRequest,
request: Request,
db: Session = Depends(get_db)
):
"""
User login with username and password
User login via external Microsoft Azure AD API
Returns JWT access token for authentication
Returns JWT access token and stores session information
- **username**: User's username
- **username**: User's email address
- **password**: User's password
"""
# Query user by username
user = db.query(User).filter(User.username == login_data.username).first()
# Call external authentication API
success, auth_response, error_msg = await external_auth_service.authenticate_user(
username=login_data.username,
password=login_data.password
)
# Verify user exists and password is correct
if not user or not verify_password(login_data.password, user.password_hash):
logger.warning(f"Failed login attempt for username: {login_data.username}")
if not success or not auth_response:
logger.warning(
f"External auth failed for user {login_data.username}: {error_msg}"
)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password",
detail=error_msg or "Authentication failed",
headers={"WWW-Authenticate": "Bearer"},
)
# Check if user is active
if not user.is_active:
logger.warning(f"Inactive user login attempt: {login_data.username}")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="User account is inactive"
)
# Extract user info from external API response
user_info = auth_response.user_info
email = user_info.email
display_name = user_info.name
# Create access token
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
access_token = create_access_token(
data={"sub": str(user.id), "username": user.username},
expires_delta=access_token_expires
# Find or create user in database
user = db.query(User).filter(User.email == email).first()
if not user:
# Create new user
user = User(
email=email,
display_name=display_name,
is_active=True,
last_login=datetime.utcnow()
)
db.add(user)
db.commit()
db.refresh(user)
logger.info(f"Created new user: {email} (ID: {user.id})")
else:
# Update existing user
user.display_name = display_name
user.last_login = datetime.utcnow()
# Check if user is active
if not user.is_active:
logger.warning(f"Inactive user login attempt: {email}")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="User account is inactive"
)
db.commit()
db.refresh(user)
logger.info(f"Updated existing user: {email} (ID: {user.id})")
# Parse token expiration
try:
expires_at = datetime.fromisoformat(auth_response.expires_at.replace('Z', '+00:00'))
issued_at = datetime.fromisoformat(auth_response.issued_at.replace('Z', '+00:00'))
except Exception as e:
logger.error(f"Failed to parse token timestamps: {e}")
expires_at = datetime.utcnow() + timedelta(seconds=auth_response.expires_in)
issued_at = datetime.utcnow()
# Create session in database
# TODO: Implement token encryption before storing
session = UserSession(
user_id=user.id,
access_token=auth_response.access_token, # Should be encrypted
id_token=auth_response.id_token, # Should be encrypted
token_type=auth_response.token_type,
expires_at=expires_at,
issued_at=issued_at,
ip_address=get_client_ip(request),
user_agent=get_user_agent(request)
)
db.add(session)
db.commit()
db.refresh(session)
logger.info(
f"Created session {session.id} for user {user.email} "
f"(expires: {expires_at})"
)
logger.info(f"Successful login: {user.username} (ID: {user.id})")
# Create internal JWT token for API access
# This token contains user ID and session ID
internal_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
internal_access_token = create_access_token(
data={
"sub": str(user.id),
"email": user.email,
"session_id": session.id
},
expires_delta=internal_token_expires
)
return {
"access_token": access_token,
"access_token": internal_access_token,
"token_type": "bearer",
"expires_in": settings.access_token_expire_minutes * 60 # Convert to seconds
"expires_in": int(internal_token_expires.total_seconds()),
"user": {
"id": user.id,
"email": user.email,
"display_name": user.display_name
}
}
@router.post("/logout", summary="User logout")
async def logout(
session_id: Optional[int] = None,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
User logout - invalidates session
- **session_id**: Session ID to logout (optional, logs out all if not provided)
"""
# TODO: Implement proper current_user dependency from JWT token
# For now, this is a placeholder
if session_id:
# Logout specific session
session = db.query(UserSession).filter(
UserSession.id == session_id,
UserSession.user_id == current_user.id
).first()
if session:
db.delete(session)
db.commit()
logger.info(f"Logged out session {session_id} for user {current_user.email}")
return {"message": "Logged out successfully"}
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Session not found"
)
else:
# Logout all sessions
sessions = db.query(UserSession).filter(
UserSession.user_id == current_user.id
).all()
count = len(sessions)
for session in sessions:
db.delete(session)
db.commit()
logger.info(f"Logged out all {count} sessions for user {current_user.email}")
return {"message": f"Logged out {count} sessions"}
@router.get("/me", response_model=UserResponse, summary="Get current user")
async def get_me(
current_user: User = Depends(get_current_user)
):
"""
Get current authenticated user information
"""
# TODO: Implement proper current_user dependency from JWT token
return {
"id": current_user.id,
"email": current_user.email,
"display_name": current_user.display_name,
"created_at": current_user.created_at,
"last_login": current_user.last_login,
"is_active": current_user.is_active
}
@router.get("/sessions", summary="List user sessions")
async def list_sessions(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
List all active sessions for current user
"""
sessions = db.query(UserSession).filter(
UserSession.user_id == current_user.id
).order_by(UserSession.created_at.desc()).all()
return {
"sessions": [
{
"id": s.id,
"token_type": s.token_type,
"expires_at": s.expires_at,
"issued_at": s.issued_at,
"ip_address": s.ip_address,
"user_agent": s.user_agent,
"created_at": s.created_at,
"last_accessed_at": s.last_accessed_at,
"is_expired": s.is_expired,
"time_until_expiry": s.time_until_expiry
}
for s in sessions
]
}
@router.post("/refresh", response_model=Token, summary="Refresh access token")
async def refresh_token(
request: Request,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Refresh access token before expiration
Re-authenticates with external API using stored session.
Note: Since external API doesn't provide refresh tokens,
we re-issue internal JWT tokens with extended expiry.
"""
try:
# Find user's most recent session
session = db.query(UserSession).filter(
UserSession.user_id == current_user.id
).order_by(UserSession.created_at.desc()).first()
if not session:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="No active session found"
)
# Check if token is expiring soon (within TOKEN_REFRESH_BUFFER)
if not external_auth_service.is_token_expiring_soon(session.expires_at):
# Token still valid for a while, just issue new internal JWT
internal_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
internal_access_token = create_access_token(
data={
"sub": str(current_user.id),
"email": current_user.email,
"session_id": session.id
},
expires_delta=internal_token_expires
)
logger.info(f"Refreshed internal token for user {current_user.email}")
return {
"access_token": internal_access_token,
"token_type": "bearer",
"expires_in": int(internal_token_expires.total_seconds()),
"user": {
"id": current_user.id,
"email": current_user.email,
"display_name": current_user.display_name
}
}
# External token expiring soon - would need re-authentication
# For now, we extend internal token and log a warning
logger.warning(
f"External token expiring soon for user {current_user.email}. "
"User should re-authenticate."
)
internal_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
internal_access_token = create_access_token(
data={
"sub": str(current_user.id),
"email": current_user.email,
"session_id": session.id
},
expires_delta=internal_token_expires
)
return {
"access_token": internal_access_token,
"token_type": "bearer",
"expires_in": int(internal_token_expires.total_seconds()),
"user": {
"id": current_user.id,
"email": current_user.email,
"display_name": current_user.display_name
}
}
except HTTPException:
raise
except Exception as e:
logger.exception(f"Token refresh failed for user {current_user.id}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Token refresh failed: {str(e)}"
)
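
A hypothetical client-side walk-through of the new flow (endpoint paths are
taken from the router above; host and credentials are placeholders):

    import httpx

    # 1. Login via the external-auth endpoint; the response carries the
    #    internal JWT plus basic user info.
    resp = httpx.post(
        "http://localhost:8000/api/v2/auth/login",
        json={"username": "user@example.com", "password": "secret"},
    )
    token = resp.json()["access_token"]

    # 2. Use the internal JWT on subsequent calls.
    headers = {"Authorization": f"Bearer {token}"}
    me = httpx.get("http://localhost:8000/api/v2/auth/me", headers=headers)

    # 3. Logout without a session_id invalidates all sessions.
    httpx.post("http://localhost:8000/api/v2/auth/logout", headers=headers)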

View File: app/routers/auth_v2.py (deleted)

@@ -1,347 +0,0 @@
"""
Tool_OCR - External Authentication Router (V2)
Handles authentication via external Microsoft Azure AD API
"""
from datetime import datetime, timedelta
import logging
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, status, Request
from sqlalchemy.orm import Session
from app.core.config import settings
from app.core.deps import get_db, get_current_user_v2
from app.core.security import create_access_token
from app.models.user_v2 import User
from app.models.session import Session as UserSession
from app.schemas.auth import LoginRequest, Token, UserResponse
from app.services.external_auth_service import external_auth_service
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v2/auth", tags=["Authentication V2"])
def get_client_ip(request: Request) -> str:
"""Extract client IP address from request"""
# Check X-Forwarded-For header (for proxies)
forwarded = request.headers.get("X-Forwarded-For")
if forwarded:
return forwarded.split(",")[0].strip()
# Check X-Real-IP header
real_ip = request.headers.get("X-Real-IP")
if real_ip:
return real_ip
# Fallback to direct client
return request.client.host if request.client else "unknown"
def get_user_agent(request: Request) -> str:
"""Extract user agent from request"""
return request.headers.get("User-Agent", "unknown")[:500]
@router.post("/login", response_model=Token, summary="External API login")
async def login(
login_data: LoginRequest,
request: Request,
db: Session = Depends(get_db)
):
"""
User login via external Microsoft Azure AD API
Returns JWT access token and stores session information
- **username**: User's email address
- **password**: User's password
"""
# Call external authentication API
success, auth_response, error_msg = await external_auth_service.authenticate_user(
username=login_data.username,
password=login_data.password
)
if not success or not auth_response:
logger.warning(
f"External auth failed for user {login_data.username}: {error_msg}"
)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=error_msg or "Authentication failed",
headers={"WWW-Authenticate": "Bearer"},
)
# Extract user info from external API response
user_info = auth_response.user_info
email = user_info.email
display_name = user_info.name
# Find or create user in database
user = db.query(User).filter(User.email == email).first()
if not user:
# Create new user
user = User(
email=email,
display_name=display_name,
is_active=True,
last_login=datetime.utcnow()
)
db.add(user)
db.commit()
db.refresh(user)
logger.info(f"Created new user: {email} (ID: {user.id})")
else:
# Update existing user
user.display_name = display_name
user.last_login = datetime.utcnow()
# Check if user is active
if not user.is_active:
logger.warning(f"Inactive user login attempt: {email}")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="User account is inactive"
)
db.commit()
db.refresh(user)
logger.info(f"Updated existing user: {email} (ID: {user.id})")
# Parse token expiration
try:
expires_at = datetime.fromisoformat(auth_response.expires_at.replace('Z', '+00:00'))
issued_at = datetime.fromisoformat(auth_response.issued_at.replace('Z', '+00:00'))
except Exception as e:
logger.error(f"Failed to parse token timestamps: {e}")
expires_at = datetime.utcnow() + timedelta(seconds=auth_response.expires_in)
issued_at = datetime.utcnow()
# Create session in database
# TODO: Implement token encryption before storing
session = UserSession(
user_id=user.id,
access_token=auth_response.access_token, # Should be encrypted
id_token=auth_response.id_token, # Should be encrypted
token_type=auth_response.token_type,
expires_at=expires_at,
issued_at=issued_at,
ip_address=get_client_ip(request),
user_agent=get_user_agent(request)
)
db.add(session)
db.commit()
db.refresh(session)
logger.info(
f"Created session {session.id} for user {user.email} "
f"(expires: {expires_at})"
)
# Create internal JWT token for API access
# This token contains user ID and session ID
internal_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
internal_access_token = create_access_token(
data={
"sub": str(user.id),
"email": user.email,
"session_id": session.id
},
expires_delta=internal_token_expires
)
return {
"access_token": internal_access_token,
"token_type": "bearer",
"expires_in": int(internal_token_expires.total_seconds()),
"user": {
"id": user.id,
"email": user.email,
"display_name": user.display_name
}
}
@router.post("/logout", summary="User logout")
async def logout(
session_id: Optional[int] = None,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user_v2)
):
"""
User logout - invalidates session
- **session_id**: Session ID to logout (optional, logs out all if not provided)
"""
# TODO: Implement proper current_user dependency from JWT token
# For now, this is a placeholder
if session_id:
# Logout specific session
session = db.query(UserSession).filter(
UserSession.id == session_id,
UserSession.user_id == current_user.id
).first()
if session:
db.delete(session)
db.commit()
logger.info(f"Logged out session {session_id} for user {current_user.email}")
return {"message": "Logged out successfully"}
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Session not found"
)
else:
# Logout all sessions
sessions = db.query(UserSession).filter(
UserSession.user_id == current_user.id
).all()
count = len(sessions)
for session in sessions:
db.delete(session)
db.commit()
logger.info(f"Logged out all {count} sessions for user {current_user.email}")
return {"message": f"Logged out {count} sessions"}
@router.get("/me", response_model=UserResponse, summary="Get current user")
async def get_me(
current_user: User = Depends(get_current_user_v2)
):
"""
Get current authenticated user information
"""
# TODO: Implement proper current_user dependency from JWT token
return {
"id": current_user.id,
"email": current_user.email,
"display_name": current_user.display_name,
"created_at": current_user.created_at,
"last_login": current_user.last_login,
"is_active": current_user.is_active
}
@router.get("/sessions", summary="List user sessions")
async def list_sessions(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user_v2)
):
"""
List all active sessions for current user
"""
sessions = db.query(UserSession).filter(
UserSession.user_id == current_user.id
).order_by(UserSession.created_at.desc()).all()
return {
"sessions": [
{
"id": s.id,
"token_type": s.token_type,
"expires_at": s.expires_at,
"issued_at": s.issued_at,
"ip_address": s.ip_address,
"user_agent": s.user_agent,
"created_at": s.created_at,
"last_accessed_at": s.last_accessed_at,
"is_expired": s.is_expired,
"time_until_expiry": s.time_until_expiry
}
for s in sessions
]
}
@router.post("/refresh", response_model=Token, summary="Refresh access token")
async def refresh_token(
request: Request,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user_v2)
):
"""
Refresh access token before expiration
Re-authenticates with external API using stored session.
Note: Since external API doesn't provide refresh tokens,
we re-issue internal JWT tokens with extended expiry.
"""
try:
# Find user's most recent session
session = db.query(UserSession).filter(
UserSession.user_id == current_user.id
).order_by(UserSession.created_at.desc()).first()
if not session:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="No active session found"
)
# Check if token is expiring soon (within TOKEN_REFRESH_BUFFER)
if not external_auth_service.is_token_expiring_soon(session.expires_at):
# Token still valid for a while, just issue new internal JWT
internal_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
internal_access_token = create_access_token(
data={
"sub": str(current_user.id),
"email": current_user.email,
"session_id": session.id
},
expires_delta=internal_token_expires
)
logger.info(f"Refreshed internal token for user {current_user.email}")
return {
"access_token": internal_access_token,
"token_type": "bearer",
"expires_in": int(internal_token_expires.total_seconds()),
"user": {
"id": current_user.id,
"email": current_user.email,
"display_name": current_user.display_name
}
}
# External token expiring soon - would need re-authentication
# For now, we extend internal token and log a warning
logger.warning(
f"External token expiring soon for user {current_user.email}. "
"User should re-authenticate."
)
internal_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
internal_access_token = create_access_token(
data={
"sub": str(current_user.id),
"email": current_user.email,
"session_id": session.id
},
expires_delta=internal_token_expires
)
return {
"access_token": internal_access_token,
"token_type": "bearer",
"expires_in": int(internal_token_expires.total_seconds()),
"user": {
"id": current_user.id,
"email": current_user.email,
"display_name": current_user.display_name
}
}
except HTTPException:
raise
except Exception as e:
logger.exception(f"Token refresh failed for user {current_user.id}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Token refresh failed: {str(e)}"
)

View File: app/routers/export.py (deleted)

@@ -1,338 +0,0 @@
"""
Tool_OCR - Export Router
Export results in multiple formats
"""
import logging
from typing import List
from pathlib import Path
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import FileResponse
from sqlalchemy.orm import Session
from app.core.deps import get_db, get_current_active_user
from app.models.user import User
from app.models.ocr import OCRBatch, OCRFile, OCRResult, FileStatus
from app.models.export import ExportRule
from app.schemas.export import (
ExportRequest,
ExportRuleCreate,
ExportRuleUpdate,
ExportRuleResponse,
CSSTemplateResponse,
)
from app.services.export_service import ExportService, ExportError
from app.services.pdf_generator import PDFGenerator
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/export", tags=["Export"])
# Initialize services
export_service = ExportService()
pdf_generator = PDFGenerator()
@router.post("", summary="Export OCR results")
async def export_results(
request: ExportRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_active_user)
):
"""
Export OCR results in specified format
Supports multiple export formats: txt, json, excel, markdown, pdf, zip
- **batch_id**: Batch ID to export
- **format**: Export format (txt, json, excel, markdown, pdf, zip)
- **rule_id**: Optional export rule ID to apply filters
- **css_template**: CSS template for PDF export (default, academic, business)
- **include_formats**: Formats to include in ZIP export
"""
# Verify batch ownership
batch = db.query(OCRBatch).filter(
OCRBatch.id == request.batch_id,
OCRBatch.user_id == current_user.id
).first()
if not batch:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Batch not found"
)
# Get completed results
results = db.query(OCRResult).join(OCRFile).filter(
OCRFile.batch_id == request.batch_id,
OCRFile.status == FileStatus.COMPLETED
).all()
if not results:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="No completed results found for this batch"
)
# Apply export rule if specified
if request.rule_id:
try:
results = export_service.apply_export_rule(db, results, request.rule_id)
except ExportError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(e)
)
try:
# Generate export based on format
export_dir = Path(f"uploads/batches/{batch.id}/exports")
export_dir.mkdir(parents=True, exist_ok=True)
if request.format == "txt":
output_path = export_dir / f"batch_{batch.id}_export.txt"
export_service.export_to_txt(results, output_path)
elif request.format == "json":
output_path = export_dir / f"batch_{batch.id}_export.json"
export_service.export_to_json(results, output_path)
elif request.format == "excel":
output_path = export_dir / f"batch_{batch.id}_export.xlsx"
export_service.export_to_excel(results, output_path)
elif request.format == "markdown":
output_path = export_dir / f"batch_{batch.id}_export.md"
export_service.export_to_markdown(results, output_path, combine=True)
elif request.format == "zip":
output_path = export_dir / f"batch_{batch.id}_export.zip"
include_formats = request.include_formats or ["markdown", "json"]
export_service.export_batch_to_zip(db, batch.id, output_path, include_formats)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Unsupported export format: {request.format}"
)
logger.info(f"Exported batch {batch.id} to {request.format} format: {output_path}")
# Return file for download
return FileResponse(
path=str(output_path),
filename=output_path.name,
media_type="application/octet-stream"
)
except ExportError as e:
logger.error(f"Export error: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e)
)
except Exception as e:
logger.error(f"Unexpected export error: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Export failed"
)
@router.get("/pdf/{file_id}", summary="Generate PDF for single file")
async def generate_pdf(
file_id: int,
css_template: str = "default",
db: Session = Depends(get_db),
current_user: User = Depends(get_current_active_user)
):
"""
Generate layout-preserved PDF for a single file
- **file_id**: File ID
- **css_template**: CSS template (default, academic, business)
"""
# Get file and verify ownership
ocr_file = db.query(OCRFile).join(OCRBatch).filter(
OCRFile.id == file_id,
OCRBatch.user_id == current_user.id
).first()
if not ocr_file:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="File not found"
)
# Get result
result = db.query(OCRResult).filter(OCRResult.file_id == file_id).first()
if not result:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="OCR result not found"
)
try:
# Generate PDF
export_dir = Path(f"uploads/batches/{ocr_file.batch_id}/exports")
export_dir.mkdir(parents=True, exist_ok=True)
output_path = export_dir / f"file_{file_id}_export.pdf"
export_service.export_to_pdf(
result=result,
output_path=output_path,
css_template=css_template,
metadata={"title": ocr_file.original_filename}
)
logger.info(f"Generated PDF for file {file_id}: {output_path}")
return FileResponse(
path=str(output_path),
filename=f"{Path(ocr_file.original_filename).stem}.pdf",
media_type="application/pdf"
)
except ExportError as e:
logger.error(f"PDF generation error: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e)
)
@router.get("/rules", response_model=List[ExportRuleResponse], summary="List export rules")
async def list_export_rules(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_active_user)
):
"""
List all export rules for current user
Returns list of saved export rules
"""
rules = db.query(ExportRule).filter(ExportRule.user_id == current_user.id).all()
return rules
@router.post("/rules", response_model=ExportRuleResponse, summary="Create export rule")
async def create_export_rule(
rule: ExportRuleCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_active_user)
):
"""
Create new export rule
Saves custom export configuration for reuse
- **rule_name**: Rule name
- **description**: Optional description
- **config_json**: Rule configuration (filters, formatting, export_options)
- **css_template**: Optional custom CSS for PDF export
"""
# Create rule
new_rule = ExportRule(
user_id=current_user.id,
rule_name=rule.rule_name,
description=rule.description,
config_json=rule.config_json,
css_template=rule.css_template
)
db.add(new_rule)
db.commit()
db.refresh(new_rule)
logger.info(f"Created export rule {new_rule.id} for user {current_user.id}")
return new_rule
@router.put("/rules/{rule_id}", response_model=ExportRuleResponse, summary="Update export rule")
async def update_export_rule(
rule_id: int,
rule: ExportRuleUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_active_user)
):
"""
Update existing export rule
- **rule_id**: Rule ID to update
- **rule_name**: Optional new rule name
- **description**: Optional new description
- **config_json**: Optional new configuration
- **css_template**: Optional new CSS template
"""
# Get rule and verify ownership
db_rule = db.query(ExportRule).filter(
ExportRule.id == rule_id,
ExportRule.user_id == current_user.id
).first()
if not db_rule:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Export rule not found"
)
# Update fields
update_data = rule.dict(exclude_unset=True)
for field, value in update_data.items():
setattr(db_rule, field, value)
db.commit()
db.refresh(db_rule)
logger.info(f"Updated export rule {rule_id}")
return db_rule
@router.delete("/rules/{rule_id}", summary="Delete export rule")
async def delete_export_rule(
rule_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_active_user)
):
"""
Delete export rule
- **rule_id**: Rule ID to delete
"""
# Get rule and verify ownership
db_rule = db.query(ExportRule).filter(
ExportRule.id == rule_id,
ExportRule.user_id == current_user.id
).first()
if not db_rule:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Export rule not found"
)
db.delete(db_rule)
db.commit()
logger.info(f"Deleted export rule {rule_id}")
return {"message": "Export rule deleted successfully"}
@router.get("/css-templates", response_model=List[CSSTemplateResponse], summary="List CSS templates")
async def list_css_templates():
"""
List available CSS templates for PDF generation
Returns list of predefined CSS templates with descriptions
"""
templates = pdf_generator.get_available_templates()
return [
{"name": name, "description": desc}
for name, desc in templates.items()
]

View File: app/routers/ocr.py (deleted)

@@ -1,244 +0,0 @@
"""
Tool_OCR - OCR Router
File upload, OCR processing, and status endpoints
"""
import logging
from typing import List
from pathlib import Path
from fastapi import APIRouter, Depends, HTTPException, status, UploadFile, File, BackgroundTasks
from sqlalchemy.orm import Session
from app.core.deps import get_db, get_current_active_user
from app.models.user import User
from app.models.ocr import OCRBatch, OCRFile, OCRResult, BatchStatus, FileStatus
from app.schemas.ocr import (
OCRBatchResponse,
BatchStatusResponse,
FileStatusResponse,
OCRResultDetailResponse,
UploadBatchResponse,
ProcessRequest,
ProcessResponse,
)
from app.services.file_manager import FileManager, FileManagementError
from app.services.ocr_service import OCRService
from app.services.background_tasks import process_batch_files_with_retry
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1", tags=["OCR"])
# Initialize services
file_manager = FileManager()
ocr_service = OCRService()
@router.post("/upload", response_model=UploadBatchResponse, summary="Upload files for OCR")
async def upload_files(
files: List[UploadFile] = File(..., description="Files to upload (PNG, JPG, PDF)"),
batch_name: str = None,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_active_user)
):
"""
Upload files for OCR processing
Creates a new batch and uploads files to it
- **files**: List of files to upload (PNG, JPG, JPEG, PDF)
- **batch_name**: Optional name for the batch
"""
if not files:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="No files provided"
)
try:
# Create batch
batch = file_manager.create_batch(db, current_user.id, batch_name)
# Upload files
uploaded_files = file_manager.add_files_to_batch(db, batch.id, files)
logger.info(f"Uploaded {len(uploaded_files)} files to batch {batch.id} for user {current_user.id}")
# Refresh batch to get updated counts
db.refresh(batch)
# Return response matching frontend expectations
return {
"batch_id": batch.id,
"files": uploaded_files
}
except FileManagementError as e:
logger.error(f"File upload error: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e)
)
except Exception as e:
logger.error(f"Unexpected error during upload: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to upload files"
)
# NOTE: process_batch_files function moved to app.services.background_tasks
# Now using process_batch_files_with_retry with retry logic
@router.post("/ocr/process", response_model=ProcessResponse, summary="Trigger OCR processing")
async def process_ocr(
request: ProcessRequest,
background_tasks: BackgroundTasks,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_active_user)
):
"""
Trigger OCR processing for a batch
Starts background processing of all files in the batch
- **batch_id**: Batch ID to process
- **lang**: Language code (ch, en, japan, korean)
- **detect_layout**: Enable layout detection
"""
# Verify batch ownership
batch = db.query(OCRBatch).filter(
OCRBatch.id == request.batch_id,
OCRBatch.user_id == current_user.id
).first()
if not batch:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Batch not found"
)
if batch.status != BatchStatus.PENDING:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Batch is already {batch.status.value}"
)
# Start background processing with retry logic
background_tasks.add_task(
process_batch_files_with_retry,
batch_id=batch.id,
lang=request.lang,
detect_layout=request.detect_layout,
db=SessionLocal() # Create new session for background task
)
logger.info(f"Started OCR processing for batch {batch.id}")
return {
"message": "OCR processing started",
"batch_id": batch.id,
"total_files": batch.total_files,
"status": "processing"
}
@router.get("/batch/{batch_id}/status", response_model=BatchStatusResponse, summary="Get batch status")
async def get_batch_status(
batch_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_active_user)
):
"""
Get batch processing status
Returns batch information and all files in the batch
- **batch_id**: Batch ID
"""
# Verify batch ownership
batch = db.query(OCRBatch).filter(
OCRBatch.id == batch_id,
OCRBatch.user_id == current_user.id
).first()
if not batch:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Batch not found"
)
# Get all files in batch
files = db.query(OCRFile).filter(OCRFile.batch_id == batch_id).all()
return {
"batch": batch,
"files": files
}
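Continuing the client sketch, a hedged trigger-and-poll loop against the two endpoints above; the lowercase status strings are an assumption based on the schema examples elsewhere in this commit.

import time

def run_and_wait(batch_id):
    httpx.post(
        f"{BASE_URL}/api/v1/ocr/process",
        json={"batch_id": batch_id, "lang": "ch", "detect_layout": True},
        headers={"Authorization": f"Bearer {TOKEN}"},
    ).raise_for_status()
    while True:
        status = httpx.get(
            f"{BASE_URL}/api/v1/batch/{batch_id}/status",
            headers={"Authorization": f"Bearer {TOKEN}"},
        ).json()
        if status["batch"]["status"] in ("completed", "partial", "failed"):
            return status
        time.sleep(2)  # modest poll interval to avoid hammering the API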
@router.get("/ocr/result/{file_id}", response_model=OCRResultDetailResponse, summary="Get OCR result")
async def get_ocr_result(
file_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_active_user)
):
"""
Get OCR result for a file
Returns flattened file and OCR result information for frontend preview
- **file_id**: File ID
"""
# Get file
ocr_file = db.query(OCRFile).join(OCRBatch).filter(
OCRFile.id == file_id,
OCRBatch.user_id == current_user.id
).first()
if not ocr_file:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="File not found"
)
# Get result if exists
result = db.query(OCRResult).filter(OCRResult.file_id == file_id).first()
# Read markdown content if result exists
markdown_content = None
if result and result.markdown_path:
markdown_file = Path(result.markdown_path)
if markdown_file.exists():
try:
markdown_content = markdown_file.read_text(encoding='utf-8')
except Exception as e:
logger.warning(f"Failed to read markdown file {result.markdown_path}: {e}")
# Build JSON data from result if available
json_data = None
if result:
json_data = {
"total_text_regions": result.total_text_regions,
"average_confidence": result.average_confidence,
"detected_language": result.detected_language,
"layout_data": result.layout_data,
"images_metadata": result.images_metadata,
}
# Return flattened structure matching frontend expectations
return {
"file_id": ocr_file.id,
"filename": ocr_file.filename,
"status": ocr_file.status.value,
"markdown_content": markdown_content,
"json_data": json_data,
"confidence": result.average_confidence if result else None,
"processing_time": ocr_file.processing_time,
}
# Import SessionLocal for background tasks (module-level, so it is evaluated at import time, before any route executes)
from app.core.database import SessionLocal
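To round off the sketch, fetching the flattened result for a single file once its batch has finished; the response keys mirror OCRResultDetailResponse.

def fetch_result(file_id):
    resp = httpx.get(
        f"{BASE_URL}/api/v1/ocr/result/{file_id}",
        headers={"Authorization": f"Bearer {TOKEN}"},
    )
    resp.raise_for_status()
    return resp.json()  # file_id, filename, status, markdown_content, json_data, confidence, processing_time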


@@ -10,8 +10,8 @@ from fastapi import APIRouter, Depends, HTTPException, status, Query
from fastapi.responses import FileResponse
from sqlalchemy.orm import Session
from app.core.deps import get_db, get_current_user_v2
from app.models.user_v2 import User
from app.core.deps import get_db, get_current_user
from app.models.user import User
from app.models.task import TaskStatus
from app.schemas.task import (
TaskCreate,
@@ -34,7 +34,7 @@ router = APIRouter(prefix="/api/v2/tasks", tags=["Tasks"])
async def create_task(
task_data: TaskCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user_v2)
current_user: User = Depends(get_current_user)
):
"""
Create a new OCR task
@@ -72,7 +72,7 @@ async def list_tasks(
order_by: str = Query("created_at"),
order_desc: bool = Query(True),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user_v2)
current_user: User = Depends(get_current_user)
):
"""
List user's tasks with pagination and filtering
@@ -134,7 +134,7 @@ async def list_tasks(
@router.get("/stats", response_model=TaskStatsResponse)
async def get_task_stats(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user_v2)
current_user: User = Depends(get_current_user)
):
"""
Get task statistics for current user
@@ -157,7 +157,7 @@ async def get_task_stats(
async def get_task(
task_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user_v2)
current_user: User = Depends(get_current_user)
):
"""
Get task details by ID
@@ -184,7 +184,7 @@ async def update_task(
task_id: str,
task_update: TaskUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user_v2)
current_user: User = Depends(get_current_user)
):
"""
Update task status and results
@@ -253,7 +253,7 @@ async def update_task(
async def delete_task(
task_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user_v2)
current_user: User = Depends(get_current_user)
):
"""
Delete a task
@@ -280,7 +280,7 @@ async def delete_task(
async def download_json(
task_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user_v2)
current_user: User = Depends(get_current_user)
):
"""
Download task result as JSON file
@@ -327,7 +327,7 @@ async def download_json(
async def download_markdown(
task_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user_v2)
current_user: User = Depends(get_current_user)
):
"""
Download task result as Markdown file
@@ -374,7 +374,7 @@ async def download_markdown(
async def download_pdf(
task_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user_v2)
current_user: User = Depends(get_current_user)
):
"""
Download task result as searchable PDF file
@@ -421,7 +421,7 @@ async def download_pdf(
async def start_task(
task_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user_v2)
current_user: User = Depends(get_current_user)
):
"""
Start processing a pending task
@@ -459,7 +459,7 @@ async def start_task(
async def cancel_task(
task_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user_v2)
current_user: User = Depends(get_current_user)
):
"""
Cancel a pending or processing task
@@ -513,7 +513,7 @@ async def cancel_task(
async def retry_task(
task_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user_v2)
current_user: User = Depends(get_current_user)
):
"""
Retry a failed task


@@ -1,189 +0,0 @@
"""
Tool_OCR - Translation Router (RESERVED)
Stub endpoints for future translation feature
"""
import logging
from typing import List
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from app.core.deps import get_db, get_current_active_user
from app.models.user import User
from app.schemas.translation import (
TranslationRequest,
TranslationResponse,
TranslationFeatureStatus,
LanguageInfo,
)
from app.services.translation_service import StubTranslationService
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/translate", tags=["Translation (RESERVED)"])
@router.get("/status", response_model=TranslationFeatureStatus, summary="Get translation feature status")
async def get_translation_status():
"""
Get translation feature status
Returns current implementation status and roadmap for translation feature.
This is a RESERVED feature that will be implemented in Phase 5.
**Status**: RESERVED - Not yet implemented
**Phase**: Phase 5 (Post-production)
**Priority**: Implemented after production deployment and user feedback
"""
return StubTranslationService.get_feature_status()
@router.get("/languages", response_model=List[LanguageInfo], summary="Get supported languages")
async def get_supported_languages():
"""
Get list of languages planned for translation support
Returns list of languages that will be supported when translation
feature is implemented.
**Status**: RESERVED - Planning phase
"""
return StubTranslationService.get_supported_languages()
@router.post("/document", response_model=TranslationResponse, summary="Translate document (RESERVED)")
async def translate_document(
request: TranslationRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_active_user)
):
"""
Translate OCR document (RESERVED - NOT IMPLEMENTED)
This endpoint is reserved for future translation functionality.
Returns 501 Not Implemented status.
**Expected Functionality** (when implemented):
- Translate markdown documents while preserving structure
- Support multiple translation engines (offline, ERNIE, Google, DeepL)
- Maintain layout and formatting
- Handle technical terminology
**Planned Features**:
- Offline translation (Argos Translate)
- Cloud API integration (ERNIE, Google, DeepL)
- Batch translation support
- Translation memory
- Glossary support
**Current Status**: RESERVED for Phase 5 implementation
---
**Request Parameters** (planned):
- **file_id**: ID of OCR result file to translate
- **source_lang**: Source language code (zh, en, ja, ko)
- **target_lang**: Target language code (zh, en, ja, ko)
- **engine_type**: Translation engine (offline, ernie, google, deepl)
- **preserve_structure**: Whether to preserve markdown structure
- **engine_config**: Engine-specific configuration
**Response** (planned):
- **task_id**: Translation task ID for tracking progress
- **status**: Translation status
- **translated_file_path**: Path to translated file (when completed)
"""
logger.info(f"Translation request received from user {current_user.id} (stub endpoint)")
# Return 501 Not Implemented with informative message
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
detail={
"error": "Translation feature not implemented",
"message": "This feature is reserved for future development (Phase 5)",
"status": "RESERVED",
"roadmap": {
"phase": "Phase 5",
"priority": "Implemented after production deployment",
"planned_features": [
"Offline translation (Argos Translate)",
"Cloud API integration (ERNIE, Google, DeepL)",
"Structure-preserving markdown translation",
"Batch translation support"
]
},
"request_received": {
"file_id": request.file_id,
"source_lang": request.source_lang,
"target_lang": request.target_lang,
"engine_type": request.engine_type
},
"action": "Please check back in a future release or contact support for updates"
}
)
@router.get("/task/{task_id}", summary="Get translation task status (RESERVED)")
async def get_translation_task_status(
task_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_active_user)
):
"""
Get translation task status (RESERVED - NOT IMPLEMENTED)
This endpoint would track translation task progress.
Returns 501 Not Implemented status.
**Planned Functionality**:
- Real-time translation progress
- Status updates (pending, processing, completed, failed)
- Estimated completion time
- Error reporting
**Current Status**: RESERVED for Phase 5 implementation
"""
logger.info(f"Translation status check for task {task_id} from user {current_user.id} (stub endpoint)")
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
detail={
"error": "Translation feature not implemented",
"message": "Translation task tracking is reserved for Phase 5",
"task_id": task_id,
"status": "RESERVED"
}
)
@router.delete("/task/{task_id}", summary="Cancel translation task (RESERVED)")
async def cancel_translation_task(
task_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_active_user)
):
"""
Cancel ongoing translation task (RESERVED - NOT IMPLEMENTED)
This endpoint would allow cancellation of translation tasks.
Returns 501 Not Implemented status.
**Planned Functionality**:
- Cancel in-progress translations
- Clean up temporary files
- Refund credits (if applicable)
**Current Status**: RESERVED for Phase 5 implementation
"""
logger.info(f"Translation cancellation request for task {task_id} from user {current_user.id} (stub endpoint)")
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
detail={
"error": "Translation feature not implemented",
"message": "This feature is reserved for Phase 5",
"status": "RESERVED"
}
)
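A hedged sketch of how a caller might probe the reserved endpoint and surface the roadmap carried in the 501 payload; URL and token are placeholders, and FastAPI wraps the HTTPException detail under a "detail" key.

import httpx

resp = httpx.post(
    "http://localhost:8000/api/v1/translate/document",
    json={"file_id": 1, "source_lang": "zh", "target_lang": "en"},
    headers={"Authorization": "Bearer <jwt>"},
)
if resp.status_code == 501:
    detail = resp.json()["detail"]
    print(detail["message"])           # "This feature is reserved for future development (Phase 5)"
    print(detail["roadmap"]["phase"])  # "Phase 5"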


@@ -1,59 +1,30 @@
"""
Tool_OCR - API Schemas
Tool_OCR - API Schemas (V2)
Pydantic models for request/response validation
"""
from app.schemas.auth import Token, TokenData, LoginRequest
from app.schemas.user import UserBase, UserCreate, UserResponse
from app.schemas.ocr import (
OCRBatchResponse,
OCRFileResponse,
OCRResultResponse,
BatchStatusResponse,
FileStatusResponse,
ProcessRequest,
ProcessResponse,
)
from app.schemas.export import (
ExportRequest,
ExportRuleCreate,
ExportRuleUpdate,
ExportRuleResponse,
CSSTemplateResponse,
)
from app.schemas.translation import (
TranslationRequest,
TranslationResponse,
TranslationFeatureStatus,
LanguageInfo,
)
from app.schemas.auth import LoginRequest, Token, UserResponse
from app.schemas.task import (
TaskCreate,
TaskUpdate,
TaskResponse,
TaskDetailResponse,
TaskListResponse,
TaskStatsResponse,
TaskStatusEnum,
)
__all__ = [
# Auth
"Token",
"TokenData",
"LoginRequest",
# User
"UserBase",
"UserCreate",
"Token",
"UserResponse",
# OCR
"OCRBatchResponse",
"OCRFileResponse",
"OCRResultResponse",
"BatchStatusResponse",
"FileStatusResponse",
"ProcessRequest",
"ProcessResponse",
# Export
"ExportRequest",
"ExportRuleCreate",
"ExportRuleUpdate",
"ExportRuleResponse",
"CSSTemplateResponse",
# Translation (RESERVED)
"TranslationRequest",
"TranslationResponse",
"TranslationFeatureStatus",
"LanguageInfo",
# Task
"TaskCreate",
"TaskUpdate",
"TaskResponse",
"TaskDetailResponse",
"TaskListResponse",
"TaskStatsResponse",
"TaskStatusEnum",
]


@@ -1,104 +0,0 @@
"""
Tool_OCR - Export Schemas
"""
from datetime import datetime
from typing import Optional, Dict, Any, List
from pydantic import BaseModel, Field
class ExportOptions(BaseModel):
"""Export options schema"""
confidence_threshold: Optional[float] = Field(None, description="Minimum confidence threshold")
include_metadata: Optional[bool] = Field(True, description="Include metadata in export")
filename_pattern: Optional[str] = Field(None, description="Filename pattern for export")
css_template: Optional[str] = Field(None, description="CSS template for PDF export")
class ExportRequest(BaseModel):
"""Export request schema"""
batch_id: int = Field(..., description="Batch ID to export")
format: str = Field(..., description="Export format (txt, json, excel, markdown, pdf, zip)")
rule_id: Optional[int] = Field(None, description="Optional export rule ID to apply")
css_template: Optional[str] = Field("default", description="CSS template for PDF export")
include_formats: Optional[List[str]] = Field(None, description="Formats to include in ZIP export")
options: Optional[ExportOptions] = Field(None, description="Additional export options")
class Config:
json_schema_extra = {
"example": {
"batch_id": 1,
"format": "pdf",
"rule_id": None,
"css_template": "default",
"include_formats": ["markdown", "json"],
"options": {
"confidence_threshold": 0.8,
"include_metadata": True
}
}
}
class ExportRuleCreate(BaseModel):
"""Export rule creation schema"""
rule_name: str = Field(..., max_length=100, description="Rule name")
description: Optional[str] = Field(None, description="Rule description")
config_json: Dict[str, Any] = Field(..., description="Rule configuration as JSON")
css_template: Optional[str] = Field(None, description="Custom CSS template")
class Config:
json_schema_extra = {
"example": {
"rule_name": "High Confidence Only",
"description": "Export only results with confidence > 0.8",
"config_json": {
"filters": {
"confidence_threshold": 0.8
},
"formatting": {
"add_line_numbers": True
}
},
"css_template": None
}
}
class ExportRuleUpdate(BaseModel):
"""Export rule update schema"""
rule_name: Optional[str] = Field(None, max_length=100)
description: Optional[str] = None
config_json: Optional[Dict[str, Any]] = None
css_template: Optional[str] = None
class ExportRuleResponse(BaseModel):
"""Export rule response schema"""
id: int
user_id: int
rule_name: str
description: Optional[str] = None
config_json: Dict[str, Any]
css_template: Optional[str] = None
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
class CSSTemplateResponse(BaseModel):
"""CSS template response schema"""
name: str = Field(..., description="Template name")
description: str = Field(..., description="Template description")
filename: str = Field(..., description="Template filename")
class Config:
json_schema_extra = {
"example": {
"name": "default",
"description": "通用排版模板,適合大多數文檔",
"filename": "default.css"
}
}


@@ -1,151 +0,0 @@
"""
Tool_OCR - OCR Schemas
"""
from datetime import datetime
from typing import Optional, Dict, List, Any
from pydantic import BaseModel, Field
from app.models.ocr import BatchStatus, FileStatus
class OCRFileResponse(BaseModel):
"""OCR file response schema"""
id: int
batch_id: int
filename: str
original_filename: str
file_size: int
file_format: str
status: FileStatus
error: Optional[str] = Field(None, validation_alias='error_message') # Map from error_message to error
created_at: datetime
processing_time: Optional[float] = None
class Config:
from_attributes = True
populate_by_name = True
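A hedged illustration of the alias above (pydantic v2 assumed): when validating from attributes, the ORM attribute error_message populates the response field error.

# Illustration only; SimpleNamespace stands in for the ORM row.
from types import SimpleNamespace

_row = SimpleNamespace(
    id=1, batch_id=1, filename="doc_1.png", original_filename="document.png",
    file_size=1024, file_format="png", status=FileStatus.FAILED,
    error_message="unsupported page size", created_at=datetime.utcnow(),
    processing_time=None,
)
print(OCRFileResponse.model_validate(_row).error)  # "unsupported page size"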
class OCRResultResponse(BaseModel):
"""OCR result response schema"""
id: int
file_id: int
markdown_path: Optional[str] = None
markdown_content: Optional[str] = None # Added for frontend preview
json_path: Optional[str] = None
images_dir: Optional[str] = None
detected_language: Optional[str] = None
total_text_regions: int
average_confidence: Optional[float] = None
layout_data: Optional[Dict[str, Any]] = None
images_metadata: Optional[List[Dict[str, Any]]] = None
created_at: datetime
class Config:
from_attributes = True
class OCRBatchResponse(BaseModel):
"""OCR batch response schema"""
id: int
user_id: int
batch_name: Optional[str] = None
status: BatchStatus
total_files: int
completed_files: int
failed_files: int
progress_percentage: float
created_at: datetime
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
class Config:
from_attributes = True
class BatchStatusResponse(BaseModel):
"""Batch status with file details"""
batch: OCRBatchResponse
files: List[OCRFileResponse]
class FileStatusResponse(BaseModel):
"""File status with result details"""
file: OCRFileResponse
result: Optional[OCRResultResponse] = None
class OCRResultDetailResponse(BaseModel):
"""OCR result detail response for frontend preview - flattened structure"""
file_id: int
filename: str
status: str
markdown_content: Optional[str] = None
json_data: Optional[Dict[str, Any]] = None
confidence: Optional[float] = None
processing_time: Optional[float] = None
class Config:
from_attributes = True
class UploadBatchResponse(BaseModel):
"""Upload response schema matching frontend expectations"""
batch_id: int = Field(..., description="Batch ID")
files: List[OCRFileResponse] = Field(..., description="Uploaded files")
class Config:
json_schema_extra = {
"example": {
"batch_id": 1,
"files": [
{
"id": 1,
"batch_id": 1,
"filename": "doc_1.png",
"original_filename": "document.png",
"file_size": 1024000,
"file_format": "png",
"status": "pending",
"error_message": None,
"created_at": "2025-01-01T00:00:00",
"processing_time": None
}
]
}
}
class ProcessRequest(BaseModel):
"""OCR process request schema"""
batch_id: int = Field(..., description="Batch ID to process")
lang: str = Field(default="ch", description="Language code (ch, en, japan, korean)")
detect_layout: bool = Field(default=True, description="Enable layout detection")
class Config:
json_schema_extra = {
"example": {
"batch_id": 1,
"lang": "ch",
"detect_layout": True
}
}
class ProcessResponse(BaseModel):
"""OCR process response schema"""
message: str
batch_id: int
total_files: int
status: str
class Config:
json_schema_extra = {
"example": {
"message": "OCR processing started",
"batch_id": 1,
"total_files": 5,
"status": "processing"
}
}


@@ -1,124 +0,0 @@
"""
Tool_OCR - Translation Schemas (RESERVED)
Request/response models for translation endpoints
"""
from typing import Optional, Dict, List, Any
from pydantic import BaseModel, Field
class TranslationRequest(BaseModel):
"""
Translation request schema (RESERVED)
Expected format for document translation requests
"""
file_id: int = Field(..., description="File ID to translate")
source_lang: str = Field(..., description="Source language code (zh, en, ja, ko)")
target_lang: str = Field(..., description="Target language code (zh, en, ja, ko)")
engine_type: Optional[str] = Field("offline", description="Translation engine (offline, ernie, google, deepl)")
preserve_structure: bool = Field(True, description="Preserve markdown structure")
engine_config: Optional[Dict[str, Any]] = Field(None, description="Engine-specific configuration")
class Config:
json_schema_extra = {
"example": {
"file_id": 1,
"source_lang": "zh",
"target_lang": "en",
"engine_type": "offline",
"preserve_structure": True,
"engine_config": {}
}
}
class TranslationResponse(BaseModel):
"""
Translation response schema (RESERVED)
Expected format for translation results
"""
task_id: int = Field(..., description="Translation task ID")
file_id: int
source_lang: str
target_lang: str
engine_type: str
status: str = Field(..., description="Translation status (pending, processing, completed, failed)")
translated_file_path: Optional[str] = Field(None, description="Path to translated markdown file")
progress: float = Field(0.0, description="Translation progress (0.0-1.0)")
error_message: Optional[str] = None
class Config:
json_schema_extra = {
"example": {
"task_id": 1,
"file_id": 1,
"source_lang": "zh",
"target_lang": "en",
"engine_type": "offline",
"status": "processing",
"translated_file_path": None,
"progress": 0.5,
"error_message": None
}
}
class TranslationStatusResponse(BaseModel):
"""Translation task status response (RESERVED)"""
task_id: int
status: str
progress: float
created_at: str
completed_at: Optional[str] = None
error_message: Optional[str] = None
class TranslationConfigRequest(BaseModel):
"""Translation configuration request (RESERVED)"""
source_lang: str = Field(..., max_length=20)
target_lang: str = Field(..., max_length=20)
engine_type: str = Field(..., max_length=50)
engine_config: Optional[Dict[str, Any]] = None
class Config:
json_schema_extra = {
"example": {
"source_lang": "zh",
"target_lang": "en",
"engine_type": "offline",
"engine_config": {
"model_path": "/path/to/model"
}
}
}
class TranslationConfigResponse(BaseModel):
"""Translation configuration response (RESERVED)"""
id: int
user_id: int
source_lang: str
target_lang: str
engine_type: str
engine_config: Optional[Dict[str, Any]] = None
created_at: str
updated_at: str
class TranslationFeatureStatus(BaseModel):
"""Translation feature status response"""
available: bool = Field(..., description="Whether translation is available")
status: str = Field(..., description="Feature status (reserved, planned, implemented)")
message: str = Field(..., description="Status message")
supported_engines: List[str] = Field(default_factory=list, description="Currently supported engines")
planned_engines: List[Dict[str, str]] = Field(default_factory=list, description="Planned engines")
roadmap: Dict[str, Any] = Field(default_factory=dict, description="Implementation roadmap")
class LanguageInfo(BaseModel):
"""Language information"""
code: str = Field(..., description="Language code (ISO 639-1)")
name: str = Field(..., description="Language name")
status: str = Field(..., description="Support status (planned, supported)")


@@ -1,53 +0,0 @@
"""
Tool_OCR - User Schemas
"""
from datetime import datetime
from typing import Optional
from pydantic import BaseModel, EmailStr, Field
class UserBase(BaseModel):
"""Base user schema"""
username: str = Field(..., min_length=3, max_length=50)
email: EmailStr
full_name: Optional[str] = Field(None, max_length=100)
class UserCreate(UserBase):
"""User creation schema"""
password: str = Field(..., min_length=6, description="Password (min 6 characters)")
class Config:
json_schema_extra = {
"example": {
"username": "johndoe",
"email": "john@example.com",
"full_name": "John Doe",
"password": "secret123"
}
}
class UserResponse(UserBase):
"""User response schema"""
id: int
is_active: bool
is_admin: bool
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
json_schema_extra = {
"example": {
"id": 1,
"username": "johndoe",
"email": "john@example.com",
"full_name": "John Doe",
"is_active": True,
"is_admin": False,
"created_at": "2025-01-01T00:00:00",
"updated_at": "2025-01-01T00:00:00"
}
}


@@ -9,7 +9,7 @@ from sqlalchemy.orm import Session
from sqlalchemy import func, and_
from datetime import datetime, timedelta
from app.models.user_v2 import User
from app.models.user import User
from app.models.task import Task, TaskStatus
from app.models.session import Session as UserSession
from app.models.audit_log import AuditLog


@@ -1,421 +0,0 @@
"""
Tool_OCR - Background Tasks Service
Handles async processing, cleanup, and scheduled tasks
"""
import logging
import asyncio
import time
from datetime import datetime, timedelta
from pathlib import Path
from typing import Optional, Callable, Any
from sqlalchemy.orm import Session
from app.core.database import SessionLocal
from app.models.ocr import OCRBatch, OCRFile, OCRResult, BatchStatus, FileStatus
from app.services.ocr_service import OCRService
from app.services.file_manager import FileManager
from app.services.pdf_generator import PDFGenerator
logger = logging.getLogger(__name__)
class BackgroundTaskManager:
"""
Manages background tasks including retry logic, cleanup, and scheduled jobs
"""
def __init__(
self,
max_retries: int = 3,
retry_delay: int = 5,
cleanup_interval: int = 3600, # 1 hour
file_retention_hours: int = 24
):
self.max_retries = max_retries
self.retry_delay = retry_delay
self.cleanup_interval = cleanup_interval
self.file_retention_hours = file_retention_hours
self.ocr_service = OCRService()
self.file_manager = FileManager()
self.pdf_generator = PDFGenerator()
async def execute_with_retry(
self,
func: Callable,
*args,
max_retries: Optional[int] = None,
retry_delay: Optional[int] = None,
**kwargs
) -> Any:
"""
Execute a function with retry logic
Args:
func: Function to execute
args: Positional arguments for func
max_retries: Maximum retry attempts (overrides default)
retry_delay: Delay between retries in seconds (overrides default)
kwargs: Keyword arguments for func
Returns:
Function result
Raises:
Exception: If all retries are exhausted
"""
max_retries = max_retries or self.max_retries
retry_delay = retry_delay or self.retry_delay
last_exception = None
for attempt in range(max_retries + 1):
try:
if asyncio.iscoroutinefunction(func):
return await func(*args, **kwargs)
else:
return func(*args, **kwargs)
except Exception as e:
last_exception = e
if attempt < max_retries:
logger.warning(
f"Attempt {attempt + 1}/{max_retries + 1} failed for {func.__name__}: {e}. "
f"Retrying in {retry_delay}s..."
)
await asyncio.sleep(retry_delay)
else:
logger.error(
f"All {max_retries + 1} attempts failed for {func.__name__}: {e}"
)
raise last_exception
def process_single_file_with_retry(
self,
ocr_file: OCRFile,
batch_id: int,
lang: str,
detect_layout: bool,
db: Session
) -> bool:
"""
Process a single file with retry logic
Args:
ocr_file: OCRFile instance
batch_id: Batch ID
lang: Language code
detect_layout: Whether to detect layout
db: Database session
Returns:
bool: True if successful, False otherwise
"""
for attempt in range(self.max_retries + 1):
try:
# Update file status
ocr_file.status = FileStatus.PROCESSING
ocr_file.started_at = datetime.utcnow()
ocr_file.retry_count = attempt
db.commit()
# Get file paths
file_path = Path(ocr_file.file_path)
paths = self.file_manager.get_file_paths(batch_id, ocr_file.id)
# Process OCR
result = self.ocr_service.process_image(
file_path,
lang=lang,
detect_layout=detect_layout
)
# Check if processing was successful
if result['status'] != 'success':
raise Exception(result.get('error_message', 'Unknown error during OCR processing'))
# Save results
json_path, markdown_path = self.ocr_service.save_results(
result=result,
output_dir=paths["output_dir"],
file_id=str(ocr_file.id)
)
# Extract data from result
text_regions = result.get('text_regions', [])
layout_data = result.get('layout_data')
images_metadata = result.get('images_metadata', [])
# Calculate average confidence (or use from result)
avg_confidence = result.get('average_confidence')
# Create OCR result record
ocr_result = OCRResult(
file_id=ocr_file.id,
markdown_path=str(markdown_path) if markdown_path else None,
json_path=str(json_path) if json_path else None,
images_dir=None, # Images dir not used in current implementation
detected_language=lang,
total_text_regions=len(text_regions),
average_confidence=avg_confidence,
layout_data=layout_data,
images_metadata=images_metadata
)
db.add(ocr_result)
# Update file status
ocr_file.status = FileStatus.COMPLETED
ocr_file.completed_at = datetime.utcnow()
ocr_file.processing_time = (ocr_file.completed_at - ocr_file.started_at).total_seconds()
# Commit with retry on connection errors
try:
db.commit()
except Exception as commit_error:
logger.warning(f"Commit failed, rolling back and retrying: {commit_error}")
db.rollback()
db.refresh(ocr_file)
ocr_file.status = FileStatus.COMPLETED
ocr_file.completed_at = datetime.utcnow()
ocr_file.processing_time = (ocr_file.completed_at - ocr_file.started_at).total_seconds()
db.commit()
logger.info(f"Successfully processed file {ocr_file.id} ({ocr_file.original_filename})")
return True
except Exception as e:
logger.error(f"Attempt {attempt + 1}/{self.max_retries + 1} failed for file {ocr_file.id}: {e}")
db.rollback() # Rollback failed transaction
if attempt < self.max_retries:
# Wait before retry
time.sleep(self.retry_delay)
else:
# Final failure
try:
ocr_file.status = FileStatus.FAILED
ocr_file.error_message = f"Failed after {self.max_retries + 1} attempts: {str(e)}"
ocr_file.completed_at = datetime.utcnow()
ocr_file.retry_count = attempt
db.commit()
except Exception as final_error:
logger.error(f"Failed to update error status: {final_error}")
db.rollback()
return False
return False
async def cleanup_expired_files(self, db: Session):
"""
Clean up files and batches older than retention period
Args:
db: Database session
"""
try:
cutoff_time = datetime.utcnow() - timedelta(hours=self.file_retention_hours)
# Find expired batches
expired_batches = db.query(OCRBatch).filter(
OCRBatch.created_at < cutoff_time,
OCRBatch.status.in_([BatchStatus.COMPLETED, BatchStatus.FAILED, BatchStatus.PARTIAL])
).all()
logger.info(f"Found {len(expired_batches)} expired batches to clean up")
for batch in expired_batches:
try:
# Get batch directory
batch_dir = self.file_manager.base_upload_dir / "batches" / str(batch.id)
# Delete physical files
if batch_dir.exists():
import shutil
shutil.rmtree(batch_dir)
logger.info(f"Deleted batch directory: {batch_dir}")
# Delete database records
# Delete results first (foreign key constraint)
db.query(OCRResult).filter(
OCRResult.file_id.in_(
db.query(OCRFile.id).filter(OCRFile.batch_id == batch.id)
)
).delete(synchronize_session=False)
# Delete files
db.query(OCRFile).filter(OCRFile.batch_id == batch.id).delete()
# Delete batch
db.delete(batch)
db.commit()
logger.info(f"Cleaned up expired batch {batch.id}")
except Exception as e:
logger.error(f"Error cleaning up batch {batch.id}: {e}")
db.rollback()
except Exception as e:
logger.error(f"Error in cleanup_expired_files: {e}")
async def generate_pdf_background(
self,
result_id: int,
output_path: str,
css_template: str = "default",
db: Session = None
):
"""
Generate PDF in background with retry logic
Args:
result_id: OCR result ID
output_path: Output PDF path
css_template: CSS template name
db: Database session
"""
should_close_db = False
if db is None:
db = SessionLocal()
should_close_db = True
try:
# Get result
result = db.query(OCRResult).filter(OCRResult.id == result_id).first()
if not result:
logger.error(f"Result {result_id} not found")
return
# Generate PDF with retry
await self.execute_with_retry(
self.pdf_generator.generate_pdf,
markdown_path=result.markdown_path,
output_path=output_path,
css_template=css_template,
max_retries=2,
retry_delay=3
)
logger.info(f"Successfully generated PDF for result {result_id}: {output_path}")
except Exception as e:
logger.error(f"Failed to generate PDF for result {result_id}: {e}")
finally:
if should_close_db:
db.close()
async def start_cleanup_scheduler(self):
"""
Start periodic cleanup scheduler
Runs cleanup task at specified intervals
"""
logger.info(f"Starting cleanup scheduler (interval: {self.cleanup_interval}s, retention: {self.file_retention_hours}h)")
while True:
db = SessionLocal()
try:
await self.cleanup_expired_files(db)
except Exception as e:
logger.error(f"Error in cleanup scheduler: {e}")
finally:
db.close()  # Always release the session, even when cleanup raises
# Wait for next interval
await asyncio.sleep(self.cleanup_interval)
# Global task manager instance
task_manager = BackgroundTaskManager()
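A hedged wiring sketch (assumed to live in an application startup hook, which is not part of this file): launch the periodic cleanup loop and exercise the generic retry helper with a stand-in flaky callable.

import random

def flaky_probe():
    # Stand-in for a transiently failing operation.
    if random.random() < 0.5:
        raise ConnectionError("transient failure")
    return "ok"

async def on_startup():
    # Fire-and-forget scheduler: runs cleanup_expired_files every cleanup_interval seconds.
    asyncio.create_task(task_manager.start_cleanup_scheduler())
    # Up to 6 attempts in total (max_retries=5), 2 seconds apart, before re-raising.
    print(await task_manager.execute_with_retry(flaky_probe, max_retries=5, retry_delay=2))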
def process_batch_files_with_retry(
batch_id: int,
lang: str,
detect_layout: bool,
db: Session
):
"""
Process all files in a batch with retry logic
Args:
batch_id: Batch ID
lang: Language code
detect_layout: Whether to detect layout
db: Database session
"""
try:
# Get batch
batch = db.query(OCRBatch).filter(OCRBatch.id == batch_id).first()
if not batch:
logger.error(f"Batch {batch_id} not found")
return
# Update batch status
batch.status = BatchStatus.PROCESSING
batch.started_at = datetime.utcnow()
db.commit()
# Get pending files
files = db.query(OCRFile).filter(
OCRFile.batch_id == batch_id,
OCRFile.status == FileStatus.PENDING
).all()
logger.info(f"Processing {len(files)} files in batch {batch_id} with retry logic")
# Process each file with retry
for ocr_file in files:
success = task_manager.process_single_file_with_retry(
ocr_file=ocr_file,
batch_id=batch_id,
lang=lang,
detect_layout=detect_layout,
db=db
)
# Update batch progress
if success:
batch.completed_files += 1
else:
batch.failed_files += 1
db.commit()
# Update batch final status
if batch.failed_files == 0:
batch.status = BatchStatus.COMPLETED
elif batch.completed_files > 0:
batch.status = BatchStatus.PARTIAL
else:
batch.status = BatchStatus.FAILED
batch.completed_at = datetime.utcnow()
# Commit with retry on connection errors
try:
db.commit()
except Exception as commit_error:
logger.warning(f"Batch commit failed, rolling back and retrying: {commit_error}")
db.rollback()
batch = db.query(OCRBatch).filter(OCRBatch.id == batch_id).first()
if batch:
batch.completed_at = datetime.utcnow()
db.commit()
logger.info(
f"Batch {batch_id} processing complete: "
f"{batch.completed_files} succeeded, {batch.failed_files} failed"
)
except Exception as e:
logger.error(f"Fatal error processing batch {batch_id}: {e}")
db.rollback() # Rollback any failed transaction
try:
batch = db.query(OCRBatch).filter(OCRBatch.id == batch_id).first()
if batch:
batch.status = BatchStatus.FAILED
batch.completed_at = datetime.utcnow()
db.commit()
except Exception as commit_error:
logger.error(f"Error updating batch status: {commit_error}")
db.rollback()
finally:
db.close()  # Release the session created for this background task


@@ -1,512 +0,0 @@
"""
Tool_OCR - Export Service
Handles OCR result export in multiple formats with filtering and formatting rules
"""
import json
import logging
import zipfile
from pathlib import Path
from typing import List, Dict, Optional, Any
from datetime import datetime
import pandas as pd
from sqlalchemy.orm import Session
from app.core.config import settings
from app.models.ocr import OCRBatch, OCRFile, OCRResult, FileStatus
from app.models.export import ExportRule
from app.services.pdf_generator import PDFGenerator, PDFGenerationError
logger = logging.getLogger(__name__)
class ExportError(Exception):
"""Exception raised for export errors"""
pass
class ExportService:
"""
Export service for OCR results
Supported formats:
- TXT: Plain text export
- JSON: Full metadata export
- Excel: Tabular data export
- Markdown: Direct Markdown export
- PDF: Layout-preserved PDF export
- ZIP: Batch export archive
"""
def __init__(self):
"""Initialize export service"""
self.pdf_generator = PDFGenerator()
def apply_filters(
self,
results: List[OCRResult],
filters: Dict[str, Any]
) -> List[OCRResult]:
"""
Apply filters to OCR results
Args:
results: List of OCR results
filters: Filter configuration
- confidence_threshold: Minimum confidence (0.0-1.0)
- filename_pattern: Glob pattern for filename matching
- language: Filter by detected language
Returns:
List[OCRResult]: Filtered results
"""
filtered = results
# Confidence threshold filter
if "confidence_threshold" in filters:
threshold = filters["confidence_threshold"]
filtered = [r for r in filtered if r.average_confidence and r.average_confidence >= threshold]
# Filename pattern filter (using simple substring match)
if "filename_pattern" in filters:
pattern = filters["filename_pattern"].lower()
filtered = [
r for r in filtered
if pattern in r.file.original_filename.lower()
]
# Language filter
if "language" in filters:
lang = filters["language"]
filtered = [r for r in filtered if r.detected_language == lang]
return filtered
def export_to_txt(
self,
results: List[OCRResult],
output_path: Path,
formatting: Optional[Dict] = None
) -> Path:
"""
Export results to plain text file
Args:
results: List of OCR results
output_path: Output file path
formatting: Formatting options
- add_line_numbers: Add line numbers
- group_by_filename: Group text by source file
- include_metadata: Add file metadata headers
Returns:
Path: Output file path
Raises:
ExportError: If export fails
"""
try:
formatting = formatting or {}
output_lines = []
for idx, result in enumerate(results, 1):
# Read Markdown file
if not result.markdown_path or not Path(result.markdown_path).exists():
logger.warning(f"Markdown file not found for result {result.id}")
continue
markdown_content = Path(result.markdown_path).read_text(encoding="utf-8")
# Add metadata header if requested
if formatting.get("include_metadata", False):
output_lines.append(f"=" * 80)
output_lines.append(f"文件: {result.file.original_filename}")
output_lines.append(f"語言: {result.detected_language or '未知'}")
output_lines.append(f"信心度: {result.average_confidence:.2%}" if result.average_confidence else "信心度: N/A")
output_lines.append(f"=" * 80)
output_lines.append("")
# Add content with optional line numbers
if formatting.get("add_line_numbers", False):
for line_num, line in enumerate(markdown_content.split('\n'), 1):
output_lines.append(f"{line_num:4d} | {line}")
else:
output_lines.append(markdown_content)
# Add separator between files if grouping
if formatting.get("group_by_filename", False) and idx < len(results):
output_lines.append("\n" + "-" * 80 + "\n")
# Write to file
output_path.parent.mkdir(parents=True, exist_ok=True)
output_path.write_text("\n".join(output_lines), encoding="utf-8")
logger.info(f"Exported {len(results)} results to TXT: {output_path}")
return output_path
except Exception as e:
raise ExportError(f"TXT export failed: {str(e)}")
def export_to_json(
self,
results: List[OCRResult],
output_path: Path,
include_layout: bool = True,
include_images: bool = True
) -> Path:
"""
Export results to JSON file with full metadata
Args:
results: List of OCR results
output_path: Output file path
include_layout: Include layout data
include_images: Include images metadata
Returns:
Path: Output file path
Raises:
ExportError: If export fails
"""
try:
export_data = {
"export_time": datetime.utcnow().isoformat(),
"total_files": len(results),
"results": []
}
for result in results:
result_data = {
"file_id": result.file.id,
"filename": result.file.original_filename,
"file_format": result.file.file_format,
"file_size": result.file.file_size,
"processing_time": result.file.processing_time,
"detected_language": result.detected_language,
"total_text_regions": result.total_text_regions,
"average_confidence": result.average_confidence,
"markdown_path": result.markdown_path,
}
# Include layout data if requested
if include_layout and result.layout_data:
result_data["layout_data"] = result.layout_data
# Include images metadata if requested
if include_images and result.images_metadata:
result_data["images_metadata"] = result.images_metadata
export_data["results"].append(result_data)
# Write to file
output_path.parent.mkdir(parents=True, exist_ok=True)
output_path.write_text(
json.dumps(export_data, ensure_ascii=False, indent=2),
encoding="utf-8"
)
logger.info(f"Exported {len(results)} results to JSON: {output_path}")
return output_path
except Exception as e:
raise ExportError(f"JSON export failed: {str(e)}")
def export_to_excel(
self,
results: List[OCRResult],
output_path: Path,
include_confidence: bool = True,
include_processing_time: bool = True
) -> Path:
"""
Export results to Excel file
Args:
results: List of OCR results
output_path: Output file path
include_confidence: Include confidence scores
include_processing_time: Include processing time
Returns:
Path: Output file path
Raises:
ExportError: If export fails
"""
try:
rows = []
for result in results:
# Read Markdown content
text_content = ""
if result.markdown_path and Path(result.markdown_path).exists():
text_content = Path(result.markdown_path).read_text(encoding="utf-8")
row = {
"文件名": result.file.original_filename,
"格式": result.file.file_format,
"大小(字節)": result.file.file_size,
"語言": result.detected_language or "未知",
"文本區域數": result.total_text_regions,
"提取內容": text_content[:1000] + "..." if len(text_content) > 1000 else text_content,
}
if include_confidence:
row["平均信心度"] = f"{result.average_confidence:.2%}" if result.average_confidence else "N/A"
if include_processing_time:
row["處理時間(秒)"] = f"{result.file.processing_time:.2f}" if result.file.processing_time else "N/A"
rows.append(row)
# Create DataFrame and export
df = pd.DataFrame(rows)
output_path.parent.mkdir(parents=True, exist_ok=True)
df.to_excel(output_path, index=False, engine='openpyxl')
logger.info(f"Exported {len(results)} results to Excel: {output_path}")
return output_path
except Exception as e:
raise ExportError(f"Excel export failed: {str(e)}")
def export_to_markdown(
self,
results: List[OCRResult],
output_path: Path,
combine: bool = True
) -> Path:
"""
Export results to Markdown file(s)
Args:
results: List of OCR results
output_path: Output file path (or directory if not combining)
combine: Combine all results into one file
Returns:
Path: Output file/directory path
Raises:
ExportError: If export fails
"""
try:
if combine:
# Combine all Markdown files into one
combined_content = []
for result in results:
if not result.markdown_path or not Path(result.markdown_path).exists():
continue
markdown_content = Path(result.markdown_path).read_text(encoding="utf-8")
# Add header
combined_content.append(f"# {result.file.original_filename}\n")
combined_content.append(markdown_content)
combined_content.append("\n---\n") # Separator
output_path.parent.mkdir(parents=True, exist_ok=True)
output_path.write_text("\n".join(combined_content), encoding="utf-8")
logger.info(f"Exported {len(results)} results to combined Markdown: {output_path}")
return output_path
else:
# Export each result to separate file
output_path.mkdir(parents=True, exist_ok=True)
for result in results:
if not result.markdown_path or not Path(result.markdown_path).exists():
continue
# Copy Markdown file to output directory
src_path = Path(result.markdown_path)
dst_path = output_path / f"{result.file.original_filename}.md"
dst_path.write_text(src_path.read_text(encoding="utf-8"), encoding="utf-8")
logger.info(f"Exported {len(results)} results to separate Markdown files: {output_path}")
return output_path
except Exception as e:
raise ExportError(f"Markdown export failed: {str(e)}")
def export_to_pdf(
self,
result: OCRResult,
output_path: Path,
css_template: str = "default",
metadata: Optional[Dict] = None
) -> Path:
"""
Export single result to PDF with layout preservation
Args:
result: OCR result
output_path: Output PDF path
css_template: CSS template name or custom CSS
metadata: Optional PDF metadata
Returns:
Path: Output PDF path
Raises:
ExportError: If export fails
"""
try:
if not result.markdown_path or not Path(result.markdown_path).exists():
raise ExportError(f"Markdown file not found for result {result.id}")
markdown_path = Path(result.markdown_path)
# Prepare metadata
pdf_metadata = metadata or {}
if "title" not in pdf_metadata:
pdf_metadata["title"] = result.file.original_filename
# Generate PDF
self.pdf_generator.generate_pdf(
markdown_path=markdown_path,
output_path=output_path,
css_template=css_template,
metadata=pdf_metadata
)
logger.info(f"Exported result {result.id} to PDF: {output_path}")
return output_path
except PDFGenerationError as e:
raise ExportError(f"PDF generation failed: {str(e)}")
except Exception as e:
raise ExportError(f"PDF export failed: {str(e)}")
def export_batch_to_zip(
self,
db: Session,
batch_id: int,
output_path: Path,
include_formats: Optional[List[str]] = None
) -> Path:
"""
Export entire batch to ZIP archive
Args:
db: Database session
batch_id: Batch ID
output_path: Output ZIP path
include_formats: List of formats to include (markdown, json, txt, excel, pdf)
Returns:
Path: Output ZIP path
Raises:
ExportError: If export fails
"""
try:
include_formats = include_formats or ["markdown", "json"]
# Get batch and results
batch = db.query(OCRBatch).filter(OCRBatch.id == batch_id).first()
if not batch:
raise ExportError(f"Batch {batch_id} not found")
results = db.query(OCRResult).join(OCRFile).filter(
OCRFile.batch_id == batch_id,
OCRFile.status == FileStatus.COMPLETED
).all()
if not results:
raise ExportError(f"No completed results found for batch {batch_id}")
# Create temporary export directory
temp_dir = output_path.parent / f"temp_export_{batch_id}"
temp_dir.mkdir(parents=True, exist_ok=True)
try:
# Export in requested formats
if "markdown" in include_formats:
md_dir = temp_dir / "markdown"
self.export_to_markdown(results, md_dir, combine=False)
if "json" in include_formats:
json_path = temp_dir / "batch_results.json"
self.export_to_json(results, json_path)
if "txt" in include_formats:
txt_path = temp_dir / "batch_results.txt"
self.export_to_txt(results, txt_path)
if "excel" in include_formats:
excel_path = temp_dir / "batch_results.xlsx"
self.export_to_excel(results, excel_path)
# Create ZIP archive
output_path.parent.mkdir(parents=True, exist_ok=True)
with zipfile.ZipFile(output_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
for file_path in temp_dir.rglob('*'):
if file_path.is_file():
arcname = file_path.relative_to(temp_dir)
zipf.write(file_path, arcname)
logger.info(f"Exported batch {batch_id} to ZIP: {output_path}")
return output_path
finally:
# Clean up temporary directory
import shutil
shutil.rmtree(temp_dir, ignore_errors=True)
except Exception as e:
raise ExportError(f"Batch ZIP export failed: {str(e)}")
def apply_export_rule(
self,
db: Session,
results: List[OCRResult],
rule_id: int
) -> List[OCRResult]:
"""
Apply export rule to filter and format results
Args:
db: Database session
results: List of OCR results
rule_id: Export rule ID
Returns:
List[OCRResult]: Filtered results
Raises:
ExportError: If rule not found
"""
rule = db.query(ExportRule).filter(ExportRule.id == rule_id).first()
if not rule:
raise ExportError(f"Export rule {rule_id} not found")
config = rule.config_json
# Apply filters
if "filters" in config:
results = self.apply_filters(results, config["filters"])
# Note: Formatting options are applied in individual export methods
return results
def get_export_formats(self) -> Dict[str, str]:
"""
Get available export formats
Returns:
Dict mapping format codes to descriptions
"""
return {
"txt": "純文本格式 (.txt)",
"json": "JSON 格式 - 包含完整元數據 (.json)",
"excel": "Excel 表格格式 (.xlsx)",
"markdown": "Markdown 格式 (.md)",
"pdf": "版面保留 PDF 格式 (.pdf)",
"zip": "批次打包格式 (.zip)",
}
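A hedged usage sketch for the service; db and batch 1 are assumed to exist, and the filter and formatting keys mirror the dicts documented in the methods above.

service = ExportService()
completed = db.query(OCRResult).join(OCRFile).filter(
    OCRFile.batch_id == 1,
    OCRFile.status == FileStatus.COMPLETED,
).all()
confident = service.apply_filters(completed, {"confidence_threshold": 0.8})
service.export_to_txt(
    confident,
    Path("exports/batch_1.txt"),
    formatting={"include_metadata": True, "add_line_numbers": True},
)
service.export_batch_to_zip(
    db,
    batch_id=1,
    output_path=Path("exports/batch_1.zip"),
    include_formats=["markdown", "json", "excel"],
)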


@@ -1,420 +0,0 @@
"""
Tool_OCR - File Management Service
Handles file uploads, storage, validation, and cleanup
"""
import logging
import shutil
import uuid
from pathlib import Path
from typing import List, Tuple, Optional
from datetime import datetime, timedelta
from fastapi import UploadFile
from sqlalchemy.orm import Session
from app.core.config import settings
from app.models.ocr import OCRBatch, OCRFile, FileStatus
from app.services.preprocessor import DocumentPreprocessor
logger = logging.getLogger(__name__)
class FileManagementError(Exception):
"""Exception raised for file management errors"""
pass
class FileManager:
"""
File management service for upload, storage, and cleanup
Directory structure:
uploads/
├── batches/
│ └── {batch_id}/
│ ├── inputs/ # Original uploaded files
│ ├── outputs/ # OCR results
│ │ ├── markdown/ # Markdown files
│ │ ├── json/ # JSON files
│ │ └── images/ # Extracted images
│ └── exports/ # Export files (PDF, Excel, etc.)
"""
def __init__(self):
"""Initialize file manager"""
self.preprocessor = DocumentPreprocessor()
self.base_upload_dir = Path(settings.upload_dir)
self.base_upload_dir.mkdir(parents=True, exist_ok=True)
def create_batch_directory(self, batch_id: int) -> Path:
"""
Create directory structure for a batch
Args:
batch_id: Batch ID
Returns:
Path: Batch directory path
"""
batch_dir = self.base_upload_dir / "batches" / str(batch_id)
# Create subdirectories
(batch_dir / "inputs").mkdir(parents=True, exist_ok=True)
(batch_dir / "outputs" / "markdown").mkdir(parents=True, exist_ok=True)
(batch_dir / "outputs" / "json").mkdir(parents=True, exist_ok=True)
(batch_dir / "outputs" / "images").mkdir(parents=True, exist_ok=True)
(batch_dir / "exports").mkdir(parents=True, exist_ok=True)
logger.info(f"Created batch directory: {batch_dir}")
return batch_dir
def get_batch_directory(self, batch_id: int) -> Path:
"""
Get batch directory path
Args:
batch_id: Batch ID
Returns:
Path: Batch directory path
"""
return self.base_upload_dir / "batches" / str(batch_id)
def validate_upload(self, file: UploadFile) -> Tuple[bool, Optional[str]]:
"""
Validate uploaded file before saving
Args:
file: Uploaded file
Returns:
Tuple of (is_valid, error_message)
"""
# Check filename
if not file.filename:
return False, "文件名不能為空"
# Check file size (read content size)
file.file.seek(0, 2) # Seek to end
file_size = file.file.tell()
file.file.seek(0) # Reset to beginning
if file_size == 0:
return False, "文件為空"
if file_size > settings.max_upload_size:
max_mb = settings.max_upload_size / (1024 * 1024)
return False, f"文件大小超過限制 ({max_mb}MB)"
# Check file extension
file_ext = Path(file.filename).suffix.lower()
allowed_extensions = {'.png', '.jpg', '.jpeg', '.pdf', '.doc', '.docx', '.ppt', '.pptx'}
if file_ext not in allowed_extensions:
return False, f"不支持的文件格式 ({file_ext}),僅支持: {', '.join(allowed_extensions)}"
return True, None
def save_upload(
self,
file: UploadFile,
batch_id: int,
validate: bool = True
) -> Tuple[Path, str]:
"""
Save uploaded file to batch directory
Args:
file: Uploaded file
batch_id: Batch ID
validate: Whether to validate file
Returns:
Tuple of (file_path, original_filename)
Raises:
FileManagementError: If file validation or saving fails
"""
# Validate if requested
if validate:
is_valid, error_msg = self.validate_upload(file)
if not is_valid:
raise FileManagementError(error_msg)
# Generate unique filename to avoid conflicts
original_filename = file.filename
file_ext = Path(original_filename).suffix
unique_filename = f"{uuid.uuid4()}{file_ext}"
# Get batch input directory
batch_dir = self.get_batch_directory(batch_id)
input_dir = batch_dir / "inputs"
input_dir.mkdir(parents=True, exist_ok=True)
# Save file
file_path = input_dir / unique_filename
try:
with file_path.open("wb") as buffer:
shutil.copyfileobj(file.file, buffer)
logger.info(f"Saved upload: {file_path} (original: {original_filename})")
return file_path, original_filename
except Exception as e:
# Clean up partial file if exists
file_path.unlink(missing_ok=True)
raise FileManagementError(f"保存文件失敗: {str(e)}")
def validate_saved_file(self, file_path: Path) -> Tuple[bool, Optional[str], Optional[str]]:
"""
Validate saved file using preprocessor
Args:
file_path: Path to saved file
Returns:
Tuple of (is_valid, detected_format, error_message) — order matches the unpacking in add_file_to_batch
"""
return self.preprocessor.validate_file(file_path)
def create_batch(
self,
db: Session,
user_id: int,
batch_name: Optional[str] = None
) -> OCRBatch:
"""
Create new OCR batch
Args:
db: Database session
user_id: User ID
batch_name: Optional batch name
Returns:
OCRBatch: Created batch object
"""
# Create batch record
batch = OCRBatch(
user_id=user_id,
batch_name=batch_name or f"Batch_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
)
db.add(batch)
db.commit()
db.refresh(batch)
# Create directory structure
self.create_batch_directory(batch.id)
logger.info(f"Created batch: {batch.id} for user {user_id}")
return batch
def add_file_to_batch(
self,
db: Session,
batch_id: int,
file: UploadFile
) -> OCRFile:
"""
Add file to batch and save to disk
Args:
db: Database session
batch_id: Batch ID
file: Uploaded file
Returns:
OCRFile: Created file record
Raises:
FileManagementError: If file operations fail
"""
# Save file to disk
file_path, original_filename = self.save_upload(file, batch_id)
# Validate saved file
is_valid, detected_format, error_msg = self.validate_saved_file(file_path)
# Create file record
ocr_file = OCRFile(
batch_id=batch_id,
filename=file_path.name,
original_filename=original_filename,
file_path=str(file_path),
file_size=file_path.stat().st_size,
file_format=detected_format or Path(original_filename).suffix.lower().lstrip('.'),
status=FileStatus.PENDING if is_valid else FileStatus.FAILED,
error_message=error_msg if not is_valid else None
)
db.add(ocr_file)
# Update batch total_files count
batch = db.query(OCRBatch).filter(OCRBatch.id == batch_id).first()
if batch:
batch.total_files += 1
if not is_valid:
batch.failed_files += 1
db.commit()
db.refresh(ocr_file)
logger.info(f"Added file to batch {batch_id}: {ocr_file.id} (status: {ocr_file.status})")
return ocr_file
def add_files_to_batch(
self,
db: Session,
batch_id: int,
files: List[UploadFile]
) -> List[OCRFile]:
"""
Add multiple files to batch
Args:
db: Database session
batch_id: Batch ID
files: List of uploaded files
Returns:
List[OCRFile]: List of created file records
"""
ocr_files = []
for file in files:
try:
ocr_file = self.add_file_to_batch(db, batch_id, file)
ocr_files.append(ocr_file)
except FileManagementError as e:
logger.error(f"Failed to add file {file.filename} to batch {batch_id}: {e}")
# Continue with other files
continue
return ocr_files
def get_file_paths(self, batch_id: int, file_id: int) -> dict:
"""
Get all paths for a file in a batch
Args:
batch_id: Batch ID
file_id: File ID
Returns:
Dict containing all relevant paths
"""
batch_dir = self.get_batch_directory(batch_id)
return {
"input_dir": batch_dir / "inputs",
"output_dir": batch_dir / "outputs",
"markdown_dir": batch_dir / "outputs" / "markdown",
"json_dir": batch_dir / "outputs" / "json",
"images_dir": batch_dir / "outputs" / "images" / str(file_id),
"export_dir": batch_dir / "exports",
}
def cleanup_expired_batches(self, db: Session, retention_hours: int = 24) -> int:
"""
Clean up expired batch files
Args:
db: Database session
retention_hours: Number of hours to retain files
Returns:
int: Number of batches cleaned up
"""
cutoff_time = datetime.utcnow() - timedelta(hours=retention_hours)
# Find expired batches
expired_batches = db.query(OCRBatch).filter(
OCRBatch.created_at < cutoff_time
).all()
cleaned_count = 0
for batch in expired_batches:
try:
# Delete batch directory
batch_dir = self.get_batch_directory(batch.id)
if batch_dir.exists():
shutil.rmtree(batch_dir)
logger.info(f"Deleted batch directory: {batch_dir}")
# Delete database records (cascade will handle related records)
db.delete(batch)
cleaned_count += 1
except Exception as e:
logger.error(f"Failed to cleanup batch {batch.id}: {e}")
continue
if cleaned_count > 0:
db.commit()
logger.info(f"Cleaned up {cleaned_count} expired batches")
return cleaned_count
def verify_file_ownership(
self,
db: Session,
user_id: int,
batch_id: int
) -> bool:
"""
Verify user owns the batch
Args:
db: Database session
user_id: User ID
batch_id: Batch ID
Returns:
bool: True if user owns batch, False otherwise
"""
batch = db.query(OCRBatch).filter(
OCRBatch.id == batch_id,
OCRBatch.user_id == user_id
).first()
return batch is not None
def get_batch_statistics(self, db: Session, batch_id: int) -> dict:
"""
Get statistics for a batch
Args:
db: Database session
batch_id: Batch ID
Returns:
Dict containing batch statistics
"""
batch = db.query(OCRBatch).filter(OCRBatch.id == batch_id).first()
if not batch:
return {}
# Calculate total file size
total_size = sum(f.file_size for f in batch.files)
# Calculate processing time
processing_time = None
if batch.completed_at and batch.started_at:
processing_time = (batch.completed_at - batch.started_at).total_seconds()
return {
"batch_id": batch.id,
"batch_name": batch.batch_name,
"status": batch.status,
"total_files": batch.total_files,
"completed_files": batch.completed_files,
"failed_files": batch.failed_files,
"pending_files": batch.total_files - batch.completed_files - batch.failed_files,
"progress_percentage": batch.progress_percentage,
"total_file_size": total_size,
"total_file_size_mb": round(total_size / (1024 * 1024), 2),
"created_at": batch.created_at.isoformat(),
"started_at": batch.started_at.isoformat() if batch.started_at else None,
"completed_at": batch.completed_at.isoformat() if batch.completed_at else None,
"processing_time": processing_time,
}
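# Worked example (hypothetical values): a batch with total_files=10, completed_files=7
# and failed_files=1 reports pending_files=2; a 5,242,880-byte total yields
# total_file_size_mb=5.0; processing_time stays None until both started_at and
# completed_at are set.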

View File

@@ -1,282 +0,0 @@
"""
Tool_OCR - Translation Service (RESERVED)
Abstract interface and stub implementation for future translation feature
"""
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, List
from enum import Enum
import logging
logger = logging.getLogger(__name__)
class TranslationEngine(str, Enum):
"""Supported translation engines"""
OFFLINE = "offline" # Argos Translate (offline)
ERNIE = "ernie" # Baidu ERNIE API
GOOGLE = "google" # Google Translate API
DEEPL = "deepl" # DeepL API
class LanguageCode(str, Enum):
"""Supported language codes"""
CHINESE = "zh"
ENGLISH = "en"
JAPANESE = "ja"
KOREAN = "ko"
FRENCH = "fr"
GERMAN = "de"
SPANISH = "es"
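# Example: as str-valued enums, these round-trip cleanly from request payloads, e.g.
# TranslationEngine("offline") is TranslationEngine.OFFLINE and
# LanguageCode.CHINESE.value == "zh".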
class TranslationServiceInterface(ABC):
"""
Abstract interface for translation services
This interface defines the contract for all translation engine implementations.
Future implementations should inherit from this class.
"""
@abstractmethod
def translate_text(
self,
text: str,
source_lang: str,
target_lang: str,
**kwargs
) -> str:
"""
Translate a single text string
Args:
text: Text to translate
source_lang: Source language code
target_lang: Target language code
**kwargs: Engine-specific parameters
Returns:
str: Translated text
"""
pass
@abstractmethod
def translate_document(
self,
markdown_content: str,
source_lang: str,
target_lang: str,
preserve_structure: bool = True,
**kwargs
) -> Dict[str, Any]:
"""
Translate a Markdown document while preserving structure
Args:
markdown_content: Markdown content to translate
source_lang: Source language code
target_lang: Target language code
preserve_structure: Whether to preserve markdown structure
**kwargs: Engine-specific parameters
Returns:
Dict containing:
- translated_content: Translated markdown
- metadata: Translation metadata (engine, time, etc.)
"""
pass
@abstractmethod
def batch_translate(
self,
texts: List[str],
source_lang: str,
target_lang: str,
**kwargs
) -> List[str]:
"""
Translate multiple texts in batch
Args:
texts: List of texts to translate
source_lang: Source language code
target_lang: Target language code
**kwargs: Engine-specific parameters
Returns:
List[str]: List of translated texts
"""
pass
@abstractmethod
def get_supported_languages(self) -> List[str]:
"""
Get list of supported language codes for this engine
Returns:
List[str]: List of supported language codes
"""
pass
@abstractmethod
def validate_config(self) -> bool:
"""
Validate engine configuration (API keys, model files, etc.)
Returns:
bool: True if configuration is valid
"""
pass
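# A minimal test double is one way to exercise this contract before real engines land
# (an illustrative sketch, not one of the planned engines):
#
# class EchoTranslationEngine(TranslationServiceInterface):
#     """Identity 'translation' for pipeline tests."""
#     def translate_text(self, text, source_lang, target_lang, **kwargs):
#         return text
#     def translate_document(self, markdown_content, source_lang, target_lang,
#                            preserve_structure=True, **kwargs):
#         return {"translated_content": markdown_content, "metadata": {"engine": "echo"}}
#     def batch_translate(self, texts, source_lang, target_lang, **kwargs):
#         return [self.translate_text(t, source_lang, target_lang) for t in texts]
#     def get_supported_languages(self):
#         return [lang.value for lang in LanguageCode]
#     def validate_config(self):
#         return True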
class TranslationEngineFactory:
"""
Factory for creating translation engine instances
RESERVED: This is a placeholder for future implementation.
When translation feature is implemented, this factory will instantiate
the appropriate translation engine based on configuration.
"""
@staticmethod
def create_engine(
engine_type: TranslationEngine,
config: Optional[Dict] = None
) -> TranslationServiceInterface:
"""
Create a translation engine instance
Args:
engine_type: Type of translation engine
config: Engine-specific configuration
Returns:
TranslationServiceInterface: Translation engine instance
Raises:
NotImplementedError: Always raised (stub implementation)
"""
raise NotImplementedError(
"Translation feature is not yet implemented. "
"This is a reserved placeholder for future development."
)
@staticmethod
def get_available_engines() -> List[str]:
"""
Get list of available translation engines
Returns:
List[str]: List of engine types (currently empty)
"""
return []
@staticmethod
def is_engine_available(engine_type: TranslationEngine) -> bool:
"""
Check if a specific engine is available
Args:
engine_type: Engine type to check
Returns:
bool: Always False (stub implementation)
"""
return False
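# A sketch of how create_engine could dispatch once engines exist; the registry and
# engine classes are assumptions, not part of this stub:
#
# _ENGINES: Dict[TranslationEngine, type] = {}  # populated as engines are implemented
#
# def _create_engine(engine_type: TranslationEngine, config: Optional[Dict] = None):
#     engine_cls = _ENGINES.get(engine_type)
#     if engine_cls is None:
#         raise NotImplementedError(f"Engine '{engine_type.value}' is not available")
#     return engine_cls(**(config or {}))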
class StubTranslationService:
"""
Stub translation service for API endpoints
This service provides placeholder responses for translation endpoints
until the feature is fully implemented.
"""
@staticmethod
def get_feature_status() -> Dict[str, Any]:
"""
Get translation feature status
Returns:
Dict with feature status information
"""
return {
"available": False,
"status": "reserved",
"message": "Translation feature is reserved for future implementation",
"supported_engines": [],
"planned_engines": [
{
"type": "offline",
"name": "Argos Translate",
"description": "Offline neural translation",
"status": "planned"
},
{
"type": "ernie",
"name": "Baidu ERNIE",
"description": "Baidu AI translation API",
"status": "planned"
},
{
"type": "google",
"name": "Google Translate",
"description": "Google Cloud Translation API",
"status": "planned"
},
{
"type": "deepl",
"name": "DeepL",
"description": "DeepL translation API",
"status": "planned"
}
],
"roadmap": {
"phase": "Phase 5",
"priority": "low",
"implementation_after": "Production deployment and user feedback"
}
}
@staticmethod
def get_supported_languages() -> List[Dict[str, str]]:
"""
Get list of languages planned for translation support
Returns:
List of language info dicts
"""
return [
{"code": "zh", "name": "Chinese (Simplified)", "status": "planned"},
{"code": "en", "name": "English", "status": "planned"},
{"code": "ja", "name": "Japanese", "status": "planned"},
{"code": "ko", "name": "Korean", "status": "planned"},
{"code": "fr", "name": "French", "status": "planned"},
{"code": "de", "name": "German", "status": "planned"},
{"code": "es", "name": "Spanish", "status": "planned"},
]
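# A hedged sketch of surfacing these payloads from an API route (router and paths are
# hypothetical; this service only defines the response bodies):
#
# @router.get("/translation/status")
# def translation_status():
#     return StubTranslationService.get_feature_status()
#
# @router.get("/translation/languages")
# def translation_languages():
#     return StubTranslationService.get_supported_languages()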
# Example placeholder for future engine implementations:
#
# class ArgosTranslationEngine(TranslationServiceInterface):
# """Offline translation using Argos Translate"""
# def __init__(self, model_path: str):
# self.model_path = model_path
# # Initialize Argos models
#
# def translate_text(self, text, source_lang, target_lang, **kwargs):
# # Implementation here
# pass
#
# class ERNIETranslationEngine(TranslationServiceInterface):
# """Baidu ERNIE API translation"""
# def __init__(self, api_key: str, api_secret: str):
# self.api_key = api_key
# self.api_secret = api_secret
#
# def translate_text(self, text, source_lang, target_lang, **kwargs):
# # Implementation here
# pass