From c27022294e88db29ab309b9056c6e28853bd745e Mon Sep 17 00:00:00 2001
From: zhoujun
Date: Mon, 26 Apr 2021 21:13:21 -0500
Subject: [PATCH] add global_step to .states files (#2566)

Co-authored-by: littletomatodonkey <2120160898@bit.edu.cn>
---
 ppocr/utils/save_load.py |  4 ++--
 tools/program.py         | 11 ++++++++---
 2 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py
index 02814d62..e69b330f 100644
--- a/ppocr/utils/save_load.py
+++ b/ppocr/utils/save_load.py
@@ -121,7 +121,7 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
     return best_model_dict


-def save_model(net,
+def save_model(model,
                optimizer,
                model_path,
                logger,
@@ -133,7 +133,7 @@ def save_model(net,
     """
     _mkdir_if_not_exist(model_path, logger)
     model_prefix = os.path.join(model_path, prefix)
-    paddle.save(net.state_dict(), model_prefix + '.pdparams')
+    paddle.save(model.state_dict(), model_prefix + '.pdparams')
     paddle.save(optimizer.state_dict(), model_prefix + '.pdopt')

     # save metric and config
diff --git a/tools/program.py b/tools/program.py
index d4c35838..0aa3307e 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -159,6 +159,8 @@ def train(config,
     eval_batch_step = config['Global']['eval_batch_step']

     global_step = 0
+    if 'global_step' in pre_best_model_dict:
+        global_step = pre_best_model_dict['global_step']
     start_eval_step = 0
     if type(eval_batch_step) == list and len(eval_batch_step) >= 2:
         start_eval_step = eval_batch_step[0]
@@ -285,7 +287,8 @@ def train(config,
                         is_best=True,
                         prefix='best_accuracy',
                         best_model_dict=best_model_dict,
-                        epoch=epoch)
+                        epoch=epoch,
+                        global_step=global_step)
                     best_str = 'best metric, {}'.format(', '.join([
                         '{}: {}'.format(k, v) for k, v in best_model_dict.items()
                     ]))
@@ -307,7 +310,8 @@ def train(config,
                 is_best=False,
                 prefix='latest',
                 best_model_dict=best_model_dict,
-                epoch=epoch)
+                epoch=epoch,
+                global_step=global_step)
         if dist.get_rank() == 0 and epoch > 0 and epoch % save_epoch_step == 0:
             save_model(
                 model,
@@ -317,7 +321,8 @@ def train(config,
                 is_best=False,
                 prefix='iter_epoch_{}'.format(epoch),
                 best_model_dict=best_model_dict,
-                epoch=epoch)
+                epoch=epoch,
+                global_step=global_step)
     best_str = 'best metric, {}'.format(', '.join(
         ['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
     logger.info(best_str)
--
GitLab
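
For context on what the new keyword does end to end: save_model receives global_step as a keyword argument alongside epoch and best_model_dict and, per the commit subject, writes it into the <prefix>.states file during the "# save metric and config" step; on resume, the loader hands that dict back as pre_best_model_dict, which train() now consults to re-seed global_step. Below is a minimal, self-contained sketch of that round trip. It assumes the .states file is a pickled dict of save_model's keyword arguments; the helper names save_states/load_states and the 'latest' prefix are invented for the illustration and are not the PaddleOCR API.

import pickle


def save_states(model_prefix, **kwargs):
    # Stand-in for the tail of save_model: every keyword argument
    # (epoch, best_model_dict, and, after this patch, global_step)
    # is serialized into a single <prefix>.states file.
    with open(model_prefix + '.states', 'wb') as f:
        pickle.dump(kwargs, f, protocol=2)


def load_states(model_prefix):
    # Stand-in for the resume path: the dict read back here plays
    # the role of pre_best_model_dict in train().
    with open(model_prefix + '.states', 'rb') as f:
        return pickle.load(f)


# Round trip: the step counter survives a restart.
save_states('latest', epoch=12, global_step=3400,
            best_model_dict={'acc': 0.91})

pre_best_model_dict = load_states('latest')
global_step = 0
if 'global_step' in pre_best_model_dict:  # the check this patch adds
    global_step = pre_best_model_dict['global_step']
assert global_step == 3400

Before this change a resumed run restarted global_step at zero, so step-keyed logic in train(), such as the start_eval_step comparison visible in the same hunk, behaved as if training had just begun.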