From be5fdae5739283dd782e1c3029eaec075900b3f4 Mon Sep 17 00:00:00 2001 From: Jethong <1147925384@qq.com> Date: Tue, 13 Apr 2021 15:33:09 +0800 Subject: [PATCH] fix config format --- configs/e2e/e2e_r50_vd_pg.yml | 5 ++--- ppocr/data/pgnet_dataset.py | 5 ++++- tools/infer/predict_e2e.py | 2 +- tools/infer_e2e.py | 2 +- 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/configs/e2e/e2e_r50_vd_pg.yml b/configs/e2e/e2e_r50_vd_pg.yml index cd81ffbf..6dbeb522 100644 --- a/configs/e2e/e2e_r50_vd_pg.yml +++ b/configs/e2e/e2e_r50_vd_pg.yml @@ -69,10 +69,9 @@ Metric: Train: dataset: name: PGDataSet - data_dir: ./train_data/ + data_dir: ./train_data/train label_file_list: [.././train_data/total_text/train/] ratio_list: [1.0] - data_format: icdar #two data format: icdar/textnet transforms: - DecodeImage: # load image img_mode: BGR @@ -94,7 +93,7 @@ Train: Eval: dataset: name: PGDataSet - data_dir: ./train_data/ + data_dir: ./train_data/test label_file_list: [./train_data/total_text/test/] transforms: - DecodeImage: # load image diff --git a/ppocr/data/pgnet_dataset.py b/ppocr/data/pgnet_dataset.py index c6bc694f..b9819358 100644 --- a/ppocr/data/pgnet_dataset.py +++ b/ppocr/data/pgnet_dataset.py @@ -78,7 +78,10 @@ class PGDataSet(Dataset): file_name = substr[0] label = substr[1] img_path = os.path.join(self.data_dir, file_name) - img_id = int(data_line.split(".")[0][3:]) + if self.mode.lower() == 'eval': + img_id = int(data_line.split(".")[0][7:]) + else: + img_id = 0 data = {'img_path': img_path, 'label': label, 'img_id': img_id} if not os.path.exists(img_path): raise Exception("{} does not exist!".format(img_path)) diff --git a/tools/infer/predict_e2e.py b/tools/infer/predict_e2e.py index 8b94f24a..cd6c2005 100755 --- a/tools/infer/predict_e2e.py +++ b/tools/infer/predict_e2e.py @@ -122,7 +122,7 @@ class TextE2E(object): else: raise NotImplementedError post_result = self.postprocess_op(preds, shape_list) - points, strs = post_result['points'], post_result['strs'] + points, strs = post_result['points'], post_result['texts'] dt_boxes = self.filter_tag_det_res_only_clip(points, ori_im.shape) elapse = time.time() - starttime return dt_boxes, strs, elapse diff --git a/tools/infer_e2e.py b/tools/infer_e2e.py index b7503adb..9c079f60 100755 --- a/tools/infer_e2e.py +++ b/tools/infer_e2e.py @@ -103,7 +103,7 @@ def main(): images = paddle.to_tensor(images) preds = model(images) post_result = post_process_class(preds, shape_list) - points, strs = post_result['points'], post_result['strs'] + points, strs = post_result['points'], post_result['texts'] # write resule dt_boxes_json = [] for poly, str in zip(points, strs): -- GitLab