From d0937f0bfc202178cdc307813636ffc31fca2d88 Mon Sep 17 00:00:00 2001 From: Guo Sheng Date: Sat, 22 Feb 2020 00:50:41 +0800 Subject: [PATCH] Make dygraph Transformer support to load py2 saved models. (#4329) --- dygraph/transformer/predict.py | 3 ++- dygraph/transformer/train.py | 5 +++-- dygraph/transformer/utils/load.py | 25 +++++++++++++++++++++++++ 3 files changed, 30 insertions(+), 3 deletions(-) create mode 100644 dygraph/transformer/utils/load.py diff --git a/dygraph/transformer/predict.py b/dygraph/transformer/predict.py index d33d4e5c..c5452e8b 100644 --- a/dygraph/transformer/predict.py +++ b/dygraph/transformer/predict.py @@ -24,6 +24,7 @@ import paddle.fluid as fluid from utils.configure import PDConfig from utils.check import check_gpu, check_version +from utils.load import load_dygraph # include task-specific libs import reader @@ -95,7 +96,7 @@ def do_predict(args): # load the trained model assert args.init_from_params, ( "Please set init_from_params to load the infer model.") - model_dict, _ = fluid.load_dygraph( + model_dict, _ = load_dygraph( os.path.join(args.init_from_params, "transformer")) # to avoid a longer length than training, reset the size of position # encoding to max_length diff --git a/dygraph/transformer/train.py b/dygraph/transformer/train.py index bbfb2c12..f24553b1 100644 --- a/dygraph/transformer/train.py +++ b/dygraph/transformer/train.py @@ -24,6 +24,7 @@ import paddle.fluid as fluid from utils.configure import PDConfig from utils.check import check_gpu, check_version +from utils.load import load_dygraph # include task-specific libs import reader @@ -97,13 +98,13 @@ def do_train(args): ## init from some checkpoint, to resume the previous training if args.init_from_checkpoint: - model_dict, opt_dict = fluid.load_dygraph( + model_dict, opt_dict = load_dygraph( os.path.join(args.init_from_checkpoint, "transformer")) transformer.load_dict(model_dict) optimizer.set_dict(opt_dict) ## init from some pretrain models, to better solve the current task if args.init_from_pretrain_model: - model_dict, _ = fluid.load_dygraph( + model_dict, _ = load_dygraph( os.path.join(args.init_from_pretrain_model, "transformer")) transformer.load_dict(model_dict) diff --git a/dygraph/transformer/utils/load.py b/dygraph/transformer/utils/load.py new file mode 100644 index 00000000..436fa3bd --- /dev/null +++ b/dygraph/transformer/utils/load.py @@ -0,0 +1,25 @@ +import pickle +import six +import warnings +from functools import partial + +import paddle.fluid as fluid + +def load_dygraph(model_path, keep_name_table=False): + """ + To load python2 saved models in python3. + """ + try: + para_dict, opti_dict = fluid.load_dygraph(model_path, keep_name_table) + return para_dict, opti_dict + except UnicodeDecodeError: + warnings.warn( + "An UnicodeDecodeError is catched, which might be caused by loading " + "a python2 saved model. Encoding of pickle.load would be set and " + "load again automatically.") + if six.PY3: + load_bak = pickle.load + pickle.load = partial(load_bak, encoding="latin1") + para_dict, opti_dict = fluid.load_dygraph(model_path, keep_name_table) + pickle.load = load_bak + return para_dict, opti_dict -- GitLab