From da723385dc40c0af7916f178398685c0a1cd7651 Mon Sep 17 00:00:00 2001
From: qingqing01
Date: Sat, 29 Jun 2019 12:01:29 +0800
Subject: [PATCH] Enhance resume training (#2612)

---
 ppdet/utils/checkpoint.py | 17 +++++++++++++++++
 tools/train.py            |  4 +++-
 2 files changed, 20 insertions(+), 1 deletion(-)

diff --git a/ppdet/utils/checkpoint.py b/ppdet/utils/checkpoint.py
index 689059a4c..4a20ef725 100644
--- a/ppdet/utils/checkpoint.py
+++ b/ppdet/utils/checkpoint.py
@@ -83,6 +83,23 @@ def load_checkpoint(exe, prog, path):
     fluid.io.load_persistables(exe, path, prog)
 
 
+def global_step(scope=None):
+    """
+    Load the global step from the given scope.
+    Args:
+        scope (fluid.Scope): scope to load the global step from. If None,
+            the default global_scope() is used.
+
+    Returns:
+        int: the global step.
+    """
+    if scope is None:
+        scope = fluid.global_scope()
+    v = scope.find_var('@LR_DECAY_COUNTER@')
+    step = np.array(v.get_tensor())[0] if v else 0
+    return step
+
+
 def save(exe, prog, path):
     """
     Load model from the given path.
diff --git a/tools/train.py b/tools/train.py
index 037decdcc..274cdff9f 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -137,8 +137,10 @@ def main():
     exe.run(startup_prog)
 
     freeze_bn = getattr(model.backbone, 'freeze_norm', False)
+    start_iter = 0
     if FLAGS.resume_checkpoint:
         checkpoint.load_checkpoint(exe, train_prog, FLAGS.resume_checkpoint)
+        start_iter = checkpoint.global_step()
     elif cfg.pretrain_weights and freeze_bn:
         checkpoint.load_and_fusebn(exe, train_prog, cfg.pretrain_weights)
     elif cfg.pretrain_weights:
@@ -151,7 +153,7 @@ def main():
     cfg_name = os.path.basename(FLAGS.config).split('.')[0]
     save_dir = os.path.join(cfg.save_dir, cfg_name)
 
-    for it in range(cfg.max_iters):
+    for it in range(start_iter, cfg.max_iters):
         start_time = end_time
         end_time = time.time()
         outs = exe.run(train_compile_program, fetch_list=train_values)
--
GitLab
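
Note on why reading '@LR_DECAY_COUNTER@' recovers the resume point: in the
PaddlePaddle 1.x fluid API, the learning-rate decay schedulers create this
counter as a persistable variable that is incremented once per executor step,
so fluid.io.save_persistables/load_persistables round-trip it together with
the model weights. The sketch below, assuming that behavior (the helper name
peek_resume_iter is hypothetical, not part of the patch), shows how the
restored counter could be inspected after load_checkpoint():

    import numpy as np
    import paddle.fluid as fluid

    def peek_resume_iter(scope=None):
        # Same lookup the patch performs: the scheduler only creates
        # '@LR_DECAY_COUNTER@' when a decay schedule is in use, so fall
        # back to 0 for a fresh (non-resumed) run.
        scope = scope or fluid.global_scope()
        v = scope.find_var('@LR_DECAY_COUNTER@')
        return int(np.array(v.get_tensor())[0]) if v is not None else 0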