From 46e3442e2e38f4449d06a77010634d43fab16331 Mon Sep 17 00:00:00 2001 From: smilelite Date: Sun, 10 Jul 2022 11:47:25 +0800 Subject: [PATCH] add spin --- configs/rec/rec_r32_gaspin_bilstm_att.yml | 4 ++-- ppocr/data/imaug/rec_img_aug.py | 4 +--- .../rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml | 4 ++-- tools/export_model.py | 2 +- tools/infer/predict_rec.py | 3 ++- 5 files changed, 8 insertions(+), 9 deletions(-) diff --git a/configs/rec/rec_r32_gaspin_bilstm_att.yml b/configs/rec/rec_r32_gaspin_bilstm_att.yml index e8235415..f7c1b813 100644 --- a/configs/rec/rec_r32_gaspin_bilstm_att.yml +++ b/configs/rec/rec_r32_gaspin_bilstm_att.yml @@ -75,7 +75,7 @@ Train: data_dir: ./train_data/ic15_data/ label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"] transforms: - - NRTRDecodeImage: # load image + - DecodeImage: # load image img_mode: BGR channel_first: False - SPINAttnLabelEncode: # Class handling label @@ -98,7 +98,7 @@ Eval: data_dir: ./train_data/ic15_data label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"] transforms: - - NRTRDecodeImage: # load image + - DecodeImage: # load image img_mode: BGR channel_first: False - SPINAttnLabelEncode: # Class handling label diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py index 3f002173..c5d8a3b2 100644 --- a/ppocr/data/imaug/rec_img_aug.py +++ b/ppocr/data/imaug/rec_img_aug.py @@ -274,6 +274,7 @@ class SPINRecResizeImg(object): def __call__(self, data): img = data['image'] + img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) # different interpolation type corresponding the OpenCV if self.interpolation == 0: interpolation = cv2.INTER_NEAREST @@ -294,12 +295,9 @@ class SPINRecResizeImg(object): img = np.expand_dims(img, -1) img = img.transpose((2, 0, 1)) # normalize the image - to_rgb = False img = img.copy().astype(np.float32) mean = np.float64(self.mean.reshape(1, -1)) stdinv = 1 / np.float64(self.std.reshape(1, -1)) - if to_rgb: - cv2.cvtColor(img, cv2.COLOR_BGR2RGB) img -= mean img *= stdinv data['image'] = img diff --git a/test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml b/test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml index 3999ecda..a08efe57 100644 --- a/test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml +++ b/test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml @@ -76,7 +76,7 @@ Train: data_dir: ./train_data/ic15_data/ label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"] transforms: - - NRTRDecodeImage: # load image + - DecodeImage: # load image img_mode: BGR channel_first: False - SPINAttnLabelEncode: # Class handling label @@ -99,7 +99,7 @@ Eval: data_dir: ./train_data/ic15_data label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"] transforms: - - NRTRDecodeImage: # load image + - DecodeImage: # load image img_mode: BGR channel_first: False - SPINAttnLabelEncode: # Class handling label diff --git a/tools/export_model.py b/tools/export_model.py index afecbff8..4855c53a 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -91,7 +91,7 @@ def export_single_model(model, ] # print([None, 3, 32, 128]) model = to_static(model, input_spec=other_shape) - elif arch_config["algorithm"] == "NRTR": + elif arch_config["algorithm"] == "NRTR" or arch_config["algorithm"] == "SPIN": other_shape = [ paddle.static.InputSpec( shape=[None, 1, 32, 100], dtype="float32"), diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index 7f9aea09..5a8cb84f 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -81,7 +81,6 @@ class TextRecognizer(object): "character_dict_path": args.rec_char_dict_path, "use_space_char": args.use_space_char } - elif self.rec_algorithm == "SPIN": postprocess_params = { 'name': 'SPINAttnLabelDecode', @@ -362,6 +361,8 @@ class TextRecognizer(object): norm_img_batch.append(norm_img) elif self.rec_algorithm == 'SPIN': norm_img = self.resize_norm_img_spin(img_list[indices[ino]]) + norm_img = norm_img[np.newaxis, :] + norm_img_batch.append(norm_img) elif self.rec_algorithm == "ABINet": norm_img = self.resize_norm_img_abinet( img_list[indices[ino]], self.rec_image_shape) -- GitLab