diff --git a/slim/prune/eval.py b/slim/prune/eval.py index db74a217459b4e5ac84a5348b37767ec113797fd..ad4ceff77312d144546b58fe29840610ea65ef42 100644 --- a/slim/prune/eval.py +++ b/slim/prune/eval.py @@ -86,6 +86,7 @@ def main(): fetches = model.eval(feed_vars, multi_scale_test) eval_prog = eval_prog.clone(True) + exe.run(startup_prog) reader = create_reader(cfg.EvalReader) loader.set_sample_list_generator(reader, place) @@ -123,7 +124,7 @@ def main(): params=pruned_params, ratios=pruned_ratios, place=place, - only_graph=True) + only_graph=False) pruned_flops = flops(eval_prog) logger.info("pruned FLOPS: {}".format( float(base_flops - pruned_flops) / base_flops)) @@ -174,7 +175,6 @@ def main(): sub_eval_prog = sub_eval_prog.clone(True) # load model - exe.run(startup_prog) if 'weights' in cfg: checkpoint.load_checkpoint(exe, eval_prog, cfg.weights)