From 385f9bbd62252d07aabb0fbd49cf33b2597d0ca0 Mon Sep 17 00:00:00 2001 From: Guanghua Yu <742925032@qq.com> Date: Tue, 11 May 2021 11:32:38 +0800 Subject: [PATCH] fix static prune bug (#2933) --- static/slim/prune/prune.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/static/slim/prune/prune.py b/static/slim/prune/prune.py index 52b0b0c7c..a75c6d4cb 100644 --- a/static/slim/prune/prune.py +++ b/static/slim/prune/prune.py @@ -203,21 +203,7 @@ def main(): assert FLAGS.prune_criterion in ['l1_norm', 'geometry_median'], \ "unsupported prune criterion {}".format(FLAGS.prune_criterion) pruner = Pruner(criterion=FLAGS.prune_criterion) - train_prog = pruner.prune( - train_prog, - fluid.global_scope(), - params=pruned_params, - ratios=pruned_ratios, - place=place, - only_graph=False)[0] - - compiled_train_prog = fluid.CompiledProgram(train_prog).with_data_parallel( - loss_name=loss.name, - build_strategy=build_strategy, - exec_strategy=exec_strategy) - if FLAGS.eval: - base_flops = flops(eval_prog) eval_prog = pruner.prune( eval_prog, @@ -232,6 +218,19 @@ def main(): pruned_flops)) compiled_eval_prog = fluid.CompiledProgram(eval_prog) + train_prog = pruner.prune( + train_prog, + fluid.global_scope(), + params=pruned_params, + ratios=pruned_ratios, + place=place, + only_graph=False)[0] + + compiled_train_prog = fluid.CompiledProgram(train_prog).with_data_parallel( + loss_name=loss.name, + build_strategy=build_strategy, + exec_strategy=exec_strategy) + if FLAGS.resume_checkpoint: checkpoint.load_checkpoint(exe, train_prog, FLAGS.resume_checkpoint) start_iter = checkpoint.global_step() -- GitLab