提交 91b79f69 编写于 作者: 文幕地方's avatar 文幕地方

pair param with key when load trained model params

上级 5fd13ab8
......@@ -56,9 +56,25 @@ def load_model(config, model, optimizer=None):
if checkpoints:
if checkpoints.endswith('pdparams'):
checkpoints = checkpoints.replace('.pdparams', '')
assert os.path.exists(checkpoints + ".pdopt"), \
f"The {checkpoints}.pdopt does not exists!"
load_pretrained_params(model, checkpoints)
assert os.path.exists(checkpoints + ".pdparams"), \
f"The {checkpoints}.pdparams does not exists!"
# load params from trained model
params = paddle.load(checkpoints + '.pdparams')
state_dict = model.state_dict()
new_state_dict = {}
for key, value in state_dict.items():
if key not in params:
logger.warning(f"{key} not in loaded params {params.keys()} !")
pre_value = params[key]
if list(value.shape) == list(pre_value.shape):
new_state_dict[key] = pre_value
else:
logger.warning(
f"The shape of model params {key} {value.shape} not matched with loaded params shape {pre_value.shape} !"
)
model.set_state_dict(new_state_dict)
optim_dict = paddle.load(checkpoints + '.pdopt')
if optimizer is not None:
optimizer.set_state_dict(optim_dict)
......@@ -92,7 +108,7 @@ def load_pretrained_params(model, path):
if list(state_dict[k1].shape) == list(params[k2].shape):
new_state_dict[k1] = params[k2]
else:
logger.info(
logger.warning(
f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
)
model.set_state_dict(new_state_dict)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册