提交 bc1ad207 编写于 作者: T tink2123

support srn inference

上级 b626aa3e
...@@ -136,7 +136,7 @@ class RecModel(object): ...@@ -136,7 +136,7 @@ class RecModel(object):
else: else:
labels = None labels = None
loader = None loader = None
if self.char_type == "ch" and self.infer_img: if self.char_type == "ch" and self.infer_img and self.loss_type != "srn":
image_shape[-1] = -1 image_shape[-1] = -1
if self.tps != None: if self.tps != None:
logger.info( logger.info(
...@@ -172,16 +172,13 @@ class RecModel(object): ...@@ -172,16 +172,13 @@ class RecModel(object):
self.max_text_length self.max_text_length
], ],
dtype="float32") dtype="float32")
feed_list = [
image, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
gsrm_slf_attn_bias2
]
labels = { labels = {
'encoder_word_pos': encoder_word_pos, 'encoder_word_pos': encoder_word_pos,
'gsrm_word_pos': gsrm_word_pos, 'gsrm_word_pos': gsrm_word_pos,
'gsrm_slf_attn_bias1': gsrm_slf_attn_bias1, 'gsrm_slf_attn_bias1': gsrm_slf_attn_bias1,
'gsrm_slf_attn_bias2': gsrm_slf_attn_bias2 'gsrm_slf_attn_bias2': gsrm_slf_attn_bias2
} }
return image, labels, loader return image, labels, loader
def __call__(self, mode): def __call__(self, mode):
...@@ -218,8 +215,13 @@ class RecModel(object): ...@@ -218,8 +215,13 @@ class RecModel(object):
if self.loss_type == "ctc": if self.loss_type == "ctc":
predict = fluid.layers.softmax(predict) predict = fluid.layers.softmax(predict)
if self.loss_type == "srn": if self.loss_type == "srn":
raise Exception( return [
"Warning! SRN does not support export model currently") image, labels, {
'decoded_out': decoded_out,
'predicts': predict
}
]
return [image, {'decoded_out': decoded_out, 'predicts': predict}] return [image, {'decoded_out': decoded_out, 'predicts': predict}]
else: else:
predict = predicts['predict'] predict = predicts['predict']
......
...@@ -26,6 +26,7 @@ import copy ...@@ -26,6 +26,7 @@ import copy
import numpy as np import numpy as np
import math import math
import time import time
import paddle.fluid as fluid
from ppocr.utils.character import CharacterOps from ppocr.utils.character import CharacterOps
...@@ -37,18 +38,22 @@ class TextRecognizer(object): ...@@ -37,18 +38,22 @@ class TextRecognizer(object):
self.character_type = args.rec_char_type self.character_type = args.rec_char_type
self.rec_batch_num = args.rec_batch_num self.rec_batch_num = args.rec_batch_num
self.rec_algorithm = args.rec_algorithm self.rec_algorithm = args.rec_algorithm
self.text_len = args.max_text_length
char_ops_params = { char_ops_params = {
"character_type": args.rec_char_type, "character_type": args.rec_char_type,
"character_dict_path": args.rec_char_dict_path, "character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char, "use_space_char": args.use_space_char,
"max_text_length": args.max_text_length "max_text_length": args.max_text_length
} }
if self.rec_algorithm != "RARE": if self.rec_algorithm in ["CRNN", "Rosetta", "STAR-Net"]:
char_ops_params['loss_type'] = 'ctc' char_ops_params['loss_type'] = 'ctc'
self.loss_type = 'ctc' self.loss_type = 'ctc'
else: elif self.rec_algorithm == "RARE":
char_ops_params['loss_type'] = 'attention' char_ops_params['loss_type'] = 'attention'
self.loss_type = 'attention' self.loss_type = 'attention'
elif self.rec_algorithm == "SRN":
char_ops_params['loss_type'] = 'srn'
self.loss_type = 'srn'
self.char_ops = CharacterOps(char_ops_params) self.char_ops = CharacterOps(char_ops_params)
def resize_norm_img(self, img, max_wh_ratio): def resize_norm_img(self, img, max_wh_ratio):
...@@ -71,6 +76,83 @@ class TextRecognizer(object): ...@@ -71,6 +76,83 @@ class TextRecognizer(object):
padding_im[:, :, 0:resized_w] = resized_image padding_im[:, :, 0:resized_w] = resized_image
return padding_im return padding_im
def resize_norm_img_srn(self, img, image_shape):
imgC, imgH, imgW = image_shape
img_black = np.zeros((imgH, imgW))
im_hei = img.shape[0]
im_wid = img.shape[1]
if im_wid <= im_hei * 1:
img_new = cv2.resize(img, (imgH * 1, imgH))
elif im_wid <= im_hei * 2:
img_new = cv2.resize(img, (imgH * 2, imgH))
elif im_wid <= im_hei * 3:
img_new = cv2.resize(img, (imgH * 3, imgH))
else:
img_new = cv2.resize(img, (imgW, imgH))
img_np = np.asarray(img_new)
img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
img_black[:, 0:img_np.shape[1]] = img_np
img_black = img_black[:, :, np.newaxis]
row, col, c = img_black.shape
c = 1
return np.reshape(img_black, (c, row, col)).astype(np.float32)
def srn_other_inputs(self, image_shape, num_heads, max_text_length,
char_num):
imgC, imgH, imgW = image_shape
feature_dim = int((imgH / 8) * (imgW / 8))
encoder_word_pos = np.array(range(0, feature_dim)).reshape(
(feature_dim, 1)).astype('int64')
gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
(max_text_length, 1)).astype('int64')
gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
[-1, 1, max_text_length, max_text_length])
gsrm_slf_attn_bias1 = np.tile(
gsrm_slf_attn_bias1,
[1, num_heads, 1, 1]).astype('float32') * [-1e9]
gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
[-1, 1, max_text_length, max_text_length])
gsrm_slf_attn_bias2 = np.tile(
gsrm_slf_attn_bias2,
[1, num_heads, 1, 1]).astype('float32') * [-1e9]
encoder_word_pos = encoder_word_pos[np.newaxis, :]
gsrm_word_pos = gsrm_word_pos[np.newaxis, :]
return [
encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
gsrm_slf_attn_bias2
]
def process_image_srn(self,
img,
image_shape,
num_heads,
max_text_length,
char_ops=None):
norm_img = self.resize_norm_img_srn(img, image_shape)
norm_img = norm_img[np.newaxis, :]
char_num = char_ops.get_char_num()
[encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
self.srn_other_inputs(image_shape, num_heads, max_text_length, char_num)
gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32)
gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32)
return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
gsrm_slf_attn_bias2)
def __call__(self, img_list): def __call__(self, img_list):
img_num = len(img_list) img_num = len(img_list)
# Calculate the aspect ratio of all text bars # Calculate the aspect ratio of all text bars
...@@ -80,7 +162,7 @@ class TextRecognizer(object): ...@@ -80,7 +162,7 @@ class TextRecognizer(object):
# Sorting can speed up the recognition process # Sorting can speed up the recognition process
indices = np.argsort(np.array(width_list)) indices = np.argsort(np.array(width_list))
# rec_res = [] #rec_res = []
rec_res = [['', 0.0]] * img_num rec_res = [['', 0.0]] * img_num
batch_num = self.rec_batch_num batch_num = self.rec_batch_num
predict_time = 0 predict_time = 0
...@@ -94,16 +176,52 @@ class TextRecognizer(object): ...@@ -94,16 +176,52 @@ class TextRecognizer(object):
wh_ratio = w * 1.0 / h wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio) max_wh_ratio = max(max_wh_ratio, wh_ratio)
for ino in range(beg_img_no, end_img_no): for ino in range(beg_img_no, end_img_no):
# norm_img = self.resize_norm_img(img_list[ino], max_wh_ratio) if self.loss_type != "srn":
norm_img = self.resize_norm_img(img_list[indices[ino]], norm_img = self.resize_norm_img(img_list[indices[ino]],
max_wh_ratio) max_wh_ratio)
norm_img = norm_img[np.newaxis, :] norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img) norm_img_batch.append(norm_img)
norm_img_batch = np.concatenate(norm_img_batch) else:
norm_img_batch = norm_img_batch.copy() norm_img = self.process_image_srn(img_list[indices[ino]],
self.rec_image_shape, 8,
25, self.char_ops)
encoder_word_pos_list = []
gsrm_word_pos_list = []
gsrm_slf_attn_bias1_list = []
gsrm_slf_attn_bias2_list = []
encoder_word_pos_list.append(norm_img[1])
gsrm_word_pos_list.append(norm_img[2])
gsrm_slf_attn_bias1_list.append(norm_img[3])
gsrm_slf_attn_bias2_list.append(norm_img[4])
norm_img_batch.append(norm_img[0])
norm_img_batch = np.concatenate(norm_img_batch, axis=0)
encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list)
gsrm_slf_attn_bias1_list = np.concatenate(gsrm_slf_attn_bias1_list)
gsrm_slf_attn_bias2_list = np.concatenate(gsrm_slf_attn_bias2_list)
starttime = time.time() starttime = time.time()
self.input_tensor.copy_from_cpu(norm_img_batch)
self.predictor.zero_copy_run() norm_img_batch = fluid.core.PaddleTensor(norm_img_batch)
encoder_word_pos_list = fluid.core.PaddleTensor(
encoder_word_pos_list)
gsrm_word_pos_list = fluid.core.PaddleTensor(gsrm_word_pos_list)
gsrm_slf_attn_bias1_list = fluid.core.PaddleTensor(
gsrm_slf_attn_bias1_list)
gsrm_slf_attn_bias2_list = fluid.core.PaddleTensor(
gsrm_slf_attn_bias2_list)
inputs = [
norm_img_batch, encoder_word_pos_list, gsrm_slf_attn_bias1_list,
gsrm_slf_attn_bias2_list, gsrm_word_pos_list
]
self.predictor.run(inputs)
if self.loss_type == "ctc": if self.loss_type == "ctc":
rec_idx_batch = self.output_tensors[0].copy_to_cpu() rec_idx_batch = self.output_tensors[0].copy_to_cpu()
...@@ -128,6 +246,26 @@ class TextRecognizer(object): ...@@ -128,6 +246,26 @@ class TextRecognizer(object):
score = np.mean(probs[valid_ind, ind[valid_ind]]) score = np.mean(probs[valid_ind, ind[valid_ind]])
# rec_res.append([preds_text, score]) # rec_res.append([preds_text, score])
rec_res[indices[beg_img_no + rno]] = [preds_text, score] rec_res[indices[beg_img_no + rno]] = [preds_text, score]
elif self.loss_type == 'srn':
rec_idx_batch = self.output_tensors[0].copy_to_cpu()
probs = self.output_tensors[1].copy_to_cpu()
char_num = self.char_ops.get_char_num()
preds = rec_idx_batch.reshape(-1)
elapse = time.time() - starttime
predict_time += elapse
total_preds = preds.copy()
for ino in range(int(len(rec_idx_batch) / self.text_len)):
preds = total_preds[ino * self.text_len:(ino + 1) *
self.text_len]
ind = np.argmax(probs, axis=1)
valid_ind = np.where(preds != int(char_num - 1))[0]
if len(valid_ind) == 0:
continue
score = np.mean(probs[valid_ind, ind[valid_ind]])
preds = preds[:valid_ind[-1] + 1]
preds_text = self.char_ops.decode(preds)
rec_res[indices[beg_img_no + ino]] = [preds_text, score]
else: else:
rec_idx_batch = self.output_tensors[0].copy_to_cpu() rec_idx_batch = self.output_tensors[0].copy_to_cpu()
predict_batch = self.output_tensors[1].copy_to_cpu() predict_batch = self.output_tensors[1].copy_to_cpu()
...@@ -162,6 +300,7 @@ def main(args): ...@@ -162,6 +300,7 @@ def main(args):
continue continue
valid_image_file_list.append(image_file) valid_image_file_list.append(image_file)
img_list.append(img) img_list.append(img)
try: try:
rec_res, predict_time = text_recognizer(img_list) rec_res, predict_time = text_recognizer(img_list)
except Exception as e: except Exception as e:
......
...@@ -59,10 +59,10 @@ def parse_args(): ...@@ -59,10 +59,10 @@ def parse_args():
parser.add_argument("--det_sast_polygon", type=bool, default=False) parser.add_argument("--det_sast_polygon", type=bool, default=False)
#params for text recognizer #params for text recognizer
parser.add_argument("--rec_algorithm", type=str, default='CRNN') parser.add_argument("--rec_algorithm", type=str, default='SRN')
parser.add_argument("--rec_model_dir", type=str) parser.add_argument("--rec_model_dir", type=str)
parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320") parser.add_argument("--rec_image_shape", type=str, default="3, 64, 256")
parser.add_argument("--rec_char_type", type=str, default='ch') parser.add_argument("--rec_char_type", type=str, default='en')
parser.add_argument("--rec_batch_num", type=int, default=30) parser.add_argument("--rec_batch_num", type=int, default=30)
parser.add_argument("--max_text_length", type=int, default=25) parser.add_argument("--max_text_length", type=int, default=25)
parser.add_argument( parser.add_argument(
...@@ -107,10 +107,13 @@ def create_predictor(args, mode): ...@@ -107,10 +107,13 @@ def create_predictor(args, mode):
# use zero copy # use zero copy
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass") config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
config.switch_use_feed_fetch_ops(False) #config.switch_use_feed_fetch_ops(False)
config.switch_use_feed_fetch_ops(True)
predictor = create_paddle_predictor(config) predictor = create_paddle_predictor(config)
input_names = predictor.get_input_names() input_names = predictor.get_input_names()
input_tensor = predictor.get_input_tensor(input_names[0]) print(input_names)
for name in input_names:
input_tensor = predictor.get_input_tensor(name)
output_names = predictor.get_output_names() output_names = predictor.get_output_names()
output_tensors = [] output_tensors = []
for output_name in output_names: for output_name in output_names:
......
...@@ -145,7 +145,7 @@ def main(): ...@@ -145,7 +145,7 @@ def main():
preds = preds.reshape(-1) preds = preds.reshape(-1)
probs = np.array(predict[1]) probs = np.array(predict[1])
ind = np.argmax(probs, axis=1) ind = np.argmax(probs, axis=1)
valid_ind = np.where(preds != int(char_num-1))[0] valid_ind = np.where(preds != int(char_num - 1))[0]
if len(valid_ind) == 0: if len(valid_ind) == 0:
continue continue
score = np.mean(probs[valid_ind, ind[valid_ind]]) score = np.mean(probs[valid_ind, ind[valid_ind]])
...@@ -162,7 +162,10 @@ def main(): ...@@ -162,7 +162,10 @@ def main():
fluid.io.save_inference_model( fluid.io.save_inference_model(
"./output/", "./output/",
feeded_var_names=['image'], feeded_var_names=[
'image', 'encoder_word_pos', 'gsrm_slf_attn_bias1',
'gsrm_slf_attn_bias2', 'gsrm_word_pos'
],
target_vars=target_var, target_vars=target_var,
executor=exe, executor=exe,
main_program=eval_prog, main_program=eval_prog,
......
...@@ -208,10 +208,19 @@ def build_export(config, main_prog, startup_prog): ...@@ -208,10 +208,19 @@ def build_export(config, main_prog, startup_prog):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
func_infor = config['Architecture']['function'] func_infor = config['Architecture']['function']
model = create_module(func_infor)(params=config) model = create_module(func_infor)(params=config)
loss_type = config['Global']['loss_type']
if loss_type == "srn":
image, others, outputs = model(mode='export')
else:
image, outputs = model(mode='export') image, outputs = model(mode='export')
fetches_var_name = sorted([name for name in outputs.keys()]) fetches_var_name = sorted([name for name in outputs.keys()])
fetches_var = [outputs[name] for name in fetches_var_name] fetches_var = [outputs[name] for name in fetches_var_name]
if loss_type == "srn":
others_var_names = sorted([name for name in others.keys()])
feeded_var_names = [image.name] + others_var_names
else:
feeded_var_names = [image.name] feeded_var_names = [image.name]
target_vars = fetches_var target_vars = fetches_var
return feeded_var_names, target_vars, fetches_var_name return feeded_var_names, target_vars, fetches_var_name
...@@ -409,7 +418,9 @@ def preprocess(): ...@@ -409,7 +418,9 @@ def preprocess():
check_gpu(use_gpu) check_gpu(use_gpu)
alg = config['Global']['algorithm'] alg = config['Global']['algorithm']
assert alg in ['EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN'] assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN'
]
if alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']: if alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']:
config['Global']['char_ops'] = CharacterOps(config['Global']) config['Global']['char_ops'] = CharacterOps(config['Global'])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册