提交 da723385 编写于 作者: Q qingqing01 提交者: GitHub

Enhance resume training (#2612)

上级 fcf80d92
......@@ -83,6 +83,23 @@ def load_checkpoint(exe, prog, path):
fluid.io.load_persistables(exe, path, prog)
def global_step(scope=None):
    """
    Read the global step counter stored in a scope.

    The step is taken from the first element of the tensor held by the
    variable named '@LR_DECAY_COUNTER@'.
    NOTE(review): presumably this counter is maintained by the learning
    rate decay schedulers -- confirm against the framework.

    Args:
        scope (fluid.Scope): scope to read the counter from. If None,
            the default fluid.global_scope() is used.

    Returns:
        int: the stored global step, or 0 when the scope has no such
            variable.
    """
    if scope is None:
        scope = fluid.global_scope()

    counter_var = scope.find_var('@LR_DECAY_COUNTER@')
    # Test for absence explicitly rather than relying on the truthiness
    # of the returned variable object.
    if counter_var is None:
        return 0

    # Convert the NumPy scalar to a built-in int so the result matches the
    # documented return type.
    return int(np.array(counter_var.get_tensor())[0])
def save(exe, prog, path):
"""
Load model from the given path.
......
......@@ -137,8 +137,10 @@ def main():
exe.run(startup_prog)
freeze_bn = getattr(model.backbone, 'freeze_norm', False)
start_iter = 0
if FLAGS.resume_checkpoint:
checkpoint.load_checkpoint(exe, train_prog, FLAGS.resume_checkpoint)
start_iter = checkpoint.global_step()
elif cfg.pretrain_weights and freeze_bn:
checkpoint.load_and_fusebn(exe, train_prog, cfg.pretrain_weights)
elif cfg.pretrain_weights:
......@@ -151,7 +153,7 @@ def main():
cfg_name = os.path.basename(FLAGS.config).split('.')[0]
save_dir = os.path.join(cfg.save_dir, cfg_name)
for it in range(cfg.max_iters):
for it in range(start_iter, cfg.max_iters):
start_time = end_time
end_time = time.time()
outs = exe.run(train_compile_program, fetch_list=train_values)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册