diff --git a/PaddleNLP/PaddleMT/transformer/README.md b/PaddleNLP/PaddleMT/transformer/README.md index 90d47f53cc4566bc5428d95e24ee43641e27c90b..1e05dc22050e9132eeac1eef72988f59cc6081e9 100644 --- a/PaddleNLP/PaddleMT/transformer/README.md +++ b/PaddleNLP/PaddleMT/transformer/README.md @@ -106,7 +106,7 @@ python -u main.py \ --prepostprocess_dropout 0.3 ``` -训练时默认使用所有 GPU,可以通过 `CUDA_VISIBLE_DEVICES` 环境变量来设置使用的 GPU 数目。也可以只使用 CPU 训练(通过参数 `--use_cuda False` 设置),训练速度相对较慢。在执行训练时若提供了 `save_param` 和 `save_checkpoint`(默认为 trained_params 和 trained_ckpts),则每隔一定 iteration 后(通过参数 `save_step` 设置,默认为10000)将分别保存当前训练的参数值和 checkpoint 到相应目录,每隔一定数目的 iteration (通过参数 `print_step` 设置,默认为100)将打印如下的日志到标准输出: +训练时默认使用所有 GPU,可以通过 `CUDA_VISIBLE_DEVICES` 环境变量来设置使用的 GPU 数目。也可以只使用 CPU 训练(通过参数 `--use_cuda False` 设置),训练速度相对较慢。在执行训练时若提供了 `save_model_path`(默认为 saved_models),则每隔一定 iteration 后(通过参数 `save_step` 设置,默认为10000)将保存当前训练的 checkpoint 到相应目录(会保存分别记录了模型参数和优化器状态的 `transformer.pdparams` 和 `transformer.pdopt` 两个文件),每隔一定数目的 iteration (通过参数 `print_step` 设置,默认为100)将打印如下的日志到标准输出: ```txt [2019-08-02 15:30:51,656 INFO train.py:262] step_idx: 150100, epoch: 32, batch: 1364, avg loss: 2.880427, normalized loss: 1.504687, ppl: 17.821888, speed: 3.34 step/s @@ -195,7 +195,7 @@ BLEU = 26.35, 57.7/32.1/20.0/13.0 (BP=1.000, ratio=1.013, hyp_len=63903, ref_len ### 预训练模型 -我们这里提供了对应有以上 BLEU 值的 [base model](https://transformer-res.bj.bcebos.com/base_model_params.tar.gz) 和 [big model](https://transformer-res.bj.bcebos.com/big_model_params.tar.gz) 的模型参数提供下载使用(注意,模型使用了提供下载的数据进行训练和测试)。 +我们这里提供了对应有以上 BLEU 值的 [base model](https://transformer-res.bj.bcebos.com/base_model_graph.tar.gz) 和 [big model](https://transformer-res.bj.bcebos.com/big_model_graph.tar.gz) 的模型参数提供下载使用(注意,模型使用了提供下载的数据进行训练和测试)。 ## 进阶使用 diff --git a/PaddleNLP/PaddleMT/transformer/inference_model.py b/PaddleNLP/PaddleMT/transformer/inference_model.py index 5de0a107cd941c7d2eba42d1ec9922095bde5d8d..5ff15108935567bbe288d74ed7f4fd729b089233 100644 --- a/PaddleNLP/PaddleMT/transformer/inference_model.py +++ b/PaddleNLP/PaddleMT/transformer/inference_model.py @@ -24,6 +24,7 @@ import paddle.fluid as fluid from utils.input_field import InputField from utils.configure import PDConfig +from utils.load import load # include task-specific libs import desc @@ -31,51 +32,6 @@ import reader from transformer import create_net -def init_from_pretrain_model(args, exe, program): - - assert isinstance(args.init_from_pretrain_model, str) - - if not os.path.exists(args.init_from_pretrain_model): - raise Warning("The pretrained params do not exist.") - return False - - def existed_params(var): - if not isinstance(var, fluid.framework.Parameter): - return False - return os.path.exists( - os.path.join(args.init_from_pretrain_model, var.name)) - - fluid.io.load_vars( - exe, - args.init_from_pretrain_model, - main_program=program, - predicate=existed_params) - - print("finish initing model from pretrained params from %s" % - (args.init_from_pretrain_model)) - - return True - - -def init_from_params(args, exe, program): - - assert isinstance(args.init_from_params, str) - - if not os.path.exists(args.init_from_params): - raise Warning("the params path does not exist.") - return False - - fluid.io.load_params( - executor=exe, - dirname=args.init_from_params, - main_program=program, - filename="params.pdparams") - - print("finish init model from params from %s" % (args.init_from_params)) - - return True - - def do_save_inference_model(args): if args.use_cuda: dev_count = fluid.core.get_cuda_device_count() @@ 
-84,6 +40,11 @@ def do_save_inference_model(args): dev_count = int(os.environ.get('CPU_NUM', 1)) place = fluid.CPUPlace() + src_vocab = reader.DataProcessor.load_dict(args.src_vocab_fpath) + trg_vocab = reader.DataProcessor.load_dict(args.trg_vocab_fpath) + args.src_vocab_size = len(src_vocab) + args.trg_vocab_size = len(trg_vocab) + test_prog = fluid.default_main_program() startup_prog = fluid.default_startup_program() @@ -119,24 +80,20 @@ def do_save_inference_model(args): exe = fluid.Executor(place) exe.run(startup_prog) - assert (args.init_from_params) or (args.init_from_pretrain_model) - - if args.init_from_params: - init_from_params(args, exe, test_prog) - - elif args.init_from_pretrain_model: - init_from_pretrain_model(args, exe, test_prog) + assert ( + args.init_from_params), "must set init_from_params to load parameters" + load(test_prog, os.path.join(args.init_from_params, "transformer"), exe) + print("finish initing model from params from %s" % (args.init_from_params)) # saving inference model - fluid.io.save_inference_model( - args.inference_model_dir, - feeded_var_names=input_field_names, - target_vars=[out_ids, out_scores], - executor=exe, - main_program=test_prog, - model_filename="model.pdmodel", - params_filename="params.pdparams") + fluid.io.save_inference_model(args.inference_model_dir, + feeded_var_names=list(input_field_names), + target_vars=[out_ids, out_scores], + executor=exe, + main_program=test_prog, + model_filename="model.pdmodel", + params_filename="params.pdparams") print("save inference model at %s" % (args.inference_model_dir)) diff --git a/PaddleNLP/PaddleMT/transformer/predict.py b/PaddleNLP/PaddleMT/transformer/predict.py index 2ad93e5838d6a87c1aa9deb8e35da7f071aec51d..179e39f6efdb3d78cafb87a97f6e0d9de346dac5 100644 --- a/PaddleNLP/PaddleMT/transformer/predict.py +++ b/PaddleNLP/PaddleMT/transformer/predict.py @@ -25,6 +25,7 @@ import paddle.fluid as fluid from utils.input_field import InputField from utils.configure import PDConfig from utils.check import check_gpu, check_version +from utils.load import load # include task-specific libs import desc @@ -32,51 +33,6 @@ import reader from transformer import create_net, position_encoding_init -def init_from_pretrain_model(args, exe, program): - - assert isinstance(args.init_from_pretrain_model, str) - - if not os.path.exists(args.init_from_pretrain_model): - raise Warning("The pretrained params do not exist.") - return False - - def existed_params(var): - if not isinstance(var, fluid.framework.Parameter): - return False - return os.path.exists( - os.path.join(args.init_from_pretrain_model, var.name)) - - fluid.io.load_vars( - exe, - args.init_from_pretrain_model, - main_program=program, - predicate=existed_params) - - print("finish initing model from pretrained params from %s" % - (args.init_from_pretrain_model)) - - return True - - -def init_from_params(args, exe, program): - - assert isinstance(args.init_from_params, str) - - if not os.path.exists(args.init_from_params): - raise Warning("the params path does not exist.") - return False - - fluid.io.load_params( - executor=exe, - dirname=args.init_from_params, - main_program=program, - filename="params.pdparams") - - print("finish init model from params from %s" % (args.init_from_params)) - - return True - - def post_process_seq(seq, bos_idx, eos_idx, output_bos=False, output_eos=False): """ Post-process the beam-search decoded sequence. 
Truncate from the first @@ -160,13 +116,10 @@ def do_predict(args): exe = fluid.Executor(place) exe.run(startup_prog) - assert (args.init_from_params) or (args.init_from_pretrain_model) - - if args.init_from_params: - init_from_params(args, exe, test_prog) - - elif args.init_from_pretrain_model: - init_from_pretrain_model(args, exe, test_prog) + assert ( + args.init_from_params), "must set init_from_params to load parameters" + load(test_prog, os.path.join(args.init_from_params, "transformer"), exe) + print("finish initing model from params from %s" % (args.init_from_params)) # to avoid a longer length than training, reset the size of position encoding to max_length for pos_enc_param_name in desc.pos_enc_param_names: diff --git a/PaddleNLP/PaddleMT/transformer/train.py b/PaddleNLP/PaddleMT/transformer/train.py index c9fb5d7220c325477d6a0e5984f11e4e9b85f79a..129435baba706c37071e836bd8b6745dbafb0b1f 100644 --- a/PaddleNLP/PaddleMT/transformer/train.py +++ b/PaddleNLP/PaddleMT/transformer/train.py @@ -27,6 +27,7 @@ import utils.dist_utils as dist_utils from utils.input_field import InputField from utils.configure import PDConfig from utils.check import check_gpu, check_version +from utils.load import load # include task-specific libs import desc @@ -39,91 +40,6 @@ if os.environ.get('FLAGS_eager_delete_tensor_gb', None) is None: num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1)) -def init_from_pretrain_model(args, exe, program): - - assert isinstance(args.init_from_pretrain_model, str) - - if not os.path.exists(args.init_from_pretrain_model): - raise Warning("The pretrained params do not exist.") - return False - - def existed_params(var): - if not isinstance(var, fluid.framework.Parameter): - return False - return os.path.exists( - os.path.join(args.init_from_pretrain_model, var.name)) - - fluid.io.load_vars( - exe, - args.init_from_pretrain_model, - main_program=program, - predicate=existed_params) - - print("finish initing model from pretrained params from %s" % - (args.init_from_pretrain_model)) - - return True - - -def init_from_checkpoint(args, exe, program): - - assert isinstance(args.init_from_checkpoint, str) - - if not os.path.exists(args.init_from_checkpoint): - raise Warning("the checkpoint path does not exist.") - return False - - fluid.io.load_persistables( - executor=exe, - dirname=args.init_from_checkpoint, - main_program=program, - filename="checkpoint.pdckpt") - - print("finish initing model from checkpoint from %s" % - (args.init_from_checkpoint)) - - return True - - -def save_checkpoint(args, exe, program, dirname): - - assert isinstance(args.save_model_path, str) - - checkpoint_dir = os.path.join(args.save_model_path, args.save_checkpoint) - - if not os.path.exists(checkpoint_dir): - os.mkdir(checkpoint_dir) - - fluid.io.save_persistables( - exe, - os.path.join(checkpoint_dir, dirname), - main_program=program, - filename="checkpoint.pdparams") - - print("save checkpoint at %s" % (os.path.join(checkpoint_dir, dirname))) - - return True - - -def save_param(args, exe, program, dirname): - - assert isinstance(args.save_model_path, str) - - param_dir = os.path.join(args.save_model_path, args.save_param) - - if not os.path.exists(param_dir): - os.mkdir(param_dir) - - fluid.io.save_params( - exe, - os.path.join(param_dir, dirname), - main_program=program, - filename="params.pdparams") - print("save parameters at %s" % (os.path.join(param_dir, dirname))) - - return True - - def do_train(args): if args.use_cuda: if num_trainers > 1: # for multi-process gpu training @@ 
-226,11 +142,17 @@ def do_train(args): ## init from some checkpoint, to resume the previous training if args.init_from_checkpoint: - init_from_checkpoint(args, exe, train_prog) + load(train_prog, os.path.join(args.init_from_checkpoint, "transformer"), + exe) + print("finish initing model from checkpoint from %s" % + (args.init_from_checkpoint)) ## init from some pretrain models, to better solve the current task if args.init_from_pretrain_model: - init_from_pretrain_model(args, exe, train_prog) + load(train_prog, + os.path.join(args.init_from_pretrain_model, "transformer"), exe) + print("finish initing model from pretrained params from %s" % + (args.init_from_pretrain_model)) build_strategy = fluid.compiler.BuildStrategy() build_strategy.enable_inplace = True @@ -293,14 +215,12 @@ def do_train(args): avg_batch_time = time.time() if step_idx % args.save_step == 0 and step_idx != 0: + if args.save_model_path: + model_path = os.path.join(args.save_model_path, + "step_" + str(step_idx), + "transformer") + fluid.save(train_prog, model_path) - if args.save_checkpoint: - save_checkpoint(args, exe, train_prog, - "step_" + str(step_idx)) - - if args.save_param: - save_param(args, exe, train_prog, - "step_" + str(step_idx)) batch_id += 1 step_idx += 1 @@ -319,11 +239,10 @@ def do_train(args): time_consumed = time.time() - pass_start_time - if args.save_checkpoint: - save_checkpoint(args, exe, train_prog, "step_final") - - if args.save_param: - save_param(args, exe, train_prog, "step_final") + if args.save_model_path: + model_path = os.path.join(args.save_model_path, "step_final", + "transformer") + fluid.save(train_prog, model_path) if args.enable_ce: # For CE print("kpis\ttrain_cost_card%d\t%f" % (dev_count, total_avg_cost)) diff --git a/PaddleNLP/PaddleMT/transformer/transformer.py b/PaddleNLP/PaddleMT/transformer/transformer.py index be20001b25fdb94fcc4bc234bae220413ddfacdd..d260e82fd7648b468c0c254bda98baa9353d48d6 100644 --- a/PaddleNLP/PaddleMT/transformer/transformer.py +++ b/PaddleNLP/PaddleMT/transformer/transformer.py @@ -17,6 +17,7 @@ import numpy as np import paddle.fluid as fluid import paddle.fluid.layers as layers +from paddle.fluid.layers.utils import map_structure from desc import * @@ -24,6 +25,7 @@ from desc import * dropout_seed = None + def wrap_layer_with_block(layer, block_idx): """ Make layer define support indicating block, by which we can add layers @@ -90,7 +92,6 @@ def multi_head_attention(queries, n_head=1, dropout_rate=0., cache=None, - gather_idx=None, static_kv=False): """ Multi-Head Attention. Note that attn_bias is added to the logit before @@ -161,30 +162,28 @@ def multi_head_attention(queries, v = transpose_layer(x=reshaped_v, perm=[0, 2, 1, 3]) if cache is not None: # only for faster inference + cache_, i = cache if static_kv: # For encoder-decoder attention in inference - cache_k, cache_v = cache["static_k"], cache["static_v"] - # To init the static_k and static_v in cache. - # Maybe we can use condition_op(if_else) to do these at the first - # step in while loop to replace these, however it might be less - # efficient. + cache_k, cache_v = cache_["static_k"], cache_["static_v"] + # To init the static_k and static_v in global block. 
static_cache_init = wrap_layer_with_block( layers.assign, fluid.default_main_program().current_block().parent_idx) - static_cache_init(k, cache_k) - static_cache_init(v, cache_v) + static_cache_init( + k, + fluid.default_main_program().global_block().var( + "static_k_%d" % i)) + static_cache_init( + v, + fluid.default_main_program().global_block().var( + "static_v_%d" % i)) + k, v = cache_k, cache_v else: # For decoder self-attention in inference - cache_k, cache_v = cache["k"], cache["v"] - # gather cell states corresponding to selected parent - select_k = layers.gather(cache_k, index=gather_idx) - select_v = layers.gather(cache_v, index=gather_idx) - if not static_kv: - # For self attention in inference, use cache and concat time steps. - select_k = layers.concat([select_k, k], axis=2) - select_v = layers.concat([select_v, v], axis=2) - # update cell states(caches) cached in global block - layers.assign(select_k, cache_k) - layers.assign(select_v, cache_v) - return q, select_k, select_v + # use cache and concat time steps. + cache_k, cache_v = cache_["k"], cache_["v"] + k = layers.concat([cache_k, k], axis=2) + v = layers.concat([cache_v, v], axis=2) + cache_["k"], cache_["v"] = (k, v) return q, k, v def __combine_heads(x): @@ -405,8 +404,7 @@ def decoder_layer(dec_input, relu_dropout, preprocess_cmd, postprocess_cmd, - cache=None, - gather_idx=None): + cache=None): """ The layer to be stacked in decoder part. The structure of this module is similar to that in the encoder part except a multi-head attention is added to implement encoder-decoder attention. @@ -421,8 +419,7 @@ def decoder_layer(dec_input, d_model, n_head, attention_dropout, - cache=cache, - gather_idx=gather_idx) + cache=cache) slf_attn_output = post_process_layer( dec_input, slf_attn_output, @@ -440,7 +437,6 @@ def decoder_layer(dec_input, n_head, attention_dropout, cache=cache, - gather_idx=gather_idx, static_kv=True) enc_attn_output = post_process_layer( slf_attn_output, @@ -476,29 +472,27 @@ def decoder(dec_input, relu_dropout, preprocess_cmd, postprocess_cmd, - caches=None, - gather_idx=None): + caches=None): """ The decoder is composed of a stack of identical decoder_layer layers. """ for i in range(n_layer): - dec_output = decoder_layer( - dec_input, - enc_output, - dec_slf_attn_bias, - dec_enc_attn_bias, - n_head, - d_key, - d_value, - d_model, - d_inner_hid, - prepostprocess_dropout, - attention_dropout, - relu_dropout, - preprocess_cmd, - postprocess_cmd, - cache=None if caches is None else caches[i], - gather_idx=gather_idx) + dec_output = decoder_layer(dec_input, + enc_output, + dec_slf_attn_bias, + dec_enc_attn_bias, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd, + postprocess_cmd, + cache=None if caches is None else + (caches[i], i)) dec_input = dec_output dec_output = pre_process_layer(dec_output, preprocess_cmd, prepostprocess_dropout) @@ -654,7 +648,6 @@ def wrap_decoder(dec_inputs, weight_sharing, enc_output=None, caches=None, - gather_idx=None, bos_idx=0): """ The wrapper assembles together all needed layers for the decoder. 
@@ -687,8 +680,7 @@ def wrap_decoder(dec_inputs, relu_dropout, preprocess_cmd, postprocess_cmd, - caches=caches, - gather_idx=gather_idx) + caches=caches) # Reshape to 2D tensor to use GEMM instead of BatchedGEMM dec_output = layers.reshape( dec_output, shape=[-1, dec_output.shape[-1]], inplace=True) @@ -748,8 +740,6 @@ def fast_decode(model_input, src_vocab_size, trg_vocab_size, max_in_len, force_cpu=True) step_idx = layers.fill_constant( shape=[1], dtype=start_tokens.dtype, value=0, force_cpu=True) - cond = layers.less_than(x=step_idx, y=max_len) # default force_cpu=True - while_op = layers.While(cond) # array states will be stored for each step. ids = layers.array_write( layers.reshape(start_tokens, (-1, 1)), step_idx) @@ -773,21 +763,25 @@ def fast_decode(model_input, src_vocab_size, trg_vocab_size, max_in_len, dtype=enc_output.dtype, value=0), "static_k": # for encoder-decoder attention - layers.create_tensor(dtype=enc_output.dtype), + fluid.data(shape=[None, n_head, 0, d_key], dtype=enc_output.dtype, name=("static_k_%d"%i)), "static_v": # for encoder-decoder attention - layers.create_tensor(dtype=enc_output.dtype) + fluid.data(shape=[None, n_head, 0, d_value], dtype=enc_output.dtype, name=("static_v_%d"%i)), } for i in range(n_layer) ] - with while_op.block(): - pre_ids = layers.array_read(array=ids, i=step_idx) - # Since beam_search_op dosen't enforce pre_ids' shape, we can do - # inplace reshape here which actually change the shape of pre_ids. - # pre_ids = layers.reshape(pre_ids, (-1, 1, 1), inplace=True) - pre_scores = layers.array_read(array=scores, i=step_idx) + def cond_func(step_idx, selected_ids, selected_scores, gather_idx, + caches, trg_src_attn_bias): + length_cond = layers.less_than(x=step_idx, y=max_len) + finish_cond = layers.logical_not(layers.is_empty(x=selected_ids)) + return layers.logical_and(x=length_cond, y=finish_cond) + + def body_func(step_idx, pre_ids, pre_scores, gather_idx, caches, + trg_src_attn_bias): # gather cell states corresponding to selected parent - pre_src_attn_bias = layers.gather( - trg_src_attn_bias, index=parent_idx) + pre_caches = map_structure( + lambda x: layers.gather(x, index=gather_idx), caches) + pre_src_attn_bias = layers.gather(trg_src_attn_bias, + index=gather_idx) pre_pos = layers.elementwise_mul( x=layers.fill_constant_batch_size_like( input=pre_src_attn_bias, # cann't use lod tensor here @@ -812,14 +806,14 @@ def fast_decode(model_input, src_vocab_size, trg_vocab_size, max_in_len, postprocess_cmd, weight_sharing, enc_output=enc_output, - caches=caches, - gather_idx=parent_idx, + caches=pre_caches, bos_idx=bos_idx) # intra-beam topK topk_scores, topk_indices = layers.topk( input=layers.softmax(logits), k=beam_size) - accu_scores = layers.elementwise_add( - x=layers.log(topk_scores), y=pre_scores, axis=0) + accu_scores = layers.elementwise_add(x=layers.log(topk_scores), + y=pre_scores, + axis=0) # beam_search op uses lod to differentiate branches. accu_scores = layers.lod_reset(accu_scores, pre_ids) # topK reduction across beams, also contain special handle of @@ -832,16 +826,19 @@ def fast_decode(model_input, src_vocab_size, trg_vocab_size, max_in_len, beam_size=beam_size, end_id=eos_idx, return_parent_idx=True) - layers.increment(x=step_idx, value=1.0, in_place=True) - # cell states(caches) have been updated in wrap_decoder, - # only need to update beam search states here. 
+            step_idx = layers.increment(x=step_idx, value=1.0, in_place=False)
             layers.array_write(selected_ids, i=step_idx, array=ids)
             layers.array_write(selected_scores, i=step_idx, array=scores)
-            layers.assign(gather_idx, parent_idx)
-            layers.assign(pre_src_attn_bias, trg_src_attn_bias)
-            length_cond = layers.less_than(x=step_idx, y=max_len)
-            finish_cond = layers.logical_not(layers.is_empty(x=selected_ids))
-            layers.logical_and(x=length_cond, y=finish_cond, out=cond)
+            return (step_idx, selected_ids, selected_scores, gather_idx,
+                    pre_caches, pre_src_attn_bias)
+
+        _ = layers.while_loop(cond=cond_func,
+                              body=body_func,
+                              loop_vars=[
+                                  step_idx, start_tokens, init_scores,
+                                  parent_idx, caches, trg_src_attn_bias
+                              ],
+                              is_test=True)
 
     finished_ids, finished_scores = layers.beam_search_decode(
         ids, scores, beam_size=beam_size, end_id=eos_idx)
diff --git a/PaddleNLP/PaddleMT/transformer/transformer.yaml b/PaddleNLP/PaddleMT/transformer/transformer.yaml
index c6cbc074ed8a76c8b4d649e7631f0c125e165511..521396925f7e2d4721cab0566fa78e0dc68d6f99 100644
--- a/PaddleNLP/PaddleMT/transformer/transformer.yaml
+++ b/PaddleNLP/PaddleMT/transformer/transformer.yaml
@@ -11,10 +11,11 @@ init_from_checkpoint: ""
 init_from_pretrain_model: ""
 # path of trained parameter, to make prediction
 init_from_params: "trained_params/step_100000"
-save_model_path: ""
-# the directory for saving checkpoints.
+# the directory for saving models.
+save_model_path: "saved_models"
+# deprecated, the directory for saving checkpoints.
 save_checkpoint: "trained_ckpts"
-# the directory for saving trained parameters.
+# deprecated, the directory for saving trained parameters.
 save_param: "trained_params"
 # the directory for saving inference model.
 inference_model_dir: "infer_model"
diff --git a/PaddleNLP/PaddleMT/transformer/utils/load.py b/PaddleNLP/PaddleMT/transformer/utils/load.py
new file mode 100644
index 0000000000000000000000000000000000000000..24c5fccc59cc13959b8696eaa819613e29ee4eb8
--- /dev/null
+++ b/PaddleNLP/PaddleMT/transformer/utils/load.py
@@ -0,0 +1,24 @@
+import pickle
+import six
+import warnings
+from functools import partial
+
+import paddle.fluid as fluid
+
+
+def load(program, model_path, executor=None, var_list=None):
+    """
+    Load models saved by Python 2 in Python 3.
+    """
+    try:
+        fluid.load(program, model_path, executor, var_list)
+    except UnicodeDecodeError:
+        warnings.warn(
+            "A UnicodeDecodeError was caught, which is usually raised when "
+            "loading a model saved with Python 2. The encoding of pickle.load "
+            "will be set to latin1 and loading will be retried automatically.")
+        if six.PY3:
+            load_bak = pickle.load
+            pickle.load = partial(load_bak, encoding="latin1")
+            fluid.load(program, model_path, executor, var_list)
+            pickle.load = load_bak
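
For reference, the end-to-end effect of the new single-path saving logic is roughly the following. This is a minimal sketch rather than the repository's training script: the toy program, the CPU executor, and the `saved_models/step_10000` path are made up for illustration, while `fluid.save`, `fluid.load`, and the `utils.load.load` wrapper are the calls introduced above.

```python
import os

import paddle.fluid as fluid

from utils.load import load  # compatibility wrapper added in this patch

# Toy program with one parameter, standing in for train_prog/test_prog.
main_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    x = fluid.data(name="x", shape=[None, 8], dtype="float32")
    y = fluid.layers.fc(input=x, size=4)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup_prog)

# Training side: one fluid.save call per save_step writes the parameters to
# <prefix>.pdparams and the remaining persistables (optimizer state) to
# <prefix>.pdopt, e.g. saved_models/step_10000/transformer.pdparams.
model_path = os.path.join("saved_models", "step_10000", "transformer")
if not os.path.isdir(os.path.dirname(model_path)):
    os.makedirs(os.path.dirname(model_path))
fluid.save(main_prog, model_path)

# Resume/prediction side: point init_from_checkpoint or init_from_params at
# the step directory; the scripts append the "transformer" prefix and call
# the load wrapper, which retries with a latin1 pickle encoding for models
# saved under Python 2.
load(main_prog, model_path, exe)
```

Compared with the removed save_param/save_checkpoint pair, parameters and optimizer state now always sit together under one step_<N> directory, which is what the updated README text describes.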
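
do_save_inference_model still exports through fluid.io.save_inference_model with the model.pdmodel / params.pdparams file names. The sketch below shows how a consumer would load that artifact back; it assumes the export has already been run into the default inference_model_dir from transformer.yaml ("infer_model") and only inspects the feed/fetch interface instead of running a real translation.

```python
import paddle.fluid as fluid

place = fluid.CPUPlace()
exe = fluid.Executor(place)

# Mirrors the file names used by save_inference_model in inference_model.py.
inference_prog, feed_names, fetch_targets = fluid.io.load_inference_model(
    dirname="infer_model",
    executor=exe,
    model_filename="model.pdmodel",
    params_filename="params.pdparams")

# feed_names matches input_field_names at export time; the fetch targets are
# the beam-search outputs passed as target_vars ([out_ids, out_scores]).
print(feed_names)
print([t.name for t in fetch_targets])
```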
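
The decoder's beam-search loop is rewritten from layers.While with in-place assign/logical_and updates into layers.while_loop, where all loop state is threaded explicitly through loop_vars and returned from the body. A self-contained toy counter, unrelated to the transformer graph, illustrating that contract:

```python
import paddle.fluid as fluid
import paddle.fluid.layers as layers

i = layers.fill_constant(shape=[1], dtype="int64", value=0)
limit = layers.fill_constant(shape=[1], dtype="int64", value=10)

def cond_func(i, limit):
    # Loop while i < limit; cond receives the same structure as loop_vars.
    return layers.less_than(x=i, y=limit)

def body_func(i, limit):
    # The body returns the updated loop_vars instead of assigning in place.
    i = layers.increment(x=i, value=1, in_place=False)
    return i, limit

final_i, _ = layers.while_loop(cond=cond_func, body=body_func,
                               loop_vars=[i, limit])

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
print(exe.run(fetch_list=[final_i]))  # [array([10])]
```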