提交 46e3442e 编写于 作者: xuyang2233's avatar xuyang2233

add spin

上级 f56a7e9c
...@@ -75,7 +75,7 @@ Train: ...@@ -75,7 +75,7 @@ Train:
data_dir: ./train_data/ic15_data/ data_dir: ./train_data/ic15_data/
label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"] label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"]
transforms: transforms:
- NRTRDecodeImage: # load image - DecodeImage: # load image
img_mode: BGR img_mode: BGR
channel_first: False channel_first: False
- SPINAttnLabelEncode: # Class handling label - SPINAttnLabelEncode: # Class handling label
...@@ -98,7 +98,7 @@ Eval: ...@@ -98,7 +98,7 @@ Eval:
data_dir: ./train_data/ic15_data data_dir: ./train_data/ic15_data
label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"] label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"]
transforms: transforms:
- NRTRDecodeImage: # load image - DecodeImage: # load image
img_mode: BGR img_mode: BGR
channel_first: False channel_first: False
- SPINAttnLabelEncode: # Class handling label - SPINAttnLabelEncode: # Class handling label
......
...@@ -274,6 +274,7 @@ class SPINRecResizeImg(object): ...@@ -274,6 +274,7 @@ class SPINRecResizeImg(object):
def __call__(self, data): def __call__(self, data):
img = data['image'] img = data['image']
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
# different interpolation type corresponding the OpenCV # different interpolation type corresponding the OpenCV
if self.interpolation == 0: if self.interpolation == 0:
interpolation = cv2.INTER_NEAREST interpolation = cv2.INTER_NEAREST
...@@ -294,12 +295,9 @@ class SPINRecResizeImg(object): ...@@ -294,12 +295,9 @@ class SPINRecResizeImg(object):
img = np.expand_dims(img, -1) img = np.expand_dims(img, -1)
img = img.transpose((2, 0, 1)) img = img.transpose((2, 0, 1))
# normalize the image # normalize the image
to_rgb = False
img = img.copy().astype(np.float32) img = img.copy().astype(np.float32)
mean = np.float64(self.mean.reshape(1, -1)) mean = np.float64(self.mean.reshape(1, -1))
stdinv = 1 / np.float64(self.std.reshape(1, -1)) stdinv = 1 / np.float64(self.std.reshape(1, -1))
if to_rgb:
cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img -= mean img -= mean
img *= stdinv img *= stdinv
data['image'] = img data['image'] = img
......
...@@ -76,7 +76,7 @@ Train: ...@@ -76,7 +76,7 @@ Train:
data_dir: ./train_data/ic15_data/ data_dir: ./train_data/ic15_data/
label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"] label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"]
transforms: transforms:
- NRTRDecodeImage: # load image - DecodeImage: # load image
img_mode: BGR img_mode: BGR
channel_first: False channel_first: False
- SPINAttnLabelEncode: # Class handling label - SPINAttnLabelEncode: # Class handling label
...@@ -99,7 +99,7 @@ Eval: ...@@ -99,7 +99,7 @@ Eval:
data_dir: ./train_data/ic15_data data_dir: ./train_data/ic15_data
label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"] label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"]
transforms: transforms:
- NRTRDecodeImage: # load image - DecodeImage: # load image
img_mode: BGR img_mode: BGR
channel_first: False channel_first: False
- SPINAttnLabelEncode: # Class handling label - SPINAttnLabelEncode: # Class handling label
......
...@@ -91,7 +91,7 @@ def export_single_model(model, ...@@ -91,7 +91,7 @@ def export_single_model(model,
] ]
# print([None, 3, 32, 128]) # print([None, 3, 32, 128])
model = to_static(model, input_spec=other_shape) model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "NRTR": elif arch_config["algorithm"] == "NRTR" or arch_config["algorithm"] == "SPIN":
other_shape = [ other_shape = [
paddle.static.InputSpec( paddle.static.InputSpec(
shape=[None, 1, 32, 100], dtype="float32"), shape=[None, 1, 32, 100], dtype="float32"),
......
...@@ -81,7 +81,6 @@ class TextRecognizer(object): ...@@ -81,7 +81,6 @@ class TextRecognizer(object):
"character_dict_path": args.rec_char_dict_path, "character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char "use_space_char": args.use_space_char
} }
elif self.rec_algorithm == "SPIN": elif self.rec_algorithm == "SPIN":
postprocess_params = { postprocess_params = {
'name': 'SPINAttnLabelDecode', 'name': 'SPINAttnLabelDecode',
...@@ -362,6 +361,8 @@ class TextRecognizer(object): ...@@ -362,6 +361,8 @@ class TextRecognizer(object):
norm_img_batch.append(norm_img) norm_img_batch.append(norm_img)
elif self.rec_algorithm == 'SPIN': elif self.rec_algorithm == 'SPIN':
norm_img = self.resize_norm_img_spin(img_list[indices[ino]]) norm_img = self.resize_norm_img_spin(img_list[indices[ino]])
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
elif self.rec_algorithm == "ABINet": elif self.rec_algorithm == "ABINet":
norm_img = self.resize_norm_img_abinet( norm_img = self.resize_norm_img_abinet(
img_list[indices[ino]], self.rec_image_shape) img_list[indices[ino]], self.rec_image_shape)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册