提交 28e19603 编写于 作者: W wangguanzhong 提交者: GitHub

Update prune interface (#3336)

上级 0be30344
......@@ -106,12 +106,12 @@ def prune_feed_vars(feeded_var_names, target_vars, prog):
"""
exist_var_names = []
prog = prog.clone()
prog = prog._prune(targets=target_vars)
prog = prog._prune(feeded_var_names, targets=target_vars)
global_block = prog.global_block()
for name in feeded_var_names:
try:
v = global_block.var(name)
exist_var_names.append(v.name)
exist_var_names.append(v.name.encode('utf-8'))
except Exception:
logger.info('save_inference_model pruned unused feed '
'variables {}'.format(name))
......@@ -127,8 +127,9 @@ def save_infer_model(FLAGS, exe, feed_vars, test_fetches, infer_prog):
feeded_var_names = prune_feed_vars(feeded_var_names, target_vars,
infer_prog)
logger.info("Save inference model to {}, input: {}, output: "
"{}...".format(save_dir, feeded_var_names,
[var.name for var in target_vars]))
"{}...".format(save_dir, feeded_var_names, [
var.name.encode('utf-8') for var in target_vars
]))
fluid.io.save_inference_model(
save_dir,
feeded_var_names=feeded_var_names,
......@@ -216,7 +217,7 @@ def main():
from tb_paddle import SummaryWriter
tb_writer = SummaryWriter(FLAGS.tb_log_dir)
tb_image_step = 0
tb_image_frame = 0 # each frame can display ten pictures at most.
tb_image_frame = 0 # each frame can display ten pictures at most.
imid2path = reader.imid2path
for iter_id, data in enumerate(reader()):
......@@ -257,7 +258,7 @@ def main():
int(im_id), catid2name,
FLAGS.draw_threshold, bbox_results,
mask_results)
# use tb-paddle to log image with bbox
if FLAGS.use_tb:
infer_image_np = np.array(image)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册