diff --git a/PaddleCV/PaddleDetection/ppdet/utils/checkpoint.py b/PaddleCV/PaddleDetection/ppdet/utils/checkpoint.py index 689059a4c17fb23e1a7888c73052117fb3348af3..4a20ef7256849bb012de6f7890291131152b3a73 100644 --- a/PaddleCV/PaddleDetection/ppdet/utils/checkpoint.py +++ b/PaddleCV/PaddleDetection/ppdet/utils/checkpoint.py @@ -83,6 +83,23 @@ def load_checkpoint(exe, prog, path): fluid.io.load_persistables(exe, path, prog) +def global_step(scope=None): + """ + Load global step in scope. + Args: + scope (fluid.Scope): load global step from which scope. If None, + from default global_scope(). + + Returns: + global step: int. + """ + if scope is None: + scope = fluid.global_scope() + v = scope.find_var('@LR_DECAY_COUNTER@') + step = np.array(v.get_tensor())[0] if v else 0 + return step + + def save(exe, prog, path): """ Load model from the given path. diff --git a/PaddleCV/PaddleDetection/tools/train.py b/PaddleCV/PaddleDetection/tools/train.py index 037decdcc4139ca7336d6bb9df1772d09564033f..274cdff9fd94c93e690f352e51900254b5b8c13f 100644 --- a/PaddleCV/PaddleDetection/tools/train.py +++ b/PaddleCV/PaddleDetection/tools/train.py @@ -137,8 +137,10 @@ def main(): exe.run(startup_prog) freeze_bn = getattr(model.backbone, 'freeze_norm', False) + start_iter = 0 if FLAGS.resume_checkpoint: checkpoint.load_checkpoint(exe, train_prog, FLAGS.resume_checkpoint) + start_iter = checkpoint.global_step() elif cfg.pretrain_weights and freeze_bn: checkpoint.load_and_fusebn(exe, train_prog, cfg.pretrain_weights) elif cfg.pretrain_weights: @@ -151,7 +153,7 @@ def main(): cfg_name = os.path.basename(FLAGS.config).split('.')[0] save_dir = os.path.join(cfg.save_dir, cfg_name) - for it in range(cfg.max_iters): + for it in range(start_iter, cfg.max_iters): start_time = end_time end_time = time.time() outs = exe.run(train_compile_program, fetch_list=train_values)