提交 8ffa50fe 编写于 作者: 文幕地方's avatar 文幕地方

opt load model

上级 134e7070
...@@ -57,7 +57,7 @@ def load_model(config, model, optimizer=None): ...@@ -57,7 +57,7 @@ def load_model(config, model, optimizer=None):
if checkpoints.endswith('.pdparams'): if checkpoints.endswith('.pdparams'):
checkpoints = checkpoints.replace('.pdparams', '') checkpoints = checkpoints.replace('.pdparams', '')
assert os.path.exists(checkpoints + ".pdparams"), \ assert os.path.exists(checkpoints + ".pdparams"), \
"The {}.pdparams does not exists!".format(checkpoints) "The {}.pdparams is not exists!".format(checkpoints)
# load params from trained model # load params from trained model
params = paddle.load(checkpoints + '.pdparams') params = paddle.load(checkpoints + '.pdparams')
...@@ -67,6 +67,7 @@ def load_model(config, model, optimizer=None): ...@@ -67,6 +67,7 @@ def load_model(config, model, optimizer=None):
if key not in params: if key not in params:
logger.warning("{} not in loaded params {} !".format( logger.warning("{} not in loaded params {} !".format(
key, params.keys())) key, params.keys()))
continue
pre_value = params[key] pre_value = params[key]
if list(value.shape) == list(pre_value.shape): if list(value.shape) == list(pre_value.shape):
new_state_dict[key] = pre_value new_state_dict[key] = pre_value
...@@ -76,9 +77,14 @@ def load_model(config, model, optimizer=None): ...@@ -76,9 +77,14 @@ def load_model(config, model, optimizer=None):
format(key, value.shape, pre_value.shape)) format(key, value.shape, pre_value.shape))
model.set_state_dict(new_state_dict) model.set_state_dict(new_state_dict)
optim_dict = paddle.load(checkpoints + '.pdopt')
if optimizer is not None: if optimizer is not None:
optimizer.set_state_dict(optim_dict) if os.path.exists(checkpoints + '.pdopt'):
optim_dict = paddle.load(checkpoints + '.pdopt')
optimizer.set_state_dict(optim_dict)
else:
logger.warning(
"{}.pdopt is not exists, params of optimizer is not loaded".
format(checkpoints))
if os.path.exists(checkpoints + '.states'): if os.path.exists(checkpoints + '.states'):
with open(checkpoints + '.states', 'rb') as f: with open(checkpoints + '.states', 'rb') as f:
...@@ -100,7 +106,7 @@ def load_pretrained_params(model, path): ...@@ -100,7 +106,7 @@ def load_pretrained_params(model, path):
if path.endswith('.pdparams'): if path.endswith('.pdparams'):
path = path.replace('.pdparams', '') path = path.replace('.pdparams', '')
assert os.path.exists(path + ".pdparams"), \ assert os.path.exists(path + ".pdparams"), \
"The {}.pdparams does not exists!".format(path) "The {}.pdparams is not exists!".format(path)
params = paddle.load(path + '.pdparams') params = paddle.load(path + '.pdparams')
state_dict = model.state_dict() state_dict = model.state_dict()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册