Skip to content

Commit

Permalink
fix tabled inference (layout model changes)
Browse files Browse the repository at this point in the history
  • Loading branch information
iammosespaulr committed Nov 30, 2024
1 parent 46345b6 commit 694ea62
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 16 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,12 @@ tabled_gui
```python
from tabled.extract import extract_tables
from tabled.fileinput import load_pdfs_images
from tabled.inference.models import load_detection_models, load_recognition_models
from tabled.inference.models import load_detection_models, load_recognition_models, load_layout_models

det_models, rec_models = load_detection_models(), load_recognition_models()
det_models, rec_models, layout_models = load_detection_models(), load_recognition_models(), load_layout_models()
images, highres_images, names, text_lines = load_pdfs_images(IN_PATH)

page_results = extract_tables(images, highres_images, text_lines, det_models, rec_models)
page_results = extract_tables(images, highres_images, text_lines, det_models, layout_models, rec_models)
```

# Benchmarks
Expand Down
5 changes: 3 additions & 2 deletions extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from tabled.extract import extract_tables
from tabled.formats import formatter
from tabled.fileinput import load_pdfs_images
from tabled.inference.models import load_detection_models, load_recognition_models
from tabled.inference.models import load_detection_models, load_recognition_models, load_layout_models


@click.command(help="Extract tables from PDFs")
Expand All @@ -36,8 +36,9 @@ def main(in_path, out_folder, save_json, save_debug_images, skip_detection, dete

det_models = load_detection_models()
rec_models = load_recognition_models()
layout_models = load_layout_models()

page_results = extract_tables(images, highres_images, text_lines, det_models, rec_models, skip_detection=skip_detection, detect_boxes=detect_cell_boxes)
page_results = extract_tables(images, highres_images, text_lines, det_models, layout_models, rec_models, skip_detection=skip_detection, detect_boxes=detect_cell_boxes)

out_json = defaultdict(list)
for name, pnum, result in zip(names, pnums, page_results):
Expand Down
8 changes: 4 additions & 4 deletions table_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,24 @@
from PIL import Image

import streamlit as st
from tabled.inference.models import load_detection_models, load_recognition_models
from tabled.inference.models import load_detection_models, load_recognition_models, load_layout_models


@st.cache_resource()
def load_models():
return load_detection_models(), load_recognition_models()
return load_detection_models(), load_recognition_models(), load_layout_models()


def run_table_rec(image, highres_image, text_line, models, skip_detection=False, detect_boxes=False):
if not skip_detection:
table_imgs, table_bboxes, _ = detect_tables([image], [highres_image], models[0])
table_imgs, table_bboxes, _ = detect_tables([image], [highres_image], models[2])
else:
table_imgs = [highres_image]
table_bboxes = [[0, 0, highres_image.size[0], highres_image.size[1]]]

table_text_lines = [text_line] * len(table_imgs)
highres_image_sizes = [highres_image.size] * len(table_imgs)
cells, needs_ocr = get_cells(table_imgs, table_bboxes, highres_image_sizes, table_text_lines, models[0][:2], detect_boxes=detect_boxes)
cells, needs_ocr = get_cells(table_imgs, table_bboxes, highres_image_sizes, table_text_lines, models[0], detect_boxes=detect_boxes)

table_rec = recognize_tables(table_imgs, cells, needs_ocr, models[1])
cells = [assign_rows_columns(tr, im_size) for tr, im_size in zip(table_rec, highres_image_sizes)]
Expand Down
5 changes: 3 additions & 2 deletions tabled/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@ def extract_tables(
highres_images,
text_lines,
det_models,
layout_models,
rec_models,
skip_detection=False,
detect_boxes=False
) -> List[ExtractPageResult]:
if not skip_detection:
table_imgs, table_bboxes, table_counts = detect_tables(images, highres_images, det_models)
table_imgs, table_bboxes, table_counts = detect_tables(images, highres_images, layout_models)
else:
table_imgs = highres_images
table_bboxes = [[0, 0, img.size[0], img.size[1]] for img in highres_images]
Expand All @@ -30,7 +31,7 @@ def extract_tables(
table_text_lines.extend([text_lines[i]] * tc)
highres_image_sizes.extend([highres_images[i].size] * tc)

cells, needs_ocr = get_cells(table_imgs, table_bboxes, highres_image_sizes, table_text_lines, det_models[:2], detect_boxes=detect_boxes)
cells, needs_ocr = get_cells(table_imgs, table_bboxes, highres_image_sizes, table_text_lines, det_models, detect_boxes=detect_boxes)

table_rec = recognize_tables(table_imgs, cells, needs_ocr, rec_models)
cells = [assign_rows_columns(tr, im_size) for tr, im_size in zip(table_rec, highres_image_sizes)]
Expand Down
1 change: 0 additions & 1 deletion tabled/inference/detection.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from surya.detection import batch_text_detection
from surya.layout import batch_layout_detection
from surya.postprocessing.util import rescale_bbox
from surya.schema import Bbox
Expand Down
13 changes: 9 additions & 4 deletions tabled/inference/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,19 @@


def load_detection_models():
layout_model = load_layout_model()
layout_processor = load_layout_processor()
return layout_model, layout_processor
detection_model = load_det_model()
detection_processor = load_det_processor()
return detection_model, detection_processor


def load_recognition_models():
table_rec_model = load_table_rec_model()
table_rec_processor = load_table_rec_processor()
rec_model = load_rec_model()
rec_processor = load_rec_processor()
return table_rec_model, table_rec_processor, rec_model, rec_processor
return table_rec_model, table_rec_processor, rec_model, rec_processor

def load_layout_models():
layout_model = load_layout_model()
layout_processor = load_layout_processor()
return layout_model, layout_processor

0 comments on commit 694ea62

Please sign in to comment.