提交 aa7e9ac3 编写于 作者: T tink2123

polish code

上级 bc1ad207
...@@ -196,32 +196,36 @@ class TextRecognizer(object): ...@@ -196,32 +196,36 @@ class TextRecognizer(object):
norm_img_batch.append(norm_img[0]) norm_img_batch.append(norm_img[0])
norm_img_batch = np.concatenate(norm_img_batch, axis=0) norm_img_batch = np.concatenate(norm_img_batch, axis=0)
norm_img_batch = norm_img_batch.copy()
encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
if self.loss_type == "srn":
gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list) encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list)
gsrm_slf_attn_bias1_list = np.concatenate(gsrm_slf_attn_bias1_list) gsrm_slf_attn_bias1_list = np.concatenate(
gsrm_slf_attn_bias1_list)
gsrm_slf_attn_bias2_list = np.concatenate(gsrm_slf_attn_bias2_list) gsrm_slf_attn_bias2_list = np.concatenate(
gsrm_slf_attn_bias2_list)
starttime = time.time() starttime = time.time()
norm_img_batch = fluid.core.PaddleTensor(norm_img_batch) norm_img_batch = fluid.core.PaddleTensor(norm_img_batch)
encoder_word_pos_list = fluid.core.PaddleTensor( encoder_word_pos_list = fluid.core.PaddleTensor(
encoder_word_pos_list) encoder_word_pos_list)
gsrm_word_pos_list = fluid.core.PaddleTensor(gsrm_word_pos_list) gsrm_word_pos_list = fluid.core.PaddleTensor(gsrm_word_pos_list)
gsrm_slf_attn_bias1_list = fluid.core.PaddleTensor( gsrm_slf_attn_bias1_list = fluid.core.PaddleTensor(
gsrm_slf_attn_bias1_list) gsrm_slf_attn_bias1_list)
gsrm_slf_attn_bias2_list = fluid.core.PaddleTensor( gsrm_slf_attn_bias2_list = fluid.core.PaddleTensor(
gsrm_slf_attn_bias2_list) gsrm_slf_attn_bias2_list)
inputs = [ inputs = [
norm_img_batch, encoder_word_pos_list, gsrm_slf_attn_bias1_list, norm_img_batch, encoder_word_pos_list,
gsrm_slf_attn_bias2_list, gsrm_word_pos_list gsrm_slf_attn_bias1_list, gsrm_slf_attn_bias2_list,
] gsrm_word_pos_list
]
self.predictor.run(inputs)
self.predictor.run(inputs)
else:
self.input_tensor.copy_from_cpu(norm_img_batch)
self.predictor.zero_copy_run()
if self.loss_type == "ctc": if self.loss_type == "ctc":
rec_idx_batch = self.output_tensors[0].copy_to_cpu() rec_idx_batch = self.output_tensors[0].copy_to_cpu()
......
...@@ -59,10 +59,10 @@ def parse_args(): ...@@ -59,10 +59,10 @@ def parse_args():
parser.add_argument("--det_sast_polygon", type=bool, default=False) parser.add_argument("--det_sast_polygon", type=bool, default=False)
#params for text recognizer #params for text recognizer
parser.add_argument("--rec_algorithm", type=str, default='SRN') parser.add_argument("--rec_algorithm", type=str, default='CRNN')
parser.add_argument("--rec_model_dir", type=str) parser.add_argument("--rec_model_dir", type=str)
parser.add_argument("--rec_image_shape", type=str, default="3, 64, 256") parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
parser.add_argument("--rec_char_type", type=str, default='en') parser.add_argument("--rec_char_type", type=str, default='ch')
parser.add_argument("--rec_batch_num", type=int, default=30) parser.add_argument("--rec_batch_num", type=int, default=30)
parser.add_argument("--max_text_length", type=int, default=25) parser.add_argument("--max_text_length", type=int, default=25)
parser.add_argument( parser.add_argument(
...@@ -107,11 +107,9 @@ def create_predictor(args, mode): ...@@ -107,11 +107,9 @@ def create_predictor(args, mode):
# use zero copy # use zero copy
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass") config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
#config.switch_use_feed_fetch_ops(False) config.switch_use_feed_fetch_ops(False)
config.switch_use_feed_fetch_ops(True)
predictor = create_paddle_predictor(config) predictor = create_paddle_predictor(config)
input_names = predictor.get_input_names() input_names = predictor.get_input_names()
print(input_names)
for name in input_names: for name in input_names:
input_tensor = predictor.get_input_tensor(name) input_tensor = predictor.get_input_tensor(name)
output_names = predictor.get_output_names() output_names = predictor.get_output_names()
......
...@@ -162,10 +162,7 @@ def main(): ...@@ -162,10 +162,7 @@ def main():
fluid.io.save_inference_model( fluid.io.save_inference_model(
"./output/", "./output/",
feeded_var_names=[ feeded_var_names=['image'],
'image', 'encoder_word_pos', 'gsrm_slf_attn_bias1',
'gsrm_slf_attn_bias2', 'gsrm_word_pos'
],
target_vars=target_var, target_vars=target_var,
executor=exe, executor=exe,
main_program=eval_prog, main_program=eval_prog,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册