提交 cf670ec9 编写于 作者: Y Yang Zhang 提交者: hong

Serialize to pickle format (#20820)

test=develop
上级 3b31b74e
......@@ -19,6 +19,7 @@ import errno
import warnings
import six
import logging
import pickle
from functools import reduce
import numpy as np
......@@ -1505,15 +1506,21 @@ def save(program, model_path):
assert base_name != "", \
"model_path MUST be format of dirname/filename [dirname\\filename in Window], Now filename is empty str"
def get_tensor(var):
t = global_scope().find_var(var.name).get_tensor()
return np.array(t)
parameter_list = list(filter(is_parameter, program.list_vars()))
paddle.fluid.core._save_static_dict(model_path + ".pdparams",
parameter_list, global_scope())
param_dict = {p.name: get_tensor(p) for p in parameter_list}
with open(model_path + ".pdparams", 'wb') as f:
pickle.dump(param_dict, f)
optimizer_var_list = list(
filter(is_belong_to_optimizer, program.list_vars()))
paddle.fluid.core._save_static_dict(model_path + ".pdopt",
optimizer_var_list, global_scope())
opt_dict = {p.name: get_tensor(p) for p in optimizer_var_list}
with open(model_path + ".pdopt", 'wb') as f:
pickle.dump(opt_dict, f)
main_program = program.clone()
program.desc.flush()
......@@ -1552,11 +1559,30 @@ def load(program, model_path):
parameter_file_name = model_path + ".pdparams"
assert os.path.exists(parameter_file_name), \
"Parameter file [{}] not exits".format( parameter_file_name)
"Parameter file [{}] not exits".format(parameter_file_name)
def set_var(var, ndarray):
t = global_scope().find_var(var.name).get_tensor()
p = t._place()
if p.is_cpu_place():
place = paddle.fluid.CPUPlace()
elif p.is_cuda_pinned_place():
place = paddle.fluid.CUDAPinnedPlace()
else:
p = paddle.fluid.core.Place()
p.set_place(t._place())
place = paddle.fluid.CUDAPlace(p.gpu_device_id())
t.set(ndarray, place)
parameter_list = list(filter(is_parameter, program.list_vars()))
paddle.fluid.core._load_static_dict(parameter_file_name, parameter_list,
global_scope())
with open(parameter_file_name, 'rb') as f:
load_dict = pickle.load(f)
for v in parameter_list:
assert v.name in load_dict, \
"Can not find [{}] in model file [{}]".format(
v.name, parameter_file_name)
set_var(v, load_dict[v.name])
optimizer_var_list = list(
filter(is_belong_to_optimizer, program.list_vars()))
......@@ -1564,6 +1590,12 @@ def load(program, model_path):
if len(optimizer_var_list) > 0:
opt_file_name = model_path + ".pdopt"
assert os.path.exists(opt_file_name), \
"Optimizer file [{}] not exits".format( opt_file_name)
paddle.fluid.core._load_static_dict(opt_file_name, optimizer_var_list,
global_scope())
"Optimizer file [{}] not exits".format(opt_file_name)
with open(opt_file_name, 'rb') as f:
load_dict = pickle.load(f)
for v in optimizer_var_list:
assert v.name in load_dict, \
"Can not find [{}] in model file [{}]".format(
v.name, opt_file_name)
set_var(v, load_dict[v.name])
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册