From 28e19603460bb60cbbe1a1fa55e953fbec2909db Mon Sep 17 00:00:00 2001 From: wangguanzhong Date: Tue, 17 Sep 2019 09:36:40 +0800 Subject: [PATCH] Update prune interface (#3336) --- tools/infer.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tools/infer.py b/tools/infer.py index 00b2749c9..383d16d03 100644 --- a/tools/infer.py +++ b/tools/infer.py @@ -106,12 +106,12 @@ def prune_feed_vars(feeded_var_names, target_vars, prog): """ exist_var_names = [] prog = prog.clone() - prog = prog._prune(targets=target_vars) + prog = prog._prune(feeded_var_names, targets=target_vars) global_block = prog.global_block() for name in feeded_var_names: try: v = global_block.var(name) - exist_var_names.append(v.name) + exist_var_names.append(v.name.encode('utf-8')) except Exception: logger.info('save_inference_model pruned unused feed ' 'variables {}'.format(name)) @@ -127,8 +127,9 @@ def save_infer_model(FLAGS, exe, feed_vars, test_fetches, infer_prog): feeded_var_names = prune_feed_vars(feeded_var_names, target_vars, infer_prog) logger.info("Save inference model to {}, input: {}, output: " - "{}...".format(save_dir, feeded_var_names, - [var.name for var in target_vars])) + "{}...".format(save_dir, feeded_var_names, [ + var.name.encode('utf-8') for var in target_vars + ])) fluid.io.save_inference_model( save_dir, feeded_var_names=feeded_var_names, @@ -216,7 +217,7 @@ def main(): from tb_paddle import SummaryWriter tb_writer = SummaryWriter(FLAGS.tb_log_dir) tb_image_step = 0 - tb_image_frame = 0 # each frame can display ten pictures at most. + tb_image_frame = 0 # each frame can display ten pictures at most. imid2path = reader.imid2path for iter_id, data in enumerate(reader()): @@ -257,7 +258,7 @@ def main(): int(im_id), catid2name, FLAGS.draw_threshold, bbox_results, mask_results) - + # use tb-paddle to log image with bbox if FLAGS.use_tb: infer_image_np = np.array(image) -- GitLab