提交 49958dca 编写于 作者: W WenmuZhou

适配rc版本

上级 a414dd86
...@@ -89,7 +89,8 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None): ...@@ -89,7 +89,8 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
"Given dir {}.pdparams not exist.".format(checkpoints) "Given dir {}.pdparams not exist.".format(checkpoints)
assert os.path.exists(checkpoints + ".pdopt"), \ assert os.path.exists(checkpoints + ".pdopt"), \
"Given dir {}.pdopt not exist.".format(checkpoints) "Given dir {}.pdopt not exist.".format(checkpoints)
para_dict, opti_dict = paddle.load(checkpoints) para_dict = paddle.load(checkpoints + '.pdparams')
opti_dict = paddle.load(checkpoints + '.pdopt')
model.set_dict(para_dict) model.set_dict(para_dict)
if optimizer is not None: if optimizer is not None:
optimizer.set_state_dict(opti_dict) optimizer.set_state_dict(opti_dict)
...@@ -133,8 +134,8 @@ def save_model(net, ...@@ -133,8 +134,8 @@ def save_model(net,
""" """
_mkdir_if_not_exist(model_path, logger) _mkdir_if_not_exist(model_path, logger)
model_prefix = os.path.join(model_path, prefix) model_prefix = os.path.join(model_path, prefix)
paddle.save(net.state_dict(), model_prefix) paddle.save(net.state_dict(), model_prefix + '.pdparams')
paddle.save(optimizer.state_dict(), model_prefix) paddle.save(optimizer.state_dict(), model_prefix + '.pdopt')
# save metric and config # save metric and config
with open(model_prefix + '.states', 'wb') as f: with open(model_prefix + '.states', 'wb') as f:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册