From edcf04ca5a6f46ee34f77e3418b46ea1ec51d4fd Mon Sep 17 00:00:00 2001 From: songyouwei Date: Mon, 17 Feb 2020 09:47:51 +0800 Subject: [PATCH] [cherry-pick] fix pickle between python 2 & 3 (#22620) * cherry-pick #22555 test=release/1.7, test=develop * cherry-pick #22621 test=release/1.7, test=develop --- python/paddle/fluid/dygraph/checkpoint.py | 9 ++++++--- python/paddle/fluid/io.py | 18 +++++++++++------- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/python/paddle/fluid/dygraph/checkpoint.py b/python/paddle/fluid/dygraph/checkpoint.py index 27658ba3d46..a51c431c515 100644 --- a/python/paddle/fluid/dygraph/checkpoint.py +++ b/python/paddle/fluid/dygraph/checkpoint.py @@ -18,6 +18,7 @@ import os import collections from ..framework import Variable, default_main_program, in_dygraph_mode, dygraph_only, Parameter, ParamBase import pickle +import six from . import learning_rate_scheduler import warnings from .. import core @@ -88,7 +89,7 @@ def save_dygraph(state_dict, model_path): os.makedirs(dir_name) with open(file_name, 'wb') as f: - pickle.dump(model_dict, f) + pickle.dump(model_dict, f, protocol=2) @dygraph_only @@ -130,7 +131,8 @@ def load_dygraph(model_path, keep_name_table=False): params_file_path)) with open(params_file_path, 'rb') as f: - para_dict = pickle.load(f) + para_dict = pickle.load(f) if six.PY2 else pickle.load( + f, encoding='latin1') if not keep_name_table and "StructuredToParameterName@@" in para_dict: del para_dict["StructuredToParameterName@@"] @@ -138,6 +140,7 @@ def load_dygraph(model_path, keep_name_table=False): opti_file_path = model_path + ".pdopt" if os.path.exists(opti_file_path): with open(opti_file_path, 'rb') as f: - opti_dict = pickle.load(f) + opti_dict = pickle.load(f) if six.PY2 else pickle.load( + f, encoding='latin1') return para_dict, opti_dict diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 09adb1fb13d..448eb49563a 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -800,7 +800,7 @@ def load_vars(executor, var_temp = paddle.fluid.global_scope().find_var(each_var.name) assert var_temp != None, "can't not find var: " + each_var.name new_shape = (np.array(var_temp.get_tensor())).shape - assert each_var.name in orig_para_shape, earch_var.name + "MUST in var list" + assert each_var.name in orig_para_shape, each_var.name + "MUST in var list" orig_shape = orig_para_shape.get(each_var.name) if new_shape != orig_shape: raise RuntimeError( @@ -1579,14 +1579,14 @@ def save(program, model_path): parameter_list = list(filter(is_parameter, program.list_vars())) param_dict = {p.name: get_tensor(p) for p in parameter_list} with open(model_path + ".pdparams", 'wb') as f: - pickle.dump(param_dict, f) + pickle.dump(param_dict, f, protocol=2) optimizer_var_list = list( filter(is_belong_to_optimizer, program.list_vars())) opt_dict = {p.name: get_tensor(p) for p in optimizer_var_list} with open(model_path + ".pdopt", 'wb') as f: - pickle.dump(opt_dict, f) + pickle.dump(opt_dict, f, protocol=2) main_program = program.clone() program.desc.flush() @@ -1733,7 +1733,8 @@ def load(program, model_path, executor=None, var_list=None): global_scope(), executor._default_executor) with open(parameter_file_name, 'rb') as f: - load_dict = pickle.load(f) + load_dict = pickle.load(f) if six.PY2 else pickle.load( + f, encoding='latin1') for v in parameter_list: assert v.name in load_dict, \ "Can not find [{}] in model file [{}]".format( @@ -1753,7 +1754,8 @@ def load(program, model_path, executor=None, var_list=None): optimizer_var_list, global_scope(), executor._default_executor) with open(opt_file_name, 'rb') as f: - load_dict = pickle.load(f) + load_dict = pickle.load(f) if six.PY2 else pickle.load( + f, encoding='latin1') for v in optimizer_var_list: assert v.name in load_dict, \ "Can not find [{}] in model file [{}]".format( @@ -1877,12 +1879,14 @@ def load_program_state(model_path, var_list=None): "Parameter file [{}] not exits".format(parameter_file_name) with open(parameter_file_name, 'rb') as f: - para_dict = pickle.load(f) + para_dict = pickle.load(f) if six.PY2 else pickle.load( + f, encoding='latin1') opt_file_name = model_prefix + ".pdopt" if os.path.exists(opt_file_name): with open(opt_file_name, 'rb') as f: - opti_dict = pickle.load(f) + opti_dict = pickle.load(f) if six.PY2 else pickle.load( + f, encoding='latin1') para_dict.update(opti_dict) -- GitLab