Unverified · Commit b5e662f8 authored by WeiXin, committed by GitHub

refine jit.save/load to add support for other methods, not only forward (#28376)

* refine jit.save/load to add support for other methods, not only forward

* refine the code based on unit tests

* Add unit test for the code

* Add unit test for the code

* Modify the code according to the unit test

* Delete useless comments, save only one info file, etc.

* remove static_mode_white_list.pyc

* edit the code that generates 'extra_var_info'
Parent: 7fe5f9cc
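With this change, `paddle.jit.save` exports every method of a `Layer` that is decorated with `@paddle.jit.to_static` (any attribute that is a `StaticFunction`), not just `forward`, and `paddle.jit.load` restores each of them as a callable attribute of the loaded layer. A minimal usage sketch of the new behavior (the layer and method names here are illustrative, not taken from this diff):

    import paddle

    class Net(paddle.nn.Layer):
        def __init__(self):
            super(Net, self).__init__()
            self._fc = paddle.nn.Linear(8, 4)

        @paddle.jit.to_static
        def forward(self, x):
            return self._fc(x)

        @paddle.jit.to_static
        def double(self, x):
            return x * 2

    net = Net()
    x = paddle.randn([2, 8])
    net(x)          # trace 'forward' once so its concrete program is cached
    net.double(x)   # same for the extra method
    paddle.jit.save(net, 'saved/net')  # input_spec must stay None here

    loaded = paddle.jit.load('saved/net')
    y0 = loaded(x)          # runs the saved 'forward' program
    y1 = loaded.double(x)   # the extra method is restored as well

Note that `input_spec` must be left as `None` when the layer has static methods other than `forward`; the new check in `save` raises a `ValueError` otherwise.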
@@ -500,7 +500,20 @@ def _construct_program_holders(model_path, model_filename=None):
         # [compatible] if assign model_filename, only can load one program as Layer.forward
         model_filename = os.path.basename(model_filename)
         model_file_path = os.path.join(model_path, model_filename)
-        program_holder_dict['forward'] = _ProgramHolder(
+        model_name = model_filename[:-len(INFER_MODEL_SUFFIX)]
+        # Load every file that meets the requirements in the directory model_path.
+        for filename in os.listdir(model_path):
+            if model_filename == filename:
+                func_name = 'forward'
+                model_file_path = os.path.join(model_path, model_filename)
+            elif filename.endswith(INFER_MODEL_SUFFIX) and filename.startswith(
+                    model_name):
+                func_name = filename[len(model_name) + 1:-len(
+                    INFER_MODEL_SUFFIX)]
+                model_file_path = os.path.join(model_path, filename)
+            else:
+                continue
+            program_holder_dict[func_name] = _ProgramHolder(
                 _load_program_desc(model_file_path))
     else:
         for _, _, file_names in os.walk(model_path):
@@ -524,9 +537,23 @@ def _construct_params_and_buffers(model_path,
                                   append_suffix=True):
     var_info_filename = str(params_filename) + ".info"
     var_info_path = os.path.join(model_path, var_info_filename)
     if os.path.exists(var_info_path):
         var_dict = _load_persistable_vars(model_path, var_info_path,
                                           programs['forward'], params_filename)
+        model_name = params_filename[:-len(INFER_PARAMS_SUFFIX)]
+        # Load every file that meets the requirements in the directory model_path.
+        for file_name in os.listdir(model_path):
+            if file_name.endswith(INFER_PARAMS_SUFFIX) and file_name.startswith(
+                    model_name) and file_name != params_filename:
+                func_name = file_name[len(model_name) + 1:-len(
+                    INFER_PARAMS_SUFFIX)]
+            else:
+                continue
+            var_info_path = os.path.join(model_path, var_info_filename)
+            var_dict.update(
+                _load_persistable_vars(model_path, var_info_path, programs[
+                    func_name], file_name))
     else:
         var_dict = _load_persistable_vars_by_program(
             model_path, programs['forward'], params_filename)
...
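Both load-side helpers rely on one filename convention: the program and parameters of `forward` keep the bare `{prefix}{suffix}` names, while every other method is stored under `{prefix}.{method}{suffix}`, and the method name is recovered by stripping the prefix, the dot, and the suffix. A standalone sketch of that parsing rule (the suffix value is an assumption for illustration, matching the `.pdmodel` naming Paddle 2.0 uses):

    INFER_MODEL_SUFFIX = '.pdmodel'  # assumed value for illustration

    def func_name_from_filename(filename, model_filename):
        # Recover the method name encoded in a saved model filename.
        model_name = model_filename[:-len(INFER_MODEL_SUFFIX)]
        if filename == model_filename:
            return 'forward'
        if filename.endswith(INFER_MODEL_SUFFIX) and filename.startswith(model_name):
            # strip '{model_name}.' from the front and the suffix from the back
            return filename[len(model_name) + 1:-len(INFER_MODEL_SUFFIX)]
        return None  # not a program file belonging to this model

    assert func_name_from_filename('net.pdmodel', 'net.pdmodel') == 'forward'
    assert func_name_from_filename('net.double.pdmodel', 'net.pdmodel') == 'double'

Note that `_construct_params_and_buffers` reuses the same `var_info_path` for every method: per the commit message, only one `.info` file is saved for the whole layer.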
@@ -594,6 +594,13 @@ def save(layer, path, input_spec=None, **configs):
     # avoid change user given input_spec
     inner_input_spec = None
     if input_spec is not None:
+        for attr_func in dir(layer):
+            static_func = getattr(layer, attr_func, None)
+            if isinstance(static_func,
+                          StaticFunction) and 'forward' != attr_func:
+                raise ValueError(
+                    "If there are static functions other than 'forward' that need to be saved, the input 'input_spec' should be None, but received the type of 'input_spec' is %s."
+                    % type(input_spec))
         if not isinstance(input_spec, list):
             raise TypeError(
                 "The input input_spec should be 'list', but received input_spec's type is %s."
@@ -612,19 +619,23 @@ def save(layer, path, input_spec=None, **configs):
     # parse configs
     configs = _parse_save_configs(configs)

-    # 2. get program from Layer
-    # TODO(chenweihang): add support for other method, not only forward
-    if isinstance(layer.forward, StaticFunction):
-        concrete_program = layer.forward.concrete_program
-    else:
-        # transform in jit.save, if input_spec is incomplete, declarative will throw error
-        static_forward = declarative(layer.forward, input_spec=inner_input_spec)
-        concrete_program = static_forward.concrete_program
-        # the input_spec has been used in declarative, which is equal to
-        # @declarative with input_spec and jit.save without input_spec,
-        # avoid needless warning
-        inner_input_spec = None
+    scope = core.Scope()
+    extra_var_info = dict()
+    for attr_func in dir(layer):
+        static_func = getattr(layer, attr_func, None)
+        if isinstance(static_func, StaticFunction):
+            concrete_program = static_func.concrete_program
+        elif 'forward' == attr_func:
+            # transform in jit.save, if input_spec is incomplete, declarative will throw error
+            static_forward = declarative(
+                layer.forward, input_spec=inner_input_spec)
+            concrete_program = static_forward.concrete_program
+            # the input_spec has been used in declarative, which is equal to
+            # @declarative with input_spec and jit.save without input_spec,
+            # avoid needless warning
+            inner_input_spec = None
+        else:
+            continue

     # 3. build input & output of save_infernece_model
     # NOTE(chenweihang): [ Get input variables name ]
@@ -654,14 +665,14 @@ def save(layer, path, input_spec=None, **configs):
             state_names_dict[var.name] = structured_name

     # 4. share parameters from Layer to scope & record var info
-    scope = core.Scope()
-    extra_var_info = dict()
     for param_or_buffer in concrete_program.parameters:
         # share to scope
-        param_or_buffer_tensor = scope.var(param_or_buffer.name).get_tensor()
+        param_or_buffer_tensor = scope.var(param_or_buffer.name).get_tensor(
+        )
         src_tensor = param_or_buffer.value().get_tensor()
         param_or_buffer_tensor._share_data_with(src_tensor)
         # record var info
+        if param_or_buffer.name not in extra_var_info:
             extra_info_dict = dict()
             if param_or_buffer.name in state_names_dict:
                 extra_info_dict['structured_name'] = state_names_dict[
@@ -678,8 +689,12 @@ def save(layer, path, input_spec=None, **configs):
         model_path = dirname
         # NOTE(chenweihang): because prefix contains model and params filename,
         # so we don't support set model_filename & params_filename
-        model_filename = file_prefix + INFER_MODEL_SUFFIX
-        params_filename = file_prefix + INFER_PARAMS_SUFFIX
+        if 'forward' == attr_func:
+            model_filename = file_prefix + INFER_MODEL_SUFFIX
+            params_filename = file_prefix + INFER_PARAMS_SUFFIX
+        else:
+            model_filename = file_prefix + '.' + attr_func + INFER_MODEL_SUFFIX
+            params_filename = file_prefix + '.' + attr_func + INFER_PARAMS_SUFFIX

         with scope_guard(scope):
             save_inference_model(
@@ -708,6 +723,7 @@ def save(layer, path, input_spec=None, **configs):
     # but we can save these information in `jit.save` without changing the original
     # storage to improve user experience. So we save extra information into
     # file `***.pdiparams.info`
+    with scope_guard(scope):
         extra_var_info_path = path + INFER_PARAMS_INFO_SUFFIX
         with open(extra_var_info_path, 'wb') as f:
             pickle.dump(extra_var_info, f, protocol=2)
...
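Put together, saving a layer whose extra static method is, say, `predict` under the prefix `saved/net` yields a file set along these lines (layout inferred from the naming logic above; `predict` is an illustrative name, and the suffix values are the usual Paddle 2.0 ones):

    saved/net.pdmodel             # program of 'forward'
    saved/net.pdiparams           # parameters referenced by 'forward'
    saved/net.predict.pdmodel     # program of the extra method
    saved/net.predict.pdiparams   # parameters referenced by 'predict'
    saved/net.pdiparams.info      # extra var info, written once for the whole layer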
@@ -115,6 +115,7 @@ class TestInputSpec(unittest.TestCase):
         self.assertTrue(len(net.forward.program_cache) == 1)

         # 2. test save load
+        net.inner_function(x)
         jit.save(net, './simple_net')
         infer_net = fluid.dygraph.jit.load('./simple_net')
         pred = infer_net(x)
...
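(The extra `net.inner_function(x)` call traces `inner_function` once, so that its concrete program already exists when `jit.save` now exports every static method of the net.)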
@@ -187,6 +187,26 @@ class NoParamLayer(paddle.nn.Layer):
         return x + y

+class LinearNetWithMultiStaticFunc(fluid.dygraph.Layer):
+    def __init__(self, in_size, out_size):
+        super(LinearNetWithMultiStaticFunc, self).__init__()
+        self._linear_0 = Linear(in_size, out_size)
+        self._linear_1 = Linear(in_size, out_size)
+        self._scale = paddle.to_tensor(9.9)
+
+    @paddle.jit.to_static
+    def forward(self, x):
+        return self._linear_0(x)
+
+    @paddle.jit.to_static
+    def forward_no_param(self, x):
+        return x
+
+    @paddle.jit.to_static
+    def forward_general(self, x):
+        return self._linear_0(x) + self._linear_1(x) * self._scale
+

 def train(layer, input_size=784, label_size=1):
     # create optimizer
     sgd = fluid.optimizer.SGDOptimizer(
@@ -764,5 +784,34 @@ class TestJitSaveLoadNoParamLayer(unittest.TestCase):
         self.assertTrue(np.array_equal(out, load_out))

+class TestJitSaveLoadMultiMethods(unittest.TestCase):
+    def setUp(self):
+        # enable dygraph mode
+        paddle.disable_static()
+
+    def test_jit_save_load_inference(self):
+        model_path_inference = "jit_save_load_multi_methods/model"
+        IMAGE_SIZE = 224
+        layer = LinearNetWithMultiStaticFunc(IMAGE_SIZE, 10)
+        inps = paddle.randn([1, IMAGE_SIZE])
+        result_origin = {}
+        for func in dir(layer):
+            if func.startswith('forward'):
+                result_origin[func] = getattr(layer, func, None)(inps)
+        paddle.jit.save(layer, model_path_inference)
+        load_net = paddle.jit.load(model_path_inference)
+        for func, result in result_origin.items():
+            self.assertTrue(
+                float((result - getattr(load_net, func, None)(inps)).abs().max(
+                )) < 1e-5)
+
+    def test_jit_save_load_multi_methods_inputspec(self):
+        model_path = 'jit_save_load_multi_methods/model'
+        layer = LinearNetWithMultiStaticFunc(784, 1)
+        with self.assertRaises(ValueError):
+            paddle.jit.save(
+                layer, model_path, input_spec=[InputSpec(shape=[None, 784])])
+

 if __name__ == '__main__':
     unittest.main()