未验证 提交 b70a0f9b 编写于 作者: B Bai Yifan 提交者: GitHub

Fix parameter loading in slim distillation (#233): load the checkpoint / pretrained weights after the startup program has been run.

上级 1a30667d
@@ -156,26 +156,7 @@ def main():
train_fetches = model.train(train_feed_vars) train_fetches = model.train(train_feed_vars)
loss = train_fetches['loss'] loss = train_fetches['loss']
fuse_bn = getattr(model.backbone, 'norm_type', None) == 'affine_channel'
ignore_params = cfg.finetune_exclude_pretrained_params \
if 'finetune_exclude_pretrained_params' in cfg else []
start_iter = 0 start_iter = 0
if FLAGS.resume_checkpoint:
checkpoint.load_checkpoint(exe,
fluid.default_main_program(),
FLAGS.resume_checkpoint)
start_iter = checkpoint.global_step()
elif cfg.pretrain_weights and fuse_bn and not ignore_params:
checkpoint.load_and_fusebn(exe,
fluid.default_main_program(),
cfg.pretrain_weights)
elif cfg.pretrain_weights:
checkpoint.load_params(
exe,
fluid.default_main_program(),
cfg.pretrain_weights,
ignore_params=ignore_params)
train_reader = create_reader(cfg.TrainReader, (cfg.max_iters - start_iter) * train_reader = create_reader(cfg.TrainReader, (cfg.max_iters - start_iter) *
devices_num, cfg) devices_num, cfg)
train_loader.set_sample_list_generator(train_reader, place) train_loader.set_sample_list_generator(train_reader, place)
@@ -283,11 +264,28 @@ def main():
opt.minimize(loss) opt.minimize(loss)
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
fuse_bn = getattr(model.backbone, 'norm_type', None) == 'affine_channel'
ignore_params = cfg.finetune_exclude_pretrained_params \
if 'finetune_exclude_pretrained_params' in cfg else []
if FLAGS.resume_checkpoint:
checkpoint.load_checkpoint(exe,
fluid.default_main_program(),
FLAGS.resume_checkpoint)
start_iter = checkpoint.global_step()
elif cfg.pretrain_weights and fuse_bn and not ignore_params:
checkpoint.load_and_fusebn(exe,
fluid.default_main_program(),
cfg.pretrain_weights)
elif cfg.pretrain_weights:
checkpoint.load_params(
exe,
fluid.default_main_program(),
cfg.pretrain_weights,
ignore_params=ignore_params)
build_strategy = fluid.BuildStrategy() build_strategy = fluid.BuildStrategy()
build_strategy.fuse_all_reduce_ops = False build_strategy.fuse_all_reduce_ops = False
build_strategy.fuse_all_optimizer_ops = False build_strategy.fuse_all_optimizer_ops = False
build_strategy.fuse_elewise_add_act_ops = True
# only enable sync_bn in multi GPU devices # only enable sync_bn in multi GPU devices
sync_bn = getattr(model.backbone, 'norm_type', None) == 'sync_bn' sync_bn = getattr(model.backbone, 'norm_type', None) == 'sync_bn'
build_strategy.sync_batch_norm = sync_bn and devices_num > 1 \ build_strategy.sync_batch_norm = sync_bn and devices_num > 1 \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册