From b40ffdd45c8e5bf65aab1ff54a08f3be4849c4d3 Mon Sep 17 00:00:00 2001 From: Topdu <784990967@qq.com> Date: Wed, 29 Sep 2021 01:50:24 +0000 Subject: [PATCH] fix sar export inference model --- tools/export_model.py | 6 ++++ tools/infer/predict_rec.py | 71 +++++++++++++++++++++++++++++++++++++- 2 files changed, 76 insertions(+), 1 deletion(-) diff --git a/tools/export_model.py b/tools/export_model.py index d8fe2972..64a0d403 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -49,6 +49,12 @@ def export_single_model(model, arch_config, save_path, logger): ] ] model = to_static(model, input_spec=other_shape) + elif arch_config["algorithm"] == "SAR": + other_shape = [ + paddle.static.InputSpec( + shape=[None, 3, 48, 160], dtype="float32"), + ] + model = to_static(model, input_spec=other_shape) else: infer_shape = [3, -1, -1] if arch_config["model_type"] == "rec": diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index 332cffd5..dad70281 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -68,6 +68,13 @@ class TextRecognizer(object): "character_dict_path": args.rec_char_dict_path, "use_space_char": args.use_space_char } + elif self.rec_algorithm == "SAR": + postprocess_params = { + 'name': 'SARLabelDecode', + "character_type": args.rec_char_type, + "character_dict_path": args.rec_char_dict_path, + "use_space_char": args.use_space_char + } self.postprocess_op = build_post_process(postprocess_params) self.predictor, self.input_tensor, self.output_tensors, self.config = \ utility.create_predictor(args, 'rec', logger) @@ -194,6 +201,41 @@ class TextRecognizer(object): return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2) + def resize_norm_img_sar(self, img, image_shape, + width_downsample_ratio=0.25): + imgC, imgH, imgW_min, imgW_max = image_shape + h = img.shape[0] + w = img.shape[1] + valid_ratio = 1.0 + # make sure new_width is an integral multiple of width_divisor. + width_divisor = int(1 / width_downsample_ratio) + # resize + ratio = w / float(h) + resize_w = math.ceil(imgH * ratio) + if resize_w % width_divisor != 0: + resize_w = round(resize_w / width_divisor) * width_divisor + if imgW_min is not None: + resize_w = max(imgW_min, resize_w) + if imgW_max is not None: + valid_ratio = min(1.0, 1.0 * resize_w / imgW_max) + resize_w = min(imgW_max, resize_w) + resized_image = cv2.resize(img, (resize_w, imgH)) + resized_image = resized_image.astype('float32') + # norm + if image_shape[0] == 1: + resized_image = resized_image / 255 + resized_image = resized_image[np.newaxis, :] + else: + resized_image = resized_image.transpose((2, 0, 1)) / 255 + resized_image -= 0.5 + resized_image /= 0.5 + resize_shape = resized_image.shape + padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32) + padding_im[:, :, 0:resize_w] = resized_image + pad_shape = padding_im.shape + + return padding_im, resize_shape, pad_shape, valid_ratio + def __call__(self, img_list): img_num = len(img_list) # Calculate the aspect ratio of all text bars @@ -216,11 +258,19 @@ class TextRecognizer(object): wh_ratio = w * 1.0 / h max_wh_ratio = max(max_wh_ratio, wh_ratio) for ino in range(beg_img_no, end_img_no): - if self.rec_algorithm != "SRN": + if self.rec_algorithm != "SRN" and self.rec_algorithm != "SAR": norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio) norm_img = norm_img[np.newaxis, :] norm_img_batch.append(norm_img) + elif self.rec_algorithm == "SAR": + norm_img, _, _, valid_ratio = self.resize_norm_img_sar( + img_list[indices[ino]], self.rec_image_shape) + norm_img = norm_img[np.newaxis, :] + valid_ratio = np.expand_dims(valid_ratio, axis=0) + valid_ratios = [] + valid_ratios.append(valid_ratio) + norm_img_batch.append(norm_img) else: norm_img = self.process_image_srn( img_list[indices[ino]], self.rec_image_shape, 8, 25) @@ -266,6 +316,25 @@ class TextRecognizer(object): if self.benchmark: self.autolog.times.stamp() preds = {"predict": outputs[2]} + elif self.rec_algorithm == "SAR": + valid_ratios = np.concatenate(valid_ratios) + inputs = [ + norm_img_batch, + valid_ratios, + ] + input_names = self.predictor.get_input_names() + for i in range(len(input_names)): + input_tensor = self.predictor.get_input_handle(input_names[ + i]) + input_tensor.copy_from_cpu(inputs[i]) + self.predictor.run() + outputs = [] + for output_tensor in self.output_tensors: + output = output_tensor.copy_to_cpu() + outputs.append(output) + if self.benchmark: + self.autolog.times.stamp() + preds = outputs[0] else: self.input_tensor.copy_from_cpu(norm_img_batch) self.predictor.run() -- GitLab