You need to sign in or sign up before continuing.
未验证 提交 385f9bbd 编写于 作者: G Guanghua Yu 提交者: GitHub

fix static prune bug (#2933)

上级 83364301
...@@ -203,21 +203,7 @@ def main(): ...@@ -203,21 +203,7 @@ def main():
assert FLAGS.prune_criterion in ['l1_norm', 'geometry_median'], \ assert FLAGS.prune_criterion in ['l1_norm', 'geometry_median'], \
"unsupported prune criterion {}".format(FLAGS.prune_criterion) "unsupported prune criterion {}".format(FLAGS.prune_criterion)
pruner = Pruner(criterion=FLAGS.prune_criterion) pruner = Pruner(criterion=FLAGS.prune_criterion)
train_prog = pruner.prune(
train_prog,
fluid.global_scope(),
params=pruned_params,
ratios=pruned_ratios,
place=place,
only_graph=False)[0]
compiled_train_prog = fluid.CompiledProgram(train_prog).with_data_parallel(
loss_name=loss.name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
if FLAGS.eval: if FLAGS.eval:
base_flops = flops(eval_prog) base_flops = flops(eval_prog)
eval_prog = pruner.prune( eval_prog = pruner.prune(
eval_prog, eval_prog,
...@@ -232,6 +218,19 @@ def main(): ...@@ -232,6 +218,19 @@ def main():
pruned_flops)) pruned_flops))
compiled_eval_prog = fluid.CompiledProgram(eval_prog) compiled_eval_prog = fluid.CompiledProgram(eval_prog)
train_prog = pruner.prune(
train_prog,
fluid.global_scope(),
params=pruned_params,
ratios=pruned_ratios,
place=place,
only_graph=False)[0]
compiled_train_prog = fluid.CompiledProgram(train_prog).with_data_parallel(
loss_name=loss.name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
if FLAGS.resume_checkpoint: if FLAGS.resume_checkpoint:
checkpoint.load_checkpoint(exe, train_prog, FLAGS.resume_checkpoint) checkpoint.load_checkpoint(exe, train_prog, FLAGS.resume_checkpoint)
start_iter = checkpoint.global_step() start_iter = checkpoint.global_step()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册