未验证 提交 5a524807 编写于 作者: X xiaoting 提交者: GitHub

fix infer_rec for other input (#7286)

上级 1dbff736
...@@ -78,7 +78,7 @@ python3 tools/export_model.py -c configs/rec/rec_r50_fpn_srn.yml -o Global.pretr ...@@ -78,7 +78,7 @@ python3 tools/export_model.py -c configs/rec/rec_r50_fpn_srn.yml -o Global.pretr
SRN文本识别模型推理,可以执行如下命令: SRN文本识别模型推理,可以执行如下命令:
``` ```
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_srn/" --rec_image_shape="1,64,256" --rec_char_type="ch" --rec_algorithm="SRN" --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --use_space_char=False python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_srn/" --rec_image_shape="1,64,256" --rec_algorithm="SRN" --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --use_space_char=False
``` ```
<a name="4-2"></a> <a name="4-2"></a>
......
...@@ -349,6 +349,13 @@ class TextRecognizer(object): ...@@ -349,6 +349,13 @@ class TextRecognizer(object):
for beg_img_no in range(0, img_num, batch_num): for beg_img_no in range(0, img_num, batch_num):
end_img_no = min(img_num, beg_img_no + batch_num) end_img_no = min(img_num, beg_img_no + batch_num)
norm_img_batch = [] norm_img_batch = []
if self.rec_algorithm == "SRN":
encoder_word_pos_list = []
gsrm_word_pos_list = []
gsrm_slf_attn_bias1_list = []
gsrm_slf_attn_bias2_list = []
if self.rec_algorithm == "SAR":
valid_ratios = []
imgC, imgH, imgW = self.rec_image_shape[:3] imgC, imgH, imgW = self.rec_image_shape[:3]
max_wh_ratio = imgW / imgH max_wh_ratio = imgW / imgH
# max_wh_ratio = 0 # max_wh_ratio = 0
...@@ -357,22 +364,16 @@ class TextRecognizer(object): ...@@ -357,22 +364,16 @@ class TextRecognizer(object):
wh_ratio = w * 1.0 / h wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio) max_wh_ratio = max(max_wh_ratio, wh_ratio)
for ino in range(beg_img_no, end_img_no): for ino in range(beg_img_no, end_img_no):
if self.rec_algorithm == "SAR": if self.rec_algorithm == "SAR":
norm_img, _, _, valid_ratio = self.resize_norm_img_sar( norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
img_list[indices[ino]], self.rec_image_shape) img_list[indices[ino]], self.rec_image_shape)
norm_img = norm_img[np.newaxis, :] norm_img = norm_img[np.newaxis, :]
valid_ratio = np.expand_dims(valid_ratio, axis=0) valid_ratio = np.expand_dims(valid_ratio, axis=0)
valid_ratios = []
valid_ratios.append(valid_ratio) valid_ratios.append(valid_ratio)
norm_img_batch.append(norm_img) norm_img_batch.append(norm_img)
elif self.rec_algorithm == "SRN": elif self.rec_algorithm == "SRN":
norm_img = self.process_image_srn( norm_img = self.process_image_srn(
img_list[indices[ino]], self.rec_image_shape, 8, 25) img_list[indices[ino]], self.rec_image_shape, 8, 25)
encoder_word_pos_list = []
gsrm_word_pos_list = []
gsrm_slf_attn_bias1_list = []
gsrm_slf_attn_bias2_list = []
encoder_word_pos_list.append(norm_img[1]) encoder_word_pos_list.append(norm_img[1])
gsrm_word_pos_list.append(norm_img[2]) gsrm_word_pos_list.append(norm_img[2])
gsrm_slf_attn_bias1_list.append(norm_img[3]) gsrm_slf_attn_bias1_list.append(norm_img[3])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册