未验证 提交 9005e080 编写于 作者: W Walter 提交者: GitHub

Merge pull request #1938 from HydrogenSulfate/fix_resume_opt

Fix resume opt
...@@ -105,7 +105,8 @@ def init_model(config, net, optimizer=None, loss: paddle.nn.Layer=None): ...@@ -105,7 +105,8 @@ def init_model(config, net, optimizer=None, loss: paddle.nn.Layer=None):
net.set_state_dict(para_dict) net.set_state_dict(para_dict)
loss.set_state_dict(para_dict) loss.set_state_dict(para_dict)
for i in range(len(optimizer)): for i in range(len(optimizer)):
optimizer[i].set_state_dict(opti_dict) optimizer[i].set_state_dict(opti_dict[i] if isinstance(
opti_dict, list) else opti_dict)
logger.info("Finish load checkpoints from {}".format(checkpoints)) logger.info("Finish load checkpoints from {}".format(checkpoints))
return metric_dict return metric_dict
...@@ -117,7 +118,7 @@ def init_model(config, net, optimizer=None, loss: paddle.nn.Layer=None): ...@@ -117,7 +118,7 @@ def init_model(config, net, optimizer=None, loss: paddle.nn.Layer=None):
else: # common load else: # common load
load_dygraph_pretrain(net, path=pretrained_model) load_dygraph_pretrain(net, path=pretrained_model)
logger.info("Finish load pretrained model from {}".format( logger.info("Finish load pretrained model from {}".format(
pretrained_model)) pretrained_model))
def save_model(net, def save_model(net,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册