未验证 提交 6f7e07e6 编写于 作者: X xiaoting 提交者: GitHub

Merge pull request #4216 from andyjpaddle/fix_export_sar

fix sar export inference model
......@@ -49,6 +49,12 @@ def export_single_model(model, arch_config, save_path, logger):
]
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "SAR":
other_shape = [
paddle.static.InputSpec(
shape=[None, 3, 48, 160], dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
else:
infer_shape = [3, -1, -1]
if arch_config["model_type"] == "rec":
......
......@@ -68,6 +68,13 @@ class TextRecognizer(object):
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
elif self.rec_algorithm == "SAR":
postprocess_params = {
'name': 'SARLabelDecode',
"character_type": args.rec_char_type,
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors, self.config = \
utility.create_predictor(args, 'rec', logger)
......@@ -194,6 +201,41 @@ class TextRecognizer(object):
return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
gsrm_slf_attn_bias2)
def resize_norm_img_sar(self, img, image_shape,
width_downsample_ratio=0.25):
imgC, imgH, imgW_min, imgW_max = image_shape
h = img.shape[0]
w = img.shape[1]
valid_ratio = 1.0
# make sure new_width is an integral multiple of width_divisor.
width_divisor = int(1 / width_downsample_ratio)
# resize
ratio = w / float(h)
resize_w = math.ceil(imgH * ratio)
if resize_w % width_divisor != 0:
resize_w = round(resize_w / width_divisor) * width_divisor
if imgW_min is not None:
resize_w = max(imgW_min, resize_w)
if imgW_max is not None:
valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
resize_w = min(imgW_max, resize_w)
resized_image = cv2.resize(img, (resize_w, imgH))
resized_image = resized_image.astype('float32')
# norm
if image_shape[0] == 1:
resized_image = resized_image / 255
resized_image = resized_image[np.newaxis, :]
else:
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
resize_shape = resized_image.shape
padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32)
padding_im[:, :, 0:resize_w] = resized_image
pad_shape = padding_im.shape
return padding_im, resize_shape, pad_shape, valid_ratio
def __call__(self, img_list):
img_num = len(img_list)
# Calculate the aspect ratio of all text bars
......@@ -216,11 +258,19 @@ class TextRecognizer(object):
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
for ino in range(beg_img_no, end_img_no):
if self.rec_algorithm != "SRN":
if self.rec_algorithm != "SRN" and self.rec_algorithm != "SAR":
norm_img = self.resize_norm_img(img_list[indices[ino]],
max_wh_ratio)
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
elif self.rec_algorithm == "SAR":
norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
img_list[indices[ino]], self.rec_image_shape)
norm_img = norm_img[np.newaxis, :]
valid_ratio = np.expand_dims(valid_ratio, axis=0)
valid_ratios = []
valid_ratios.append(valid_ratio)
norm_img_batch.append(norm_img)
else:
norm_img = self.process_image_srn(
img_list[indices[ino]], self.rec_image_shape, 8, 25)
......@@ -266,6 +316,25 @@ class TextRecognizer(object):
if self.benchmark:
self.autolog.times.stamp()
preds = {"predict": outputs[2]}
elif self.rec_algorithm == "SAR":
valid_ratios = np.concatenate(valid_ratios)
inputs = [
norm_img_batch,
valid_ratios,
]
input_names = self.predictor.get_input_names()
for i in range(len(input_names)):
input_tensor = self.predictor.get_input_handle(input_names[
i])
input_tensor.copy_from_cpu(inputs[i])
self.predictor.run()
outputs = []
for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu()
outputs.append(output)
if self.benchmark:
self.autolog.times.stamp()
preds = outputs[0]
else:
self.input_tensor.copy_from_cpu(norm_img_batch)
self.predictor.run()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册