From 1b959e3eaae2aed3c7ea9e99b1809586c681bf51 Mon Sep 17 00:00:00 2001 From: littletomatodonkey Date: Sun, 25 Apr 2021 12:49:45 +0000 Subject: [PATCH] add save rec res --- .../rec_chinese_common_train_v2.0.yml | 1 + .../rec_chinese_lite_train_v2.0.yml | 1 + configs/rec/rec_icdar15_train.yml | 1 + configs/rec/rec_mv3_none_bilstm_ctc.yml | 1 + configs/rec/rec_mv3_none_none_ctc.yml | 1 + configs/rec/rec_mv3_tps_bilstm_att.yml | 1 + configs/rec/rec_mv3_tps_bilstm_ctc.yml | 1 + configs/rec/rec_r34_vd_none_bilstm_ctc.yml | 1 + configs/rec/rec_r34_vd_none_none_ctc.yml | 1 + configs/rec/rec_r34_vd_tps_bilstm_att.yml | 1 + configs/rec/rec_r34_vd_tps_bilstm_ctc.yml | 1 + configs/rec/rec_r50_fpn_srn.yml | 1 + tools/infer_rec.py | 66 +++++++++++-------- 13 files changed, 50 insertions(+), 28 deletions(-) diff --git a/configs/rec/ch_ppocr_v2.0/rec_chinese_common_train_v2.0.yml b/configs/rec/ch_ppocr_v2.0/rec_chinese_common_train_v2.0.yml index 6a524e22..717c1681 100644 --- a/configs/rec/ch_ppocr_v2.0/rec_chinese_common_train_v2.0.yml +++ b/configs/rec/ch_ppocr_v2.0/rec_chinese_common_train_v2.0.yml @@ -19,6 +19,7 @@ Global: max_text_length: 25 infer_mode: False use_space_char: True + save_res_path: ./output/rec/predicts_chinese_common_v2.0.txt Optimizer: diff --git a/configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml b/configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml index c96621c5..d60832b7 100644 --- a/configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml +++ b/configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml @@ -19,6 +19,7 @@ Global: max_text_length: 25 infer_mode: False use_space_char: True + # save_res_path: ./output/rec/predicts_chinese_lite_v2.0.txt Optimizer: diff --git a/configs/rec/rec_icdar15_train.yml b/configs/rec/rec_icdar15_train.yml index 5ae47c67..79e3ff88 100644 --- a/configs/rec/rec_icdar15_train.yml +++ b/configs/rec/rec_icdar15_train.yml @@ -19,6 +19,7 @@ Global: max_text_length: 25 infer_mode: False use_space_char: False + save_res_path: ./output/rec/predicts_ic15.txt Optimizer: name: Adam diff --git a/configs/rec/rec_mv3_none_bilstm_ctc.yml b/configs/rec/rec_mv3_none_bilstm_ctc.yml index 900e98b6..9e0bd23e 100644 --- a/configs/rec/rec_mv3_none_bilstm_ctc.yml +++ b/configs/rec/rec_mv3_none_bilstm_ctc.yml @@ -19,6 +19,7 @@ Global: max_text_length: 25 infer_mode: False use_space_char: False + save_res_path: ./output/rec/predicts_mv3_none_bilstm_ctc.txt Optimizer: name: Adam diff --git a/configs/rec/rec_mv3_none_none_ctc.yml b/configs/rec/rec_mv3_none_none_ctc.yml index 6d86b90c..904afe11 100644 --- a/configs/rec/rec_mv3_none_none_ctc.yml +++ b/configs/rec/rec_mv3_none_none_ctc.yml @@ -19,6 +19,7 @@ Global: max_text_length: 25 infer_mode: False use_space_char: False + save_res_path: ./output/rec/predicts_mv3_none_none_ctc.txt Optimizer: name: Adam diff --git a/configs/rec/rec_mv3_tps_bilstm_att.yml b/configs/rec/rec_mv3_tps_bilstm_att.yml index 33aed74d..feaeb054 100644 --- a/configs/rec/rec_mv3_tps_bilstm_att.yml +++ b/configs/rec/rec_mv3_tps_bilstm_att.yml @@ -19,6 +19,7 @@ Global: max_text_length: 25 infer_mode: False use_space_char: False + save_res_path: ./output/rec/predicts_mv3_tps_bilstm_att.txt Optimizer: diff --git a/configs/rec/rec_mv3_tps_bilstm_ctc.yml b/configs/rec/rec_mv3_tps_bilstm_ctc.yml index 026c6a9d..65ab23c4 100644 --- a/configs/rec/rec_mv3_tps_bilstm_ctc.yml +++ b/configs/rec/rec_mv3_tps_bilstm_ctc.yml @@ -19,6 +19,7 @@ Global: max_text_length: 25 infer_mode: False use_space_char: False + save_res_path: ./output/rec/predicts_mv3_tps_bilstm_ctc.txt Optimizer: name: Adam diff --git a/configs/rec/rec_r34_vd_none_bilstm_ctc.yml b/configs/rec/rec_r34_vd_none_bilstm_ctc.yml index 4052d426..331bb36e 100644 --- a/configs/rec/rec_r34_vd_none_bilstm_ctc.yml +++ b/configs/rec/rec_r34_vd_none_bilstm_ctc.yml @@ -19,6 +19,7 @@ Global: max_text_length: 25 infer_mode: False use_space_char: False + save_res_path: ./output/rec/predicts_r34_vd_none_bilstm_ctc.txt Optimizer: name: Adam diff --git a/configs/rec/rec_r34_vd_none_none_ctc.yml b/configs/rec/rec_r34_vd_none_none_ctc.yml index c3e1d9a3..695a4695 100644 --- a/configs/rec/rec_r34_vd_none_none_ctc.yml +++ b/configs/rec/rec_r34_vd_none_none_ctc.yml @@ -19,6 +19,7 @@ Global: max_text_length: 25 infer_mode: False use_space_char: False + save_res_path: ./output/rec/predicts_r34_vd_none_none_ctc.txt Optimizer: name: Adam diff --git a/configs/rec/rec_r34_vd_tps_bilstm_att.yml b/configs/rec/rec_r34_vd_tps_bilstm_att.yml index 87a14559..fdd3588c 100644 --- a/configs/rec/rec_r34_vd_tps_bilstm_att.yml +++ b/configs/rec/rec_r34_vd_tps_bilstm_att.yml @@ -19,6 +19,7 @@ Global: max_text_length: 25 infer_mode: False use_space_char: False + save_res_path: ./output/rec/predicts_b3_rare_r34_none_gru.txt Optimizer: diff --git a/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml b/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml index 9c51962e..67108a6e 100644 --- a/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml +++ b/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml @@ -19,6 +19,7 @@ Global: max_text_length: 25 infer_mode: False use_space_char: False + save_res_path: ./output/rec/predicts_r34_vd_tps_bilstm_ctc.txt Optimizer: name: Adam diff --git a/configs/rec/rec_r50_fpn_srn.yml b/configs/rec/rec_r50_fpn_srn.yml index 34a997f3..fa7b1ae4 100644 --- a/configs/rec/rec_r50_fpn_srn.yml +++ b/configs/rec/rec_r50_fpn_srn.yml @@ -20,6 +20,7 @@ Global: num_heads: 8 infer_mode: False use_space_char: False + save_res_path: ./output/rec/predicts_srn.txt Optimizer: diff --git a/tools/infer_rec.py b/tools/infer_rec.py index 075ec261..2563f5a8 100755 --- a/tools/infer_rec.py +++ b/tools/infer_rec.py @@ -73,35 +73,45 @@ def main(): global_config['infer_mode'] = True ops = create_operators(transforms, global_config) + save_res_path = config['Global'].get('save_res_path', + "./output/rec/predicts_rec.txt") + if not os.path.exists(os.path.dirname(save_res_path)): + os.makedirs(os.path.dirname(save_res_path)) + model.eval() - for file in get_image_file_list(config['Global']['infer_img']): - logger.info("infer_img: {}".format(file)) - with open(file, 'rb') as f: - img = f.read() - data = {'image': img} - batch = transform(data, ops) - if config['Architecture']['algorithm'] == "SRN": - encoder_word_pos_list = np.expand_dims(batch[1], axis=0) - gsrm_word_pos_list = np.expand_dims(batch[2], axis=0) - gsrm_slf_attn_bias1_list = np.expand_dims(batch[3], axis=0) - gsrm_slf_attn_bias2_list = np.expand_dims(batch[4], axis=0) - - others = [ - paddle.to_tensor(encoder_word_pos_list), - paddle.to_tensor(gsrm_word_pos_list), - paddle.to_tensor(gsrm_slf_attn_bias1_list), - paddle.to_tensor(gsrm_slf_attn_bias2_list) - ] - - images = np.expand_dims(batch[0], axis=0) - images = paddle.to_tensor(images) - if config['Architecture']['algorithm'] == "SRN": - preds = model(images, others) - else: - preds = model(images) - post_result = post_process_class(preds) - for rec_reuslt in post_result: - logger.info('\t result: {}'.format(rec_reuslt)) + + with open(save_res_path, "w") as fout: + for file in get_image_file_list(config['Global']['infer_img']): + logger.info("infer_img: {}".format(file)) + with open(file, 'rb') as f: + img = f.read() + data = {'image': img} + batch = transform(data, ops) + if config['Architecture']['algorithm'] == "SRN": + encoder_word_pos_list = np.expand_dims(batch[1], axis=0) + gsrm_word_pos_list = np.expand_dims(batch[2], axis=0) + gsrm_slf_attn_bias1_list = np.expand_dims(batch[3], axis=0) + gsrm_slf_attn_bias2_list = np.expand_dims(batch[4], axis=0) + + others = [ + paddle.to_tensor(encoder_word_pos_list), + paddle.to_tensor(gsrm_word_pos_list), + paddle.to_tensor(gsrm_slf_attn_bias1_list), + paddle.to_tensor(gsrm_slf_attn_bias2_list) + ] + + images = np.expand_dims(batch[0], axis=0) + images = paddle.to_tensor(images) + if config['Architecture']['algorithm'] == "SRN": + preds = model(images, others) + else: + preds = model(images) + post_result = post_process_class(preds) + for rec_reuslt in post_result: + logger.info('\t result: {}'.format(rec_reuslt)) + if len(rec_reuslt) >= 2: + fout.write(file + "\t" + rec_reuslt[0] + "\t" + str( + rec_reuslt[1]) + "\n") logger.info("success!") -- GitLab