From 5a524807f9029b3a4dd4497c0e1dbbbb581b346d Mon Sep 17 00:00:00 2001 From: xiaoting <31891223+tink2123@users.noreply.github.com> Date: Tue, 23 Aug 2022 18:05:03 +0800 Subject: [PATCH] fix infer_rec for other input (#7286) --- doc/doc_ch/algorithm_rec_srn.md | 2 +- tools/infer/predict_rec.py | 13 +++++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/doc/doc_ch/algorithm_rec_srn.md b/doc/doc_ch/algorithm_rec_srn.md index ca796135..dd61a388 100644 --- a/doc/doc_ch/algorithm_rec_srn.md +++ b/doc/doc_ch/algorithm_rec_srn.md @@ -78,7 +78,7 @@ python3 tools/export_model.py -c configs/rec/rec_r50_fpn_srn.yml -o Global.pretr SRN文本识别模型推理,可以执行如下命令: ``` -python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_srn/" --rec_image_shape="1,64,256" --rec_char_type="ch" --rec_algorithm="SRN" --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --use_space_char=False +python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_srn/" --rec_image_shape="1,64,256" --rec_algorithm="SRN" --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --use_space_char=False ``` diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index 7c46e17b..176e2c68 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -349,6 +349,13 @@ class TextRecognizer(object): for beg_img_no in range(0, img_num, batch_num): end_img_no = min(img_num, beg_img_no + batch_num) norm_img_batch = [] + if self.rec_algorithm == "SRN": + encoder_word_pos_list = [] + gsrm_word_pos_list = [] + gsrm_slf_attn_bias1_list = [] + gsrm_slf_attn_bias2_list = [] + if self.rec_algorithm == "SAR": + valid_ratios = [] imgC, imgH, imgW = self.rec_image_shape[:3] max_wh_ratio = imgW / imgH # max_wh_ratio = 0 @@ -357,22 +364,16 @@ class TextRecognizer(object): wh_ratio = w * 1.0 / h max_wh_ratio = max(max_wh_ratio, wh_ratio) for ino in range(beg_img_no, end_img_no): - if self.rec_algorithm == "SAR": norm_img, _, _, valid_ratio = self.resize_norm_img_sar( img_list[indices[ino]], self.rec_image_shape) norm_img = norm_img[np.newaxis, :] valid_ratio = np.expand_dims(valid_ratio, axis=0) - valid_ratios = [] valid_ratios.append(valid_ratio) norm_img_batch.append(norm_img) elif self.rec_algorithm == "SRN": norm_img = self.process_image_srn( img_list[indices[ino]], self.rec_image_shape, 8, 25) - encoder_word_pos_list = [] - gsrm_word_pos_list = [] - gsrm_slf_attn_bias1_list = [] - gsrm_slf_attn_bias2_list = [] encoder_word_pos_list.append(norm_img[1]) gsrm_word_pos_list.append(norm_img[2]) gsrm_slf_attn_bias1_list.append(norm_img[3]) -- GitLab