diff --git a/python/paddle/fluid/dygraph/jit.py b/python/paddle/fluid/dygraph/jit.py index 3401f85a78b074f9e195b23828f56a7b3f848ddc..6d6c132ab5b6e65099fe35560bd8bf88fe12caf1 100644 --- a/python/paddle/fluid/dygraph/jit.py +++ b/python/paddle/fluid/dygraph/jit.py @@ -35,7 +35,7 @@ from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTra from paddle.fluid.dygraph.io import TranslatedLayer, INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX, INFER_PARAMS_INFO_SUFFIX from paddle.fluid.dygraph.layers import Layer from paddle.fluid.executor import Executor, scope_guard -from paddle.fluid.framework import Block, ParamBase, Program, Variable +from paddle.fluid.framework import Block, ParamBase, Program, Variable, Parameter from paddle.fluid.framework import _current_expected_place, _dygraph_guard, _dygraph_tracer from paddle.fluid.framework import dygraph_only, in_dygraph_mode from paddle.fluid.wrapped_decorator import wrap_decorator @@ -659,6 +659,10 @@ def save(layer, path, input_spec=None, **configs): raise TypeError( "The input of paddle.jit.save should be 'Layer' or 'Function', but received input type is %s." % type(layer)) + elif inspect.isfunction(layer) or isinstance(layer, StaticFunction): + warnings.warn( + 'What you save is a function, and `jit.save` will generate the name of the model file according to `path` you specify. When loading these files with `jit.load`, you get a `TranslatedLayer` whose inference result is the same as the inference result of the function you saved.' + ) # NOTE(chenweihang): If the input layer be wrapped by DataParallel, # the args and kwargs of forward method will can't be parsed by @@ -741,12 +745,38 @@ def save(layer, path, input_spec=None, **configs): else: continue + else: + # When layer is a function + if isinstance(attr_func, StaticFunction): + concrete_program = attr_func.concrete_program_specify_input_spec( + inner_input_spec) + else: + if inner_input_spec: + inner_input_spec = pack_sequence_as(input_spec, + inner_input_spec) + static_function = declarative( + attr_func, input_spec=inner_input_spec) + concrete_program = static_function.concrete_program + + if static_function._class_instance is None: + warnings.warn( + '`jit.save` will only save the `Program`, not the parameters. If you have to save the parameters, please make sure that {} is a member function of `paddle.nn.Layer` and the saved parameters are in `state_dict`'. + format(layer)) + + dygraph_state_dict = None + if isinstance(inner_layer, Layer): + dygraph_state_dict = inner_layer.state_dict() + elif isinstance(attr_func, StaticFunction): + if attr_func._class_instance: + dygraph_state_dict = attr_func._class_instance.state_dict() + + if dygraph_state_dict: # NOTE(chenweihang): we maintain the mapping of variable name to # structured name, the buffer variable (non-persistable) # saved to inference program may not need by dygraph Layer, # we only record the state_dict variable's structured name state_names_dict = dict() - for structured_name, var in six.iteritems(inner_layer.state_dict()): + for structured_name, var in six.iteritems(dygraph_state_dict): state_names_dict[var.name] = structured_name # 3. share parameters from Layer to scope & record var info @@ -767,18 +797,6 @@ def save(layer, path, input_spec=None, **configs): if isinstance(param_or_buffer, ParamBase): extra_info_dict['trainable'] = param_or_buffer.trainable extra_var_info[param_or_buffer.name] = extra_info_dict - else: - # When layer is a function - if isinstance(attr_func, StaticFunction): - concrete_program = attr_func.concrete_program_specify_input_spec( - inner_input_spec) - else: - if inner_input_spec: - inner_input_spec = pack_sequence_as(input_spec, - inner_input_spec) - static_function = declarative( - attr_func, input_spec=inner_input_spec) - concrete_program = static_function.concrete_program # 4. build input & output of save_infernece_model # NOTE(chenweihang): [ Get input variables name ] @@ -840,7 +858,14 @@ def save(layer, path, input_spec=None, **configs): # but we can save these information in `jit.save` without changing the original # storage to improve user experience. So we save extra information into # file `***.pdiparams.info` - if isinstance(layer, Layer) and extra_var_info: + + # "layer" can only be Layer or function or StaticFunction. + + contain_parameter = False + for var in concrete_program.main_program.list_vars(): + contain_parameter |= isinstance(var, Parameter) + + if (isinstance(layer, Layer) or contain_parameter) and extra_var_info: with scope_guard(scope): extra_var_info_path = path + INFER_PARAMS_INFO_SUFFIX with open(extra_var_info_path, 'wb') as f: diff --git a/python/paddle/fluid/tests/unittests/test_jit_save_load.py b/python/paddle/fluid/tests/unittests/test_jit_save_load.py index 81db84a5262fb65914ab9a32688208a3f50cbc62..1d24687a6b1994d66a9565b7b3c05f0fbe04f04e 100644 --- a/python/paddle/fluid/tests/unittests/test_jit_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_jit_save_load.py @@ -1227,6 +1227,99 @@ class TestJitSaveLoadFunctionCase3(unittest.TestCase): self.assertTrue((load_result - origin).abs().max() < 1e-10) +class TestJitSaveLoadFunctionWithParamCase1(unittest.TestCase): + def setUp(self): + paddle.disable_static() + + def test_jit_save_load_function(self): + class LinearNet(paddle.nn.Layer): + def __init__(self): + super(LinearNet, self).__init__() + self._linear = paddle.nn.Linear(5, 6) + + def forward(self, x): + return paddle.tanh(x) + + def anothor_forward(self, x): + return self._linear(x) + + layer = LinearNet() + + inps = paddle.rand([3, 5]) + origin = layer.anothor_forward(inps) + + func = paddle.jit.to_static( + layer.anothor_forward, [paddle.static.InputSpec(shape=[-1, 5])]) + path = 'test_jit_save_load_function_with_params_case1/func' + paddle.jit.save(func, path) + load_func = paddle.jit.load(path) + + load_result = load_func(inps) + self.assertTrue(np.array_equal(load_result.numpy(), origin.numpy())) + + +class TestJitSaveLoadFunctionWithParamCase2(unittest.TestCase): + def setUp(self): + paddle.disable_static() + + def test_jit_save_load_function(self): + class LinearNet(paddle.nn.Layer): + def __init__(self): + super(LinearNet, self).__init__() + self._linear = paddle.nn.Linear(5, 6) + + def forward(self, x): + return paddle.tanh(x) + + @paddle.jit.to_static(input_spec=[InputSpec(shape=[-1, 5])]) + def anothor_forward(self, x): + return self._linear(x) + + layer = LinearNet() + + inps = paddle.rand([3, 5]) + + path = 'test_jit_save_load_function_with_params_case2/func' + paddle.jit.save(layer.anothor_forward, path) + origin_result = layer.anothor_forward(inps) + load_func = paddle.jit.load(path) + + load_result = load_func(inps) + + self.assertTrue( + np.array_equal(origin_result.numpy(), load_result.numpy())) + + +class TestJitSaveLoadFunctionWithParamCase3(unittest.TestCase): + def setUp(self): + paddle.disable_static() + + def test_jit_save_load_function(self): + class LinearNet(paddle.nn.Layer): + def __init__(self): + super(LinearNet, self).__init__() + self._linear = paddle.nn.Linear(5, 6) + + def forward(self, x): + return paddle.tanh(x) + + @paddle.jit.to_static + def anothor_forward(self, x): + return self._linear(x) + + layer = LinearNet() + + inps = paddle.rand([3, 5]) + origin = layer.anothor_forward(inps) + + path = 'test_jit_save_load_function_with_params_case3/func' + paddle.jit.save(layer.anothor_forward, path) + load_func = paddle.jit.load(path) + + load_result = load_func(inps) + self.assertTrue(np.array_equal(load_result.numpy(), origin.numpy())) + + class TestJitSaveLoadDataParallel(unittest.TestCase): def verify_inference_correctness(self, layer, path): layer.eval()