未验证 提交 95a0f87b 编写于 作者: C Chen Weihang 提交者: GitHub

support jit.save datra parallel (#29135)

上级 449903de
......@@ -581,6 +581,16 @@ def save(layer, path, input_spec=None, **configs):
"The input layer of paddle.jit.save should be 'Layer', but received layer type is %s."
% type(layer))
# NOTE(chenweihang): If the input layer be wrapped by DataParallel,
# the args and kwargs of forward method will can't be parsed by
# function_spec, so here we save DataParallel._layers instead
# DataParallel it self
# NOTE(chenweihang): using inner_layer, do not change input layer
if isinstance(layer, paddle.DataParallel):
inner_layer = layer._layers
else:
inner_layer = layer
# path check
file_prefix = os.path.basename(path)
if file_prefix == "":
......@@ -596,8 +606,8 @@ def save(layer, path, input_spec=None, **configs):
# avoid change user given input_spec
inner_input_spec = None
if input_spec is not None:
for attr_func in dir(layer):
static_func = getattr(layer, attr_func, None)
for attr_func in dir(inner_layer):
static_func = getattr(inner_layer, attr_func, None)
if isinstance(static_func,
StaticFunction) and 'forward' != attr_func:
raise ValueError(
......@@ -623,14 +633,14 @@ def save(layer, path, input_spec=None, **configs):
configs = _parse_save_configs(configs)
scope = core.Scope()
extra_var_info = dict()
for attr_func in dir(layer):
static_func = getattr(layer, attr_func, None)
for attr_func in dir(inner_layer):
static_func = getattr(inner_layer, attr_func, None)
if isinstance(static_func, StaticFunction):
concrete_program = static_func.concrete_program
elif 'forward' == attr_func:
# transform in jit.save, if input_spec is incomplete, declarative will throw error
static_forward = declarative(
layer.forward, input_spec=inner_input_spec)
inner_layer.forward, input_spec=inner_input_spec)
concrete_program = static_forward.concrete_program
# the input_spec has been used in declarative, which is equal to
# @declarative with input_spec and jit.save without input_spec,
......@@ -663,7 +673,7 @@ def save(layer, path, input_spec=None, **configs):
# saved to inference program may not need by dygraph Layer,
# we only record the state_dict variable's structured name
state_names_dict = dict()
for structured_name, var in six.iteritems(layer.state_dict()):
for structured_name, var in six.iteritems(inner_layer.state_dict()):
state_names_dict[var.name] = structured_name
# 4. share parameters from Layer to scope & record var info
......
......@@ -863,5 +863,39 @@ class TestJitSaveLoadMultiMethods(unittest.TestCase):
layer, model_path, input_spec=[InputSpec(shape=[None, 784])])
class TestJitSaveLoadDataParallel(unittest.TestCase):
def verify_inference_correctness(self, layer, path):
layer.eval()
loaded_layer = paddle.jit.load(path)
loaded_layer.eval()
# inference & compare
x = paddle.to_tensor(np.random.random((1, 784)).astype('float32'))
pred = layer(x).numpy()
loaded_pred = loaded_layer(x).numpy()
self.assertTrue(
np.array_equal(pred, loaded_pred),
msg="Result diff when load and inference:\nlayer result:\n{}\n" \
"loaded layer result:\n{}".format(pred, loaded_pred))
def test_jit_save_data_parallel_with_inputspec(self):
layer = LinearNetNotDeclarative(784, 1)
layer = paddle.DataParallel(layer)
path = "jit_save_data_parallel_with_inputspec/model"
paddle.jit.save(
layer=layer, path=path, input_spec=[InputSpec(shape=[None, 784])])
self.verify_inference_correctness(layer, path)
def test_jit_save_data_parallel_with_to_static(self):
layer = LinearNetWithInputSpec(784, 1)
layer = paddle.DataParallel(layer)
path = "jit_save_data_parallel_with_to_static/model"
paddle.jit.save(layer, path)
self.verify_inference_correctness(layer, path)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册