From 64ce951d8566098381d7857b0465542ae89d53a2 Mon Sep 17 00:00:00 2001 From: Kaipeng Deng Date: Thu, 18 Jul 2019 10:14:13 +0800 Subject: [PATCH] fix save_inference_model cannot find feed vars (#2842) * fix save_inference_model cannot find feed vars * remove comment * fix format * add comment for pruned var * add log for prune --- PaddleCV/PaddleDetection/tools/infer.py | 26 +++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/PaddleCV/PaddleDetection/tools/infer.py b/PaddleCV/PaddleDetection/tools/infer.py index 41e15500..09e440e3 100644 --- a/PaddleCV/PaddleDetection/tools/infer.py +++ b/PaddleCV/PaddleDetection/tools/infer.py @@ -82,13 +82,35 @@ def get_test_images(infer_dir, infer_img): return images +def prune_feed_vars(feeded_var_names, target_vars, prog): + """ + Filter out feed variables which are not in program, + pruned feed variables are only used in post processing + on model output, which are not used in program, such + as im_id to identify image order, im_shape to clip bbox + in image. + """ + exist_var_names = [] + prog = prog.clone() + prog = prog._prune(targets=target_vars) + global_block = prog.global_block() + for name in feeded_var_names: + try: + v = global_block.var(name) + exist_var_names.append(v.name) + except Exception: + logger.info('save_inference_model pruned unused feed ' + 'variables {}'.format(name)) + pass + return exist_var_names + + def save_infer_model(FLAGS, exe, feed_vars, test_fetches, infer_prog): cfg_name = os.path.basename(FLAGS.config).split('.')[0] save_dir = os.path.join(FLAGS.output_dir, cfg_name) feeded_var_names = [var.name for var in feed_vars.values()] - # im_id is only used for visualize, not used in inference model - feeded_var_names.remove('im_id') target_vars = test_fetches.values() + feeded_var_names = prune_feed_vars(feeded_var_names, target_vars, infer_prog) logger.info("Save inference model to {}, input: {}, output: " "{}...".format(save_dir, feeded_var_names, [var.name for var in target_vars])) -- GitLab