未验证 提交 7a3a05cc 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2Stat]Support to save model with nested output (#28224)

上级 4671d85a
......@@ -25,6 +25,7 @@ import paddle
from paddle.fluid import core
from paddle.fluid.compiler import BuildStrategy, CompiledProgram, ExecutionStrategy
from paddle.fluid.data_feeder import check_type
from paddle.fluid.layers.utils import flatten
from paddle.fluid.dygraph.base import program_desc_tracing_guard, switch_to_static_graph
from paddle.fluid.dygraph.dygraph_to_static import logging_utils
from paddle.fluid.dygraph.dygraph_to_static.logging_utils import set_code_level, set_verbosity
......@@ -397,7 +398,7 @@ def _get_output_vars(outputs, output_spec):
"Layer.forward method."
result_list = []
output_vars_dict = OrderedDict()
for var in outputs:
for var in flatten(outputs):
if isinstance(var, Variable):
output_vars_dict[var.name] = var
if output_spec is None:
......
......@@ -21,6 +21,7 @@ import numpy as np
import paddle
from paddle.static import InputSpec
import paddle.fluid as fluid
from paddle.fluid.layers.utils import flatten
from paddle.fluid.dygraph import Linear
from paddle.fluid.dygraph import declarative, ProgramTranslator
from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX, INFER_PARAMS_INFO_SUFFIX
......@@ -153,6 +154,21 @@ class LinearNetReturnHidden(fluid.dygraph.Layer):
return y, loss
class LinearNetWithNestOut(fluid.dygraph.Layer):
def __init__(self, in_size, out_size):
super(LinearNetWithNestOut, self).__init__()
self._linear_1 = Linear(in_size, out_size)
self._linear_2 = Linear(in_size, out_size)
@declarative
def forward(self, x):
y = self._linear_1(x)
z = self._linear_2(y)
out = y + z
loss = fluid.layers.mean(out)
return y, [(z, loss), out]
class EmptyLayer(paddle.nn.Layer):
def __init__(self):
super(EmptyLayer, self).__init__()
......@@ -299,6 +315,30 @@ class TestJitSaveLoad(unittest.TestCase):
loaded_layer = paddle.jit.load(path)
class TestSaveLoadWithNestOut(unittest.TestCase):
def setUp(self):
# enable dygraph mode
fluid.enable_dygraph()
def test_nest_output(self):
x = fluid.dygraph.to_variable(
np.random.random((4, 8)).astype('float32'))
net = LinearNetWithNestOut(8, 8)
dy_outs = flatten(net(x))
net = declarative(net, input_spec=[InputSpec([None, 8], name='x')])
model_path = "net_with_nest_out/model"
paddle.jit.save(net, model_path)
load_net = paddle.jit.load(model_path)
load_outs = flatten(load_net(x))
self.assertTrue(len(dy_outs) == 4)
for dy_out, load_out in zip(dy_outs, load_outs):
self.assertTrue(np.allclose(dy_out.numpy(), load_out.numpy()))
class TestSaveLoadWithInputSpec(unittest.TestCase):
def setUp(self):
# enable dygraph mode
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册