feat: refactor dual-track architecture (Phase 1-5)
## Backend Changes

- **Service Layer Refactoring**:
  - Add ProcessingOrchestrator for unified document processing
  - Add PDFTableRenderer for table rendering extraction
  - Add PDFFontManager for font management with CJK support
  - Add MemoryPolicyEngine (73% code reduction from MemoryGuard)
- **Bug Fixes**:
  - Fix Direct Track table row span calculation
  - Fix OCR Track image path handling
  - Add cell_boxes coordinate validation
  - Filter out small decorative images
  - Add covering image detection

## Frontend Changes

- **State Management**:
  - Add TaskStore for centralized task state management
  - Add localStorage persistence for recent tasks
  - Add processing state tracking
- **Type Consolidation**:
  - Merge shared types from api.ts to apiV2.ts
  - Update imports in authStore, uploadStore, ResultsTable, SettingsPage
- **Page Integration**:
  - Integrate TaskStore in ProcessingPage and TaskDetailPage
  - Update useTaskValidation hook with cache sync

## Testing

- Direct Track: edit.pdf (3 pages, 1.281s), edit3.pdf (2 pages, 0.203s)
- Cell boxes validation: 43 valid, 0 invalid
- Table merging: 12 merged cells verified

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
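For reviewers, a minimal usage sketch of the new `MemoryPolicyEngine` API added in this commit (`MemoryPolicyConfig`, `get_memory_policy_engine`, `check_memory`, `clear_cache`, `prediction_context`). The threshold and memory values below are illustrative only, not the production settings.

```python
# Illustrative usage of app/services/memory_policy_engine.py (values are examples only).
from app.services.memory_policy_engine import (
    MemoryPolicyConfig,
    get_memory_policy_engine,
    prediction_context,
)

config = MemoryPolicyConfig(
    warning_threshold=0.80,
    critical_threshold=0.95,
    emergency_threshold=0.98,
    max_concurrent_predictions=2,
)
engine = get_memory_policy_engine(config)

# Check GPU headroom before loading a model (required_mb is an example figure).
ok, msg = engine.check_memory(required_mb=2048)
if not ok:
    engine.clear_cache()  # try to free cached GPU memory, then re-check
    ok, msg = engine.check_memory(required_mb=2048)

# Bound concurrent predictions so parallel requests cannot exhaust GPU memory.
with prediction_context(timeout=300.0, task_id="demo-task") as acquired:
    if acquired:
        pass  # run the OCR / layout prediction here
```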
@@ -1048,19 +1048,24 @@ class DirectExtractionEngine:
bbox=cell_bbox
))

# Try to detect visual column boundaries from page drawings
# Try to detect visual column and row boundaries from page drawings
# This is more accurate than PyMuPDF's column detection for complex tables
visual_boundaries = self._detect_visual_column_boundaries(
fitz_page, bbox_data, column_widths
)
# Use table.cells (flat list of bboxes) for more accurate row detection
raw_table_cells = getattr(table, 'cells', None)
row_boundaries = self._detect_visual_row_boundaries(
fitz_page, bbox_data, raw_table_cells
)

if visual_boundaries:
# Remap cells to visual columns
cells, column_widths, num_cols = self._remap_cells_to_visual_columns(
cells, column_widths, num_rows, num_cols, visual_boundaries
# Remap cells to visual columns and rows
cells, column_widths, num_cols, num_rows = self._remap_cells_to_visual_columns(
cells, column_widths, num_rows, num_cols, visual_boundaries, row_boundaries
)
else:
# Fallback to narrow column merging
# Fallback to narrow column merging (doesn't modify rows)
cells, column_widths, num_cols = self._merge_narrow_columns(
cells, column_widths, num_rows, num_cols,
min_column_width=10.0
@@ -1290,7 +1295,13 @@ class DirectExtractionEngine:

For tables with complex merged cells, PyMuPDF's column detection often
creates too many columns. This method analyzes the visual rectangles
(cell backgrounds) to find the true column boundaries.
(cell backgrounds) to find the MAIN column boundaries by frequency analysis.

Strategy:
1. Collect all cell rectangles from drawings
2. Count how frequently each x boundary appears (rounded to 5pt)
3. Keep only boundaries that appear frequently (>= threshold)
4. These are the main column boundaries that span most rows

Args:
page: PyMuPDF page object
@@ -1301,67 +1312,215 @@ class DirectExtractionEngine:
List of column boundary x-coordinates, or None if detection fails
"""
try:
table_rect = fitz.Rect(table_bbox)
from collections import Counter

# Collect cell rectangles from page drawings
cell_rects = []
drawings = page.get_drawings()
for d in drawings:
rect = fitz.Rect(d.get('rect', (0, 0, 0, 0)))
# Filter: must intersect table, must be large enough to be a cell
if (table_rect.intersects(rect) and
rect.width > 30 and rect.height > 15):
cell_rects.append(rect)
if d.get('items'):
for item in d['items']:
if item[0] == 're': # Rectangle
rect = item[1]
# Filter: within table bounds, large enough to be a cell
if (rect.x0 >= table_bbox[0] - 5 and
rect.x1 <= table_bbox[2] + 5 and
rect.y0 >= table_bbox[1] - 5 and
rect.y1 <= table_bbox[3] + 5):
width = rect.x1 - rect.x0
height = rect.y1 - rect.y0
if width > 30 and height > 15:
cell_rects.append(rect)

if len(cell_rects) < 4:
# Not enough cell rectangles detected
logger.debug(f"Only {len(cell_rects)} cell rectangles found, skipping visual detection")
return None

# Collect unique x boundaries
all_x = set()
logger.debug(f"Found {len(cell_rects)} cell rectangles for visual column detection")

# Count frequency of each boundary (rounded to 5pt)
boundary_counts = Counter()
for r in cell_rects:
all_x.add(round(r.x0, 0))
all_x.add(round(r.x1, 0))
boundary_counts[round(r.x0 / 5) * 5] += 1
boundary_counts[round(r.x1 / 5) * 5] += 1

# Merge close boundaries (within 15pt threshold)
def merge_close(values, threshold=15):
if not values:
return []
values = sorted(values)
result = [values[0]]
for v in values[1:]:
if v - result[-1] > threshold:
result.append(v)
return result
# Keep only boundaries that appear frequently
# Use 8% threshold to catch internal column boundaries (like nested sub-columns)
min_frequency = max(3, len(cell_rects) * 0.08)
frequent_boundaries = sorted([
x for x, count in boundary_counts.items()
if count >= min_frequency
])

boundaries = merge_close(list(all_x), threshold=15)
# Always include table edges
table_left = round(table_bbox[0] / 5) * 5
table_right = round(table_bbox[2] / 5) * 5
if not frequent_boundaries or frequent_boundaries[0] > table_left + 10:
frequent_boundaries.insert(0, table_left)
if not frequent_boundaries or frequent_boundaries[-1] < table_right - 10:
frequent_boundaries.append(table_right)

if len(boundaries) < 3:
logger.debug(f"Frequent boundaries (min_freq={min_frequency:.0f}): {frequent_boundaries}")

if len(frequent_boundaries) < 3:
# Need at least 3 boundaries for 2 columns
return None

# Calculate column widths from visual boundaries
visual_widths = [boundaries[i+1] - boundaries[i]
for i in range(len(boundaries)-1)]
# Merge close boundaries (within 10pt) - take the one with higher frequency
def merge_close_by_frequency(boundaries, counts, threshold=10):
if not boundaries:
return []
result = [boundaries[0]]
for b in boundaries[1:]:
if b - result[-1] <= threshold:
# Keep the one with higher frequency
if counts[b] > counts[result[-1]]:
result[-1] = b
else:
result.append(b)
return result

# Filter out narrow "separator" columns (< 20pt)
# and keep only content columns
content_boundaries = [boundaries[0]]
for i, width in enumerate(visual_widths):
if width >= 20: # Content column
content_boundaries.append(boundaries[i+1])
# Skip narrow separator columns
merged_boundaries = merge_close_by_frequency(
frequent_boundaries, boundary_counts, threshold=10
)

if len(content_boundaries) < 3:
if len(merged_boundaries) < 3:
return None

logger.info(f"Visual column detection: {len(content_boundaries)-1} columns from drawings")
logger.debug(f"Visual boundaries: {content_boundaries}")
# Calculate column widths
widths = [merged_boundaries[i+1] - merged_boundaries[i]
for i in range(len(merged_boundaries)-1)]

return content_boundaries
logger.info(f"Visual column detection: {len(widths)} columns")
logger.info(f" Boundaries: {merged_boundaries}")
logger.info(f" Widths: {[round(w) for w in widths]}")

return merged_boundaries

except Exception as e:
logger.warning(f"Visual column detection failed: {e}")
import traceback
logger.debug(traceback.format_exc())
return None
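To illustrate the frequency-analysis idea described in the docstring above, here is a standalone toy example (not part of the diff): cell edges are rounded to a 5pt grid, counted, and only edges shared by enough cells survive, so a merged cell's ragged inner edge does not create a spurious column. All rectangle values are invented.

```python
# Standalone toy illustration of frequency-based column boundary detection.
from collections import Counter

# Hypothetical cell rectangles as (x0, y0, x1, y1); values are invented.
cell_rects = [
    (50, 100, 200, 130), (200, 100, 350, 130), (350, 100, 500, 130),
    (50, 130, 200, 160), (200, 130, 350, 160), (350, 130, 500, 160),
    (50, 160, 352, 190),   # merged cell spanning the first two columns
    (352, 160, 500, 190),
]

boundary_counts = Counter()
for x0, _, x1, _ in cell_rects:
    boundary_counts[round(x0 / 5) * 5] += 1  # snap edges to a 5pt grid
    boundary_counts[round(x1 / 5) * 5] += 1

# Keep x positions that appear in at least ~8% of the cells (minimum 3 hits).
min_frequency = max(3, len(cell_rects) * 0.08)
frequent = sorted(x for x, c in boundary_counts.items() if c >= min_frequency)
print(frequent)  # -> [50, 200, 350, 500]: four boundaries, three content columns
```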
|
||||
|
||||
def _detect_visual_row_boundaries(
|
||||
self,
|
||||
page: fitz.Page,
|
||||
table_bbox: Tuple[float, float, float, float],
|
||||
table_cells: Optional[List] = None
|
||||
) -> Optional[List[float]]:
|
||||
"""
|
||||
Detect actual row boundaries from table cell bboxes.
|
||||
|
||||
Uses cell bboxes from PyMuPDF table detection for more accurate
|
||||
row boundary detection than page drawings.
|
||||
|
||||
Args:
|
||||
page: PyMuPDF page object
|
||||
table_bbox: Table bounding box (x0, y0, x1, y1)
|
||||
table_cells: List of cell bboxes from table.cells (preferred)
|
||||
|
||||
Returns:
|
||||
List of row boundary y-coordinates, or None if detection fails
|
||||
"""
|
||||
try:
|
||||
from collections import Counter
|
||||
|
||||
boundary_counts = Counter()
|
||||
cell_count = 0
|
||||
|
||||
if table_cells:
|
||||
# Use table cells directly (more accurate for row detection)
|
||||
for cell_bbox in table_cells:
|
||||
if cell_bbox:
|
||||
y0 = round(cell_bbox[1] / 5) * 5
|
||||
y1 = round(cell_bbox[3] / 5) * 5
|
||||
boundary_counts[y0] += 1
|
||||
boundary_counts[y1] += 1
|
||||
cell_count += 1
|
||||
else:
|
||||
# Fallback to page drawings
|
||||
drawings = page.get_drawings()
|
||||
for d in drawings:
|
||||
if d.get('items'):
|
||||
for item in d['items']:
|
||||
if item[0] == 're':
|
||||
rect = item[1]
|
||||
if (rect.x0 >= table_bbox[0] - 5 and
|
||||
rect.x1 <= table_bbox[2] + 5 and
|
||||
rect.y0 >= table_bbox[1] - 5 and
|
||||
rect.y1 <= table_bbox[3] + 5):
|
||||
width = rect.x1 - rect.x0
|
||||
height = rect.y1 - rect.y0
|
||||
if width > 30 and height > 15:
|
||||
y0 = round(rect.y0 / 5) * 5
|
||||
y1 = round(rect.y1 / 5) * 5
|
||||
boundary_counts[y0] += 1
|
||||
boundary_counts[y1] += 1
|
||||
cell_count += 1
|
||||
|
||||
if cell_count < 4:
|
||||
logger.debug(f"Only {cell_count} cells found, skipping visual row detection")
|
||||
return None
|
||||
|
||||
# Keep only boundaries that appear frequently
|
||||
# Use 8% threshold similar to column detection
|
||||
min_frequency = max(3, cell_count * 0.08)
|
||||
frequent_boundaries = sorted([
|
||||
y for y, count in boundary_counts.items()
|
||||
if count >= min_frequency
|
||||
])
|
||||
|
||||
# Always include table edges
|
||||
table_top = round(table_bbox[1] / 5) * 5
|
||||
table_bottom = round(table_bbox[3] / 5) * 5
|
||||
if not frequent_boundaries or frequent_boundaries[0] > table_top + 10:
|
||||
frequent_boundaries.insert(0, table_top)
|
||||
if not frequent_boundaries or frequent_boundaries[-1] < table_bottom - 10:
|
||||
frequent_boundaries.append(table_bottom)
|
||||
|
||||
logger.debug(f"Frequent Y boundaries (min_freq={min_frequency:.0f}): {frequent_boundaries}")
|
||||
|
||||
if len(frequent_boundaries) < 3:
|
||||
# Need at least 3 boundaries for 2 rows
|
||||
return None
|
||||
|
||||
# Merge close boundaries (within 10pt) - take the one with higher frequency
|
||||
def merge_close_by_frequency(boundaries, counts, threshold=10):
|
||||
if not boundaries:
|
||||
return []
|
||||
result = [boundaries[0]]
|
||||
for b in boundaries[1:]:
|
||||
if b - result[-1] <= threshold:
|
||||
# Keep the one with higher frequency
|
||||
if counts[b] > counts[result[-1]]:
|
||||
result[-1] = b
|
||||
else:
|
||||
result.append(b)
|
||||
return result
|
||||
|
||||
merged_boundaries = merge_close_by_frequency(
|
||||
frequent_boundaries, boundary_counts, threshold=10
|
||||
)
|
||||
|
||||
if len(merged_boundaries) < 3:
|
||||
return None
|
||||
|
||||
# Calculate row heights
|
||||
heights = [merged_boundaries[i+1] - merged_boundaries[i]
|
||||
for i in range(len(merged_boundaries)-1)]
|
||||
|
||||
logger.info(f"Visual row detection: {len(heights)} rows")
|
||||
logger.info(f" Y Boundaries: {merged_boundaries}")
|
||||
logger.info(f" Heights: {[round(h) for h in heights]}")
|
||||
|
||||
return merged_boundaries
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Visual row detection failed: {e}")
|
||||
import traceback
|
||||
logger.debug(traceback.format_exc())
|
||||
return None
|
||||
|
||||
def _remap_cells_to_visual_columns(
|
||||
@@ -1370,8 +1529,9 @@ class DirectExtractionEngine:
|
||||
column_widths: List[float],
|
||||
num_rows: int,
|
||||
num_cols: int,
|
||||
visual_boundaries: List[float]
|
||||
) -> Tuple[List[TableCell], List[float], int]:
|
||||
visual_boundaries: List[float],
|
||||
row_boundaries: Optional[List[float]] = None
|
||||
) -> Tuple[List[TableCell], List[float], int, int]:
|
||||
"""
|
||||
Remap cells from PyMuPDF columns to visual columns based on cell bbox.
|
||||
|
||||
@@ -1381,35 +1541,64 @@ class DirectExtractionEngine:
|
||||
num_rows: Number of rows
|
||||
num_cols: Original number of columns
|
||||
visual_boundaries: Column boundaries from visual detection
|
||||
row_boundaries: Row boundaries from visual detection (optional)
|
||||
|
||||
Returns:
|
||||
Tuple of (remapped_cells, new_widths, new_num_cols)
|
||||
Tuple of (remapped_cells, new_widths, new_num_cols, new_num_rows)
|
||||
"""
|
||||
try:
|
||||
new_num_cols = len(visual_boundaries) - 1
|
||||
new_widths = [visual_boundaries[i+1] - visual_boundaries[i]
|
||||
for i in range(new_num_cols)]
|
||||
|
||||
logger.info(f"Remapping {len(cells)} cells from {num_cols} to {new_num_cols} visual columns")
|
||||
new_num_rows = len(row_boundaries) - 1 if row_boundaries else num_rows
|
||||
|
||||
# Map each cell to visual column based on its bbox center
|
||||
cell_map = {} # (row, new_col) -> list of cells
|
||||
logger.info(f"Remapping {len(cells)} cells from {num_cols} to {new_num_cols} visual columns")
|
||||
if row_boundaries:
|
||||
logger.info(f"Using {new_num_rows} visual rows for row_span calculation")
|
||||
|
||||
# Map each cell to visual column and row based on its bbox
|
||||
# This ensures spanning cells are placed at their correct position
|
||||
cell_map = {} # (visual_row, start_col) -> list of cells
|
||||
|
||||
for cell in cells:
|
||||
if not cell.bbox:
|
||||
continue
|
||||
|
||||
# Find which visual column this cell belongs to
|
||||
cell_center_x = (cell.bbox.x0 + cell.bbox.x1) / 2
|
||||
new_col = 0
|
||||
for i in range(new_num_cols):
|
||||
if visual_boundaries[i] <= cell_center_x < visual_boundaries[i+1]:
|
||||
new_col = i
|
||||
break
|
||||
elif cell_center_x >= visual_boundaries[-1]:
|
||||
new_col = new_num_cols - 1
|
||||
# Find start column based on left edge of cell
|
||||
cell_x0 = cell.bbox.x0
|
||||
start_col = 0
|
||||
|
||||
key = (cell.row, new_col)
|
||||
# First check if cell_x0 is very close to any boundary (within 5pt)
|
||||
# If so, it belongs to the column that starts at that boundary
|
||||
snapped = False
|
||||
for i in range(1, len(visual_boundaries)): # Skip first (left edge)
|
||||
if abs(cell_x0 - visual_boundaries[i]) <= 5:
|
||||
start_col = min(i, new_num_cols - 1)
|
||||
snapped = True
|
||||
break
|
||||
|
||||
# If not snapped to boundary, use standard containment check
|
||||
if not snapped:
|
||||
for i in range(new_num_cols):
|
||||
if visual_boundaries[i] <= cell_x0 < visual_boundaries[i+1]:
|
||||
start_col = i
|
||||
break
|
||||
elif cell_x0 >= visual_boundaries[-1]:
|
||||
start_col = new_num_cols - 1
|
||||
|
||||
# Find visual row based on top edge of cell
|
||||
visual_row = cell.row # Default to original row
|
||||
if row_boundaries:
|
||||
cell_y0 = cell.bbox.y0
|
||||
for i in range(new_num_rows):
|
||||
if row_boundaries[i] <= cell_y0 + 5 < row_boundaries[i+1]:
|
||||
visual_row = i
|
||||
break
|
||||
elif cell_y0 >= row_boundaries[-1] - 5:
|
||||
visual_row = new_num_rows - 1
|
||||
|
||||
key = (visual_row, start_col)
|
||||
if key not in cell_map:
|
||||
cell_map[key] = []
|
||||
cell_map[key].append(cell)
|
||||
@@ -1418,8 +1607,8 @@ class DirectExtractionEngine:
|
||||
remapped_cells = []
|
||||
processed = set()
|
||||
|
||||
for (row, new_col), cell_list in sorted(cell_map.items()):
|
||||
if (row, new_col) in processed:
|
||||
for (visual_row, start_col), cell_list in sorted(cell_map.items()):
|
||||
if (visual_row, start_col) in processed:
|
||||
continue
|
||||
|
||||
# Sort by original column
|
||||
@@ -1433,23 +1622,35 @@ class DirectExtractionEngine:
|
||||
|
||||
merged_content = '\n'.join(contents) if contents else ''
|
||||
|
||||
# Use the first cell for span info
|
||||
base_cell = cell_list[0]
|
||||
# Use the cell with tallest bbox for row span calculation
|
||||
# (handles case where multiple cells merge into one)
|
||||
tallest_cell = max(cell_list, key=lambda c: (c.bbox.y1 - c.bbox.y0) if c.bbox else 0)
|
||||
widest_cell = max(cell_list, key=lambda c: (c.bbox.x1 - c.bbox.x0) if c.bbox else 0)
|
||||
|
||||
# Calculate col_span based on visual boundaries
|
||||
if base_cell.bbox:
|
||||
cell_x1 = base_cell.bbox.x1
|
||||
# Find end column
|
||||
end_col = new_col
|
||||
for i in range(new_col, new_num_cols):
|
||||
if visual_boundaries[i+1] <= cell_x1 + 5: # 5pt tolerance
|
||||
# Calculate col_span based on right edge of widest cell
|
||||
col_span = 1
|
||||
if widest_cell.bbox:
|
||||
cell_x1 = widest_cell.bbox.x1
|
||||
end_col = start_col
|
||||
for i in range(start_col, new_num_cols):
|
||||
if cell_x1 > visual_boundaries[i] + 5: # 5pt tolerance
|
||||
end_col = i
|
||||
col_span = max(1, end_col - new_col + 1)
|
||||
else:
|
||||
col_span = 1
|
||||
col_span = max(1, end_col - start_col + 1)
|
||||
|
||||
# Calculate row_span based on visual row boundaries
|
||||
row_span = 1
|
||||
if row_boundaries and tallest_cell.bbox:
|
||||
cell_y1 = tallest_cell.bbox.y1
|
||||
|
||||
# Find end row based on bottom edge of tallest cell
|
||||
end_row = visual_row
|
||||
for i in range(visual_row, new_num_rows):
|
||||
if cell_y1 > row_boundaries[i] + 5: # 5pt tolerance
|
||||
end_row = i
|
||||
row_span = max(1, end_row - visual_row + 1)
|
||||
|
||||
# Merge bbox from all cells
|
||||
merged_bbox = base_cell.bbox
|
||||
merged_bbox = tallest_cell.bbox
|
||||
for c in cell_list:
|
||||
if c.bbox and merged_bbox:
|
||||
merged_bbox = BoundingBox(
|
||||
@@ -1462,23 +1663,39 @@ class DirectExtractionEngine:
|
||||
merged_bbox = c.bbox
|
||||
|
||||
remapped_cells.append(TableCell(
|
||||
row=row,
|
||||
col=new_col,
|
||||
row_span=base_cell.row_span,
|
||||
row=visual_row,
|
||||
col=start_col,
|
||||
row_span=row_span,
|
||||
col_span=col_span,
|
||||
content=merged_content,
|
||||
bbox=merged_bbox
|
||||
))
|
||||
processed.add((row, new_col))
|
||||
processed.add((visual_row, start_col))
|
||||
|
||||
logger.info(f"Remapped to {len(remapped_cells)} cells in {new_num_cols} columns")
|
||||
# Filter out cells that are covered by spans from other cells
|
||||
# Build a set of positions covered by spans
|
||||
covered_positions = set()
|
||||
for cell in remapped_cells:
|
||||
if cell.col_span > 1 or cell.row_span > 1:
|
||||
for r in range(cell.row, cell.row + cell.row_span):
|
||||
for c in range(cell.col, cell.col + cell.col_span):
|
||||
if (r, c) != (cell.row, cell.col): # Don't cover the origin
|
||||
covered_positions.add((r, c))
|
||||
|
||||
return remapped_cells, new_widths, new_num_cols
|
||||
# Remove covered cells
|
||||
final_cells = [
|
||||
cell for cell in remapped_cells
|
||||
if (cell.row, cell.col) not in covered_positions
|
||||
]
|
||||
|
||||
logger.info(f"Remapped to {len(final_cells)} cells in {new_num_cols} columns x {new_num_rows} rows (filtered {len(remapped_cells) - len(final_cells)} covered cells)")
|
||||
|
||||
return final_cells, new_widths, new_num_cols, new_num_rows
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Cell remapping failed: {e}")
|
||||
# Fallback to original
|
||||
return cells, column_widths, num_cols
|
||||
return cells, column_widths, num_cols, num_rows
|
||||
|
||||
def _detect_tables_by_position(self, page: fitz.Page, page_num: int, counter: int) -> List[DocumentElement]:
|
||||
"""Detect tables by analyzing text positioning"""
|
||||
@@ -2138,12 +2355,23 @@ class DirectExtractionEngine:
|
||||
logger.warning(f"Custom clustering failed ({e}), using fallback method")
|
||||
drawing_clusters = self._cluster_drawings_fallback(page, non_table_drawings)
|
||||
|
||||
# Get page dimensions for filtering
|
||||
page_rect = page.rect
|
||||
page_area = page_rect.width * page_rect.height
|
||||
|
||||
for cluster_idx, bbox in enumerate(drawing_clusters):
|
||||
# Ignore small regions (likely noise or separator lines)
|
||||
if bbox.width < 50 or bbox.height < 50:
|
||||
logger.debug(f"Skipping small cluster {cluster_idx}: {bbox.width:.1f}x{bbox.height:.1f}")
|
||||
continue
|
||||
|
||||
# Ignore very large regions that cover most of the page
|
||||
# These are usually background elements, page borders, or misdetected regions
|
||||
cluster_area = bbox.width * bbox.height
|
||||
if cluster_area > page_area * 0.7: # More than 70% of page
|
||||
logger.debug(f"Skipping large cluster {cluster_idx}: covers {cluster_area/page_area*100:.0f}% of page")
|
||||
continue
|
||||
|
||||
# Render the region to a raster image
|
||||
# matrix=fitz.Matrix(2, 2) increases resolution to ~200 DPI
|
||||
try:
|
||||
|
||||
backend/app/services/memory_policy_engine.py (new file, 791 lines)
@@ -0,0 +1,791 @@
|
||||
"""
|
||||
Memory Policy Engine - Simplified memory management for OCR processing.
|
||||
|
||||
This module consolidates the essential memory management features:
|
||||
- GPU memory monitoring
|
||||
- Prediction concurrency control
|
||||
- Model lifecycle management
|
||||
|
||||
Removed unused features from the original memory_manager.py:
|
||||
- BatchProcessor
|
||||
- ProgressiveLoader
|
||||
- PriorityOperationQueue
|
||||
- RecoveryManager
|
||||
- MemoryDumper
|
||||
- PrometheusMetrics
|
||||
"""
|
||||
|
||||
import gc
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from collections import deque
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Configuration
|
||||
# ============================================================================
|
||||
|
||||
@dataclass
|
||||
class MemoryPolicyConfig:
|
||||
"""
|
||||
Simplified memory policy configuration.
|
||||
|
||||
Only includes parameters that are actually used in production.
|
||||
"""
|
||||
# GPU memory thresholds (ratio 0.0-1.0)
|
||||
warning_threshold: float = 0.80
|
||||
critical_threshold: float = 0.95
|
||||
emergency_threshold: float = 0.98
|
||||
|
||||
# Model management
|
||||
model_idle_timeout_seconds: int = 300 # 5 minutes
|
||||
memory_check_interval_seconds: int = 30
|
||||
|
||||
# Concurrency control
|
||||
max_concurrent_predictions: int = 2
|
||||
prediction_timeout_seconds: float = 300.0
|
||||
|
||||
# GPU settings
|
||||
gpu_memory_limit_mb: int = 6144 # 6GB default
|
||||
|
||||
|
||||
class MemoryBackend(Enum):
|
||||
"""Available memory monitoring backends."""
|
||||
PYNVML = "pynvml"
|
||||
TORCH = "torch"
|
||||
PADDLE = "paddle"
|
||||
NONE = "none"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Memory Statistics
|
||||
# ============================================================================
|
||||
|
||||
@dataclass
|
||||
class MemoryStats:
|
||||
"""Memory usage statistics."""
|
||||
gpu_used_mb: float = 0.0
|
||||
gpu_free_mb: float = 0.0
|
||||
gpu_total_mb: float = 0.0
|
||||
gpu_used_ratio: float = 0.0
|
||||
cpu_used_mb: float = 0.0
|
||||
cpu_available_mb: float = 0.0
|
||||
timestamp: datetime = field(default_factory=datetime.now)
|
||||
backend: str = "none"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryAlert:
|
||||
"""Memory alert record."""
|
||||
level: str # warning, critical, emergency
|
||||
message: str
|
||||
stats: MemoryStats
|
||||
timestamp: datetime = field(default_factory=datetime.now)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# GPU Memory Monitor
|
||||
# ============================================================================
|
||||
|
||||
class GPUMemoryMonitor:
|
||||
"""
|
||||
Monitors GPU memory usage with multiple backend support.
|
||||
|
||||
Priority: pynvml > torch > paddle > none
|
||||
"""
|
||||
|
||||
def __init__(self, config: MemoryPolicyConfig):
|
||||
self.config = config
|
||||
self._backend: MemoryBackend = MemoryBackend.NONE
|
||||
self._nvml_handle = None
|
||||
self._history: deque = deque(maxlen=100)
|
||||
self._alerts: deque = deque(maxlen=50)
|
||||
self._lock = threading.Lock()
|
||||
|
||||
self._init_backend()
|
||||
|
||||
def _init_backend(self):
|
||||
"""Initialize the best available memory monitoring backend."""
|
||||
# Try pynvml first (most accurate)
|
||||
try:
|
||||
import pynvml
|
||||
pynvml.nvmlInit()
|
||||
self._nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(0)
|
||||
self._backend = MemoryBackend.PYNVML
|
||||
logger.info("GPU memory monitoring using pynvml")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.debug(f"pynvml not available: {e}")
|
||||
|
||||
# Try torch
|
||||
try:
|
||||
import torch
|
||||
if torch.cuda.is_available():
|
||||
self._backend = MemoryBackend.TORCH
|
||||
logger.info("GPU memory monitoring using torch")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.debug(f"torch CUDA not available: {e}")
|
||||
|
||||
# Try paddle
|
||||
try:
|
||||
import paddle
|
||||
if paddle.is_compiled_with_cuda():
|
||||
self._backend = MemoryBackend.PADDLE
|
||||
logger.info("GPU memory monitoring using paddle")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.debug(f"paddle CUDA not available: {e}")
|
||||
|
||||
logger.warning("No GPU memory monitoring available")
|
||||
|
||||
def get_stats(self, device_id: int = 0) -> MemoryStats:
|
||||
"""Get current memory statistics."""
|
||||
stats = MemoryStats(backend=self._backend.value)
|
||||
|
||||
try:
|
||||
if self._backend == MemoryBackend.PYNVML:
|
||||
stats = self._get_pynvml_stats(device_id)
|
||||
elif self._backend == MemoryBackend.TORCH:
|
||||
stats = self._get_torch_stats(device_id)
|
||||
elif self._backend == MemoryBackend.PADDLE:
|
||||
stats = self._get_paddle_stats(device_id)
|
||||
|
||||
# Add CPU stats
|
||||
stats = self._add_cpu_stats(stats)
|
||||
|
||||
# Store in history
|
||||
with self._lock:
|
||||
self._history.append(stats)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get memory stats: {e}")
|
||||
|
||||
return stats
|
||||
|
||||
def _get_pynvml_stats(self, device_id: int) -> MemoryStats:
|
||||
"""Get stats using pynvml."""
|
||||
import pynvml
|
||||
handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
|
||||
info = pynvml.nvmlDeviceGetMemoryInfo(handle)
|
||||
|
||||
return MemoryStats(
|
||||
gpu_used_mb=info.used / (1024 * 1024),
|
||||
gpu_free_mb=info.free / (1024 * 1024),
|
||||
gpu_total_mb=info.total / (1024 * 1024),
|
||||
gpu_used_ratio=info.used / info.total if info.total > 0 else 0,
|
||||
backend="pynvml"
|
||||
)
|
||||
|
||||
def _get_torch_stats(self, device_id: int) -> MemoryStats:
|
||||
"""Get stats using torch."""
|
||||
import torch
|
||||
allocated = torch.cuda.memory_allocated(device_id)
|
||||
reserved = torch.cuda.memory_reserved(device_id)
|
||||
total = torch.cuda.get_device_properties(device_id).total_memory
|
||||
|
||||
return MemoryStats(
|
||||
gpu_used_mb=reserved / (1024 * 1024),
|
||||
gpu_free_mb=(total - reserved) / (1024 * 1024),
|
||||
gpu_total_mb=total / (1024 * 1024),
|
||||
gpu_used_ratio=reserved / total if total > 0 else 0,
|
||||
backend="torch"
|
||||
)
|
||||
|
||||
def _get_paddle_stats(self, device_id: int) -> MemoryStats:
|
||||
"""Get stats using paddle."""
|
||||
import paddle
|
||||
allocated = paddle.device.cuda.memory_allocated(device_id)
|
||||
reserved = paddle.device.cuda.memory_reserved(device_id)
|
||||
total = paddle.device.cuda.get_device_properties(device_id).total_memory
|
||||
|
||||
return MemoryStats(
|
||||
gpu_used_mb=reserved / (1024 * 1024),
|
||||
gpu_free_mb=(total - reserved) / (1024 * 1024),
|
||||
gpu_total_mb=total / (1024 * 1024),
|
||||
gpu_used_ratio=reserved / total if total > 0 else 0,
|
||||
backend="paddle"
|
||||
)
|
||||
|
||||
def _add_cpu_stats(self, stats: MemoryStats) -> MemoryStats:
|
||||
"""Add CPU memory stats."""
|
||||
try:
|
||||
import psutil
|
||||
mem = psutil.virtual_memory()
|
||||
stats.cpu_used_mb = mem.used / (1024 * 1024)
|
||||
stats.cpu_available_mb = mem.available / (1024 * 1024)
|
||||
except Exception:
|
||||
pass
|
||||
return stats
|
||||
|
||||
def check_memory(self, required_mb: float = 0, device_id: int = 0) -> Tuple[bool, str]:
|
||||
"""
|
||||
Check if memory is available.
|
||||
|
||||
Returns:
|
||||
Tuple of (is_available, message)
|
||||
"""
|
||||
stats = self.get_stats(device_id)
|
||||
|
||||
# Check thresholds
|
||||
if stats.gpu_used_ratio >= self.config.emergency_threshold:
|
||||
msg = f"Emergency: GPU at {stats.gpu_used_ratio*100:.1f}%"
|
||||
self._add_alert("emergency", msg, stats)
|
||||
return False, msg
|
||||
|
||||
if stats.gpu_used_ratio >= self.config.critical_threshold:
|
||||
msg = f"Critical: GPU at {stats.gpu_used_ratio*100:.1f}%"
|
||||
self._add_alert("critical", msg, stats)
|
||||
return False, msg
|
||||
|
||||
if stats.gpu_used_ratio >= self.config.warning_threshold:
|
||||
msg = f"Warning: GPU at {stats.gpu_used_ratio*100:.1f}%"
|
||||
self._add_alert("warning", msg, stats)
|
||||
# Warning doesn't block, just logs
|
||||
|
||||
# Check if required memory is available
|
||||
if required_mb > 0 and stats.gpu_free_mb < required_mb:
|
||||
msg = f"Insufficient memory: need {required_mb}MB, have {stats.gpu_free_mb:.0f}MB"
|
||||
return False, msg
|
||||
|
||||
return True, "OK"
|
||||
|
||||
def _add_alert(self, level: str, message: str, stats: MemoryStats):
|
||||
"""Add alert to history."""
|
||||
with self._lock:
|
||||
self._alerts.append(MemoryAlert(
|
||||
level=level,
|
||||
message=message,
|
||||
stats=stats
|
||||
))
|
||||
log_func = getattr(logger, level if level != "emergency" else "error")
|
||||
log_func(message)
|
||||
|
||||
def clear_cache(self):
|
||||
"""Clear GPU memory caches."""
|
||||
try:
|
||||
import torch
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
import paddle
|
||||
if paddle.is_compiled_with_cuda():
|
||||
paddle.device.cuda.empty_cache()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
gc.collect()
|
||||
|
||||
def get_alerts(self, limit: int = 10) -> List[MemoryAlert]:
|
||||
"""Get recent alerts."""
|
||||
with self._lock:
|
||||
return list(self._alerts)[-limit:]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Prediction Semaphore
|
||||
# ============================================================================
|
||||
|
||||
class PredictionSemaphore:
|
||||
"""
|
||||
Controls concurrent predictions to prevent GPU OOM.
|
||||
|
||||
Singleton pattern ensures single point of concurrency control.
|
||||
"""
|
||||
|
||||
_instance: Optional['PredictionSemaphore'] = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __new__(cls, max_concurrent: int = 2):
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
instance = super().__new__(cls)
|
||||
instance._initialized = False
|
||||
cls._instance = instance
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, max_concurrent: int = 2):
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
self._semaphore = threading.Semaphore(max_concurrent)
|
||||
self._max_concurrent = max_concurrent
|
||||
self._active_count = 0
|
||||
self._queue_depth = 0
|
||||
self._stats_lock = threading.Lock()
|
||||
|
||||
# Metrics
|
||||
self._total_predictions = 0
|
||||
self._total_timeouts = 0
|
||||
self._total_wait_time = 0.0
|
||||
|
||||
self._initialized = True
|
||||
logger.info(f"PredictionSemaphore initialized: max_concurrent={max_concurrent}")
|
||||
|
||||
@classmethod
|
||||
def reset(cls):
|
||||
"""Reset singleton (for testing)."""
|
||||
with cls._lock:
|
||||
cls._instance = None
|
||||
|
||||
def acquire(self, timeout: float = 300.0, task_id: str = "") -> bool:
|
||||
"""
|
||||
Acquire a prediction slot.
|
||||
|
||||
Args:
|
||||
timeout: Maximum wait time in seconds
|
||||
task_id: Optional task identifier for logging
|
||||
|
||||
Returns:
|
||||
True if acquired, False on timeout
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
with self._stats_lock:
|
||||
self._queue_depth += 1
|
||||
|
||||
try:
|
||||
acquired = self._semaphore.acquire(timeout=timeout)
|
||||
|
||||
wait_time = time.time() - start_time
|
||||
|
||||
with self._stats_lock:
|
||||
self._queue_depth -= 1
|
||||
if acquired:
|
||||
self._active_count += 1
|
||||
self._total_predictions += 1
|
||||
self._total_wait_time += wait_time
|
||||
else:
|
||||
self._total_timeouts += 1
|
||||
|
||||
if not acquired:
|
||||
logger.warning(f"Prediction semaphore timeout after {timeout}s")
|
||||
|
||||
return acquired
|
||||
|
||||
except Exception as e:
|
||||
with self._stats_lock:
|
||||
self._queue_depth -= 1
|
||||
logger.error(f"Semaphore acquire error: {e}")
|
||||
return False
|
||||
|
||||
def release(self):
|
||||
"""Release a prediction slot."""
|
||||
with self._stats_lock:
|
||||
if self._active_count > 0:
|
||||
self._active_count -= 1
|
||||
self._semaphore.release()
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get semaphore statistics."""
|
||||
with self._stats_lock:
|
||||
avg_wait = (self._total_wait_time / self._total_predictions
|
||||
if self._total_predictions > 0 else 0)
|
||||
return {
|
||||
"max_concurrent": self._max_concurrent,
|
||||
"active_predictions": self._active_count,
|
||||
"queue_depth": self._queue_depth,
|
||||
"total_predictions": self._total_predictions,
|
||||
"total_timeouts": self._total_timeouts,
|
||||
"average_wait_seconds": avg_wait
|
||||
}
|
||||
|
||||
|
||||
@contextmanager
|
||||
def prediction_context(timeout: float = 300.0, task_id: str = ""):
|
||||
"""
|
||||
Context manager for prediction semaphore.
|
||||
|
||||
Usage:
|
||||
with prediction_context(timeout=300) as acquired:
|
||||
if acquired:
|
||||
# run prediction
|
||||
"""
|
||||
semaphore = get_prediction_semaphore()
|
||||
acquired = semaphore.acquire(timeout=timeout, task_id=task_id)
|
||||
try:
|
||||
yield acquired
|
||||
finally:
|
||||
if acquired:
|
||||
semaphore.release()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Model Manager
|
||||
# ============================================================================
|
||||
|
||||
@dataclass
|
||||
class ModelInfo:
|
||||
"""Information about a loaded model."""
|
||||
model_id: str
|
||||
model: Any
|
||||
reference_count: int = 0
|
||||
loaded_at: datetime = field(default_factory=datetime.now)
|
||||
last_used: datetime = field(default_factory=datetime.now)
|
||||
estimated_memory_mb: float = 0.0
|
||||
cleanup_callback: Optional[Callable] = None
|
||||
|
||||
|
||||
class ModelManager:
|
||||
"""
|
||||
Manages model lifecycle with reference counting and idle cleanup.
|
||||
|
||||
Features:
|
||||
- Reference-counted model loading
|
||||
- Automatic unload after idle timeout
|
||||
- LRU eviction on memory pressure
|
||||
"""
|
||||
|
||||
_instance: Optional['ModelManager'] = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __new__(cls, config: Optional[MemoryPolicyConfig] = None):
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
instance = super().__new__(cls)
|
||||
instance._initialized = False
|
||||
cls._instance = instance
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, config: Optional[MemoryPolicyConfig] = None):
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
self.config = config or MemoryPolicyConfig()
|
||||
self._models: Dict[str, ModelInfo] = {}
|
||||
self._models_lock = threading.Lock()
|
||||
self._monitor = GPUMemoryMonitor(self.config)
|
||||
|
||||
# Background cleanup thread
|
||||
self._shutdown = threading.Event()
|
||||
self._cleanup_thread = threading.Thread(
|
||||
target=self._cleanup_loop,
|
||||
daemon=True,
|
||||
name="ModelManager-Cleanup"
|
||||
)
|
||||
self._cleanup_thread.start()
|
||||
|
||||
self._initialized = True
|
||||
logger.info("ModelManager initialized")
|
||||
|
||||
@classmethod
|
||||
def reset(cls):
|
||||
"""Reset singleton (for testing)."""
|
||||
with cls._lock:
|
||||
if cls._instance is not None:
|
||||
cls._instance.shutdown()
|
||||
cls._instance = None
|
||||
|
||||
def get_or_load(
|
||||
self,
|
||||
model_id: str,
|
||||
loader_func: Callable[[], Any],
|
||||
estimated_memory_mb: float = 0,
|
||||
cleanup_callback: Optional[Callable] = None
|
||||
) -> Optional[Any]:
|
||||
"""
|
||||
Get a model, loading it if necessary.
|
||||
|
||||
Args:
|
||||
model_id: Unique identifier for the model
|
||||
loader_func: Function to load the model if not cached
|
||||
estimated_memory_mb: Estimated GPU memory usage
|
||||
cleanup_callback: Optional cleanup function when unloading
|
||||
|
||||
Returns:
|
||||
The model, or None if loading failed
|
||||
"""
|
||||
with self._models_lock:
|
||||
# Check if already loaded
|
||||
if model_id in self._models:
|
||||
info = self._models[model_id]
|
||||
info.reference_count += 1
|
||||
info.last_used = datetime.now()
|
||||
logger.debug(f"Model {model_id} retrieved (refs={info.reference_count})")
|
||||
return info.model
|
||||
|
||||
# Check memory before loading
|
||||
if estimated_memory_mb > 0:
|
||||
available, msg = self._monitor.check_memory(estimated_memory_mb)
|
||||
if not available:
|
||||
logger.warning(f"Cannot load {model_id}: {msg}")
|
||||
# Try eviction
|
||||
if not self._evict_lru(estimated_memory_mb):
|
||||
return None
|
||||
|
||||
# Load the model
|
||||
try:
|
||||
logger.info(f"Loading model {model_id}")
|
||||
model = loader_func()
|
||||
|
||||
self._models[model_id] = ModelInfo(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
reference_count=1,
|
||||
estimated_memory_mb=estimated_memory_mb,
|
||||
cleanup_callback=cleanup_callback
|
||||
)
|
||||
|
||||
logger.info(f"Model {model_id} loaded successfully")
|
||||
return model
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load model {model_id}: {e}")
|
||||
return None
|
||||
|
||||
def release(self, model_id: str):
|
||||
"""Release a reference to a model."""
|
||||
with self._models_lock:
|
||||
if model_id in self._models:
|
||||
info = self._models[model_id]
|
||||
info.reference_count = max(0, info.reference_count - 1)
|
||||
logger.debug(f"Model {model_id} released (refs={info.reference_count})")
|
||||
|
||||
def unload(self, model_id: str, force: bool = False) -> bool:
|
||||
"""
|
||||
Unload a model from memory.
|
||||
|
||||
Args:
|
||||
model_id: Model to unload
|
||||
force: If True, unload even if references exist
|
||||
|
||||
Returns:
|
||||
True if unloaded
|
||||
"""
|
||||
with self._models_lock:
|
||||
if model_id not in self._models:
|
||||
return False
|
||||
|
||||
info = self._models[model_id]
|
||||
|
||||
if not force and info.reference_count > 0:
|
||||
logger.warning(f"Cannot unload {model_id}: {info.reference_count} refs")
|
||||
return False
|
||||
|
||||
# Run cleanup callback
|
||||
if info.cleanup_callback:
|
||||
try:
|
||||
info.cleanup_callback(info.model)
|
||||
except Exception as e:
|
||||
logger.error(f"Cleanup callback failed for {model_id}: {e}")
|
||||
|
||||
# Remove model
|
||||
del self._models[model_id]
|
||||
logger.info(f"Model {model_id} unloaded")
|
||||
|
||||
# Clear GPU cache
|
||||
self._monitor.clear_cache()
|
||||
return True
|
||||
|
||||
def _evict_lru(self, required_mb: float) -> bool:
|
||||
"""Evict least-recently-used models to free memory."""
|
||||
freed_mb = 0.0
|
||||
|
||||
# Sort by last_used (oldest first)
|
||||
candidates = sorted(
|
||||
[(k, v) for k, v in self._models.items() if v.reference_count == 0],
|
||||
key=lambda x: x[1].last_used
|
||||
)
|
||||
|
||||
for model_id, info in candidates:
|
||||
if freed_mb >= required_mb:
|
||||
break
|
||||
|
||||
if self.unload(model_id, force=True):
|
||||
freed_mb += info.estimated_memory_mb
|
||||
logger.info(f"Evicted {model_id}, freed ~{info.estimated_memory_mb}MB")
|
||||
|
||||
return freed_mb >= required_mb
|
||||
|
||||
def _cleanup_loop(self):
|
||||
"""Background thread for idle model cleanup."""
|
||||
while not self._shutdown.is_set():
|
||||
self._shutdown.wait(self.config.memory_check_interval_seconds)
|
||||
|
||||
if self._shutdown.is_set():
|
||||
break
|
||||
|
||||
self._cleanup_idle_models()
|
||||
|
||||
def _cleanup_idle_models(self):
|
||||
"""Unload models that have been idle too long."""
|
||||
now = datetime.now()
|
||||
timeout = self.config.model_idle_timeout_seconds
|
||||
|
||||
with self._models_lock:
|
||||
to_unload = []
|
||||
|
||||
for model_id, info in self._models.items():
|
||||
if info.reference_count > 0:
|
||||
continue
|
||||
|
||||
idle_seconds = (now - info.last_used).total_seconds()
|
||||
if idle_seconds > timeout:
|
||||
to_unload.append(model_id)
|
||||
|
||||
for model_id in to_unload:
|
||||
self.unload(model_id)
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get model manager statistics."""
|
||||
with self._models_lock:
|
||||
models_info = {}
|
||||
for model_id, info in self._models.items():
|
||||
models_info[model_id] = {
|
||||
"reference_count": info.reference_count,
|
||||
"loaded_at": info.loaded_at.isoformat(),
|
||||
"last_used": info.last_used.isoformat(),
|
||||
"estimated_memory_mb": info.estimated_memory_mb
|
||||
}
|
||||
|
||||
return {
|
||||
"total_models": len(self._models),
|
||||
"models": models_info,
|
||||
"memory": self._monitor.get_stats().__dict__
|
||||
}
|
||||
|
||||
def shutdown(self):
|
||||
"""Shutdown the model manager."""
|
||||
logger.info("Shutting down ModelManager")
|
||||
self._shutdown.set()
|
||||
|
||||
# Unload all models
|
||||
with self._models_lock:
|
||||
for model_id in list(self._models.keys()):
|
||||
self.unload(model_id, force=True)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Memory Policy Engine (Unified Interface)
|
||||
# ============================================================================
|
||||
|
||||
class MemoryPolicyEngine:
|
||||
"""
|
||||
Unified memory policy engine.
|
||||
|
||||
Provides a single entry point for all memory management operations:
|
||||
- GPU memory monitoring
|
||||
- Prediction concurrency control
|
||||
- Model lifecycle management
|
||||
"""
|
||||
|
||||
_instance: Optional['MemoryPolicyEngine'] = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __new__(cls, config: Optional[MemoryPolicyConfig] = None):
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
instance = super().__new__(cls)
|
||||
instance._initialized = False
|
||||
cls._instance = instance
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, config: Optional[MemoryPolicyConfig] = None):
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
self.config = config or MemoryPolicyConfig()
|
||||
self._monitor = GPUMemoryMonitor(self.config)
|
||||
self._model_manager = ModelManager(self.config)
|
||||
self._prediction_semaphore = PredictionSemaphore(
|
||||
self.config.max_concurrent_predictions
|
||||
)
|
||||
|
||||
self._initialized = True
|
||||
logger.info("MemoryPolicyEngine initialized")
|
||||
|
||||
@classmethod
|
||||
def reset(cls):
|
||||
"""Reset singleton (for testing)."""
|
||||
with cls._lock:
|
||||
if cls._instance is not None:
|
||||
cls._instance.shutdown()
|
||||
cls._instance = None
|
||||
ModelManager.reset()
|
||||
PredictionSemaphore.reset()
|
||||
|
||||
@property
|
||||
def monitor(self) -> GPUMemoryMonitor:
|
||||
"""Get the GPU memory monitor."""
|
||||
return self._monitor
|
||||
|
||||
@property
|
||||
def model_manager(self) -> ModelManager:
|
||||
"""Get the model manager."""
|
||||
return self._model_manager
|
||||
|
||||
@property
|
||||
def prediction_semaphore(self) -> PredictionSemaphore:
|
||||
"""Get the prediction semaphore."""
|
||||
return self._prediction_semaphore
|
||||
|
||||
def check_memory(self, required_mb: float = 0) -> Tuple[bool, str]:
|
||||
"""Check if memory is available."""
|
||||
return self._monitor.check_memory(required_mb)
|
||||
|
||||
def get_memory_stats(self) -> MemoryStats:
|
||||
"""Get current memory statistics."""
|
||||
return self._monitor.get_stats()
|
||||
|
||||
def clear_cache(self):
|
||||
"""Clear GPU memory caches."""
|
||||
self._monitor.clear_cache()
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get comprehensive statistics."""
|
||||
return {
|
||||
"memory": self._monitor.get_stats().__dict__,
|
||||
"models": self._model_manager.get_stats(),
|
||||
"predictions": self._prediction_semaphore.get_stats()
|
||||
}
|
||||
|
||||
def shutdown(self):
|
||||
"""Shutdown all components."""
|
||||
logger.info("Shutting down MemoryPolicyEngine")
|
||||
self._model_manager.shutdown()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Convenience Functions
|
||||
# ============================================================================
|
||||
|
||||
_engine: Optional[MemoryPolicyEngine] = None
|
||||
|
||||
|
||||
def get_memory_policy_engine(config: Optional[MemoryPolicyConfig] = None) -> MemoryPolicyEngine:
|
||||
"""Get the global MemoryPolicyEngine instance."""
|
||||
global _engine
|
||||
if _engine is None:
|
||||
_engine = MemoryPolicyEngine(config)
|
||||
return _engine
|
||||
|
||||
|
||||
def get_prediction_semaphore(max_concurrent: int = 2) -> PredictionSemaphore:
|
||||
"""Get the global PredictionSemaphore instance."""
|
||||
return PredictionSemaphore(max_concurrent)
|
||||
|
||||
|
||||
def get_model_manager(config: Optional[MemoryPolicyConfig] = None) -> ModelManager:
|
||||
"""Get the global ModelManager instance."""
|
||||
return ModelManager(config)
|
||||
|
||||
|
||||
def shutdown_memory_policy():
|
||||
"""Shutdown all memory management components."""
|
||||
global _engine
|
||||
if _engine is not None:
|
||||
_engine.shutdown()
|
||||
_engine = None
|
||||
MemoryPolicyEngine.reset()
|
||||
@@ -26,6 +26,10 @@ except ImportError:
|
||||
from app.core.config import settings
|
||||
from app.services.office_converter import OfficeConverter, OfficeConverterError
|
||||
from app.services.memory_manager import get_model_manager, MemoryConfig, MemoryGuard, prediction_context
|
||||
from app.services.memory_policy_engine import (
|
||||
MemoryPolicyEngine, MemoryPolicyConfig, get_memory_policy_engine,
|
||||
prediction_context as new_prediction_context
|
||||
)
|
||||
from app.services.layout_preprocessing_service import (
|
||||
get_layout_preprocessing_service,
|
||||
LayoutPreprocessingService,
|
||||
@@ -38,6 +42,9 @@ try:
|
||||
from app.services.direct_extraction_engine import DirectExtractionEngine
|
||||
from app.services.ocr_to_unified_converter import OCRToUnifiedConverter
|
||||
from app.services.unified_document_exporter import UnifiedDocumentExporter
|
||||
from app.services.processing_orchestrator import (
|
||||
ProcessingOrchestrator, ProcessingConfig, ProcessingResult
|
||||
)
|
||||
from app.models.unified_document import (
|
||||
UnifiedDocument, DocumentMetadata,
|
||||
ProcessingTrack, ElementType, DocumentElement, Page, Dimensions,
|
||||
@@ -48,6 +55,7 @@ except ImportError as e:
|
||||
logging.getLogger(__name__).warning(f"Dual-track components not available: {e}")
|
||||
DUAL_TRACK_AVAILABLE = False
|
||||
UnifiedDocumentExporter = None
|
||||
ProcessingOrchestrator = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -98,11 +106,16 @@ class OCRService:
|
||||
)
|
||||
self.ocr_to_unified_converter = OCRToUnifiedConverter()
|
||||
self.dual_track_enabled = True
|
||||
logger.info("Dual-track processing enabled")
|
||||
|
||||
# Initialize ProcessingOrchestrator for cleaner flow control
|
||||
self._orchestrator = ProcessingOrchestrator()
|
||||
self._orchestrator.set_ocr_service(self) # Dependency injection
|
||||
logger.info("Dual-track processing enabled (with ProcessingOrchestrator)")
|
||||
else:
|
||||
self.document_detector = None
|
||||
self.direct_extraction_engine = None
|
||||
self.ocr_to_unified_converter = None
|
||||
self._orchestrator = None
|
||||
self.dual_track_enabled = False
|
||||
logger.info("Dual-track processing not available, using OCR-only mode")
|
||||
|
||||
@@ -115,22 +128,39 @@ class OCRService:
|
||||
self._model_last_used = {} # Track last usage time for each model
|
||||
self._memory_warning_logged = False
|
||||
|
||||
# Initialize MemoryGuard for enhanced memory monitoring
|
||||
# Initialize memory management (use new MemoryPolicyEngine)
|
||||
self._memory_guard = None
|
||||
self._memory_policy_engine = None
|
||||
if settings.enable_model_lifecycle_management:
|
||||
try:
|
||||
memory_config = MemoryConfig(
|
||||
# Use new MemoryPolicyEngine (simplified, consolidated)
|
||||
policy_config = MemoryPolicyConfig(
|
||||
warning_threshold=settings.memory_warning_threshold,
|
||||
critical_threshold=settings.memory_critical_threshold,
|
||||
emergency_threshold=settings.memory_emergency_threshold,
|
||||
model_idle_timeout_seconds=settings.pp_structure_idle_timeout_seconds,
|
||||
gpu_memory_limit_mb=settings.gpu_memory_limit_mb,
|
||||
enable_cpu_fallback=settings.enable_cpu_fallback,
|
||||
max_concurrent_predictions=2,
|
||||
prediction_timeout_seconds=settings.service_acquire_timeout_seconds,
|
||||
)
|
||||
self._memory_guard = MemoryGuard(memory_config)
|
||||
logger.debug("MemoryGuard initialized for OCRService")
|
||||
self._memory_policy_engine = get_memory_policy_engine(policy_config)
|
||||
logger.info("MemoryPolicyEngine initialized for OCRService")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to initialize MemoryGuard: {e}")
|
||||
logger.warning(f"Failed to initialize MemoryPolicyEngine: {e}")
|
||||
# Fallback to legacy MemoryGuard
|
||||
try:
|
||||
memory_config = MemoryConfig(
|
||||
warning_threshold=settings.memory_warning_threshold,
|
||||
critical_threshold=settings.memory_critical_threshold,
|
||||
emergency_threshold=settings.memory_emergency_threshold,
|
||||
model_idle_timeout_seconds=settings.pp_structure_idle_timeout_seconds,
|
||||
gpu_memory_limit_mb=settings.gpu_memory_limit_mb,
|
||||
enable_cpu_fallback=settings.enable_cpu_fallback,
|
||||
)
|
||||
self._memory_guard = MemoryGuard(memory_config)
|
||||
logger.debug("Fallback: MemoryGuard initialized for OCRService")
|
||||
except Exception as e2:
|
||||
logger.warning(f"Failed to initialize MemoryGuard fallback: {e2}")
|
||||
|
||||
# Track if CPU fallback was activated
|
||||
self._cpu_fallback_active = False
|
||||
@@ -262,9 +292,9 @@ class OCRService:
|
||||
return
|
||||
|
||||
try:
|
||||
# Use MemoryGuard if available for better monitoring
|
||||
if self._memory_guard:
|
||||
stats = self._memory_guard.get_memory_stats()
|
||||
# Use MemoryPolicyEngine (preferred) or MemoryGuard for monitoring
|
||||
if self._memory_policy_engine:
|
||||
stats = self._memory_policy_engine.get_memory_stats()
|
||||
|
||||
# Log based on usage ratio
|
||||
if stats.gpu_used_ratio > 0.90 and not self._memory_warning_logged:
|
||||
@@ -278,15 +308,33 @@ class OCRService:
|
||||
# Trigger emergency cleanup if enabled
|
||||
if settings.enable_emergency_cleanup:
|
||||
self._cleanup_unused_models()
|
||||
self._memory_guard.clear_gpu_cache()
|
||||
self._memory_policy_engine.clear_cache()
|
||||
|
||||
elif stats.gpu_used_ratio > 0.75:
|
||||
logger.info(
|
||||
f"GPU memory: {stats.gpu_used_mb:.0f}MB / {stats.gpu_total_mb:.0f}MB "
|
||||
f"({stats.gpu_used_ratio*100:.1f}%)"
|
||||
)
|
||||
elif self._memory_guard:
|
||||
# Fallback to legacy MemoryGuard
|
||||
stats = self._memory_guard.get_memory_stats()
|
||||
|
||||
if stats.gpu_used_ratio > 0.90 and not self._memory_warning_logged:
|
||||
logger.warning(
|
||||
f"GPU memory usage critical: {stats.gpu_used_mb:.0f}MB / {stats.gpu_total_mb:.0f}MB "
|
||||
f"({stats.gpu_used_ratio*100:.1f}%)"
|
||||
)
|
||||
self._memory_warning_logged = True
|
||||
if settings.enable_emergency_cleanup:
|
||||
self._cleanup_unused_models()
|
||||
self._memory_guard.clear_gpu_cache()
|
||||
elif stats.gpu_used_ratio > 0.75:
|
||||
logger.info(
|
||||
f"GPU memory: {stats.gpu_used_mb:.0f}MB / {stats.gpu_total_mb:.0f}MB "
|
||||
f"({stats.gpu_used_ratio*100:.1f}%)"
|
||||
)
|
||||
else:
|
||||
# Fallback to original implementation
|
||||
# No memory monitoring available - use direct paddle query
|
||||
device_id = self.gpu_info.get('device_id', 0)
|
||||
memory_allocated = paddle.device.cuda.memory_allocated(device_id)
|
||||
memory_allocated_mb = memory_allocated / (1024**2)
|
||||
@@ -296,7 +344,6 @@ class OCRService:
|
||||
|
||||
if utilization > 90 and not self._memory_warning_logged:
|
||||
logger.warning(f"GPU memory usage high: {memory_allocated_mb:.0f}MB / {memory_limit_mb}MB ({utilization:.1f}%)")
|
||||
logger.warning("Consider enabling auto_unload_unused_models or reducing batch size")
|
||||
self._memory_warning_logged = True
|
||||
elif utilization > 75:
|
||||
logger.info(f"GPU memory: {memory_allocated_mb:.0f}MB / {memory_limit_mb}MB ({utilization:.1f}%)")
|
||||
@@ -830,8 +877,50 @@ class OCRService:
|
||||
return True
|
||||
|
||||
try:
|
||||
# Use MemoryGuard if available for accurate multi-backend memory queries
|
||||
if self._memory_guard:
|
||||
# Use MemoryPolicyEngine (preferred) or MemoryGuard for memory checks
|
||||
if self._memory_policy_engine:
|
||||
is_available, msg = self._memory_policy_engine.check_memory(required_mb)
|
||||
|
||||
if not is_available:
|
||||
stats = self._memory_policy_engine.get_memory_stats()
|
||||
logger.warning(
|
||||
f"GPU memory check failed: {stats.gpu_free_mb:.0f}MB free, "
|
||||
f"{required_mb}MB required ({stats.gpu_used_ratio*100:.1f}% used)"
|
||||
)
|
||||
|
||||
# Try to free memory
|
||||
logger.info("Attempting memory cleanup before retry...")
|
||||
self._cleanup_unused_models()
|
||||
self._memory_policy_engine.clear_cache()
|
||||
|
||||
# Check again
|
||||
is_available, msg = self._memory_policy_engine.check_memory(required_mb)
|
||||
|
||||
if not is_available:
|
||||
stats = self._memory_policy_engine.get_memory_stats()
|
||||
if enable_fallback and settings.enable_cpu_fallback:
|
||||
logger.warning(
|
||||
f"Insufficient GPU memory ({stats.gpu_free_mb:.0f}MB) after cleanup. "
|
||||
f"Activating CPU fallback mode."
|
||||
)
|
||||
self._activate_cpu_fallback()
|
||||
return True
|
||||
else:
|
||||
logger.error(
|
||||
f"Insufficient GPU memory: {stats.gpu_free_mb:.0f}MB available, "
|
||||
f"{required_mb}MB required"
|
||||
)
|
||||
return False
|
||||
|
||||
stats = self._memory_policy_engine.get_memory_stats()
|
||||
logger.debug(
|
||||
f"GPU memory check passed: {stats.gpu_free_mb:.0f}MB free "
|
||||
f"({stats.gpu_used_ratio*100:.1f}% used)"
|
||||
)
|
||||
return True
|
||||
|
||||
elif self._memory_guard:
|
||||
# Fallback to legacy MemoryGuard
|
||||
is_available, stats = self._memory_guard.check_memory(
|
||||
required_mb=required_mb,
|
||||
device_id=self.gpu_info.get('device_id', 0)
|
||||
@@ -843,23 +932,20 @@ class OCRService:
                    f"{required_mb}MB required ({stats.gpu_used_ratio*100:.1f}% used)"
                )

                # Try to free memory
                logger.info("Attempting memory cleanup before retry...")
                self._cleanup_unused_models()
                self._memory_guard.clear_gpu_cache()

                # Check again
                is_available, stats = self._memory_guard.check_memory(required_mb=required_mb)

                if not is_available:
                    # Memory still insufficient after cleanup
                    if enable_fallback and settings.enable_cpu_fallback:
                        logger.warning(
                            f"Insufficient GPU memory ({stats.gpu_free_mb:.0f}MB) after cleanup. "
                            f"Activating CPU fallback mode."
                        )
                        self._activate_cpu_fallback()
                        return True  # Continue with CPU
                        return True
                    else:
                        logger.error(
                            f"Insufficient GPU memory: {stats.gpu_free_mb:.0f}MB available, "
@@ -937,7 +1023,9 @@ class OCRService:
        self.gpu_info['fallback_reason'] = 'GPU memory insufficient'

        # Clear GPU cache to free memory
        if self._memory_guard:
        if self._memory_policy_engine:
            self._memory_policy_engine.clear_cache()
        elif self._memory_guard:
            self._memory_guard.clear_gpu_cache()

    def _restore_gpu_mode(self):
@@ -952,7 +1040,17 @@ class OCRService:
            return

        # Check if GPU memory is now available
        if self._memory_guard:
        if self._memory_policy_engine:
            is_available, msg = self._memory_policy_engine.check_memory(
                settings.structure_model_memory_mb
            )
            if is_available:
                logger.info("GPU memory available, restoring GPU mode")
                self._cpu_fallback_active = False
                self.use_gpu = True
                self.gpu_info.pop('cpu_fallback', None)
                self.gpu_info.pop('fallback_reason', None)
        elif self._memory_guard:
            is_available, stats = self._memory_guard.check_memory(
                required_mb=settings.structure_model_memory_mb
            )
@@ -2204,6 +2302,81 @@ class OCRService:
            file_path, lang, detect_layout, confidence_threshold, output_dir
        )

    @property
    def orchestrator(self) -> Optional['ProcessingOrchestrator']:
        """Get the ProcessingOrchestrator instance (if available)."""
        return self._orchestrator

    def process_with_orchestrator(
        self,
        file_path: Path,
        lang: str = 'ch',
        detect_layout: bool = True,
        confidence_threshold: Optional[float] = None,
        output_dir: Optional[Path] = None,
        force_track: Optional[str] = None,
        layout_model: Optional[str] = None,
        preprocessing_mode: Optional[PreprocessingModeEnum] = None,
        preprocessing_config: Optional[PreprocessingConfig] = None,
        table_detection_config: Optional[TableDetectionConfig] = None
    ) -> Union[UnifiedDocument, Dict]:
        """
        Process document using the ProcessingOrchestrator.

        This method provides a cleaner separation of concerns by delegating
        to the orchestrator, which coordinates the processing pipelines.

        Args:
            file_path: Path to document file
            lang: Language for OCR (if needed)
            detect_layout: Whether to perform layout analysis
            confidence_threshold: Minimum confidence threshold
            output_dir: Optional output directory
            force_track: Force specific track ("ocr" or "direct")
            layout_model: Layout detection model
            preprocessing_mode: Layout preprocessing mode
            preprocessing_config: Manual preprocessing config
            table_detection_config: Table detection config

        Returns:
            UnifiedDocument with processed results
        """
        if not self._orchestrator:
            logger.warning("ProcessingOrchestrator not available, falling back to legacy processing")
            return self.process_with_dual_track(
                file_path, lang, detect_layout, confidence_threshold, output_dir,
                force_track, layout_model, preprocessing_mode, preprocessing_config, table_detection_config
            )

        # Build ProcessingConfig
        config = ProcessingConfig(
            detect_layout=detect_layout,
            confidence_threshold=confidence_threshold or self.confidence_threshold,
            output_dir=Path(output_dir) if output_dir else None,
            lang=lang,
            layout_model=layout_model or "default",
            preprocessing_mode=preprocessing_mode.value if preprocessing_mode else "auto",
            preprocessing_config=preprocessing_config.dict() if preprocessing_config else None,
            table_detection_config=table_detection_config.dict() if table_detection_config else None,
            force_track=force_track,
            use_dual_track=True
        )

        # Process using orchestrator
        result = self._orchestrator.process(Path(file_path), config)

        if result.success and result.document:
            return result.document
        elif result.legacy_result:
            return result.legacy_result
        else:
            logger.error(f"Orchestrator processing failed: {result.error}")
            # Fallback to legacy processing
            return self.process_with_dual_track(
                file_path, lang, detect_layout, confidence_threshold, output_dir,
                force_track, layout_model, preprocessing_mode, preprocessing_config, table_detection_config
            )

    def get_track_recommendation(self, file_path: Path) -> Optional[ProcessingTrackRecommendation]:
        """
        Get processing track recommendation for a file.
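For reference, a minimal usage sketch of the new entry point. The `ocr_service` instance and the sample file path below are placeholders (not part of this commit); the call itself delegates to the orchestrator and falls back to `process_with_dual_track` exactly as shown above.

```python
from pathlib import Path

# Hypothetical, already-initialized service; construction depends on app wiring.
# ocr_service = OCRService(...)

doc = ocr_service.process_with_orchestrator(
    Path("samples/edit.pdf"),   # placeholder path
    lang="ch",
    detect_layout=True,
    force_track="direct",       # "ocr", "direct", or None to auto-detect
)
print(type(doc))  # UnifiedDocument on success, legacy dict on fallback
```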
backend/app/services/pdf_font_manager.py (new file, 312 lines)
@@ -0,0 +1,312 @@
"""
PDF Font Manager - Handles font loading, registration, and fallback.

This module provides unified font management for PDF generation,
including CJK font support and font fallback mechanisms.
"""

import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple

from reportlab.pdfbase import pdfmetrics
from reportlab.pdfbase.ttfonts import TTFont

logger = logging.getLogger(__name__)


# ============================================================================
# Configuration
# ============================================================================

@dataclass
class FontConfig:
    """Configuration for font management."""
    # Primary fonts
    chinese_font_name: str = "NotoSansSC"
    chinese_font_path: Optional[Path] = None

    # Fallback fonts (built-in)
    fallback_font_name: str = "Helvetica"
    fallback_cjk_font_name: str = "HeiseiMin-W3"  # Built-in ReportLab CJK

    # Font sizes
    default_font_size: int = 10
    min_font_size: int = 6
    max_font_size: int = 14

    # Font registration options
    auto_register: bool = True
    enable_cjk_fallback: bool = True

# ============================================================================
# Font Manager
# ============================================================================

class FontManager:
    """
    Manages font registration and selection for PDF generation.

    Features:
    - Lazy font registration
    - CJK (Chinese/Japanese/Korean) font support
    - Automatic fallback to built-in fonts
    - Font caching to avoid duplicate registration
    """

    _instance = None
    _registered_fonts: Dict[str, Path] = {}

    def __new__(cls, *args, **kwargs):
        """Singleton pattern to avoid duplicate font registration."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self, config: Optional[FontConfig] = None):
        """
        Initialize FontManager.

        Args:
            config: FontConfig instance (uses defaults if None)
        """
        if self._initialized:
            return

        self.config = config or FontConfig()
        self._primary_font_registered = False
        self._cjk_fallback_available = False

        # Auto-register fonts if enabled
        if self.config.auto_register:
            self._register_fonts()

        self._initialized = True

    @property
    def primary_font_name(self) -> str:
        """Get the primary font name to use."""
        if self._primary_font_registered:
            return self.config.chinese_font_name
        return self.config.fallback_font_name

    @property
    def is_cjk_enabled(self) -> bool:
        """Check if CJK fonts are available."""
        return self._primary_font_registered or self._cjk_fallback_available

    @classmethod
    def reset(cls):
        """Reset singleton instance (for testing)."""
        cls._instance = None
        cls._registered_fonts = {}

    def get_font_for_text(self, text: str) -> str:
        """
        Get appropriate font name for given text.

        Args:
            text: Text to render

        Returns:
            Font name suitable for the text content
        """
        if self._contains_cjk(text):
            if self._primary_font_registered:
                return self.config.chinese_font_name
            elif self._cjk_fallback_available:
                return self.config.fallback_cjk_font_name
        return self.primary_font_name

    def get_font_size(
        self,
        text: str,
        available_width: float,
        available_height: float,
        pdf_canvas=None
    ) -> int:
        """
        Calculate optimal font size for text to fit within bounds.

        Args:
            text: Text to render
            available_width: Maximum width available
            available_height: Maximum height available
            pdf_canvas: Optional canvas for precise measurement

        Returns:
            Font size that fits within bounds
        """
        font_name = self.get_font_for_text(text)

        for size in range(self.config.max_font_size, self.config.min_font_size - 1, -1):
            if pdf_canvas:
                # Precise measurement with canvas
                text_width = pdf_canvas.stringWidth(text, font_name, size)
            else:
                # Approximate measurement
                text_width = len(text) * size * 0.6  # Rough estimate

            text_height = size * 1.2  # Line height

            if text_width <= available_width and text_height <= available_height:
                return size

        return self.config.min_font_size

    def register_font(
        self,
        font_name: str,
        font_path: Path,
        force: bool = False
    ) -> bool:
        """
        Register a custom font.

        Args:
            font_name: Name to register font under
            font_path: Path to TTF font file
            force: Force re-registration if already registered

        Returns:
            True if registration successful
        """
        if font_name in self._registered_fonts and not force:
            logger.debug(f"Font {font_name} already registered")
            return True

        try:
            if not font_path.exists():
                logger.error(f"Font file not found: {font_path}")
                return False

            pdfmetrics.registerFont(TTFont(font_name, str(font_path)))
            self._registered_fonts[font_name] = font_path
            logger.info(f"Font registered: {font_name} from {font_path}")
            return True

        except Exception as e:
            logger.error(f"Failed to register font {font_name}: {e}")
            return False

    def get_registered_fonts(self) -> List[str]:
        """Get list of registered custom font names."""
        return list(self._registered_fonts.keys())

    # =========================================================================
    # Private Methods
    # =========================================================================

    def _register_fonts(self):
        """Register configured fonts."""
        # Register primary Chinese font
        if self.config.chinese_font_path:
            self._register_chinese_font()

        # Setup CJK fallback
        if self.config.enable_cjk_fallback:
            self._setup_cjk_fallback()

    def _register_chinese_font(self):
        """Register the primary Chinese font."""
        font_path = self.config.chinese_font_path

        if font_path is None:
            # Try to load from settings
            try:
                from app.core.config import settings
                font_path = Path(settings.chinese_font_path)
            except Exception as e:
                logger.debug(f"Could not load font path from settings: {e}")
                return

        # Resolve relative path
        if not font_path.is_absolute():
            # Try project root
            project_root = Path(__file__).resolve().parent.parent.parent.parent
            font_path = project_root / font_path

        if not font_path.exists():
            logger.warning(f"Chinese font not found at {font_path}")
            return

        try:
            pdfmetrics.registerFont(TTFont(self.config.chinese_font_name, str(font_path)))
            self._registered_fonts[self.config.chinese_font_name] = font_path
            self._primary_font_registered = True
            logger.info(f"Chinese font registered: {self.config.chinese_font_name}")
        except Exception as e:
            logger.error(f"Failed to register Chinese font: {e}")

    def _setup_cjk_fallback(self):
        """Setup CJK fallback using built-in fonts."""
        try:
            # ReportLab includes CID fonts for CJK
            from reportlab.pdfbase.cidfonts import UnicodeCIDFont

            # Register CJK fonts if not already registered
            try:
                pdfmetrics.registerFont(UnicodeCIDFont('HeiseiMin-W3'))
                self._cjk_fallback_available = True
                logger.debug("CJK fallback font available: HeiseiMin-W3")
            except Exception:
                pass  # Font may already be registered

        except ImportError:
            logger.debug("CID fonts not available for CJK fallback")

    def _contains_cjk(self, text: str) -> bool:
        """
        Check if text contains CJK characters.

        Args:
            text: Text to check

        Returns:
            True if text contains Chinese, Japanese, or Korean characters
        """
        if not text:
            return False

        for char in text:
            code = ord(char)
            # CJK Unified Ideographs and related ranges
            if any([
                0x4E00 <= code <= 0x9FFF,    # CJK Unified Ideographs
                0x3400 <= code <= 0x4DBF,    # CJK Extension A
                0x20000 <= code <= 0x2A6DF,  # CJK Extension B
                0x3000 <= code <= 0x303F,    # CJK Punctuation
                0x3040 <= code <= 0x309F,    # Hiragana
                0x30A0 <= code <= 0x30FF,    # Katakana
                0xAC00 <= code <= 0xD7AF,    # Korean Hangul
            ]):
                return True
        return False


# ============================================================================
# Convenience Functions
# ============================================================================

_default_manager: Optional[FontManager] = None


def get_font_manager() -> FontManager:
    """Get the default FontManager instance."""
    global _default_manager
    if _default_manager is None:
        _default_manager = FontManager()
    return _default_manager


def register_font(font_name: str, font_path: Path) -> bool:
    """Register a font using the default manager."""
    return get_font_manager().register_font(font_name, font_path)


def get_font_for_text(text: str) -> str:
    """Get appropriate font for text using the default manager."""
    return get_font_manager().get_font_for_text(text)
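A short usage sketch of the module-level helpers defined above. The import path follows the file location added in this commit; the font path in the last line is a placeholder.

```python
from pathlib import Path

from app.services.pdf_font_manager import get_font_manager, get_font_for_text

manager = get_font_manager()          # singleton, auto-registers configured fonts
print(manager.primary_font_name)      # "NotoSansSC" if registered, else "Helvetica"
print(get_font_for_text("Hello"))     # Latin text -> primary/fallback font
print(get_font_for_text("中文表格"))   # CJK text -> NotoSansSC or HeiseiMin-W3
print(manager.get_font_size("中文表格", available_width=80, available_height=14))

# Explicit registration with a custom TTF (placeholder path):
manager.register_font("MyFont", Path("fonts/MyFont.ttf"))
```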
@@ -925,7 +925,9 @@ class PDFGeneratorService:
                element.type = ElementType.LIST_ITEM
            elif element.is_text or element.type in [
                ElementType.TEXT, ElementType.TITLE, ElementType.HEADER,
                ElementType.FOOTER, ElementType.PARAGRAPH
                ElementType.FOOTER, ElementType.PARAGRAPH,
                ElementType.FOOTNOTE, ElementType.REFERENCE,
                ElementType.EQUATION, ElementType.CAPTION
            ]:
                text_elements.append(element)

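The hunk above widens the set of element types that get routed to text rendering. Purely as an illustration (the helper name below is hypothetical), the updated branch condition amounts to:

```python
TEXT_LIKE_TYPES = {
    ElementType.TEXT, ElementType.TITLE, ElementType.HEADER,
    ElementType.FOOTER, ElementType.PARAGRAPH,
    ElementType.FOOTNOTE, ElementType.REFERENCE,
    ElementType.EQUATION, ElementType.CAPTION,
}

def is_text_like(element) -> bool:
    """Hypothetical helper mirroring the updated membership test."""
    return element.is_text or element.type in TEXT_LIKE_TYPES
```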
backend/app/services/pdf_table_renderer.py (new file, 917 lines)
@@ -0,0 +1,917 @@
|
||||
"""
|
||||
PDF Table Renderer - Handles table rendering for PDF generation.
|
||||
|
||||
This module provides unified table rendering capabilities extracted from
|
||||
PDFGeneratorService, supporting multiple input formats:
|
||||
- HTML tables
|
||||
- Cell boxes (layered approach)
|
||||
- Cells dictionary (Direct track)
|
||||
- TableData objects
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from html.parser import HTMLParser
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from reportlab.lib import colors
|
||||
from reportlab.lib.enums import TA_CENTER, TA_LEFT, TA_RIGHT
|
||||
from reportlab.lib.styles import ParagraphStyle
|
||||
from reportlab.lib.utils import ImageReader
|
||||
from reportlab.platypus import Paragraph, Table, TableStyle
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Configuration
|
||||
# ============================================================================
|
||||
|
||||
@dataclass
|
||||
class TableRenderConfig:
|
||||
"""Configuration for table rendering."""
|
||||
font_name: str = "Helvetica"
|
||||
font_size: int = 8
|
||||
min_font_size: int = 6
|
||||
max_font_size: int = 10
|
||||
|
||||
# Padding options
|
||||
left_padding: int = 2
|
||||
right_padding: int = 2
|
||||
top_padding: int = 2
|
||||
bottom_padding: int = 2
|
||||
|
||||
# Border options
|
||||
border_color: Any = colors.black
|
||||
border_width: float = 0.5
|
||||
|
||||
# Alignment
|
||||
horizontal_align: str = "CENTER"
|
||||
vertical_align: str = "MIDDLE"
|
||||
|
||||
# Header styling
|
||||
header_background: Any = colors.lightgrey
|
||||
|
||||
# Grid normalization threshold
|
||||
grid_threshold: float = 10.0
|
||||
|
||||
# Merged cells threshold
|
||||
merge_boundary_threshold: float = 5.0
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# HTML Table Parser
|
||||
# ============================================================================
|
||||
|
||||
class HTMLTableParser(HTMLParser):
|
||||
"""
|
||||
Parse HTML table structure for rendering.
|
||||
|
||||
Extracts table rows, cells, and merged cell information (colspan/rowspan)
|
||||
from HTML table markup.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.tables = []
|
||||
self.current_table = None
|
||||
self.current_row = None
|
||||
self.current_cell = None
|
||||
self.in_cell = False
|
||||
|
||||
def handle_starttag(self, tag: str, attrs: List[Tuple[str, str]]):
|
||||
if tag == 'table':
|
||||
self.current_table = {'rows': []}
|
||||
elif tag == 'tr':
|
||||
self.current_row = {'cells': []}
|
||||
elif tag in ('td', 'th'):
|
||||
# Extract colspan and rowspan attributes
|
||||
attrs_dict = dict(attrs)
|
||||
colspan = int(attrs_dict.get('colspan', 1))
|
||||
rowspan = int(attrs_dict.get('rowspan', 1))
|
||||
self.current_cell = {
|
||||
'text': '',
|
||||
'is_header': tag == 'th',
|
||||
'colspan': colspan,
|
||||
'rowspan': rowspan
|
||||
}
|
||||
self.in_cell = True
|
||||
|
||||
def handle_endtag(self, tag: str):
|
||||
if tag == 'table' and self.current_table:
|
||||
self.tables.append(self.current_table)
|
||||
self.current_table = None
|
||||
elif tag == 'tr' and self.current_row:
|
||||
if self.current_table:
|
||||
self.current_table['rows'].append(self.current_row)
|
||||
self.current_row = None
|
||||
elif tag in ('td', 'th') and self.current_cell:
|
||||
if self.current_row:
|
||||
self.current_row['cells'].append(self.current_cell)
|
||||
self.current_cell = None
|
||||
self.in_cell = False
|
||||
|
||||
def handle_data(self, data: str):
|
||||
if self.in_cell and self.current_cell is not None:
|
||||
self.current_cell['text'] += data
|
||||
|
||||
|
||||
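A quick sketch of how the parser above can be driven on its own; the sample HTML is illustrative and the import path assumes the file location added in this commit.

```python
from app.services.pdf_table_renderer import HTMLTableParser

parser = HTMLTableParser()
parser.feed("""
<table>
  <tr><th colspan="2">Quarter</th></tr>
  <tr><td>Q1</td><td>42</td></tr>
</table>
""")

table = parser.tables[0]
header = table['rows'][0]['cells'][0]
print(header['text'].strip(), header['colspan'], header['is_header'])  # Quarter 2 True
```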
# ============================================================================
|
||||
# Table Renderer
|
||||
# ============================================================================
|
||||
|
||||
class TableRenderer:
|
||||
"""
|
||||
Unified table rendering engine for PDF generation.
|
||||
|
||||
Supports multiple input formats and rendering modes:
|
||||
- HTML table parsing and rendering
|
||||
- Cell boxes rendering (layered approach)
|
||||
- Direct track cells dictionary
|
||||
- Translated content with dynamic font sizing
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[TableRenderConfig] = None):
|
||||
"""
|
||||
Initialize TableRenderer with configuration.
|
||||
|
||||
Args:
|
||||
config: TableRenderConfig instance (uses defaults if None)
|
||||
"""
|
||||
self.config = config or TableRenderConfig()
|
||||
|
||||
def render_from_html(
|
||||
self,
|
||||
pdf_canvas,
|
||||
html_content: str,
|
||||
table_bbox: Tuple[float, float, float, float],
|
||||
page_height: float,
|
||||
scale_w: float = 1.0,
|
||||
scale_h: float = 1.0
|
||||
) -> bool:
|
||||
"""
|
||||
Parse HTML and render table to PDF canvas.
|
||||
|
||||
Args:
|
||||
pdf_canvas: ReportLab canvas
|
||||
html_content: HTML table string
|
||||
table_bbox: (x0, y0, x1, y1) bounding box
|
||||
page_height: PDF page height for Y coordinate flip
|
||||
scale_w: Horizontal scale factor
|
||||
scale_h: Vertical scale factor
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Parse HTML
|
||||
parser = HTMLTableParser()
|
||||
parser.feed(html_content)
|
||||
|
||||
if not parser.tables:
|
||||
logger.warning("No tables found in HTML content")
|
||||
return False
|
||||
|
||||
table_data = parser.tables[0]
|
||||
return self._render_parsed_table(
|
||||
pdf_canvas, table_data, table_bbox, page_height, scale_w, scale_h
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"HTML table rendering failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
def render_from_cells_dict(
|
||||
self,
|
||||
pdf_canvas,
|
||||
cells_dict: Dict,
|
||||
table_bbox: Tuple[float, float, float, float],
|
||||
page_height: float,
|
||||
cell_boxes: Optional[List] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Render table from Direct track cell structure.
|
||||
|
||||
Args:
|
||||
pdf_canvas: ReportLab canvas
|
||||
cells_dict: Dict with 'rows', 'cols', 'cells' keys
|
||||
table_bbox: (x0, y0, x1, y1) bounding box
|
||||
page_height: PDF page height
|
||||
cell_boxes: Optional precomputed cell boxes
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Convert cells dict to row format
|
||||
rows = self._build_rows_from_cells_dict(cells_dict)
|
||||
|
||||
if not rows:
|
||||
logger.warning("No rows built from cells dict")
|
||||
return False
|
||||
|
||||
# Build table data structure
|
||||
table_data = {'rows': rows}
|
||||
|
||||
# Calculate dimensions
|
||||
x0, y0, x1, y1 = table_bbox
|
||||
table_width = (x1 - x0)
|
||||
table_height = (y1 - y0)
|
||||
|
||||
# Determine grid dimensions
|
||||
num_rows = cells_dict.get('rows', len(rows))
|
||||
num_cols = cells_dict.get('cols',
|
||||
max(len(row['cells']) for row in rows) if rows else 1
|
||||
)
|
||||
|
||||
# Calculate column widths and row heights
|
||||
if cell_boxes:
|
||||
col_widths, row_heights = self.compute_grid_from_cell_boxes(
|
||||
cell_boxes, table_bbox, num_rows, num_cols
|
||||
)
|
||||
else:
|
||||
col_widths = [table_width / num_cols] * num_cols
|
||||
row_heights = [table_height / num_rows] * num_rows
|
||||
|
||||
return self._render_with_dimensions(
|
||||
pdf_canvas, table_data, table_bbox, page_height,
|
||||
col_widths, row_heights
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Cells dict rendering failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
def render_cell_borders(
|
||||
self,
|
||||
pdf_canvas,
|
||||
cell_boxes: List[List[float]],
|
||||
table_bbox: Tuple[float, float, float, float],
|
||||
page_height: float,
|
||||
embedded_images: Optional[List] = None,
|
||||
output_dir: Optional[Path] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Render table cell borders only (layered approach).
|
||||
|
||||
This renders only the cell borders, not the text content.
|
||||
Text is typically rendered separately by GapFillingService.
|
||||
|
||||
Args:
|
||||
pdf_canvas: ReportLab canvas
|
||||
cell_boxes: List of [x0, y0, x1, y1] for each cell
|
||||
table_bbox: Table bounding box
|
||||
page_height: PDF page height
|
||||
embedded_images: Optional list of images within cells
|
||||
output_dir: Directory for image files
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
if not cell_boxes:
|
||||
# Draw outer border only
|
||||
return self._draw_table_border(
|
||||
pdf_canvas, table_bbox, page_height
|
||||
)
|
||||
|
||||
# Normalize cell boxes to grid
|
||||
normalized_boxes = self.normalize_cell_boxes_to_grid(cell_boxes)
|
||||
|
||||
# Draw each cell border
|
||||
pdf_canvas.saveState()
|
||||
pdf_canvas.setStrokeColor(self.config.border_color)
|
||||
pdf_canvas.setLineWidth(self.config.border_width)
|
||||
|
||||
for box in normalized_boxes:
|
||||
if box is None:
|
||||
continue
|
||||
|
||||
x0, y0, x1, y1 = box
|
||||
# Convert to PDF coordinates (flip Y)
|
||||
pdf_x0 = x0
|
||||
pdf_y0 = page_height - y1
|
||||
pdf_x1 = x1
|
||||
pdf_y1 = page_height - y0
|
||||
|
||||
# Draw cell rectangle
|
||||
pdf_canvas.rect(pdf_x0, pdf_y0, pdf_x1 - pdf_x0, pdf_y1 - pdf_y0)
|
||||
|
||||
pdf_canvas.restoreState()
|
||||
|
||||
# Draw embedded images if any
|
||||
if embedded_images and output_dir:
|
||||
for img_info in embedded_images:
|
||||
self._draw_embedded_image(
|
||||
pdf_canvas, img_info, page_height, output_dir
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Cell borders rendering failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
def render_with_translated_text(
|
||||
self,
|
||||
pdf_canvas,
|
||||
cells: List[Dict],
|
||||
cell_boxes: List,
|
||||
table_bbox: Tuple[float, float, float, float],
|
||||
page_height: float
|
||||
) -> bool:
|
||||
"""
|
||||
Render table with translated content and dynamic font sizing.
|
||||
|
||||
Args:
|
||||
pdf_canvas: ReportLab canvas
|
||||
cells: List of cell dicts with 'translated_content'
|
||||
cell_boxes: List of cell bounding boxes
|
||||
table_bbox: Table bounding box
|
||||
page_height: PDF page height
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Draw outer border
|
||||
self._draw_table_border(pdf_canvas, table_bbox, page_height)
|
||||
|
||||
# Normalize cell boxes
|
||||
if cell_boxes:
|
||||
normalized_boxes = self.normalize_cell_boxes_to_grid(cell_boxes)
|
||||
else:
|
||||
logger.warning("No cell boxes for translated table")
|
||||
return False
|
||||
|
||||
pdf_canvas.saveState()
|
||||
pdf_canvas.setStrokeColor(self.config.border_color)
|
||||
pdf_canvas.setLineWidth(self.config.border_width)
|
||||
|
||||
# Draw cell borders
|
||||
for box in normalized_boxes:
|
||||
if box is None:
|
||||
continue
|
||||
x0, y0, x1, y1 = box
|
||||
pdf_y0 = page_height - y1
|
||||
pdf_canvas.rect(x0, pdf_y0, x1 - x0, y1 - y0)
|
||||
|
||||
pdf_canvas.restoreState()
|
||||
|
||||
# Render text in cells with dynamic font sizing
|
||||
for i, cell in enumerate(cells):
|
||||
if i >= len(normalized_boxes):
|
||||
break
|
||||
|
||||
box = normalized_boxes[i]
|
||||
if box is None:
|
||||
continue
|
||||
|
||||
translated_text = cell.get('translated_content', '')
|
||||
if not translated_text:
|
||||
continue
|
||||
|
||||
x0, y0, x1, y1 = box
|
||||
cell_width = x1 - x0
|
||||
cell_height = y1 - y0
|
||||
|
||||
# Find appropriate font size
|
||||
font_size = self._fit_text_to_cell(
|
||||
pdf_canvas, translated_text, cell_width, cell_height
|
||||
)
|
||||
|
||||
# Render centered text
|
||||
pdf_canvas.setFont(self.config.font_name, font_size)
|
||||
|
||||
# Calculate text position (centered)
|
||||
text_width = pdf_canvas.stringWidth(translated_text, self.config.font_name, font_size)
|
||||
text_x = x0 + (cell_width - text_width) / 2
|
||||
text_y = page_height - y0 - cell_height / 2 - font_size / 3
|
||||
|
||||
pdf_canvas.drawString(text_x, text_y, translated_text)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Translated table rendering failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
# =========================================================================
|
||||
# Grid and Cell Box Helpers
|
||||
# =========================================================================
|
||||
|
||||
def compute_grid_from_cell_boxes(
|
||||
self,
|
||||
cell_boxes: List,
|
||||
table_bbox: Tuple[float, float, float, float],
|
||||
num_rows: int,
|
||||
num_cols: int
|
||||
) -> Tuple[Optional[List[float]], Optional[List[float]]]:
|
||||
"""
|
||||
Calculate column widths and row heights from cell bounding boxes.
|
||||
|
||||
Args:
|
||||
cell_boxes: List of [x0, y0, x1, y1] for each cell
|
||||
table_bbox: Table bounding box
|
||||
num_rows: Expected number of rows
|
||||
num_cols: Expected number of columns
|
||||
|
||||
Returns:
|
||||
Tuple of (col_widths, row_heights) or (None, None) on failure
|
||||
"""
|
||||
try:
|
||||
if not cell_boxes:
|
||||
return None, None
|
||||
|
||||
# Filter valid boxes
|
||||
valid_boxes = [b for b in cell_boxes if b is not None and len(b) >= 4]
|
||||
if not valid_boxes:
|
||||
return None, None
|
||||
|
||||
# Extract unique X and Y boundaries
|
||||
x_boundaries = set()
|
||||
y_boundaries = set()
|
||||
|
||||
for box in valid_boxes:
|
||||
x0, y0, x1, y1 = box[:4]
|
||||
x_boundaries.add(round(x0, 1))
|
||||
x_boundaries.add(round(x1, 1))
|
||||
y_boundaries.add(round(y0, 1))
|
||||
y_boundaries.add(round(y1, 1))
|
||||
|
||||
# Sort boundaries
|
||||
x_sorted = sorted(x_boundaries)
|
||||
y_sorted = sorted(y_boundaries)
|
||||
|
||||
# Merge nearby boundaries
|
||||
x_merged = self._merge_boundaries(x_sorted, self.config.merge_boundary_threshold)
|
||||
y_merged = self._merge_boundaries(y_sorted, self.config.merge_boundary_threshold)
|
||||
|
||||
# Calculate widths and heights
|
||||
col_widths = []
|
||||
for i in range(len(x_merged) - 1):
|
||||
col_widths.append(x_merged[i + 1] - x_merged[i])
|
||||
|
||||
row_heights = []
|
||||
for i in range(len(y_merged) - 1):
|
||||
row_heights.append(y_merged[i + 1] - y_merged[i])
|
||||
|
||||
# Validate against expected dimensions (allow for merged cells)
|
||||
tolerance = max(num_cols, num_rows) // 2 + 1
|
||||
if abs(len(col_widths) - num_cols) > tolerance:
|
||||
logger.debug(f"Column count mismatch: {len(col_widths)} vs {num_cols}")
|
||||
if abs(len(row_heights) - num_rows) > tolerance:
|
||||
logger.debug(f"Row count mismatch: {len(row_heights)} vs {num_rows}")
|
||||
|
||||
return col_widths if col_widths else None, row_heights if row_heights else None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Grid computation failed: {e}")
|
||||
return None, None
|
||||
|
||||
def normalize_cell_boxes_to_grid(
|
||||
self,
|
||||
cell_boxes: List,
|
||||
threshold: Optional[float] = None
|
||||
) -> List:
|
||||
"""
|
||||
Snap cell boxes to aligned grid to eliminate coordinate variations.
|
||||
|
||||
Args:
|
||||
cell_boxes: List of [x0, y0, x1, y1] for each cell
|
||||
threshold: Clustering threshold (uses config default if None)
|
||||
|
||||
Returns:
|
||||
Normalized cell boxes
|
||||
"""
|
||||
threshold = threshold or self.config.grid_threshold
|
||||
|
||||
if not cell_boxes:
|
||||
return []
|
||||
|
||||
try:
|
||||
# Collect all coordinates
|
||||
all_x = []
|
||||
all_y = []
|
||||
|
||||
for box in cell_boxes:
|
||||
if box is None or len(box) < 4:
|
||||
continue
|
||||
x0, y0, x1, y1 = box[:4]
|
||||
all_x.extend([x0, x1])
|
||||
all_y.extend([y0, y1])
|
||||
|
||||
if not all_x or not all_y:
|
||||
return cell_boxes
|
||||
|
||||
# Cluster and normalize X coordinates
|
||||
x_clusters = self._cluster_values(sorted(all_x), threshold)
|
||||
y_clusters = self._cluster_values(sorted(all_y), threshold)
|
||||
|
||||
# Build mapping
|
||||
x_map = {v: avg for avg, values in x_clusters for v in values}
|
||||
y_map = {v: avg for avg, values in y_clusters for v in values}
|
||||
|
||||
# Normalize boxes
|
||||
normalized = []
|
||||
for box in cell_boxes:
|
||||
if box is None or len(box) < 4:
|
||||
normalized.append(box)
|
||||
continue
|
||||
|
||||
x0, y0, x1, y1 = box[:4]
|
||||
normalized.append([
|
||||
x_map.get(x0, x0),
|
||||
y_map.get(y0, y0),
|
||||
x_map.get(x1, x1),
|
||||
y_map.get(y1, y1)
|
||||
])
|
||||
|
||||
return normalized
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Cell box normalization failed: {e}")
|
||||
return cell_boxes
|
||||
|
||||
# =========================================================================
|
||||
# Private Helper Methods
|
||||
# =========================================================================
|
||||
|
||||
def _render_parsed_table(
|
||||
self,
|
||||
pdf_canvas,
|
||||
table_data: Dict,
|
||||
table_bbox: Tuple[float, float, float, float],
|
||||
page_height: float,
|
||||
scale_w: float = 1.0,
|
||||
scale_h: float = 1.0
|
||||
) -> bool:
|
||||
"""Render a parsed table structure."""
|
||||
rows = table_data.get('rows', [])
|
||||
if not rows:
|
||||
return False
|
||||
|
||||
# Build grid content
|
||||
num_rows = len(rows)
|
||||
num_cols = max(len(row.get('cells', [])) for row in rows)
|
||||
|
||||
# Track occupied cells for rowspan handling
|
||||
occupied = [[False] * num_cols for _ in range(num_rows)]
|
||||
|
||||
grid = []
|
||||
span_commands = []
|
||||
|
||||
for row_idx, row in enumerate(rows):
|
||||
grid_row = [''] * num_cols
|
||||
col_idx = 0
|
||||
|
||||
for cell in row.get('cells', []):
|
||||
# Skip occupied cells
|
||||
while col_idx < num_cols and occupied[row_idx][col_idx]:
|
||||
col_idx += 1
|
||||
|
||||
if col_idx >= num_cols:
|
||||
break
|
||||
|
||||
text = cell.get('text', '').strip()
|
||||
colspan = cell.get('colspan', 1)
|
||||
rowspan = cell.get('rowspan', 1)
|
||||
|
||||
# Place cell content
|
||||
grid_row[col_idx] = text
|
||||
|
||||
# Mark occupied cells and build SPAN command
|
||||
if colspan > 1 or rowspan > 1:
|
||||
end_col = min(col_idx + colspan - 1, num_cols - 1)
|
||||
end_row = min(row_idx + rowspan - 1, num_rows - 1)
|
||||
span_commands.append(
|
||||
('SPAN', (col_idx, row_idx), (end_col, end_row))
|
||||
)
|
||||
|
||||
for r in range(row_idx, end_row + 1):
|
||||
for c in range(col_idx, end_col + 1):
|
||||
if r < num_rows and c < num_cols:
|
||||
occupied[r][c] = True
|
||||
else:
|
||||
occupied[row_idx][col_idx] = True
|
||||
|
||||
col_idx += colspan
|
||||
|
||||
grid.append(grid_row)
|
||||
|
||||
# Calculate dimensions
|
||||
x0, y0, x1, y1 = table_bbox
|
||||
table_width = (x1 - x0) * scale_w
|
||||
table_height = (y1 - y0) * scale_h
|
||||
|
||||
col_widths = [table_width / num_cols] * num_cols
|
||||
row_heights = [table_height / num_rows] * num_rows
|
||||
|
||||
# Create paragraph style
|
||||
style = ParagraphStyle(
|
||||
'TableCell',
|
||||
fontName=self.config.font_name,
|
||||
fontSize=self.config.font_size,
|
||||
alignment=TA_CENTER,
|
||||
leading=self.config.font_size * 1.2
|
||||
)
|
||||
|
||||
# Convert to Paragraph objects
|
||||
para_grid = []
|
||||
for row in grid:
|
||||
para_row = []
|
||||
for cell in row:
|
||||
if cell:
|
||||
para_row.append(Paragraph(cell, style))
|
||||
else:
|
||||
para_row.append('')
|
||||
para_grid.append(para_row)
|
||||
|
||||
# Build TableStyle
|
||||
table_style_commands = [
|
||||
('GRID', (0, 0), (-1, -1), self.config.border_width, self.config.border_color),
|
||||
('VALIGN', (0, 0), (-1, -1), self.config.vertical_align),
|
||||
('ALIGN', (0, 0), (-1, -1), self.config.horizontal_align),
|
||||
('LEFTPADDING', (0, 0), (-1, -1), self.config.left_padding),
|
||||
('RIGHTPADDING', (0, 0), (-1, -1), self.config.right_padding),
|
||||
('TOPPADDING', (0, 0), (-1, -1), self.config.top_padding),
|
||||
('BOTTOMPADDING', (0, 0), (-1, -1), self.config.bottom_padding),
|
||||
('FONTNAME', (0, 0), (-1, -1), self.config.font_name),
|
||||
('FONTSIZE', (0, 0), (-1, -1), self.config.font_size),
|
||||
]
|
||||
table_style_commands.extend(span_commands)
|
||||
|
||||
# Create and draw table
|
||||
table = Table(para_grid, colWidths=col_widths, rowHeights=row_heights)
|
||||
table.setStyle(TableStyle(table_style_commands))
|
||||
|
||||
# Position and draw
|
||||
pdf_x = x0
|
||||
pdf_y = page_height - y1 # Flip Y
|
||||
|
||||
table.wrapOn(pdf_canvas, table_width, table_height)
|
||||
table.drawOn(pdf_canvas, pdf_x, pdf_y)
|
||||
|
||||
return True
|
||||
|
||||
def _render_with_dimensions(
|
||||
self,
|
||||
pdf_canvas,
|
||||
table_data: Dict,
|
||||
table_bbox: Tuple[float, float, float, float],
|
||||
page_height: float,
|
||||
col_widths: List[float],
|
||||
row_heights: List[float]
|
||||
) -> bool:
|
||||
"""Render table with specified dimensions."""
|
||||
rows = table_data.get('rows', [])
|
||||
if not rows:
|
||||
return False
|
||||
|
||||
num_rows = len(rows)
|
||||
num_cols = max(len(row.get('cells', [])) for row in rows)
|
||||
|
||||
# Adjust widths/heights if needed
|
||||
if len(col_widths) != num_cols:
|
||||
x0, y0, x1, y1 = table_bbox
|
||||
col_widths = [(x1 - x0) / num_cols] * num_cols
|
||||
if len(row_heights) != num_rows:
|
||||
x0, y0, x1, y1 = table_bbox
|
||||
row_heights = [(y1 - y0) / num_rows] * num_rows
|
||||
|
||||
# Build grid with proper positioning
|
||||
grid = []
|
||||
span_commands = []
|
||||
occupied = [[False] * num_cols for _ in range(num_rows)]
|
||||
|
||||
for row_idx, row in enumerate(rows):
|
||||
grid_row = [''] * num_cols
|
||||
|
||||
for cell in row.get('cells', []):
|
||||
# Get column position
|
||||
col_idx = cell.get('col', 0)
|
||||
|
||||
# Skip if out of bounds or occupied
|
||||
while col_idx < num_cols and occupied[row_idx][col_idx]:
|
||||
col_idx += 1
|
||||
if col_idx >= num_cols:
|
||||
continue
|
||||
|
||||
text = cell.get('text', '').strip()
|
||||
colspan = cell.get('colspan', 1)
|
||||
rowspan = cell.get('rowspan', 1)
|
||||
|
||||
grid_row[col_idx] = text
|
||||
|
||||
if colspan > 1 or rowspan > 1:
|
||||
end_col = min(col_idx + colspan - 1, num_cols - 1)
|
||||
end_row = min(row_idx + rowspan - 1, num_rows - 1)
|
||||
span_commands.append(
|
||||
('SPAN', (col_idx, row_idx), (end_col, end_row))
|
||||
)
|
||||
for r in range(row_idx, end_row + 1):
|
||||
for c in range(col_idx, end_col + 1):
|
||||
if r < num_rows and c < num_cols:
|
||||
occupied[r][c] = True
|
||||
else:
|
||||
occupied[row_idx][col_idx] = True
|
||||
|
||||
grid.append(grid_row)
|
||||
|
||||
# Create style and table
|
||||
style = ParagraphStyle(
|
||||
'TableCell',
|
||||
fontName=self.config.font_name,
|
||||
fontSize=self.config.font_size,
|
||||
alignment=TA_CENTER
|
||||
)
|
||||
|
||||
para_grid = []
|
||||
for row in grid:
|
||||
para_row = [Paragraph(cell, style) if cell else '' for cell in row]
|
||||
para_grid.append(para_row)
|
||||
|
||||
table_style_commands = [
|
||||
('GRID', (0, 0), (-1, -1), self.config.border_width, self.config.border_color),
|
||||
('VALIGN', (0, 0), (-1, -1), self.config.vertical_align),
|
||||
('LEFTPADDING', (0, 0), (-1, -1), 0),
|
||||
('RIGHTPADDING', (0, 0), (-1, -1), 0),
|
||||
('TOPPADDING', (0, 0), (-1, -1), 0),
|
||||
('BOTTOMPADDING', (0, 0), (-1, -1), 1),
|
||||
]
|
||||
table_style_commands.extend(span_commands)
|
||||
|
||||
table = Table(para_grid, colWidths=col_widths, rowHeights=row_heights)
|
||||
table.setStyle(TableStyle(table_style_commands))
|
||||
|
||||
x0, y0, x1, y1 = table_bbox
|
||||
pdf_x = x0
|
||||
pdf_y = page_height - y1
|
||||
|
||||
table.wrapOn(pdf_canvas, x1 - x0, y1 - y0)
|
||||
table.drawOn(pdf_canvas, pdf_x, pdf_y)
|
||||
|
||||
return True
|
||||
|
||||
def _build_rows_from_cells_dict(self, cells_dict: Dict) -> List[Dict]:
|
||||
"""Convert Direct track cell structure to row format."""
|
||||
cells = cells_dict.get('cells', [])
|
||||
if not cells:
|
||||
return []
|
||||
|
||||
num_rows = cells_dict.get('rows', 0)
|
||||
num_cols = cells_dict.get('cols', 0)
|
||||
|
||||
# Group cells by row
|
||||
rows_data = {}
|
||||
for cell in cells:
|
||||
row_idx = cell.get('row', 0)
|
||||
if row_idx not in rows_data:
|
||||
rows_data[row_idx] = []
|
||||
rows_data[row_idx].append(cell)
|
||||
|
||||
# Build row list
|
||||
rows = []
|
||||
for row_idx in range(num_rows):
|
||||
row_cells = rows_data.get(row_idx, [])
|
||||
|
||||
# Sort by column
|
||||
row_cells.sort(key=lambda c: c.get('col', 0))
|
||||
|
||||
formatted_cells = []
|
||||
for cell in row_cells:
|
||||
content = cell.get('content', '')
|
||||
if isinstance(content, list):
|
||||
content = '\n'.join(str(c) for c in content)
|
||||
|
||||
formatted_cells.append({
|
||||
'text': str(content) if content else '',
|
||||
'colspan': cell.get('col_span', 1),
|
||||
'rowspan': cell.get('row_span', 1),
|
||||
'col': cell.get('col', 0),
|
||||
'is_header': cell.get('is_header', False)
|
||||
})
|
||||
|
||||
rows.append({'cells': formatted_cells})
|
||||
|
||||
return rows
|
||||
|
||||
def _draw_table_border(
|
||||
self,
|
||||
pdf_canvas,
|
||||
table_bbox: Tuple[float, float, float, float],
|
||||
page_height: float
|
||||
) -> bool:
|
||||
"""Draw outer table border."""
|
||||
try:
|
||||
x0, y0, x1, y1 = table_bbox
|
||||
pdf_y0 = page_height - y1
|
||||
pdf_y1 = page_height - y0
|
||||
|
||||
pdf_canvas.saveState()
|
||||
pdf_canvas.setStrokeColor(self.config.border_color)
|
||||
pdf_canvas.setLineWidth(self.config.border_width)
|
||||
pdf_canvas.rect(x0, pdf_y0, x1 - x0, pdf_y1 - pdf_y0)
|
||||
pdf_canvas.restoreState()
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to draw table border: {e}")
|
||||
return False
|
||||
|
||||
def _draw_embedded_image(
|
||||
self,
|
||||
pdf_canvas,
|
||||
img_info: Dict,
|
||||
page_height: float,
|
||||
output_dir: Path
|
||||
) -> bool:
|
||||
"""Draw an image embedded within a table cell."""
|
||||
try:
|
||||
img_path = img_info.get('path')
|
||||
if not img_path:
|
||||
return False
|
||||
|
||||
# Resolve path
|
||||
if not Path(img_path).is_absolute():
|
||||
img_path = output_dir / img_path
|
||||
|
||||
if not Path(img_path).exists():
|
||||
logger.warning(f"Embedded image not found: {img_path}")
|
||||
return False
|
||||
|
||||
bbox = img_info.get('bbox', {})
|
||||
x0 = bbox.get('x0', 0)
|
||||
y0 = bbox.get('y0', 0)
|
||||
width = bbox.get('width', 100)
|
||||
height = bbox.get('height', 100)
|
||||
|
||||
# Flip Y coordinate
|
||||
pdf_y = page_height - y0 - height
|
||||
|
||||
# Draw image
|
||||
img = ImageReader(str(img_path))
|
||||
pdf_canvas.drawImage(img, x0, pdf_y, width, height)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to draw embedded image: {e}")
|
||||
return False
|
||||
|
||||
def _fit_text_to_cell(
|
||||
self,
|
||||
pdf_canvas,
|
||||
text: str,
|
||||
cell_width: float,
|
||||
cell_height: float
|
||||
) -> int:
|
||||
"""Find font size that fits text in cell."""
|
||||
for size in range(self.config.max_font_size, self.config.min_font_size - 1, -1):
|
||||
text_width = pdf_canvas.stringWidth(text, self.config.font_name, size)
|
||||
if text_width <= cell_width - 6: # 3pt padding each side
|
||||
return size
|
||||
return self.config.min_font_size
|
||||
|
||||
def _merge_boundaries(self, values: List[float], threshold: float) -> List[float]:
|
||||
"""Merge nearby boundary values."""
|
||||
if not values:
|
||||
return []
|
||||
|
||||
merged = [values[0]]
|
||||
for v in values[1:]:
|
||||
if abs(v - merged[-1]) > threshold:
|
||||
merged.append(v)
|
||||
|
||||
return merged
|
||||
|
||||
def _cluster_values(self, values: List[float], threshold: float) -> List[Tuple[float, List[float]]]:
|
||||
"""Cluster nearby values and return (average, members) pairs."""
|
||||
if not values:
|
||||
return []
|
||||
|
||||
clusters = []
|
||||
current_cluster = [values[0]]
|
||||
|
||||
for v in values[1:]:
|
||||
if abs(v - current_cluster[-1]) <= threshold:
|
||||
current_cluster.append(v)
|
||||
else:
|
||||
avg = sum(current_cluster) / len(current_cluster)
|
||||
clusters.append((avg, current_cluster))
|
||||
current_cluster = [v]
|
||||
|
||||
if current_cluster:
|
||||
avg = sum(current_cluster) / len(current_cluster)
|
||||
clusters.append((avg, current_cluster))
|
||||
|
||||
return clusters
|
||||
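A minimal end-to-end sketch of the renderer above, assuming a plain ReportLab canvas; the output file name, geometry, and import path are placeholders/assumptions rather than part of the commit.

```python
from reportlab.lib.pagesizes import A4
from reportlab.pdfgen import canvas

from app.services.pdf_table_renderer import TableRenderer, TableRenderConfig

page_w, page_h = A4
c = canvas.Canvas("table_demo.pdf", pagesize=A4)

renderer = TableRenderer(TableRenderConfig(font_size=9))
ok = renderer.render_from_html(
    pdf_canvas=c,
    html_content="<table><tr><td>A</td><td>B</td></tr>"
                 "<tr><td colspan='2'>merged</td></tr></table>",
    table_bbox=(72, 72, 300, 150),   # (x0, y0, x1, y1), top-left origin
    page_height=page_h,
)
c.save()
print("rendered:", ok)
```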
backend/app/services/processing_orchestrator.py (new file, 645 lines)
@@ -0,0 +1,645 @@
"""
Processing Orchestrator - Coordinates document processing across tracks.

This module provides a unified orchestration layer for document processing,
separating the high-level flow control from track-specific implementations.
"""

import logging
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

from app.models.unified_document import (
    ProcessingTrack,
    UnifiedDocument,
    DocumentMetadata,
    ElementType,
)

logger = logging.getLogger(__name__)


# ============================================================================
# Data Classes
# ============================================================================

@dataclass
class ProcessingConfig:
    """Configuration for document processing."""
    detect_layout: bool = True
    confidence_threshold: float = 0.5
    output_dir: Optional[Path] = None
    lang: str = "ch"
    layout_model: str = "ppyolov2_r50vd_dcn_365e_publaynet"
    preprocessing_mode: str = "auto"
    preprocessing_config: Optional[Dict] = None
    table_detection_config: Optional[Dict] = None
    force_track: Optional[str] = None  # "direct" or "ocr"
    use_dual_track: bool = True


@dataclass
class TrackRecommendation:
    """Recommendation for which processing track to use."""
    track: ProcessingTrack
    confidence: float
    reason: str
    metrics: Dict[str, Any] = field(default_factory=dict)


@dataclass
class ProcessingResult:
    """Result of document processing."""
    document: Optional[UnifiedDocument] = None
    legacy_result: Optional[Dict] = None
    track_used: ProcessingTrack = ProcessingTrack.DIRECT
    processing_time: float = 0.0
    success: bool = True
    error: Optional[str] = None

# ============================================================================
# Pipeline Interface
# ============================================================================

class ProcessingPipeline(ABC):
    """Abstract base class for processing pipelines."""

    @property
    @abstractmethod
    def track_type(self) -> ProcessingTrack:
        """Return the processing track type for this pipeline."""
        pass

    @abstractmethod
    def process(
        self,
        file_path: Path,
        config: ProcessingConfig
    ) -> ProcessingResult:
        """
        Process a document through this pipeline.

        Args:
            file_path: Path to the document
            config: Processing configuration

        Returns:
            ProcessingResult with the extracted document
        """
        pass

    @abstractmethod
    def can_process(self, file_path: Path) -> bool:
        """
        Check if this pipeline can process the given file.

        Args:
            file_path: Path to the document

        Returns:
            True if the pipeline can process this file type
        """
        pass

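The interface above is what the orchestrator routes through. A skeletal third pipeline would only need these three members; the class below is purely illustrative and not part of the commit.

```python
class NoOpPipeline(ProcessingPipeline):
    """Hypothetical pipeline that accepts nothing; shows the required surface."""

    @property
    def track_type(self) -> ProcessingTrack:
        return ProcessingTrack.DIRECT  # reuse an existing enum member

    def can_process(self, file_path: Path) -> bool:
        return False

    def process(self, file_path: Path, config: ProcessingConfig) -> ProcessingResult:
        return ProcessingResult(success=False, error="NoOpPipeline cannot process files")
```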
# ============================================================================
# Direct Track Pipeline
# ============================================================================

class DirectPipeline(ProcessingPipeline):
    """Pipeline for processing editable PDFs via direct text extraction."""

    def __init__(self):
        self._engine = None
        self._office_converter = None

    @property
    def track_type(self) -> ProcessingTrack:
        return ProcessingTrack.DIRECT

    @property
    def engine(self):
        """Lazy-load DirectExtractionEngine."""
        if self._engine is None:
            from app.services.direct_extraction_engine import DirectExtractionEngine
            self._engine = DirectExtractionEngine(
                enable_table_detection=True,
                enable_image_extraction=True,
                min_image_area=200.0,
                enable_whiteout_detection=True,
                enable_content_sanitization=True
            )
        return self._engine

    @property
    def office_converter(self):
        """Lazy-load OfficeConverter."""
        if self._office_converter is None:
            from app.services.office_converter import OfficeConverter
            self._office_converter = OfficeConverter()
        return self._office_converter

    def can_process(self, file_path: Path) -> bool:
        """Check if file is processable (PDF or Office document)."""
        suffix = file_path.suffix.lower()
        return suffix in ['.pdf', '.docx', '.doc', '.xlsx', '.xls', '.pptx', '.ppt']

    def process(
        self,
        file_path: Path,
        config: ProcessingConfig
    ) -> ProcessingResult:
        """Process document using direct text extraction."""
        start_time = time.time()

        try:
            logger.info(f"DirectPipeline: Processing {file_path.name}")

            # Handle Office document conversion
            actual_path = file_path
            if self._is_office_document(file_path):
                actual_path = self._convert_office_to_pdf(file_path, config.output_dir)
                if actual_path is None:
                    return ProcessingResult(
                        success=False,
                        error=f"Failed to convert Office document: {file_path.name}",
                        track_used=ProcessingTrack.DIRECT,
                        processing_time=time.time() - start_time
                    )

            # Extract document
            unified_doc = self.engine.extract(
                actual_path,
                output_dir=config.output_dir
            )

            processing_time = time.time() - start_time

            # Update metadata
            if unified_doc.metadata is None:
                unified_doc.metadata = DocumentMetadata()
            unified_doc.metadata.processing_track = ProcessingTrack.DIRECT
            unified_doc.metadata.processing_time = processing_time

            logger.info(f"DirectPipeline: Completed in {processing_time:.2f}s, "
                        f"{len(unified_doc.pages)} pages extracted")

            return ProcessingResult(
                document=unified_doc,
                track_used=ProcessingTrack.DIRECT,
                processing_time=processing_time,
                success=True
            )

        except Exception as e:
            logger.error(f"DirectPipeline: Error processing {file_path.name}: {e}")
            import traceback
            traceback.print_exc()
            return ProcessingResult(
                success=False,
                error=str(e),
                track_used=ProcessingTrack.DIRECT,
                processing_time=time.time() - start_time
            )

    def check_for_missing_images(
        self,
        file_path: Path,
        unified_doc: UnifiedDocument
    ) -> List[int]:
        """
        Check if document has pages with missing inline images.

        Args:
            file_path: Path to the PDF
            unified_doc: Extracted document

        Returns:
            List of page indices with missing images
        """
        return self.engine.check_document_for_missing_images(file_path)

    def render_missing_images(
        self,
        file_path: Path,
        unified_doc: UnifiedDocument,
        page_list: List[int],
        output_dir: Path
    ) -> UnifiedDocument:
        """
        Render inline image regions that couldn't be extracted.

        Args:
            file_path: Path to the PDF
            unified_doc: Document to update
            page_list: Pages with missing images
            output_dir: Directory for output images

        Returns:
            Updated UnifiedDocument
        """
        return self.engine.render_inline_image_regions(
            file_path, unified_doc, page_list, output_dir
        )

    def _is_office_document(self, file_path: Path) -> bool:
        """Check if file is an Office document."""
        return self.office_converter.is_office_document(file_path)

    def _convert_office_to_pdf(
        self,
        file_path: Path,
        output_dir: Optional[Path]
    ) -> Optional[Path]:
        """Convert Office document to PDF."""
        try:
            return self.office_converter.convert_to_pdf(file_path, output_dir)
        except Exception as e:
            logger.error(f"Office conversion failed: {e}")
            return None

# ============================================================================
# OCR Track Pipeline
# ============================================================================


class OCRPipeline(ProcessingPipeline):
    """Pipeline for processing scanned documents via OCR."""

    def __init__(self):
        self._ocr_service = None
        self._converter = None

    @property
    def track_type(self) -> ProcessingTrack:
        return ProcessingTrack.OCR

    @property
    def ocr_service(self):
        """
        Get reference to OCR service.

        Note: This creates a circular dependency that needs careful handling.
        The OCRPipeline should receive the service via dependency injection.
        """
        if self._ocr_service is None:
            raise RuntimeError(
                "OCRPipeline requires OCR service to be set via set_ocr_service()"
            )
        return self._ocr_service

    def set_ocr_service(self, service):
        """Set the OCR service for this pipeline (dependency injection)."""
        self._ocr_service = service

    @property
    def converter(self):
        """Lazy-load OCR to Unified converter."""
        if self._converter is None:
            from app.services.ocr_to_unified_converter import OCRToUnifiedConverter
            self._converter = OCRToUnifiedConverter()
        return self._converter

    def can_process(self, file_path: Path) -> bool:
        """Check if file is processable (images or PDFs)."""
        suffix = file_path.suffix.lower()
        return suffix in ['.pdf', '.png', '.jpg', '.jpeg', '.tiff', '.tif', '.bmp']

    def process(
        self,
        file_path: Path,
        config: ProcessingConfig
    ) -> ProcessingResult:
        """Process document using OCR."""
        start_time = time.time()

        try:
            logger.info(f"OCRPipeline: Processing {file_path.name}")

            # Use OCR service's traditional processing
            ocr_result = self.ocr_service.process_file_traditional(
                file_path,
                detect_layout=config.detect_layout,
                confidence_threshold=config.confidence_threshold,
                output_dir=config.output_dir,
                lang=config.lang,
                layout_model=config.layout_model,
                preprocessing_mode=config.preprocessing_mode,
                preprocessing_config=config.preprocessing_config,
                table_detection_config=config.table_detection_config
            )

            processing_time = time.time() - start_time

            # Convert to UnifiedDocument
            unified_doc = self.converter.convert(
                ocr_result,
                file_path,
                processing_time,
                config.lang
            )

            # Update metadata
            if unified_doc.metadata is None:
                unified_doc.metadata = DocumentMetadata()
            unified_doc.metadata.processing_track = ProcessingTrack.OCR
            unified_doc.metadata.processing_time = processing_time

            logger.info(f"OCRPipeline: Completed in {processing_time:.2f}s")

            return ProcessingResult(
                document=unified_doc,
                legacy_result=ocr_result,
                track_used=ProcessingTrack.OCR,
                processing_time=processing_time,
                success=True
            )

        except Exception as e:
            logger.error(f"OCRPipeline: Error processing {file_path.name}: {e}")
            import traceback
            traceback.print_exc()
            return ProcessingResult(
                success=False,
                error=str(e),
                track_used=ProcessingTrack.OCR,
                processing_time=time.time() - start_time
            )

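# --------------------------------------------------------------------------
# Usage sketch (editorial illustration, not part of the diff): how the OCR
# service is expected to be injected before the pipeline is used. The
# `OCRService` import path mirrors the one in TYPE_CHECKING elsewhere in this
# commit; any other names are assumptions.
#
#     from pathlib import Path
#     from app.services.ocr_service import OCRService
#
#     pipeline = OCRPipeline()
#     pipeline.set_ocr_service(OCRService())       # dependency injection
#     scan = Path("scan.pdf")
#     if pipeline.can_process(scan):
#         result = pipeline.process(scan, ProcessingConfig())
#         print(result.success, result.processing_time)
# --------------------------------------------------------------------------
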
# ============================================================================
# Processing Orchestrator
# ============================================================================


class ProcessingOrchestrator:
    """
    Orchestrates document processing across Direct and OCR tracks.

    This class coordinates the high-level processing flow:
    1. Determines the optimal processing track
    2. Routes to the appropriate pipeline
    3. Handles hybrid mode (Direct + OCR fallback)
    4. Manages result format conversion
    """

    def __init__(self):
        self._document_detector = None
        self._direct_pipeline = DirectPipeline()
        self._ocr_pipeline = OCRPipeline()

    @property
    def document_detector(self):
        """Lazy-load DocumentTypeDetector."""
        if self._document_detector is None:
            from app.services.document_type_detector import DocumentTypeDetector
            self._document_detector = DocumentTypeDetector()
        return self._document_detector

    @property
    def direct_pipeline(self) -> DirectPipeline:
        return self._direct_pipeline

    @property
    def ocr_pipeline(self) -> OCRPipeline:
        return self._ocr_pipeline

    def set_ocr_service(self, service):
        """Set OCR service for the OCR pipeline (dependency injection)."""
        self._ocr_pipeline.set_ocr_service(service)

    def determine_processing_track(
        self,
        file_path: Path,
        force_track: Optional[str] = None
    ) -> TrackRecommendation:
        """
        Determine the optimal processing track for a document.

        Args:
            file_path: Path to the document
            force_track: Optional override ("direct" or "ocr")

        Returns:
            TrackRecommendation with track, confidence, and reason
        """
        # Handle forced track
        if force_track:
            track = ProcessingTrack.DIRECT if force_track == "direct" else ProcessingTrack.OCR
            return TrackRecommendation(
                track=track,
                confidence=1.0,
                reason=f"Forced to use {force_track} track"
            )

        # Use document detector
        try:
            recommendation = self.document_detector.detect(file_path)
            # Convert string track to ProcessingTrack enum
            track_str = recommendation.track
            if isinstance(track_str, str):
                track = ProcessingTrack.DIRECT if track_str == "direct" else ProcessingTrack.OCR
            else:
                track = track_str  # Already an enum
            return TrackRecommendation(
                track=track,
                confidence=recommendation.confidence,
                reason=recommendation.reason,
                metrics=getattr(recommendation, 'metrics', {})
            )
        except Exception as e:
            logger.warning(f"Document detection failed: {e}, defaulting to DIRECT")
            return TrackRecommendation(
                track=ProcessingTrack.DIRECT,
                confidence=0.5,
                reason=f"Detection failed ({e}), using default"
            )

    def process(
        self,
        file_path: Path,
        config: ProcessingConfig
    ) -> ProcessingResult:
        """
        Process a document using the optimal track.

        Args:
            file_path: Path to the document
            config: Processing configuration

        Returns:
            ProcessingResult with extracted document
        """
        file_path = Path(file_path)
        start_time = time.time()

        logger.info(f"ProcessingOrchestrator: Processing {file_path.name}")

        # Determine track
        recommendation = self.determine_processing_track(
            file_path,
            config.force_track
        )

        logger.info(f"Track recommendation: {recommendation.track.value} "
                    f"(confidence: {recommendation.confidence:.2f}, "
                    f"reason: {recommendation.reason})")

        # Route to appropriate pipeline
        if recommendation.track == ProcessingTrack.DIRECT:
            result = self._execute_direct_with_fallback(file_path, config)
        else:
            result = self._ocr_pipeline.process(file_path, config)

        # Update total processing time
        result.processing_time = time.time() - start_time

        return result

    def _execute_direct_with_fallback(
        self,
        file_path: Path,
        config: ProcessingConfig
    ) -> ProcessingResult:
        """
        Execute direct track with hybrid fallback for missing images.

        Args:
            file_path: Path to the document
            config: Processing configuration

        Returns:
            ProcessingResult (may be HYBRID if OCR was used for images)
        """
        # Run direct extraction
        result = self._direct_pipeline.process(file_path, config)

        if not result.success or result.document is None:
            logger.warning("Direct extraction failed, falling back to OCR")
            return self._ocr_pipeline.process(file_path, config)

        # Check for missing images
        try:
            missing_pages = self._direct_pipeline.check_for_missing_images(
                file_path, result.document
            )

            if missing_pages:
                logger.info(f"Found {len(missing_pages)} pages with missing images, "
                            f"entering hybrid mode")
                return self._execute_hybrid(
                    file_path, config, result.document, missing_pages
                )
        except Exception as e:
            logger.warning(f"Missing image check failed: {e}")

        return result

    def _execute_hybrid(
        self,
        file_path: Path,
        config: ProcessingConfig,
        direct_doc: UnifiedDocument,
        missing_pages: List[int]
    ) -> ProcessingResult:
        """
        Execute hybrid mode: Direct extraction + OCR for missing images.

        Args:
            file_path: Path to the document
            config: Processing configuration
            direct_doc: Document from direct extraction
            missing_pages: Pages with missing images

        Returns:
            ProcessingResult with HYBRID track
        """
        start_time = time.time()

        try:
            # Try OCR for missing images
            ocr_result = self._ocr_pipeline.process(file_path, config)

            if ocr_result.success and ocr_result.document:
                # Merge OCR images into direct result
                images_added = self._merge_ocr_images(
                    direct_doc,
                    ocr_result.document,
                    missing_pages
                )
                logger.info(f"Hybrid mode: Added {images_added} images from OCR")
            else:
                # Fallback: render inline images directly
                logger.warning("OCR failed, rendering inline images as fallback")
                if config.output_dir:
                    direct_doc = self._direct_pipeline.render_missing_images(
                        file_path,
                        direct_doc,
                        missing_pages,
                        config.output_dir
                    )

            # Update metadata
            if direct_doc.metadata is None:
                direct_doc.metadata = DocumentMetadata()
            direct_doc.metadata.processing_track = ProcessingTrack.HYBRID

            return ProcessingResult(
                document=direct_doc,
                track_used=ProcessingTrack.HYBRID,
                processing_time=time.time() - start_time,
                success=True
            )

        except Exception as e:
            logger.error(f"Hybrid processing failed: {e}")
            # Return direct result as-is
            return ProcessingResult(
                document=direct_doc,
                track_used=ProcessingTrack.DIRECT,
                processing_time=time.time() - start_time,
                success=True
            )

    def _merge_ocr_images(
        self,
        direct_doc: UnifiedDocument,
        ocr_doc: UnifiedDocument,
        target_pages: List[int]
    ) -> int:
        """
        Merge image elements from OCR result into direct result.

        Args:
            direct_doc: Target document
            ocr_doc: Source document with images
            target_pages: Page indices to merge images from

        Returns:
            Number of images added
        """
        images_added = 0

        for page_idx in target_pages:
            if page_idx >= len(direct_doc.pages) or page_idx >= len(ocr_doc.pages):
                continue

            direct_page = direct_doc.pages[page_idx]
            ocr_page = ocr_doc.pages[page_idx]

            # Find image elements in OCR result
            for elem in ocr_page.elements:
                if elem.type in [
                    ElementType.IMAGE, ElementType.FIGURE,
                    ElementType.CHART, ElementType.DIAGRAM,
                    ElementType.LOGO, ElementType.STAMP
                ]:
                    # Generate unique element ID
                    elem.element_id = f"ocr_img_{page_idx}_{images_added}"
                    direct_page.elements.append(elem)
                    images_added += 1

        return images_added
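
# --------------------------------------------------------------------------
# Usage sketch (editorial illustration, not part of the diff): the intended
# call flow through the orchestrator. `ocr_service` and the ProcessingConfig
# defaults are assumptions based on the code above.
#
#     orchestrator = ProcessingOrchestrator()
#     orchestrator.set_ocr_service(ocr_service)    # inject the shared OCR service
#
#     # 1. Ask for a recommendation only (e.g. to surface it in the UI)
#     rec = orchestrator.determine_processing_track(Path("report.pdf"))
#     print(rec.track, rec.confidence, rec.reason)
#
#     # 2. Or process end-to-end; the orchestrator routes to Direct, OCR,
#     #    or Hybrid and stamps the track used on the result.
#     result = orchestrator.process(Path("report.pdf"), ProcessingConfig())
#     if result.success:
#         print(result.track_used, len(result.document.pages))
# --------------------------------------------------------------------------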
@@ -14,6 +14,7 @@ from enum import Enum
 from typing import Any, Dict, List, Optional, TYPE_CHECKING
 
 from app.services.memory_manager import get_model_manager, MemoryConfig
+from app.services.memory_policy_engine import get_memory_policy_engine
 
 if TYPE_CHECKING:
     from app.services.ocr_service import OCRService
@@ -263,10 +264,16 @@ class OCRServicePool:
 
             # Clean up GPU memory after release
             try:
-                model_manager = get_model_manager()
-                model_manager.memory_guard.clear_gpu_cache()
-            except Exception as e:
-                logger.debug(f"Cache clear after release failed: {e}")
+                # Prefer new MemoryPolicyEngine
+                engine = get_memory_policy_engine()
+                engine.clear_cache()
+            except Exception:
+                # Fallback to legacy model_manager
+                try:
+                    model_manager = get_model_manager()
+                    model_manager.memory_guard.clear_gpu_cache()
+                except Exception as e:
+                    logger.debug(f"Cache clear after release failed: {e}")
 
             # Notify waiting threads
             self._condition.notify_all()