未验证 提交 7a3a05cc 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2Stat]Support to save model with nested output (#28224)

上级 4671d85a
...@@ -25,6 +25,7 @@ import paddle ...@@ -25,6 +25,7 @@ import paddle
from paddle.fluid import core from paddle.fluid import core
from paddle.fluid.compiler import BuildStrategy, CompiledProgram, ExecutionStrategy from paddle.fluid.compiler import BuildStrategy, CompiledProgram, ExecutionStrategy
from paddle.fluid.data_feeder import check_type from paddle.fluid.data_feeder import check_type
from paddle.fluid.layers.utils import flatten
from paddle.fluid.dygraph.base import program_desc_tracing_guard, switch_to_static_graph from paddle.fluid.dygraph.base import program_desc_tracing_guard, switch_to_static_graph
from paddle.fluid.dygraph.dygraph_to_static import logging_utils from paddle.fluid.dygraph.dygraph_to_static import logging_utils
from paddle.fluid.dygraph.dygraph_to_static.logging_utils import set_code_level, set_verbosity from paddle.fluid.dygraph.dygraph_to_static.logging_utils import set_code_level, set_verbosity
...@@ -397,7 +398,7 @@ def _get_output_vars(outputs, output_spec): ...@@ -397,7 +398,7 @@ def _get_output_vars(outputs, output_spec):
"Layer.forward method." "Layer.forward method."
result_list = [] result_list = []
output_vars_dict = OrderedDict() output_vars_dict = OrderedDict()
for var in outputs: for var in flatten(outputs):
if isinstance(var, Variable): if isinstance(var, Variable):
output_vars_dict[var.name] = var output_vars_dict[var.name] = var
if output_spec is None: if output_spec is None:
......
...@@ -21,6 +21,7 @@ import numpy as np ...@@ -21,6 +21,7 @@ import numpy as np
import paddle import paddle
from paddle.static import InputSpec from paddle.static import InputSpec
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.layers.utils import flatten
from paddle.fluid.dygraph import Linear from paddle.fluid.dygraph import Linear
from paddle.fluid.dygraph import declarative, ProgramTranslator from paddle.fluid.dygraph import declarative, ProgramTranslator
from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX, INFER_PARAMS_INFO_SUFFIX from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX, INFER_PARAMS_INFO_SUFFIX
...@@ -153,6 +154,21 @@ class LinearNetReturnHidden(fluid.dygraph.Layer): ...@@ -153,6 +154,21 @@ class LinearNetReturnHidden(fluid.dygraph.Layer):
return y, loss return y, loss
class LinearNetWithNestOut(fluid.dygraph.Layer):
def __init__(self, in_size, out_size):
super(LinearNetWithNestOut, self).__init__()
self._linear_1 = Linear(in_size, out_size)
self._linear_2 = Linear(in_size, out_size)
@declarative
def forward(self, x):
y = self._linear_1(x)
z = self._linear_2(y)
out = y + z
loss = fluid.layers.mean(out)
return y, [(z, loss), out]
class EmptyLayer(paddle.nn.Layer): class EmptyLayer(paddle.nn.Layer):
def __init__(self): def __init__(self):
super(EmptyLayer, self).__init__() super(EmptyLayer, self).__init__()
...@@ -299,6 +315,30 @@ class TestJitSaveLoad(unittest.TestCase): ...@@ -299,6 +315,30 @@ class TestJitSaveLoad(unittest.TestCase):
loaded_layer = paddle.jit.load(path) loaded_layer = paddle.jit.load(path)
class TestSaveLoadWithNestOut(unittest.TestCase):
def setUp(self):
# enable dygraph mode
fluid.enable_dygraph()
def test_nest_output(self):
x = fluid.dygraph.to_variable(
np.random.random((4, 8)).astype('float32'))
net = LinearNetWithNestOut(8, 8)
dy_outs = flatten(net(x))
net = declarative(net, input_spec=[InputSpec([None, 8], name='x')])
model_path = "net_with_nest_out/model"
paddle.jit.save(net, model_path)
load_net = paddle.jit.load(model_path)
load_outs = flatten(load_net(x))
self.assertTrue(len(dy_outs) == 4)
for dy_out, load_out in zip(dy_outs, load_outs):
self.assertTrue(np.allclose(dy_out.numpy(), load_out.numpy()))
class TestSaveLoadWithInputSpec(unittest.TestCase): class TestSaveLoadWithInputSpec(unittest.TestCase):
def setUp(self): def setUp(self):
# enable dygraph mode # enable dygraph mode
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册