diff --git a/static/slim/prune/prune.py b/static/slim/prune/prune.py index 52b0b0c7c5519ef326242207479d9926c6febd18..a75c6d4cbc89de0df5c8150af5024b02ab83178f 100644 --- a/static/slim/prune/prune.py +++ b/static/slim/prune/prune.py @@ -203,21 +203,7 @@ def main(): assert FLAGS.prune_criterion in ['l1_norm', 'geometry_median'], \ "unsupported prune criterion {}".format(FLAGS.prune_criterion) pruner = Pruner(criterion=FLAGS.prune_criterion) - train_prog = pruner.prune( - train_prog, - fluid.global_scope(), - params=pruned_params, - ratios=pruned_ratios, - place=place, - only_graph=False)[0] - - compiled_train_prog = fluid.CompiledProgram(train_prog).with_data_parallel( - loss_name=loss.name, - build_strategy=build_strategy, - exec_strategy=exec_strategy) - if FLAGS.eval: - base_flops = flops(eval_prog) eval_prog = pruner.prune( eval_prog, @@ -232,6 +218,19 @@ def main(): pruned_flops)) compiled_eval_prog = fluid.CompiledProgram(eval_prog) + train_prog = pruner.prune( + train_prog, + fluid.global_scope(), + params=pruned_params, + ratios=pruned_ratios, + place=place, + only_graph=False)[0] + + compiled_train_prog = fluid.CompiledProgram(train_prog).with_data_parallel( + loss_name=loss.name, + build_strategy=build_strategy, + exec_strategy=exec_strategy) + if FLAGS.resume_checkpoint: checkpoint.load_checkpoint(exe, train_prog, FLAGS.resume_checkpoint) start_iter = checkpoint.global_step()