提交 4d775dc9 编写于 作者: W WenmuZhou

rc版本适配

上级 44840726
......@@ -68,11 +68,11 @@ def load_dygraph_pretrain(model, logger, path=None, load_static_weights=False):
param_state_dict[key] = pre_state_dict[weight_name]
else:
param_state_dict[key] = model_dict[key]
model.set_dict(param_state_dict)
model.set_state_dict(param_state_dict)
return
param_state_dict, optim_state_dict = paddle.load(path)
model.set_dict(param_state_dict)
param_state_dict = paddle.load(path + '.pdparams')
model.set_state_dict(param_state_dict)
return
......@@ -91,7 +91,7 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
"Given dir {}.pdopt not exist.".format(checkpoints)
para_dict = paddle.load(checkpoints + '.pdparams')
opti_dict = paddle.load(checkpoints + '.pdopt')
model.set_dict(para_dict)
model.set_state_dict(para_dict)
if optimizer is not None:
optimizer.set_state_dict(opti_dict)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册