提交 28e19603 编写于 作者: W wangguanzhong 提交者: GitHub

Update prune interface (#3336)

上级 0be30344
...@@ -106,12 +106,12 @@ def prune_feed_vars(feeded_var_names, target_vars, prog): ...@@ -106,12 +106,12 @@ def prune_feed_vars(feeded_var_names, target_vars, prog):
""" """
exist_var_names = [] exist_var_names = []
prog = prog.clone() prog = prog.clone()
prog = prog._prune(targets=target_vars) prog = prog._prune(feeded_var_names, targets=target_vars)
global_block = prog.global_block() global_block = prog.global_block()
for name in feeded_var_names: for name in feeded_var_names:
try: try:
v = global_block.var(name) v = global_block.var(name)
exist_var_names.append(v.name) exist_var_names.append(v.name.encode('utf-8'))
except Exception: except Exception:
logger.info('save_inference_model pruned unused feed ' logger.info('save_inference_model pruned unused feed '
'variables {}'.format(name)) 'variables {}'.format(name))
...@@ -127,8 +127,9 @@ def save_infer_model(FLAGS, exe, feed_vars, test_fetches, infer_prog): ...@@ -127,8 +127,9 @@ def save_infer_model(FLAGS, exe, feed_vars, test_fetches, infer_prog):
feeded_var_names = prune_feed_vars(feeded_var_names, target_vars, feeded_var_names = prune_feed_vars(feeded_var_names, target_vars,
infer_prog) infer_prog)
logger.info("Save inference model to {}, input: {}, output: " logger.info("Save inference model to {}, input: {}, output: "
"{}...".format(save_dir, feeded_var_names, "{}...".format(save_dir, feeded_var_names, [
[var.name for var in target_vars])) var.name.encode('utf-8') for var in target_vars
]))
fluid.io.save_inference_model( fluid.io.save_inference_model(
save_dir, save_dir,
feeded_var_names=feeded_var_names, feeded_var_names=feeded_var_names,
...@@ -216,7 +217,7 @@ def main(): ...@@ -216,7 +217,7 @@ def main():
from tb_paddle import SummaryWriter from tb_paddle import SummaryWriter
tb_writer = SummaryWriter(FLAGS.tb_log_dir) tb_writer = SummaryWriter(FLAGS.tb_log_dir)
tb_image_step = 0 tb_image_step = 0
tb_image_frame = 0 # each frame can display ten pictures at most. tb_image_frame = 0 # each frame can display ten pictures at most.
imid2path = reader.imid2path imid2path = reader.imid2path
for iter_id, data in enumerate(reader()): for iter_id, data in enumerate(reader()):
...@@ -257,7 +258,7 @@ def main(): ...@@ -257,7 +258,7 @@ def main():
int(im_id), catid2name, int(im_id), catid2name,
FLAGS.draw_threshold, bbox_results, FLAGS.draw_threshold, bbox_results,
mask_results) mask_results)
# use tb-paddle to log image with bbox # use tb-paddle to log image with bbox
if FLAGS.use_tb: if FLAGS.use_tb:
infer_image_np = np.array(image) infer_image_np = np.array(image)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册