fix: make torch import optional and add PaddlePaddle GPU memory management

Problem:
- Backend failed to start with ModuleNotFoundError for the torch module
- torch was imported as a hard dependency but is not listed in requirements.txt
- The project uses PaddlePaddle, which has its own CUDA implementation

Changes:
- Make torch import optional with try/except in ocr_service.py
- Make torch import optional in pp_structure_enhanced.py
- Add cleanup_gpu_memory() method using PaddlePaddle's memory management
- Add check_gpu_memory() method to monitor available GPU memory
- Use paddle.device.cuda.empty_cache() for GPU cleanup
- Use torch.cuda only when the TORCH_AVAILABLE flag is True (see the condensed sketch after this list)
- Add cleanup calls after OCR processing to prevent OOM errors
- Add memory checks before GPU-intensive operations
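
Condensed, the guard pattern these changes apply looks like this (illustrative sketch only; the real methods below also synchronize, log, and force garbage collection):

    try:
        import torch  # optional; used only for extra memory monitoring
        TORCH_AVAILABLE = True
    except ImportError:
        TORCH_AVAILABLE = False

    import paddle

    def cleanup_gpu_memory():
        # torch path runs only when torch is installed and CUDA is usable
        if TORCH_AVAILABLE and torch.cuda.is_available():
            torch.cuda.empty_cache()
        # PaddlePaddle path is the primary cleanup mechanism
        if paddle.device.is_compiled_with_cuda():
            paddle.device.cuda.empty_cache()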

Benefits:
- Backend can start without torch installed
- GPU memory is properly managed using PaddlePaddle
- Optional torch support provides additional memory monitoring
- Prevents GPU OOM errors during document processing

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Author: egg
Date: 2025-11-20 16:40:44 +08:00
Commit: b997f9355a (parent 7064ea30d5)
2 changed files with 121 additions and 0 deletions

File: ocr_service.py

@@ -9,12 +9,20 @@ from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from datetime import datetime
import uuid
import gc  # For garbage collection

from paddleocr import PaddleOCR, PPStructureV3
from PIL import Image
from pdf2image import convert_from_path
import paddle

# Optional torch import for additional GPU memory management
try:
    import torch
    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False

from app.core.config import settings
from app.services.office_converter import OfficeConverter, OfficeConverterError
@@ -401,6 +409,78 @@ class OCRService:
        return self.structure_engine

    def cleanup_gpu_memory(self):
        """
        Clean up GPU memory to prevent OOM errors.

        This should be called after processing each document or batch.
        Uses PaddlePaddle's built-in memory management and optionally torch if available.
        """
        try:
            # Clear PyTorch GPU cache if torch is available
            if TORCH_AVAILABLE and torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
                logger.debug("Cleared PyTorch GPU cache")

            # Clear PaddlePaddle GPU cache
            if paddle.device.is_compiled_with_cuda():
                paddle.device.cuda.empty_cache()
                logger.debug("Cleared PaddlePaddle GPU cache")

            # Force garbage collection
            gc.collect()

            # Log current GPU memory status
            if TORCH_AVAILABLE and torch.cuda.is_available():
                allocated_mb = torch.cuda.memory_allocated() / 1024**2
                reserved_mb = torch.cuda.memory_reserved() / 1024**2
                logger.debug(f"GPU memory after cleanup - Allocated: {allocated_mb:.1f}MB, Reserved: {reserved_mb:.1f}MB")
        except Exception as e:
            logger.warning(f"GPU memory cleanup failed (non-critical): {e}")
            # Don't fail the processing if cleanup fails

    def check_gpu_memory(self, required_mb: int = 2000) -> bool:
        """
        Check if sufficient GPU memory is available.

        Args:
            required_mb: Required memory in MB (default 2000MB for OCR models)

        Returns:
            True if sufficient memory is available or GPU is not used
        """
        try:
            # Check GPU memory using torch if available, otherwise use PaddlePaddle
            free_memory = None
            if TORCH_AVAILABLE and torch.cuda.is_available():
                free_memory = torch.cuda.mem_get_info()[0] / 1024**2
            elif paddle.device.is_compiled_with_cuda():
                # PaddlePaddle doesn't have a direct API to get free memory,
                # so we rely on cleanup and continue
                logger.debug("Using PaddlePaddle GPU, memory info not directly available")
                return True

            if free_memory is not None:
                if free_memory < required_mb:
                    logger.warning(f"Low GPU memory: {free_memory:.0f}MB available, {required_mb}MB required")
                    # Try to free memory
                    self.cleanup_gpu_memory()
                    # Check again
                    if TORCH_AVAILABLE and torch.cuda.is_available():
                        free_memory = torch.cuda.mem_get_info()[0] / 1024**2
                    if free_memory < required_mb:
                        logger.error(f"Insufficient GPU memory after cleanup: {free_memory:.0f}MB")
                        return False
                logger.debug(f"GPU memory check passed: {free_memory:.0f}MB available")
            return True
        except Exception as e:
            logger.warning(f"GPU memory check failed: {e}")
            return True  # Continue processing even if check fails

    def convert_pdf_to_images(self, pdf_path: Path, output_dir: Path) -> List[Path]:
        """
        Convert PDF to images (one per page)
@@ -587,6 +667,10 @@ class OCRService:
            # Get OCR engine (for non-PDF images)
            ocr_engine = self.get_ocr_engine(lang)

            # Check GPU memory before OCR processing
            if not self.check_gpu_memory(required_mb=1500):
                logger.warning("Insufficient GPU memory for OCR, attempting to proceed anyway")

            # Get the actual image dimensions that OCR will use
            from PIL import Image
            with Image.open(image_path) as img:
@@ -686,6 +770,9 @@ class OCRService:
f"{processing_time:.2f}s" f"{processing_time:.2f}s"
) )
# Clean up GPU memory after processing
self.cleanup_gpu_memory()
return result return result
except Exception as e: except Exception as e:
@@ -804,6 +891,8 @@ class OCRService:
                        'bbox': elem['bbox']
                    })

                # Clean up GPU memory after enhanced processing
                self.cleanup_gpu_memory()

                return layout_data, images_metadata
            else:
                logger.info("parsing_res_list not available, using standard processing")
@@ -815,6 +904,11 @@ class OCRService:
            # Standard processing (original implementation)
            logger.info(f"Running standard layout analysis on {image_path.name}")

            # Check GPU memory before processing
            if not self.check_gpu_memory(required_mb=2000):
                logger.warning("Insufficient GPU memory for PP-StructureV3, attempting to proceed anyway")

            results = structure_engine.predict(str(image_path))

            layout_elements = []
@@ -910,6 +1004,8 @@ class OCRService:
                    'reading_order': list(range(len(layout_elements))),
                }

                logger.info(f"Detected {len(layout_elements)} layout elements")

                # Clean up GPU memory after standard processing
                self.cleanup_gpu_memory()

                return layout_data, images_metadata
            else:
                logger.warning("No layout elements detected")
@@ -1135,6 +1231,10 @@ class OCRService:
            # Combine results
            combined_result = self._combine_results(all_results)
            combined_result['filename'] = file_path.name

            # Clean up GPU memory after processing all pages
            self.cleanup_gpu_memory()

            return combined_result
        else:
File: pp_structure_enhanced.py

@@ -9,7 +9,16 @@ import logging
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any
import json
import gc

# Optional torch import for additional GPU memory management
try:
    import torch
    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False

import paddle

from paddleocr import PPStructureV3

from app.models.unified_document import ElementType
@@ -155,6 +164,18 @@ class PPStructureEnhanced:
logger.error(f"Enhanced PP-StructureV3 analysis error: {e}") logger.error(f"Enhanced PP-StructureV3 analysis error: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
# Clean up GPU memory on error
try:
if TORCH_AVAILABLE and torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
if paddle.device.is_compiled_with_cuda():
paddle.device.cuda.empty_cache()
gc.collect()
except:
pass # Ignore cleanup errors
return { return {
'elements': [], 'elements': [],
'total_elements': 0, 'total_elements': 0,