diff --git a/deploy/cpp_infer/src/postprocess_op.cpp b/deploy/cpp_infer/src/postprocess_op.cpp
index 8d7af6474c996067a532e3d3eefb86ea5d6d3e3b..551f98a1668124f83ef615f0a41b081508898d6e 100644
--- a/deploy/cpp_infer/src/postprocess_op.cpp
+++ b/deploy/cpp_infer/src/postprocess_op.cpp
@@ -400,7 +400,7 @@ void TablePostProcessor::Run(
score += char_score;
rec_html_tags.push_back(html_tag);
// box
- if (html_tag == "
" || html_tag == " | " || html_tag == " | | ") {
for (int point_idx = 0; point_idx < loc_preds_shape[2];
point_idx += 2) {
std::vector point(2, 0);
@@ -416,7 +416,7 @@ void TablePostProcessor::Run(
}
}
score /= count;
- if (isnan(score) || rec_boxes.size() == 0 || rec_html_tags.size() == 0) {
+ if (isnan(score) || rec_boxes.size() == 0) {
score = -1;
}
rec_scores.push_back(score);
diff --git a/ppocr/utils/visual.py b/ppocr/utils/visual.py
index 030d1c38d2d985d9673eb1ca6e1664adb21608f6..5bd805ea6e76be37612a142102beab492bece941 100644
--- a/ppocr/utils/visual.py
+++ b/ppocr/utils/visual.py
@@ -114,6 +114,7 @@ def draw_re_results(image,
def draw_rectangle(img_path, boxes):
+ boxes = np.array(boxes)
img = cv2.imread(img_path)
img_show = img.copy()
for box in boxes.astype(int):
diff --git a/ppstructure/docs/quickstart.md b/ppstructure/docs/quickstart.md
index 700800759923ec4f3703bdb542bee0df930d6cd1..f4645bdfe011a12370bedc7bd7a125b28ded41ff 100644
--- a/ppstructure/docs/quickstart.md
+++ b/ppstructure/docs/quickstart.md
@@ -4,7 +4,7 @@
- [2. 便捷使用](#2-便捷使用)
- [2.1 命令行使用](#21-命令行使用)
- [2.1.1 图像方向分类+版面分析+表格识别](#211-图像方向分类版面分析表格识别)
- - [2.1.1 版面分析+表格识别](#211-版面分析表格识别)
+ - [2.1.2 版面分析+表格识别](#212-版面分析表格识别)
- [2.1.3 版面分析](#213-版面分析)
- [2.1.4 表格识别](#214-表格识别)
- [2.1.5 DocVQA](#215-docvqa)
@@ -44,7 +44,7 @@ paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/1.png --type=structure --
```
-#### 2.1.1 版面分析+表格识别
+#### 2.1.2 版面分析+表格识别
```bash
paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/1.png --type=structure
```
diff --git a/ppstructure/table/eval_table.py b/ppstructure/table/eval_table.py
index 435d693223bb4e7b7fb51e8778135775ad345f0d..4fc16b5d4c6a0143dcea149508bd6b62730092b6 100755
--- a/ppstructure/table/eval_table.py
+++ b/ppstructure/table/eval_table.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import os
import sys
diff --git a/ppstructure/table/predict_table.py b/ppstructure/table/predict_table.py
index f580213753b10d775e031fa70f69fdd774f4a9df..e94347d86144cd66474546e99a2c9dffee4978d9 100644
--- a/ppstructure/table/predict_table.py
+++ b/ppstructure/table/predict_table.py
@@ -117,7 +117,6 @@ class TableSystem(object):
pred_html = self.match(structure_res, dt_boxes, rec_res)
toc = time.time()
time_dict['match'] = toc - tic
- # pred_html = self.match(1, 1, 1,img_name)
result['html'] = pred_html
if self.benchmark:
self.autolog.times.end(stamp=True)
@@ -212,8 +211,12 @@ def main(args):
elapse = time.time() - starttime
logger.info("Predict time : {:.3f}s".format(elapse))
- # img = predict_strture.draw_rectangle(image_file, pred_res['cell_bbox'], use_xywh)
- img = utility.draw_boxes(cv2.imread(image_file), pred_res['cell_bbox'])
+ if len(pred_res['cell_bbox']) > 0 and len(pred_res['cell_bbox'][
+ 0]) == 4:
+ img = predict_strture.draw_rectangle(image_file,
+ pred_res['cell_bbox'])
+ else:
+ img = utility.draw_boxes(img, pred_res['cell_bbox'])
img_save_path = os.path.join(args.output, os.path.basename(image_file))
cv2.imwrite(img_save_path, img)
diff --git a/ppstructure/table/table_master_match.py b/ppstructure/table/table_master_match.py
index 6a4c4e9dc688f683f009b135821a020a2e391957..163b13063b245210d38f22bc2b741f9019cdb03d 100644
--- a/ppstructure/table/table_master_match.py
+++ b/ppstructure/table/table_master_match.py
@@ -273,10 +273,6 @@ def sort_bbox(end2end_xywh_bboxes, no_match_end2end_indexes):
end2end_sorted_idx_list, end2end_sorted_bbox_list \
= flatten(sorted_groups, sorted_bbox_groups)
- # check sorted
- #img = cv2.imread('/data_0/yejiaquan/data/TableRecognization/singleVal/PMC3286376_004_00.png')
- #img = drawBboxAfterSorted(img, sorted_groups, sorted_bbox_groups)
-
return end2end_sorted_idx_list, end2end_sorted_bbox_list, sorted_groups, sorted_bbox_groups
@@ -302,9 +298,6 @@ def get_bboxes_list(end2end_result, structure_master_result):
# structure master
src_bboxes = structure_master_result['bbox']
src_bboxes = remove_empty_bboxes(src_bboxes)
- # structure_master_xywh_bboxes = src_bboxes
- # xyxy_bboxes = xywh2xyxy(src_bboxes)
- # structure_master_xyxy_bboxes = xyxy_bboxes
structure_master_xyxy_bboxes = src_bboxes
xywh_bbox = xyxy2xywh(src_bboxes)
structure_master_xywh_bboxes = xywh_bbox
@@ -410,64 +403,6 @@ def extra_match(no_match_end2end_indexes, master_bbox_nums):
return extra_match_list
-def match_visual(file_name,
- match_list,
- end2end_xyxy,
- master_xyxy,
- prex='ordinary_match'):
- """
- Show the match result by xyxy coord style.
- :param file_name:
- :param match_list:
- :param end2end_xyxy:
- :param master_xyxy:
- :param prex:
- :return:
- """
- folder = ''
- save_folder = '/data_0/cache'
- file_path = os.path.join(folder, file_name)
- img_end2end = cv2.imread(file_path)
- img_master = copy.deepcopy(img_end2end)
- text_color = (0, 0, 255)
- bbox_color = (255, 0, 0)
- master_nums = len(master_xyxy)
-
- for idx, match_group in enumerate(match_list):
- end2end_idx, master_index = match_group[0], match_group[1]
-
- # master_index larger than master_nums, did not draw master bbox.
- if master_index < master_nums:
- # draw master
- master_bbox = master_xyxy[master_index]
- img_master = cv2.rectangle(
- img_master, (int(master_bbox[0]), int(master_bbox[1])),
- (int(master_bbox[2]), int(master_bbox[3])),
- bbox_color,
- thickness=1)
- master_text_coord = (int(master_bbox[0]) - 4, int(master_bbox[1]))
- img_master = cv2.putText(img_master,
- str(master_index), master_text_coord, 1, 1,
- text_color, 2)
-
- # draw end2end
- end2end_bbox = end2end_xyxy[end2end_idx]
- img_end2end = cv2.rectangle(
- img_end2end, (int(end2end_bbox[0]), int(end2end_bbox[1])),
- (int(end2end_bbox[2]), int(end2end_bbox[3])),
- bbox_color,
- thickness=1)
- end2end_text_coord = (int(end2end_bbox[0]) - 4, int(end2end_bbox[1]))
- # write end2end bbox matching master bbox's index
- img_end2end = cv2.putText(img_end2end,
- str(master_index), end2end_text_coord, 1, 1,
- text_color, 2)
-
- img = np.hstack([img_end2end, img_master])
- save_path = os.path.join(save_folder, '{}_matchShow.png'.format(prex))
- cv2.imwrite(save_path, img)
-
-
def get_match_dict(match_list):
"""
Convert match_list to a dict, where key is master bbox's index, value is end2end bbox index.
@@ -555,8 +490,6 @@ def merge_span_token(master_token_list):
pattern
' | ' + ' | '
"""
- # tmp = master_token_list[pointer] + master_token_list[pointer+1] + master_token_list[pointer+2] + \
- # master_token_list[pointer+3]
tmp = ''.join(master_token_list[pointer:pointer + 3 + 1])
pointer += 4
new_master_token_list.append(tmp)
@@ -569,8 +502,6 @@ def merge_span_token(master_token_list):
pattern
' | ' + ' | '
"""
- # tmp = master_token_list[pointer] + master_token_list[pointer+1] + \
- # master_token_list[pointer+2] + master_token_list[pointer+3] + master_token_list[pointer+4]
tmp = ''.join(master_token_list[pointer:pointer + 4 + 1])
pointer += 5
new_master_token_list.append(tmp)
@@ -909,11 +840,6 @@ class Matcher:
'sorted_bboxes_groups': sorted_bboxes_groups
}
- # ordinary match show
- # match_visual(file_name, match_list, end2end_xyxy_bboxes, structure_master_xyxy_bboxes, prex='ordinary_match')
- # extra match show
- # match_visual(file_name, match_list_add_extra_match, end2end_xyxy_bboxes, structure_master_xyxy_bboxes, prex='extra_match')
-
# format output
match_result_dict = self._format(match_result_dict, file_name)
diff --git a/tools/train.py b/tools/train.py
index 0c881ecae8daf78860829b1419178358c2209f25..a46cd67cb93271ca03a53018bc5140d5375910a7 100755
--- a/tools/train.py
+++ b/tools/train.py
@@ -125,6 +125,7 @@ def main(config, device, logger, vdl_writer):
logger.info('convert_sync_batchnorm')
model = apply_to_static(model, config, logger)
+ logger.info(model)
# build loss
loss_class = build_loss(config['Loss'])