未验证 提交 b768708e 编写于 作者: S songyouwei 提交者: GitHub

fix pickle load encoding between python 2 and 3 (#22621)

test=develop
上级 cb4560b7
...@@ -132,7 +132,7 @@ def load_dygraph(model_path, keep_name_table=False): ...@@ -132,7 +132,7 @@ def load_dygraph(model_path, keep_name_table=False):
with open(params_file_path, 'rb') as f: with open(params_file_path, 'rb') as f:
para_dict = pickle.load(f) if six.PY2 else pickle.load( para_dict = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='bytes') f, encoding='latin1')
if not keep_name_table and "StructuredToParameterName@@" in para_dict: if not keep_name_table and "StructuredToParameterName@@" in para_dict:
del para_dict["StructuredToParameterName@@"] del para_dict["StructuredToParameterName@@"]
...@@ -141,6 +141,6 @@ def load_dygraph(model_path, keep_name_table=False): ...@@ -141,6 +141,6 @@ def load_dygraph(model_path, keep_name_table=False):
if os.path.exists(opti_file_path): if os.path.exists(opti_file_path):
with open(opti_file_path, 'rb') as f: with open(opti_file_path, 'rb') as f:
opti_dict = pickle.load(f) if six.PY2 else pickle.load( opti_dict = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='bytes') f, encoding='latin1')
return para_dict, opti_dict return para_dict, opti_dict
...@@ -1696,7 +1696,7 @@ def load(program, model_path, executor=None, var_list=None): ...@@ -1696,7 +1696,7 @@ def load(program, model_path, executor=None, var_list=None):
executor._default_executor) executor._default_executor)
with open(parameter_file_name, 'rb') as f: with open(parameter_file_name, 'rb') as f:
load_dict = pickle.load(f) if six.PY2 else pickle.load( load_dict = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='bytes') f, encoding='latin1')
for v in parameter_list: for v in parameter_list:
assert v.name in load_dict, \ assert v.name in load_dict, \
"Can not find [{}] in model file [{}]".format( "Can not find [{}] in model file [{}]".format(
...@@ -1717,7 +1717,7 @@ def load(program, model_path, executor=None, var_list=None): ...@@ -1717,7 +1717,7 @@ def load(program, model_path, executor=None, var_list=None):
with open(opt_file_name, 'rb') as f: with open(opt_file_name, 'rb') as f:
load_dict = pickle.load(f) if six.PY2 else pickle.load( load_dict = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='bytes') f, encoding='latin1')
for v in optimizer_var_list: for v in optimizer_var_list:
assert v.name in load_dict, \ assert v.name in load_dict, \
"Can not find [{}] in model file [{}]".format( "Can not find [{}] in model file [{}]".format(
...@@ -1842,13 +1842,13 @@ def load_program_state(model_path, var_list=None): ...@@ -1842,13 +1842,13 @@ def load_program_state(model_path, var_list=None):
with open(parameter_file_name, 'rb') as f: with open(parameter_file_name, 'rb') as f:
para_dict = pickle.load(f) if six.PY2 else pickle.load( para_dict = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='bytes') f, encoding='latin1')
opt_file_name = model_prefix + ".pdopt" opt_file_name = model_prefix + ".pdopt"
if os.path.exists(opt_file_name): if os.path.exists(opt_file_name):
with open(opt_file_name, 'rb') as f: with open(opt_file_name, 'rb') as f:
opti_dict = pickle.load(f) if six.PY2 else pickle.load( opti_dict = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='bytes') f, encoding='latin1')
para_dict.update(opti_dict) para_dict.update(opti_dict)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册