From 95a0f87b442ac8e185b6cc02487e26b48039ab83 Mon Sep 17 00:00:00 2001
From: Chen Weihang
Date: Fri, 27 Nov 2020 09:59:31 +0800
Subject: [PATCH] support jit.save data parallel (#29135)

---
 python/paddle/fluid/dygraph/jit.py            | 22 ++++++++----
 .../tests/unittests/test_jit_save_load.py     | 34 +++++++++++++++++++
 2 files changed, 50 insertions(+), 6 deletions(-)

diff --git a/python/paddle/fluid/dygraph/jit.py b/python/paddle/fluid/dygraph/jit.py
index d618874ad98..d1e6b70a198 100644
--- a/python/paddle/fluid/dygraph/jit.py
+++ b/python/paddle/fluid/dygraph/jit.py
@@ -581,6 +581,16 @@ def save(layer, path, input_spec=None, **configs):
             "The input layer of paddle.jit.save should be 'Layer', but received layer type is %s."
             % type(layer))
 
+    # NOTE(chenweihang): If the input layer is wrapped by DataParallel,
+    # the args and kwargs of the forward method can't be parsed by
+    # function_spec, so here we save DataParallel._layers instead
+    # of DataParallel itself.
+    # NOTE(chenweihang): use inner_layer, do not change the input layer
+    if isinstance(layer, paddle.DataParallel):
+        inner_layer = layer._layers
+    else:
+        inner_layer = layer
+
     # path check
     file_prefix = os.path.basename(path)
     if file_prefix == "":
@@ -596,8 +606,8 @@ def save(layer, path, input_spec=None, **configs):
     # avoid change user given input_spec
     inner_input_spec = None
     if input_spec is not None:
-        for attr_func in dir(layer):
-            static_func = getattr(layer, attr_func, None)
+        for attr_func in dir(inner_layer):
+            static_func = getattr(inner_layer, attr_func, None)
             if isinstance(static_func,
                           StaticFunction) and 'forward' != attr_func:
                 raise ValueError(
@@ -623,14 +633,14 @@ def save(layer, path, input_spec=None, **configs):
     configs = _parse_save_configs(configs)
     scope = core.Scope()
     extra_var_info = dict()
-    for attr_func in dir(layer):
-        static_func = getattr(layer, attr_func, None)
+    for attr_func in dir(inner_layer):
+        static_func = getattr(inner_layer, attr_func, None)
         if isinstance(static_func, StaticFunction):
             concrete_program = static_func.concrete_program
         elif 'forward' == attr_func:
             # transform in jit.save, if input_spec is incomplete, declarative will throw error
             static_forward = declarative(
-                layer.forward, input_spec=inner_input_spec)
+                inner_layer.forward, input_spec=inner_input_spec)
             concrete_program = static_forward.concrete_program
             # the input_spec has been used in declarative, which is equal to
             # @declarative with input_spec and jit.save without input_spec,
@@ -663,7 +673,7 @@ def save(layer, path, input_spec=None, **configs):
         # saved to inference program may not need by dygraph Layer,
         # we only record the state_dict variable's structured name
         state_names_dict = dict()
-        for structured_name, var in six.iteritems(layer.state_dict()):
+        for structured_name, var in six.iteritems(inner_layer.state_dict()):
            state_names_dict[var.name] = structured_name
 
         # 4. share parameters from Layer to scope & record var info
diff --git a/python/paddle/fluid/tests/unittests/test_jit_save_load.py b/python/paddle/fluid/tests/unittests/test_jit_save_load.py
index 62d1d175d10..258136c3cf0 100644
--- a/python/paddle/fluid/tests/unittests/test_jit_save_load.py
+++ b/python/paddle/fluid/tests/unittests/test_jit_save_load.py
@@ -863,5 +863,39 @@ class TestJitSaveLoadMultiMethods(unittest.TestCase):
             layer, model_path, input_spec=[InputSpec(shape=[None, 784])])
 
 
+class TestJitSaveLoadDataParallel(unittest.TestCase):
+    def verify_inference_correctness(self, layer, path):
+        layer.eval()
+        loaded_layer = paddle.jit.load(path)
+        loaded_layer.eval()
+        # inference & compare
+        x = paddle.to_tensor(np.random.random((1, 784)).astype('float32'))
+        pred = layer(x).numpy()
+        loaded_pred = loaded_layer(x).numpy()
+        self.assertTrue(
+            np.array_equal(pred, loaded_pred),
+            msg="Result diff when load and inference:\nlayer result:\n{}\n" \
+                "loaded layer result:\n{}".format(pred, loaded_pred))
+
+    def test_jit_save_data_parallel_with_inputspec(self):
+        layer = LinearNetNotDeclarative(784, 1)
+        layer = paddle.DataParallel(layer)
+
+        path = "jit_save_data_parallel_with_inputspec/model"
+        paddle.jit.save(
+            layer=layer, path=path, input_spec=[InputSpec(shape=[None, 784])])
+
+        self.verify_inference_correctness(layer, path)
+
+    def test_jit_save_data_parallel_with_to_static(self):
+        layer = LinearNetWithInputSpec(784, 1)
+        layer = paddle.DataParallel(layer)
+
+        path = "jit_save_data_parallel_with_to_static/model"
+        paddle.jit.save(layer, path)
+
+        self.verify_inference_correctness(layer, path)
+
+
 if __name__ == '__main__':
     unittest.main()
--
GitLab