未验证 提交 d0937f0b 编写于 作者: G Guo Sheng 提交者: GitHub

Make dygraph Transformer support to load py2 saved models. (#4329)

上级 e15f205a
......@@ -24,6 +24,7 @@ import paddle.fluid as fluid
from utils.configure import PDConfig
from utils.check import check_gpu, check_version
from utils.load import load_dygraph
# include task-specific libs
import reader
......@@ -95,7 +96,7 @@ def do_predict(args):
# load the trained model
assert args.init_from_params, (
"Please set init_from_params to load the infer model.")
model_dict, _ = fluid.load_dygraph(
model_dict, _ = load_dygraph(
os.path.join(args.init_from_params, "transformer"))
# to avoid a longer length than training, reset the size of position
# encoding to max_length
......
......@@ -24,6 +24,7 @@ import paddle.fluid as fluid
from utils.configure import PDConfig
from utils.check import check_gpu, check_version
from utils.load import load_dygraph
# include task-specific libs
import reader
......@@ -97,13 +98,13 @@ def do_train(args):
## init from some checkpoint, to resume the previous training
if args.init_from_checkpoint:
model_dict, opt_dict = fluid.load_dygraph(
model_dict, opt_dict = load_dygraph(
os.path.join(args.init_from_checkpoint, "transformer"))
transformer.load_dict(model_dict)
optimizer.set_dict(opt_dict)
## init from some pretrain models, to better solve the current task
if args.init_from_pretrain_model:
model_dict, _ = fluid.load_dygraph(
model_dict, _ = load_dygraph(
os.path.join(args.init_from_pretrain_model, "transformer"))
transformer.load_dict(model_dict)
......
import pickle
import six
import warnings
from functools import partial
import paddle.fluid as fluid
def load_dygraph(model_path, keep_name_table=False):
"""
To load python2 saved models in python3.
"""
try:
para_dict, opti_dict = fluid.load_dygraph(model_path, keep_name_table)
return para_dict, opti_dict
except UnicodeDecodeError:
warnings.warn(
"An UnicodeDecodeError is catched, which might be caused by loading "
"a python2 saved model. Encoding of pickle.load would be set and "
"load again automatically.")
if six.PY3:
load_bak = pickle.load
pickle.load = partial(load_bak, encoding="latin1")
para_dict, opti_dict = fluid.load_dygraph(model_path, keep_name_table)
pickle.load = load_bak
return para_dict, opti_dict
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册