提交 5161d825 编写于 作者: Z zhiboniu 提交者: zhiboniu

delete other rec algorithm

上级 653009ab
......@@ -19,7 +19,6 @@ VEHICLE_PLATE:
det_model_dir: output_inference/ch_PP-OCRv3_det_infer/
det_limit_side_len: 480
det_limit_type: "max"
rec_algorithm: "SVTR_LCNet"
rec_model_dir: output_inference/ch_PP-OCRv3_rec_infer/
rec_image_shape: [3, 48, 320]
rec_batch_num: 6
......
......@@ -151,7 +151,6 @@ class TextRecognizer(object):
def __init__(self, args, cfg, use_gpu=True):
self.rec_image_shape = cfg['rec_image_shape']
self.rec_batch_num = cfg['rec_batch_num']
self.rec_algorithm = cfg['rec_algorithm']
word_dict_path = cfg['word_dict_path']
use_space_char = True
......@@ -160,30 +159,6 @@ class TextRecognizer(object):
"character_dict_path": word_dict_path,
"use_space_char": use_space_char
}
if self.rec_algorithm == "SRN":
postprocess_params = {
'name': 'SRNLabelDecode',
"character_dict_path": word_dict_path,
"use_space_char": use_space_char
}
elif self.rec_algorithm == "RARE":
postprocess_params = {
'name': 'AttnLabelDecode',
"character_dict_path": word_dict_path,
"use_space_char": use_space_char
}
elif self.rec_algorithm == 'NRTR':
postprocess_params = {
'name': 'NRTRLabelDecode',
"character_dict_path": word_dict_path,
"use_space_char": use_space_char
}
elif self.rec_algorithm == "SAR":
postprocess_params = {
'name': 'SARLabelDecode',
"character_dict_path": word_dict_path,
"use_space_char": use_space_char
}
self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors, self.config = \
create_predictor(args, cfg, 'rec')
......@@ -191,15 +166,6 @@ class TextRecognizer(object):
def resize_norm_img(self, img, max_wh_ratio):
imgC, imgH, imgW = self.rec_image_shape
if self.rec_algorithm == 'NRTR':
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
# return padding_im
image_pil = Image.fromarray(np.uint8(img))
img = image_pil.resize([100, 32], Image.ANTIALIAS)
img = np.array(img)
norm_img = np.expand_dims(img, -1)
norm_img = norm_img.transpose((2, 0, 1))
return norm_img.astype(np.float32) / 128. - 1.
assert imgC == img.shape[2]
imgW = int((imgH * max_wh_ratio))
......@@ -214,10 +180,6 @@ class TextRecognizer(object):
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
if self.rec_algorithm == 'RARE':
if resized_w > self.rec_image_shape[2]:
resized_w = self.rec_image_shape[2]
imgW = self.rec_image_shape[2]
resized_image = cv2.resize(img, (resized_w, imgH))
resized_image = resized_image.astype('float32')
resized_image = resized_image.transpose((2, 0, 1)) / 255
......@@ -227,124 +189,6 @@ class TextRecognizer(object):
padding_im[:, :, 0:resized_w] = resized_image
return padding_im
def resize_norm_img_svtr(self, img, image_shape):
imgC, imgH, imgW = image_shape
resized_image = cv2.resize(
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_image = resized_image.astype('float32')
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
return resized_image
def resize_norm_img_srn(self, img, image_shape):
imgC, imgH, imgW = image_shape
img_black = np.zeros((imgH, imgW))
im_hei = img.shape[0]
im_wid = img.shape[1]
if im_wid <= im_hei * 1:
img_new = cv2.resize(img, (imgH * 1, imgH))
elif im_wid <= im_hei * 2:
img_new = cv2.resize(img, (imgH * 2, imgH))
elif im_wid <= im_hei * 3:
img_new = cv2.resize(img, (imgH * 3, imgH))
else:
img_new = cv2.resize(img, (imgW, imgH))
img_np = np.asarray(img_new)
img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
img_black[:, 0:img_np.shape[1]] = img_np
img_black = img_black[:, :, np.newaxis]
row, col, c = img_black.shape
c = 1
return np.reshape(img_black, (c, row, col)).astype(np.float32)
def srn_other_inputs(self, image_shape, num_heads, max_text_length):
imgC, imgH, imgW = image_shape
feature_dim = int((imgH / 8) * (imgW / 8))
encoder_word_pos = np.array(range(0, feature_dim)).reshape(
(feature_dim, 1)).astype('int64')
gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
(max_text_length, 1)).astype('int64')
gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
[-1, 1, max_text_length, max_text_length])
gsrm_slf_attn_bias1 = np.tile(
gsrm_slf_attn_bias1,
[1, num_heads, 1, 1]).astype('float32') * [-1e9]
gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
[-1, 1, max_text_length, max_text_length])
gsrm_slf_attn_bias2 = np.tile(
gsrm_slf_attn_bias2,
[1, num_heads, 1, 1]).astype('float32') * [-1e9]
encoder_word_pos = encoder_word_pos[np.newaxis, :]
gsrm_word_pos = gsrm_word_pos[np.newaxis, :]
return [
encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
gsrm_slf_attn_bias2
]
def process_image_srn(self, img, image_shape, num_heads, max_text_length):
norm_img = self.resize_norm_img_srn(img, image_shape)
norm_img = norm_img[np.newaxis, :]
[encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
self.srn_other_inputs(image_shape, num_heads, max_text_length)
gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32)
gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32)
encoder_word_pos = encoder_word_pos.astype(np.int64)
gsrm_word_pos = gsrm_word_pos.astype(np.int64)
return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
gsrm_slf_attn_bias2)
def resize_norm_img_sar(self, img, image_shape,
width_downsample_ratio=0.25):
imgC, imgH, imgW_min, imgW_max = image_shape
h = img.shape[0]
w = img.shape[1]
valid_ratio = 1.0
# make sure new_width is an integral multiple of width_divisor.
width_divisor = int(1 / width_downsample_ratio)
# resize
ratio = w / float(h)
resize_w = math.ceil(imgH * ratio)
if resize_w % width_divisor != 0:
resize_w = round(resize_w / width_divisor) * width_divisor
if imgW_min is not None:
resize_w = max(imgW_min, resize_w)
if imgW_max is not None:
valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
resize_w = min(imgW_max, resize_w)
resized_image = cv2.resize(img, (resize_w, imgH))
resized_image = resized_image.astype('float32')
# norm
if image_shape[0] == 1:
resized_image = resized_image / 255
resized_image = resized_image[np.newaxis, :]
else:
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
resize_shape = resized_image.shape
padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32)
padding_im[:, :, 0:resize_w] = resized_image
pad_shape = padding_im.shape
return padding_im, resize_shape, pad_shape, valid_ratio
def predict_text(self, img_list):
img_num = len(img_list)
# Calculate the aspect ratio of all text bars
......@@ -367,103 +211,16 @@ class TextRecognizer(object):
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
for ino in range(beg_img_no, end_img_no):
if self.rec_algorithm == "SAR":
norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
img_list[indices[ino]], self.rec_image_shape)
norm_img = norm_img[np.newaxis, :]
valid_ratio = np.expand_dims(valid_ratio, axis=0)
valid_ratios = []
valid_ratios.append(valid_ratio)
norm_img_batch.append(norm_img)
elif self.rec_algorithm == "SRN":
norm_img = self.process_image_srn(
img_list[indices[ino]], self.rec_image_shape, 8, 25)
encoder_word_pos_list = []
gsrm_word_pos_list = []
gsrm_slf_attn_bias1_list = []
gsrm_slf_attn_bias2_list = []
encoder_word_pos_list.append(norm_img[1])
gsrm_word_pos_list.append(norm_img[2])
gsrm_slf_attn_bias1_list.append(norm_img[3])
gsrm_slf_attn_bias2_list.append(norm_img[4])
norm_img_batch.append(norm_img[0])
elif self.rec_algorithm == "SVTR":
norm_img = self.resize_norm_img_svtr(img_list[indices[ino]],
self.rec_image_shape)
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
else:
norm_img = self.resize_norm_img(img_list[indices[ino]],
max_wh_ratio)
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
norm_img_batch = np.concatenate(norm_img_batch)
norm_img_batch = norm_img_batch.copy()
if self.rec_algorithm == "SRN":
encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list)
gsrm_slf_attn_bias1_list = np.concatenate(
gsrm_slf_attn_bias1_list)
gsrm_slf_attn_bias2_list = np.concatenate(
gsrm_slf_attn_bias2_list)
inputs = [
norm_img_batch,
encoder_word_pos_list,
gsrm_word_pos_list,
gsrm_slf_attn_bias1_list,
gsrm_slf_attn_bias2_list,
]
if self.use_onnx:
input_dict = {}
input_dict[self.input_tensor.name] = norm_img_batch
outputs = self.predictor.run(self.output_tensors,
input_dict)
preds = {"predict": outputs[2]}
else:
input_names = self.predictor.get_input_names()
for i in range(len(input_names)):
input_tensor = self.predictor.get_input_handle(
input_names[i])
input_tensor.copy_from_cpu(inputs[i])
self.predictor.run()
outputs = []
for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu()
outputs.append(output)
preds = {"predict": outputs[2]}
elif self.rec_algorithm == "SAR":
valid_ratios = np.concatenate(valid_ratios)
inputs = [
norm_img_batch,
valid_ratios,
]
if self.use_onnx:
input_dict = {}
input_dict[self.input_tensor.name] = norm_img_batch
outputs = self.predictor.run(self.output_tensors,
input_dict)
preds = outputs[0]
else:
input_names = self.predictor.get_input_names()
for i in range(len(input_names)):
input_tensor = self.predictor.get_input_handle(
input_names[i])
input_tensor.copy_from_cpu(inputs[i])
self.predictor.run()
outputs = []
for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu()
outputs.append(output)
preds = outputs[0]
else:
if self.use_onnx:
input_dict = {}
input_dict[self.input_tensor.name] = norm_img_batch
outputs = self.predictor.run(self.output_tensors,
input_dict)
outputs = self.predictor.run(self.output_tensors, input_dict)
preds = outputs[0]
else:
self.input_tensor.copy_from_cpu(norm_img_batch)
......
......@@ -185,7 +185,6 @@ def create_predictor(args, cfg, mode):
def get_output_tensors(cfg, mode, predictor):
output_names = predictor.get_output_names()
output_tensors = []
if mode == "rec" and cfg['rec_algorithm'] in ["CRNN", "SVTR_LCNet"]:
output_name = 'softmax_0.tmp_0'
if output_name in output_names:
return [predictor.get_output_handle(output_name)]
......@@ -193,10 +192,6 @@ def get_output_tensors(cfg, mode, predictor):
for output_name in output_names:
output_tensor = predictor.get_output_handle(output_name)
output_tensors.append(output_tensor)
else:
for output_name in output_names:
output_tensor = predictor.get_output_handle(output_name)
output_tensors.append(output_tensor)
return output_tensors
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册