提交 7f1badf7 编写于 作者: 文幕地方's avatar 文幕地方

pair param with key when load trained model params

上级 4a6f7cec
...@@ -54,7 +54,7 @@ def load_model(config, model, optimizer=None): ...@@ -54,7 +54,7 @@ def load_model(config, model, optimizer=None):
pretrained_model = global_config.get('pretrained_model') pretrained_model = global_config.get('pretrained_model')
best_model_dict = {} best_model_dict = {}
if checkpoints: if checkpoints:
if checkpoints.endswith('pdparams'): if checkpoints.endswith('.pdparams'):
checkpoints = checkpoints.replace('.pdparams', '') checkpoints = checkpoints.replace('.pdparams', '')
assert os.path.exists(checkpoints + ".pdparams"), \ assert os.path.exists(checkpoints + ".pdparams"), \
"The {}.pdparams does not exists!".format(checkpoints) "The {}.pdparams does not exists!".format(checkpoints)
...@@ -97,7 +97,7 @@ def load_model(config, model, optimizer=None): ...@@ -97,7 +97,7 @@ def load_model(config, model, optimizer=None):
def load_pretrained_params(model, path): def load_pretrained_params(model, path):
logger = get_logger() logger = get_logger()
if path.endswith('pdparams'): if path.endswith('.pdparams'):
path = path.replace('.pdparams', '') path = path.replace('.pdparams', '')
assert os.path.exists(path + ".pdparams"), \ assert os.path.exists(path + ".pdparams"), \
"The {}.pdparams does not exists!".format(path) "The {}.pdparams does not exists!".format(path)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册