diff --git a/doc/doc_ch/algorithm_rec_srn.md b/doc/doc_ch/algorithm_rec_srn.md index ca7961359eb902fafee959b26d02f324aece233a..dd61a388c7024fabdadec1c120bd3341ed0197cc 100644 --- a/doc/doc_ch/algorithm_rec_srn.md +++ b/doc/doc_ch/algorithm_rec_srn.md @@ -78,7 +78,7 @@ python3 tools/export_model.py -c configs/rec/rec_r50_fpn_srn.yml -o Global.pretr SRN文本识别模型推理,可以执行如下命令: ``` -python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_srn/" --rec_image_shape="1,64,256" --rec_char_type="ch" --rec_algorithm="SRN" --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --use_space_char=False +python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_srn/" --rec_image_shape="1,64,256" --rec_algorithm="SRN" --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --use_space_char=False ``` diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index 7c46e17bacdf1fff464322d284e4549bd8edacf2..176e2c68e2c9b2e08f9b56378c45a57733faf8cd 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -349,6 +349,13 @@ class TextRecognizer(object): for beg_img_no in range(0, img_num, batch_num): end_img_no = min(img_num, beg_img_no + batch_num) norm_img_batch = [] + if self.rec_algorithm == "SRN": + encoder_word_pos_list = [] + gsrm_word_pos_list = [] + gsrm_slf_attn_bias1_list = [] + gsrm_slf_attn_bias2_list = [] + if self.rec_algorithm == "SAR": + valid_ratios = [] imgC, imgH, imgW = self.rec_image_shape[:3] max_wh_ratio = imgW / imgH # max_wh_ratio = 0 @@ -357,22 +364,16 @@ class TextRecognizer(object): wh_ratio = w * 1.0 / h max_wh_ratio = max(max_wh_ratio, wh_ratio) for ino in range(beg_img_no, end_img_no): - if self.rec_algorithm == "SAR": norm_img, _, _, valid_ratio = self.resize_norm_img_sar( img_list[indices[ino]], self.rec_image_shape) norm_img = norm_img[np.newaxis, :] valid_ratio = np.expand_dims(valid_ratio, axis=0) - valid_ratios = [] valid_ratios.append(valid_ratio) norm_img_batch.append(norm_img) elif self.rec_algorithm == "SRN": norm_img = self.process_image_srn( img_list[indices[ino]], self.rec_image_shape, 8, 25) - encoder_word_pos_list = [] - gsrm_word_pos_list = [] - gsrm_slf_attn_bias1_list = [] - gsrm_slf_attn_bias2_list = [] encoder_word_pos_list.append(norm_img[1]) gsrm_word_pos_list.append(norm_img[2]) gsrm_slf_attn_bias1_list.append(norm_img[3])