未验证 提交 2e9abcb9 编写于 作者: Z zhoujun 提交者: GitHub

Merge pull request #4748 from WenmuZhou/tipc

Pair params with their keys when loading trained model params
...@@ -54,11 +54,28 @@ def load_model(config, model, optimizer=None): ...@@ -54,11 +54,28 @@ def load_model(config, model, optimizer=None):
pretrained_model = global_config.get('pretrained_model') pretrained_model = global_config.get('pretrained_model')
best_model_dict = {} best_model_dict = {}
if checkpoints: if checkpoints:
if checkpoints.endswith('pdparams'): if checkpoints.endswith('.pdparams'):
checkpoints = checkpoints.replace('.pdparams', '') checkpoints = checkpoints.replace('.pdparams', '')
assert os.path.exists(checkpoints + ".pdopt"), \ assert os.path.exists(checkpoints + ".pdparams"), \
f"The {checkpoints}.pdopt does not exists!" "The {}.pdparams does not exists!".format(checkpoints)
load_pretrained_params(model, checkpoints)
# load params from trained model
params = paddle.load(checkpoints + '.pdparams')
state_dict = model.state_dict()
new_state_dict = {}
for key, value in state_dict.items():
if key not in params:
logger.warning("{} not in loaded params {} !".format(
key, params.keys()))
pre_value = params[key]
if list(value.shape) == list(pre_value.shape):
new_state_dict[key] = pre_value
else:
logger.warning(
"The shape of model params {} {} not matched with loaded params shape {} !".
format(key, value.shape, pre_value.shape))
model.set_state_dict(new_state_dict)
optim_dict = paddle.load(checkpoints + '.pdopt') optim_dict = paddle.load(checkpoints + '.pdopt')
if optimizer is not None: if optimizer is not None:
optimizer.set_state_dict(optim_dict) optimizer.set_state_dict(optim_dict)
...@@ -80,10 +97,10 @@ def load_model(config, model, optimizer=None): ...@@ -80,10 +97,10 @@ def load_model(config, model, optimizer=None):
def load_pretrained_params(model, path): def load_pretrained_params(model, path):
logger = get_logger() logger = get_logger()
if path.endswith('pdparams'): if path.endswith('.pdparams'):
path = path.replace('.pdparams', '') path = path.replace('.pdparams', '')
assert os.path.exists(path + ".pdparams"), \ assert os.path.exists(path + ".pdparams"), \
f"The {path}.pdparams does not exists!" "The {}.pdparams does not exists!".format(path)
params = paddle.load(path + '.pdparams') params = paddle.load(path + '.pdparams')
state_dict = model.state_dict() state_dict = model.state_dict()
...@@ -92,11 +109,11 @@ def load_pretrained_params(model, path): ...@@ -92,11 +109,11 @@ def load_pretrained_params(model, path):
if list(state_dict[k1].shape) == list(params[k2].shape): if list(state_dict[k1].shape) == list(params[k2].shape):
new_state_dict[k1] = params[k2] new_state_dict[k1] = params[k2]
else: else:
logger.info( logger.warning(
f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !" "The shape of model params {} {} not matched with loaded params {} {} !".
) format(k1, state_dict[k1].shape, k2, params[k2].shape))
model.set_state_dict(new_state_dict) model.set_state_dict(new_state_dict)
logger.info(f"load pretrain successful from {path}") logger.info("load pretrain successful from {}".format(path))
return model return model
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册