diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py index 02814d6208aba7ddfa6eac338229502b18b535da..e69b330f0344321d88e7d175ae093cd9e51296aa 100644 --- a/ppocr/utils/save_load.py +++ b/ppocr/utils/save_load.py @@ -121,7 +121,7 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None): return best_model_dict -def save_model(net, +def save_model(model, optimizer, model_path, logger, @@ -133,7 +133,7 @@ def save_model(net, """ _mkdir_if_not_exist(model_path, logger) model_prefix = os.path.join(model_path, prefix) - paddle.save(net.state_dict(), model_prefix + '.pdparams') + paddle.save(model.state_dict(), model_prefix + '.pdparams') paddle.save(optimizer.state_dict(), model_prefix + '.pdopt') # save metric and config diff --git a/tools/program.py b/tools/program.py index d4c3583829f5946c73fde06d0838d9b4d9376858..0aa3307e38982985519e3156214c0b4991ef54bc 100755 --- a/tools/program.py +++ b/tools/program.py @@ -159,6 +159,8 @@ def train(config, eval_batch_step = config['Global']['eval_batch_step'] global_step = 0 + if 'global_step' in pre_best_model_dict: + global_step = pre_best_model_dict['global_step'] start_eval_step = 0 if type(eval_batch_step) == list and len(eval_batch_step) >= 2: start_eval_step = eval_batch_step[0] @@ -285,7 +287,8 @@ def train(config, is_best=True, prefix='best_accuracy', best_model_dict=best_model_dict, - epoch=epoch) + epoch=epoch, + global_step=global_step) best_str = 'best metric, {}'.format(', '.join([ '{}: {}'.format(k, v) for k, v in best_model_dict.items() ])) @@ -307,7 +310,8 @@ def train(config, is_best=False, prefix='latest', best_model_dict=best_model_dict, - epoch=epoch) + epoch=epoch, + global_step=global_step) if dist.get_rank() == 0 and epoch > 0 and epoch % save_epoch_step == 0: save_model( model, @@ -317,7 +321,8 @@ def train(config, is_best=False, prefix='iter_epoch_{}'.format(epoch), best_model_dict=best_model_dict, - epoch=epoch) + epoch=epoch, + global_step=global_step) best_str = 'best metric, {}'.format(', '.join( ['{}: {}'.format(k, v) for k, v in best_model_dict.items()])) logger.info(best_str)