diff --git a/configs/rec/rec_r32_gaspin_bilstm_att.yml b/configs/rec/rec_r32_gaspin_bilstm_att.yml index e8235415c7e75addfc5599e181844a227d1e4eff..f7c1b813fcd6d0553027c3c6d0f7191e950c63ad 100644 --- a/configs/rec/rec_r32_gaspin_bilstm_att.yml +++ b/configs/rec/rec_r32_gaspin_bilstm_att.yml @@ -75,7 +75,7 @@ Train: data_dir: ./train_data/ic15_data/ label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"] transforms: - - NRTRDecodeImage: # load image + - DecodeImage: # load image img_mode: BGR channel_first: False - SPINAttnLabelEncode: # Class handling label @@ -98,7 +98,7 @@ Eval: data_dir: ./train_data/ic15_data label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"] transforms: - - NRTRDecodeImage: # load image + - DecodeImage: # load image img_mode: BGR channel_first: False - SPINAttnLabelEncode: # Class handling label diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py index 3f002173d105130f90ebac7e33d2bad9f081f1c1..c5d8a3b2fd773a1877a788401a926d7fbca07adf 100644 --- a/ppocr/data/imaug/rec_img_aug.py +++ b/ppocr/data/imaug/rec_img_aug.py @@ -274,6 +274,7 @@ class SPINRecResizeImg(object): def __call__(self, data): img = data['image'] + img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) # different interpolation type corresponding the OpenCV if self.interpolation == 0: interpolation = cv2.INTER_NEAREST @@ -294,12 +295,9 @@ class SPINRecResizeImg(object): img = np.expand_dims(img, -1) img = img.transpose((2, 0, 1)) # normalize the image - to_rgb = False img = img.copy().astype(np.float32) mean = np.float64(self.mean.reshape(1, -1)) stdinv = 1 / np.float64(self.std.reshape(1, -1)) - if to_rgb: - cv2.cvtColor(img, cv2.COLOR_BGR2RGB) img -= mean img *= stdinv data['image'] = img diff --git a/test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml b/test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml index 3999ecda8bfca5d1c92daa198bdedc9cd4f9732e..a08efe579642e3dc959a568a7ca40c8ca3f8614c 100644 --- a/test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml +++ b/test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml @@ -76,7 +76,7 @@ Train: data_dir: ./train_data/ic15_data/ label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"] transforms: - - NRTRDecodeImage: # load image + - DecodeImage: # load image img_mode: BGR channel_first: False - SPINAttnLabelEncode: # Class handling label @@ -99,7 +99,7 @@ Eval: data_dir: ./train_data/ic15_data label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"] transforms: - - NRTRDecodeImage: # load image + - DecodeImage: # load image img_mode: BGR channel_first: False - SPINAttnLabelEncode: # Class handling label diff --git a/tools/export_model.py b/tools/export_model.py index afecbff8cbb834a5aa5ef3ea1448cf04fbd8c3bb..4855c53a978706c52feaebeb7b3649a71bd66b8e 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -91,7 +91,7 @@ def export_single_model(model, ] # print([None, 3, 32, 128]) model = to_static(model, input_spec=other_shape) - elif arch_config["algorithm"] == "NRTR": + elif arch_config["algorithm"] == "NRTR" or arch_config["algorithm"] == "SPIN": other_shape = [ paddle.static.InputSpec( shape=[None, 1, 32, 100], dtype="float32"), diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index 7f9aea09d4269b80d4eb4136824499ccc80e23b8..5a8cb84f758c9a364fba17500318fa1a93283c0e 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -81,7 +81,6 @@ class TextRecognizer(object): "character_dict_path": args.rec_char_dict_path, "use_space_char": args.use_space_char } - elif self.rec_algorithm == "SPIN": postprocess_params = { 'name': 'SPINAttnLabelDecode', @@ -362,6 +361,8 @@ class TextRecognizer(object): norm_img_batch.append(norm_img) elif self.rec_algorithm == 'SPIN': norm_img = self.resize_norm_img_spin(img_list[indices[ino]]) + norm_img = norm_img[np.newaxis, :] + norm_img_batch.append(norm_img) elif self.rec_algorithm == "ABINet": norm_img = self.resize_norm_img_abinet( img_list[indices[ino]], self.rec_image_shape)