Unverified commit c2702229, authored by zhoujun, committed by GitHub

add global_step to .states files (#2566)

Co-authored-by: littletomatodonkey <2120160898@bit.edu.cn>
Parent 8e4b2138
@@ -121,7 +121,7 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
     return best_model_dict


-def save_model(net,
+def save_model(model,
               optimizer,
               model_path,
               logger,
@@ -133,7 +133,7 @@ def save_model(net,
     """
     _mkdir_if_not_exist(model_path, logger)
     model_prefix = os.path.join(model_path, prefix)
-    paddle.save(net.state_dict(), model_prefix + '.pdparams')
+    paddle.save(model.state_dict(), model_prefix + '.pdparams')
     paddle.save(optimizer.state_dict(), model_prefix + '.pdopt')

     # save metric and config
......
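
Note: the hunks above rename net to model and leave the .states write itself out of view, so the diff alone does not show where global_step lands. A minimal sketch of the save side, assuming save_model forwards its extra keyword arguments (epoch, global_step, the entries of best_model_dict) into the pickled .states file; the body below is illustrative, not the exact file contents:

import os
import pickle

import paddle


def save_model(model, optimizer, model_path, logger,
               is_best=False, prefix='ppocr', **kwargs):
    """Sketch: persist weights, optimizer state, and training state.

    kwargs carries epoch, global_step, and the best-metric entries, so
    everything needed to resume travels in a single .states file.
    """
    os.makedirs(model_path, exist_ok=True)
    model_prefix = os.path.join(model_path, prefix)
    paddle.save(model.state_dict(), model_prefix + '.pdparams')
    paddle.save(optimizer.state_dict(), model_prefix + '.pdopt')
    # global_step arrives via the new keyword argument and is pickled
    # together with the metrics and epoch counter.
    with open(model_prefix + '.states', 'wb') as f:
        pickle.dump(kwargs, f, protocol=2)
    logger.info('save {}model to {}'.format('best ' if is_best else '',
                                            model_prefix))
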
@@ -159,6 +159,8 @@ def train(config,
     eval_batch_step = config['Global']['eval_batch_step']

     global_step = 0
+    if 'global_step' in pre_best_model_dict:
+        global_step = pre_best_model_dict['global_step']
     start_eval_step = 0
     if type(eval_batch_step) == list and len(eval_batch_step) >= 2:
         start_eval_step = eval_batch_step[0]
@@ -285,7 +287,8 @@ def train(config,
                     is_best=True,
                     prefix='best_accuracy',
                     best_model_dict=best_model_dict,
-                    epoch=epoch)
+                    epoch=epoch,
+                    global_step=global_step)
                 best_str = 'best metric, {}'.format(', '.join([
                     '{}: {}'.format(k, v) for k, v in best_model_dict.items()
                 ]))
@@ -307,7 +310,8 @@ def train(config,
                 is_best=False,
                 prefix='latest',
                 best_model_dict=best_model_dict,
-                epoch=epoch)
+                epoch=epoch,
+                global_step=global_step)
         if dist.get_rank() == 0 and epoch > 0 and epoch % save_epoch_step == 0:
             save_model(
                 model,
@@ -317,7 +321,8 @@ def train(config,
                 is_best=False,
                 prefix='iter_epoch_{}'.format(epoch),
                 best_model_dict=best_model_dict,
-                epoch=epoch)
+                epoch=epoch,
+                global_step=global_step)
     best_str = 'best metric, {}'.format(', '.join(
         ['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
     logger.info(best_str)
......
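
On the resume side, train() only checks pre_best_model_dict for the restored counter; the loader has to put it there. A minimal sketch, assuming init_model (its signature appears in the first hunk) unpickles the .states file that sits next to the checkpoint and returns it as best_model_dict; the names follow the diff, the body is an assumption:

import pickle

import paddle


def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
    """Sketch of the resume path: reload weights, optimizer state, and
    the pickled training state (which now includes global_step)."""
    best_model_dict = {}
    checkpoints = config['Global'].get('checkpoints')
    if checkpoints:
        model.set_state_dict(paddle.load(checkpoints + '.pdparams'))
        if optimizer is not None:
            optimizer.set_state_dict(paddle.load(checkpoints + '.pdopt'))
        # e.g. {'best_acc': 0.91, 'epoch': 3, 'global_step': 1200, ...}
        with open(checkpoints + '.states', 'rb') as f:
            best_model_dict.update(pickle.load(f))
        logger.info('resume from {}'.format(checkpoints))
    return best_model_dict

With that in place, the new `if 'global_step' in pre_best_model_dict` guard in train() picks the step counter back up, so step-based evaluation scheduling (start_eval_step / eval_batch_step) stays aligned across restarts instead of resetting to zero.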