提交 cf670ec9 编写于 作者: Y Yang Zhang 提交者: hong

Serialize to pickle format (#20820)

test=develop
上级 3b31b74e
...@@ -19,6 +19,7 @@ import errno ...@@ -19,6 +19,7 @@ import errno
import warnings import warnings
import six import six
import logging import logging
import pickle
from functools import reduce from functools import reduce
import numpy as np import numpy as np
...@@ -1505,15 +1506,21 @@ def save(program, model_path): ...@@ -1505,15 +1506,21 @@ def save(program, model_path):
assert base_name != "", \ assert base_name != "", \
"model_path MUST be format of dirname/filename [dirname\\filename in Window], Now filename is empty str" "model_path MUST be format of dirname/filename [dirname\\filename in Window], Now filename is empty str"
def get_tensor(var):
t = global_scope().find_var(var.name).get_tensor()
return np.array(t)
parameter_list = list(filter(is_parameter, program.list_vars())) parameter_list = list(filter(is_parameter, program.list_vars()))
paddle.fluid.core._save_static_dict(model_path + ".pdparams", param_dict = {p.name: get_tensor(p) for p in parameter_list}
parameter_list, global_scope()) with open(model_path + ".pdparams", 'wb') as f:
pickle.dump(param_dict, f)
optimizer_var_list = list( optimizer_var_list = list(
filter(is_belong_to_optimizer, program.list_vars())) filter(is_belong_to_optimizer, program.list_vars()))
paddle.fluid.core._save_static_dict(model_path + ".pdopt", opt_dict = {p.name: get_tensor(p) for p in optimizer_var_list}
optimizer_var_list, global_scope()) with open(model_path + ".pdopt", 'wb') as f:
pickle.dump(opt_dict, f)
main_program = program.clone() main_program = program.clone()
program.desc.flush() program.desc.flush()
...@@ -1552,11 +1559,30 @@ def load(program, model_path): ...@@ -1552,11 +1559,30 @@ def load(program, model_path):
parameter_file_name = model_path + ".pdparams" parameter_file_name = model_path + ".pdparams"
assert os.path.exists(parameter_file_name), \ assert os.path.exists(parameter_file_name), \
"Parameter file [{}] not exits".format( parameter_file_name) "Parameter file [{}] not exits".format(parameter_file_name)
def set_var(var, ndarray):
t = global_scope().find_var(var.name).get_tensor()
p = t._place()
if p.is_cpu_place():
place = paddle.fluid.CPUPlace()
elif p.is_cuda_pinned_place():
place = paddle.fluid.CUDAPinnedPlace()
else:
p = paddle.fluid.core.Place()
p.set_place(t._place())
place = paddle.fluid.CUDAPlace(p.gpu_device_id())
t.set(ndarray, place)
parameter_list = list(filter(is_parameter, program.list_vars())) parameter_list = list(filter(is_parameter, program.list_vars()))
paddle.fluid.core._load_static_dict(parameter_file_name, parameter_list, with open(parameter_file_name, 'rb') as f:
global_scope()) load_dict = pickle.load(f)
for v in parameter_list:
assert v.name in load_dict, \
"Can not find [{}] in model file [{}]".format(
v.name, parameter_file_name)
set_var(v, load_dict[v.name])
optimizer_var_list = list( optimizer_var_list = list(
filter(is_belong_to_optimizer, program.list_vars())) filter(is_belong_to_optimizer, program.list_vars()))
...@@ -1564,6 +1590,12 @@ def load(program, model_path): ...@@ -1564,6 +1590,12 @@ def load(program, model_path):
if len(optimizer_var_list) > 0: if len(optimizer_var_list) > 0:
opt_file_name = model_path + ".pdopt" opt_file_name = model_path + ".pdopt"
assert os.path.exists(opt_file_name), \ assert os.path.exists(opt_file_name), \
"Optimizer file [{}] not exits".format( opt_file_name) "Optimizer file [{}] not exits".format(opt_file_name)
paddle.fluid.core._load_static_dict(opt_file_name, optimizer_var_list,
global_scope()) with open(opt_file_name, 'rb') as f:
load_dict = pickle.load(f)
for v in optimizer_var_list:
assert v.name in load_dict, \
"Can not find [{}] in model file [{}]".format(
v.name, opt_file_name)
set_var(v, load_dict[v.name])
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册