未验证 提交 a3eb2de1 编写于 作者: W whs 提交者: GitHub

Fix loading checkpoint in eval script of pruning demo. (#392)

上级 18a3b86f
...@@ -86,6 +86,7 @@ def main(): ...@@ -86,6 +86,7 @@ def main():
fetches = model.eval(feed_vars, multi_scale_test) fetches = model.eval(feed_vars, multi_scale_test)
eval_prog = eval_prog.clone(True) eval_prog = eval_prog.clone(True)
exe.run(startup_prog)
reader = create_reader(cfg.EvalReader) reader = create_reader(cfg.EvalReader)
loader.set_sample_list_generator(reader, place) loader.set_sample_list_generator(reader, place)
...@@ -123,7 +124,7 @@ def main(): ...@@ -123,7 +124,7 @@ def main():
params=pruned_params, params=pruned_params,
ratios=pruned_ratios, ratios=pruned_ratios,
place=place, place=place,
only_graph=True) only_graph=False)
pruned_flops = flops(eval_prog) pruned_flops = flops(eval_prog)
logger.info("pruned FLOPS: {}".format( logger.info("pruned FLOPS: {}".format(
float(base_flops - pruned_flops) / base_flops)) float(base_flops - pruned_flops) / base_flops))
...@@ -174,7 +175,6 @@ def main(): ...@@ -174,7 +175,6 @@ def main():
sub_eval_prog = sub_eval_prog.clone(True) sub_eval_prog = sub_eval_prog.clone(True)
# load model # load model
exe.run(startup_prog)
if 'weights' in cfg: if 'weights' in cfg:
checkpoint.load_checkpoint(exe, eval_prog, cfg.weights) checkpoint.load_checkpoint(exe, eval_prog, cfg.weights)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册