diff --git a/slim/distillation/distill.py b/slim/distillation/distill.py index cbc9b823d93ac1fbd8c5a7d9d5c19e6f515beff0..136507f6d825e0187c318ad1a2b2b6a66a51da09 100644 --- a/slim/distillation/distill.py +++ b/slim/distillation/distill.py @@ -156,26 +156,7 @@ def main(): train_fetches = model.train(train_feed_vars) loss = train_fetches['loss'] - fuse_bn = getattr(model.backbone, 'norm_type', None) == 'affine_channel' - ignore_params = cfg.finetune_exclude_pretrained_params \ - if 'finetune_exclude_pretrained_params' in cfg else [] start_iter = 0 - if FLAGS.resume_checkpoint: - checkpoint.load_checkpoint(exe, - fluid.default_main_program(), - FLAGS.resume_checkpoint) - start_iter = checkpoint.global_step() - elif cfg.pretrain_weights and fuse_bn and not ignore_params: - checkpoint.load_and_fusebn(exe, - fluid.default_main_program(), - cfg.pretrain_weights) - elif cfg.pretrain_weights: - checkpoint.load_params( - exe, - fluid.default_main_program(), - cfg.pretrain_weights, - ignore_params=ignore_params) - train_reader = create_reader(cfg.TrainReader, (cfg.max_iters - start_iter) * devices_num, cfg) train_loader.set_sample_list_generator(train_reader, place) @@ -283,11 +264,28 @@ def main(): opt.minimize(loss) exe.run(fluid.default_startup_program()) + fuse_bn = getattr(model.backbone, 'norm_type', None) == 'affine_channel' + ignore_params = cfg.finetune_exclude_pretrained_params \ + if 'finetune_exclude_pretrained_params' in cfg else [] + if FLAGS.resume_checkpoint: + checkpoint.load_checkpoint(exe, + fluid.default_main_program(), + FLAGS.resume_checkpoint) + start_iter = checkpoint.global_step() + elif cfg.pretrain_weights and fuse_bn and not ignore_params: + checkpoint.load_and_fusebn(exe, + fluid.default_main_program(), + cfg.pretrain_weights) + elif cfg.pretrain_weights: + checkpoint.load_params( + exe, + fluid.default_main_program(), + cfg.pretrain_weights, + ignore_params=ignore_params) build_strategy = fluid.BuildStrategy() build_strategy.fuse_all_reduce_ops = False build_strategy.fuse_all_optimizer_ops = False - build_strategy.fuse_elewise_add_act_ops = True # only enable sync_bn in multi GPU devices sync_bn = getattr(model.backbone, 'norm_type', None) == 'sync_bn' build_strategy.sync_batch_norm = sync_bn and devices_num > 1 \