提交 46e3442e 编写于 作者: xuyang2233's avatar xuyang2233

add spin

上级 f56a7e9c
......@@ -75,7 +75,7 @@ Train:
data_dir: ./train_data/ic15_data/
label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"]
transforms:
- NRTRDecodeImage: # load image
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- SPINAttnLabelEncode: # Class handling label
......@@ -98,7 +98,7 @@ Eval:
data_dir: ./train_data/ic15_data
label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"]
transforms:
- NRTRDecodeImage: # load image
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- SPINAttnLabelEncode: # Class handling label
......
......@@ -274,6 +274,7 @@ class SPINRecResizeImg(object):
def __call__(self, data):
img = data['image']
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
# different interpolation type corresponding the OpenCV
if self.interpolation == 0:
interpolation = cv2.INTER_NEAREST
......@@ -294,12 +295,9 @@ class SPINRecResizeImg(object):
img = np.expand_dims(img, -1)
img = img.transpose((2, 0, 1))
# normalize the image
to_rgb = False
img = img.copy().astype(np.float32)
mean = np.float64(self.mean.reshape(1, -1))
stdinv = 1 / np.float64(self.std.reshape(1, -1))
if to_rgb:
cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img -= mean
img *= stdinv
data['image'] = img
......
......@@ -76,7 +76,7 @@ Train:
data_dir: ./train_data/ic15_data/
label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"]
transforms:
- NRTRDecodeImage: # load image
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- SPINAttnLabelEncode: # Class handling label
......@@ -99,7 +99,7 @@ Eval:
data_dir: ./train_data/ic15_data
label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"]
transforms:
- NRTRDecodeImage: # load image
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- SPINAttnLabelEncode: # Class handling label
......
......@@ -91,7 +91,7 @@ def export_single_model(model,
]
# print([None, 3, 32, 128])
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "NRTR":
elif arch_config["algorithm"] == "NRTR" or arch_config["algorithm"] == "SPIN":
other_shape = [
paddle.static.InputSpec(
shape=[None, 1, 32, 100], dtype="float32"),
......
......@@ -81,7 +81,6 @@ class TextRecognizer(object):
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
elif self.rec_algorithm == "SPIN":
postprocess_params = {
'name': 'SPINAttnLabelDecode',
......@@ -362,6 +361,8 @@ class TextRecognizer(object):
norm_img_batch.append(norm_img)
elif self.rec_algorithm == 'SPIN':
norm_img = self.resize_norm_img_spin(img_list[indices[ino]])
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
elif self.rec_algorithm == "ABINet":
norm_img = self.resize_norm_img_abinet(
img_list[indices[ino]], self.rec_image_shape)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册