diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index 8f676294c3079f3a686e44d86b38865ee47c270f..e592ccdc7288e5688adb0bd6b08ecc6f031c8933 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -196,32 +196,36 @@ class TextRecognizer(object): norm_img_batch.append(norm_img[0]) norm_img_batch = np.concatenate(norm_img_batch, axis=0) - - encoder_word_pos_list = np.concatenate(encoder_word_pos_list) - - gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list) - - gsrm_slf_attn_bias1_list = np.concatenate(gsrm_slf_attn_bias1_list) - - gsrm_slf_attn_bias2_list = np.concatenate(gsrm_slf_attn_bias2_list) - - starttime = time.time() - - norm_img_batch = fluid.core.PaddleTensor(norm_img_batch) - encoder_word_pos_list = fluid.core.PaddleTensor( - encoder_word_pos_list) - gsrm_word_pos_list = fluid.core.PaddleTensor(gsrm_word_pos_list) - gsrm_slf_attn_bias1_list = fluid.core.PaddleTensor( - gsrm_slf_attn_bias1_list) - gsrm_slf_attn_bias2_list = fluid.core.PaddleTensor( - gsrm_slf_attn_bias2_list) - - inputs = [ - norm_img_batch, encoder_word_pos_list, gsrm_slf_attn_bias1_list, - gsrm_slf_attn_bias2_list, gsrm_word_pos_list - ] - - self.predictor.run(inputs) + norm_img_batch = norm_img_batch.copy() + + if self.loss_type == "srn": + encoder_word_pos_list = np.concatenate(encoder_word_pos_list) + gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list) + gsrm_slf_attn_bias1_list = np.concatenate( + gsrm_slf_attn_bias1_list) + gsrm_slf_attn_bias2_list = np.concatenate( + gsrm_slf_attn_bias2_list) + starttime = time.time() + + norm_img_batch = fluid.core.PaddleTensor(norm_img_batch) + encoder_word_pos_list = fluid.core.PaddleTensor( + encoder_word_pos_list) + gsrm_word_pos_list = fluid.core.PaddleTensor(gsrm_word_pos_list) + gsrm_slf_attn_bias1_list = fluid.core.PaddleTensor( + gsrm_slf_attn_bias1_list) + gsrm_slf_attn_bias2_list = fluid.core.PaddleTensor( + gsrm_slf_attn_bias2_list) + + inputs = [ + norm_img_batch, encoder_word_pos_list, + gsrm_slf_attn_bias1_list, gsrm_slf_attn_bias2_list, + gsrm_word_pos_list + ] + + self.predictor.run(inputs) + else: + self.input_tensor.copy_from_cpu(norm_img_batch) + self.predictor.zero_copy_run() if self.loss_type == "ctc": rec_idx_batch = self.output_tensors[0].copy_to_cpu() diff --git a/tools/infer/utility.py b/tools/infer/utility.py index e4e2a9fcca52c033a4ed49e15a4e78df5dc8aaf1..ccf6e32b43589c22a3f63ec242de431872f15283 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -59,10 +59,10 @@ def parse_args(): parser.add_argument("--det_sast_polygon", type=bool, default=False) #params for text recognizer - parser.add_argument("--rec_algorithm", type=str, default='SRN') + parser.add_argument("--rec_algorithm", type=str, default='CRNN') parser.add_argument("--rec_model_dir", type=str) - parser.add_argument("--rec_image_shape", type=str, default="3, 64, 256") - parser.add_argument("--rec_char_type", type=str, default='en') + parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320") + parser.add_argument("--rec_char_type", type=str, default='ch') parser.add_argument("--rec_batch_num", type=int, default=30) parser.add_argument("--max_text_length", type=int, default=25) parser.add_argument( @@ -107,11 +107,9 @@ def create_predictor(args, mode): # use zero copy config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass") - #config.switch_use_feed_fetch_ops(False) - config.switch_use_feed_fetch_ops(True) + config.switch_use_feed_fetch_ops(False) predictor = create_paddle_predictor(config) input_names = predictor.get_input_names() - print(input_names) for name in input_names: input_tensor = predictor.get_input_tensor(name) output_names = predictor.get_output_names() diff --git a/tools/infer_rec.py b/tools/infer_rec.py index e4b75e86f0883a6ab2b0baae8123ad1be9d5d634..29fc5b40a890cd6e8fa3ca7d3f0999835555d9bd 100755 --- a/tools/infer_rec.py +++ b/tools/infer_rec.py @@ -162,10 +162,7 @@ def main(): fluid.io.save_inference_model( "./output/", - feeded_var_names=[ - 'image', 'encoder_word_pos', 'gsrm_slf_attn_bias1', - 'gsrm_slf_attn_bias2', 'gsrm_word_pos' - ], + feeded_var_names=['image'], target_vars=target_var, executor=exe, main_program=eval_prog,