未验证 提交 211aaf02 编写于 作者: K Kaipeng Deng 提交者: GitHub

fix prune resume (#223)

上级 8dae8b5e
...@@ -171,10 +171,7 @@ def main(): ...@@ -171,10 +171,7 @@ def main():
fuse_bn = getattr(model.backbone, 'norm_type', None) == 'affine_channel' fuse_bn = getattr(model.backbone, 'norm_type', None) == 'affine_channel'
start_iter = 0 start_iter = 0
if FLAGS.resume_checkpoint: if cfg.pretrain_weights:
checkpoint.load_checkpoint(exe, train_prog, FLAGS.resume_checkpoint)
start_iter = checkpoint.global_step()
elif cfg.pretrain_weights:
checkpoint.load_params(exe, train_prog, cfg.pretrain_weights) checkpoint.load_params(exe, train_prog, cfg.pretrain_weights)
pruned_params = FLAGS.pruned_params pruned_params = FLAGS.pruned_params
...@@ -220,6 +217,10 @@ def main(): ...@@ -220,6 +217,10 @@ def main():
pruned_flops)) pruned_flops))
compiled_eval_prog = fluid.compiler.CompiledProgram(eval_prog) compiled_eval_prog = fluid.compiler.CompiledProgram(eval_prog)
if FLAGS.resume_checkpoint:
checkpoint.load_checkpoint(exe, train_prog, FLAGS.resume_checkpoint)
start_iter = checkpoint.global_step()
train_reader = create_reader(cfg.TrainReader, (cfg.max_iters - start_iter) * train_reader = create_reader(cfg.TrainReader, (cfg.max_iters - start_iter) *
devices_num, cfg) devices_num, cfg)
train_loader.set_sample_list_generator(train_reader, place) train_loader.set_sample_list_generator(train_reader, place)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册