diff --git a/dygraph/transformer/predict.py b/dygraph/transformer/predict.py index d33d4e5c909d9565dccbddaf3181e5a0b56c0d88..c5452e8bdf482d0932851521e7f35880590f9471 100644 --- a/dygraph/transformer/predict.py +++ b/dygraph/transformer/predict.py @@ -24,6 +24,7 @@ import paddle.fluid as fluid from utils.configure import PDConfig from utils.check import check_gpu, check_version +from utils.load import load_dygraph # include task-specific libs import reader @@ -95,7 +96,7 @@ def do_predict(args): # load the trained model assert args.init_from_params, ( "Please set init_from_params to load the infer model.") - model_dict, _ = fluid.load_dygraph( + model_dict, _ = load_dygraph( os.path.join(args.init_from_params, "transformer")) # to avoid a longer length than training, reset the size of position # encoding to max_length diff --git a/dygraph/transformer/train.py b/dygraph/transformer/train.py index bbfb2c12f58a4a4b4a2ac39baeabaa08f670e242..f24553b17f884e9372f70e7791e6e96ed9cb53e7 100644 --- a/dygraph/transformer/train.py +++ b/dygraph/transformer/train.py @@ -24,6 +24,7 @@ import paddle.fluid as fluid from utils.configure import PDConfig from utils.check import check_gpu, check_version +from utils.load import load_dygraph # include task-specific libs import reader @@ -97,13 +98,13 @@ def do_train(args): ## init from some checkpoint, to resume the previous training if args.init_from_checkpoint: - model_dict, opt_dict = fluid.load_dygraph( + model_dict, opt_dict = load_dygraph( os.path.join(args.init_from_checkpoint, "transformer")) transformer.load_dict(model_dict) optimizer.set_dict(opt_dict) ## init from some pretrain models, to better solve the current task if args.init_from_pretrain_model: - model_dict, _ = fluid.load_dygraph( + model_dict, _ = load_dygraph( os.path.join(args.init_from_pretrain_model, "transformer")) transformer.load_dict(model_dict) diff --git a/dygraph/transformer/utils/load.py b/dygraph/transformer/utils/load.py new file mode 100644 index 0000000000000000000000000000000000000000..436fa3bd7a0fd4ec9404c97c0491bd80adbbcfc4 --- /dev/null +++ b/dygraph/transformer/utils/load.py @@ -0,0 +1,25 @@ +import pickle +import six +import warnings +from functools import partial + +import paddle.fluid as fluid + +def load_dygraph(model_path, keep_name_table=False): + """ + To load python2 saved models in python3. + """ + try: + para_dict, opti_dict = fluid.load_dygraph(model_path, keep_name_table) + return para_dict, opti_dict + except UnicodeDecodeError: + warnings.warn( + "An UnicodeDecodeError is catched, which might be caused by loading " + "a python2 saved model. Encoding of pickle.load would be set and " + "load again automatically.") + if six.PY3: + load_bak = pickle.load + pickle.load = partial(load_bak, encoding="latin1") + para_dict, opti_dict = fluid.load_dygraph(model_path, keep_name_table) + pickle.load = load_bak + return para_dict, opti_dict