diff --git a/configs/e2e/e2e_r50_vd_pg.yml b/configs/e2e/e2e_r50_vd_pg.yml index cd81ffbf75ca8ee568228a377b18df80105720cd..6dbeb522b38d267478b9601bc301fa240291751f 100644 --- a/configs/e2e/e2e_r50_vd_pg.yml +++ b/configs/e2e/e2e_r50_vd_pg.yml @@ -69,10 +69,9 @@ Metric: Train: dataset: name: PGDataSet - data_dir: ./train_data/ + data_dir: ./train_data/train label_file_list: [.././train_data/total_text/train/] ratio_list: [1.0] - data_format: icdar #two data format: icdar/textnet transforms: - DecodeImage: # load image img_mode: BGR @@ -94,7 +93,7 @@ Train: Eval: dataset: name: PGDataSet - data_dir: ./train_data/ + data_dir: ./train_data/test label_file_list: [./train_data/total_text/test/] transforms: - DecodeImage: # load image diff --git a/ppocr/data/pgnet_dataset.py b/ppocr/data/pgnet_dataset.py index c6bc694f1f5d4e657002b76593013964a383ada2..b98193584857177855d60e9e380e16619536e74e 100644 --- a/ppocr/data/pgnet_dataset.py +++ b/ppocr/data/pgnet_dataset.py @@ -78,7 +78,10 @@ class PGDataSet(Dataset): file_name = substr[0] label = substr[1] img_path = os.path.join(self.data_dir, file_name) - img_id = int(data_line.split(".")[0][3:]) + if self.mode.lower() == 'eval': + img_id = int(data_line.split(".")[0][7:]) + else: + img_id = 0 data = {'img_path': img_path, 'label': label, 'img_id': img_id} if not os.path.exists(img_path): raise Exception("{} does not exist!".format(img_path)) diff --git a/tools/infer/predict_e2e.py b/tools/infer/predict_e2e.py index 8b94f24a9ffa16ca5683795afa6392fa23c24a94..cd6c2005a7cc77c356e3f004cd586a84676ea7fa 100755 --- a/tools/infer/predict_e2e.py +++ b/tools/infer/predict_e2e.py @@ -122,7 +122,7 @@ class TextE2E(object): else: raise NotImplementedError post_result = self.postprocess_op(preds, shape_list) - points, strs = post_result['points'], post_result['strs'] + points, strs = post_result['points'], post_result['texts'] dt_boxes = self.filter_tag_det_res_only_clip(points, ori_im.shape) elapse = time.time() - starttime return dt_boxes, strs, elapse diff --git a/tools/infer_e2e.py b/tools/infer_e2e.py index b7503adb94eb797d4fb12cf47b377fa72d02158b..9c079f6074f088ef0298cab839f74faefad82abb 100755 --- a/tools/infer_e2e.py +++ b/tools/infer_e2e.py @@ -103,7 +103,7 @@ def main(): images = paddle.to_tensor(images) preds = model(images) post_result = post_process_class(preds, shape_list) - points, strs = post_result['points'], post_result['strs'] + points, strs = post_result['points'], post_result['texts'] # write resule dt_boxes_json = [] for poly, str in zip(points, strs):