提交 aa7e9ac3 编写于 作者: T tink2123

polish code

上级 bc1ad207
......@@ -196,32 +196,36 @@ class TextRecognizer(object):
norm_img_batch.append(norm_img[0])
norm_img_batch = np.concatenate(norm_img_batch, axis=0)
encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list)
gsrm_slf_attn_bias1_list = np.concatenate(gsrm_slf_attn_bias1_list)
gsrm_slf_attn_bias2_list = np.concatenate(gsrm_slf_attn_bias2_list)
starttime = time.time()
norm_img_batch = fluid.core.PaddleTensor(norm_img_batch)
encoder_word_pos_list = fluid.core.PaddleTensor(
encoder_word_pos_list)
gsrm_word_pos_list = fluid.core.PaddleTensor(gsrm_word_pos_list)
gsrm_slf_attn_bias1_list = fluid.core.PaddleTensor(
gsrm_slf_attn_bias1_list)
gsrm_slf_attn_bias2_list = fluid.core.PaddleTensor(
gsrm_slf_attn_bias2_list)
inputs = [
norm_img_batch, encoder_word_pos_list, gsrm_slf_attn_bias1_list,
gsrm_slf_attn_bias2_list, gsrm_word_pos_list
]
self.predictor.run(inputs)
norm_img_batch = norm_img_batch.copy()
if self.loss_type == "srn":
encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list)
gsrm_slf_attn_bias1_list = np.concatenate(
gsrm_slf_attn_bias1_list)
gsrm_slf_attn_bias2_list = np.concatenate(
gsrm_slf_attn_bias2_list)
starttime = time.time()
norm_img_batch = fluid.core.PaddleTensor(norm_img_batch)
encoder_word_pos_list = fluid.core.PaddleTensor(
encoder_word_pos_list)
gsrm_word_pos_list = fluid.core.PaddleTensor(gsrm_word_pos_list)
gsrm_slf_attn_bias1_list = fluid.core.PaddleTensor(
gsrm_slf_attn_bias1_list)
gsrm_slf_attn_bias2_list = fluid.core.PaddleTensor(
gsrm_slf_attn_bias2_list)
inputs = [
norm_img_batch, encoder_word_pos_list,
gsrm_slf_attn_bias1_list, gsrm_slf_attn_bias2_list,
gsrm_word_pos_list
]
self.predictor.run(inputs)
else:
self.input_tensor.copy_from_cpu(norm_img_batch)
self.predictor.zero_copy_run()
if self.loss_type == "ctc":
rec_idx_batch = self.output_tensors[0].copy_to_cpu()
......
......@@ -59,10 +59,10 @@ def parse_args():
parser.add_argument("--det_sast_polygon", type=bool, default=False)
#params for text recognizer
parser.add_argument("--rec_algorithm", type=str, default='SRN')
parser.add_argument("--rec_algorithm", type=str, default='CRNN')
parser.add_argument("--rec_model_dir", type=str)
parser.add_argument("--rec_image_shape", type=str, default="3, 64, 256")
parser.add_argument("--rec_char_type", type=str, default='en')
parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
parser.add_argument("--rec_char_type", type=str, default='ch')
parser.add_argument("--rec_batch_num", type=int, default=30)
parser.add_argument("--max_text_length", type=int, default=25)
parser.add_argument(
......@@ -107,11 +107,9 @@ def create_predictor(args, mode):
# use zero copy
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
#config.switch_use_feed_fetch_ops(False)
config.switch_use_feed_fetch_ops(True)
config.switch_use_feed_fetch_ops(False)
predictor = create_paddle_predictor(config)
input_names = predictor.get_input_names()
print(input_names)
for name in input_names:
input_tensor = predictor.get_input_tensor(name)
output_names = predictor.get_output_names()
......
......@@ -162,10 +162,7 @@ def main():
fluid.io.save_inference_model(
"./output/",
feeded_var_names=[
'image', 'encoder_word_pos', 'gsrm_slf_attn_bias1',
'gsrm_slf_attn_bias2', 'gsrm_word_pos'
],
feeded_var_names=['image'],
target_vars=target_var,
executor=exe,
main_program=eval_prog,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册