diff --git a/slim/prune/prune.py b/slim/prune/prune.py index d5fd5f8b35309405f00bfa428650417bd9d20489..606b29e7cc76ac6538737acde2fce2bb99bfd241 100644 --- a/slim/prune/prune.py +++ b/slim/prune/prune.py @@ -171,10 +171,7 @@ def main(): fuse_bn = getattr(model.backbone, 'norm_type', None) == 'affine_channel' start_iter = 0 - if FLAGS.resume_checkpoint: - checkpoint.load_checkpoint(exe, train_prog, FLAGS.resume_checkpoint) - start_iter = checkpoint.global_step() - elif cfg.pretrain_weights: + if cfg.pretrain_weights: checkpoint.load_params(exe, train_prog, cfg.pretrain_weights) pruned_params = FLAGS.pruned_params @@ -220,6 +217,10 @@ def main(): pruned_flops)) compiled_eval_prog = fluid.compiler.CompiledProgram(eval_prog) + if FLAGS.resume_checkpoint: + checkpoint.load_checkpoint(exe, train_prog, FLAGS.resume_checkpoint) + start_iter = checkpoint.global_step() + train_reader = create_reader(cfg.TrainReader, (cfg.max_iters - start_iter) * devices_num, cfg) train_loader.set_sample_list_generator(train_reader, place)