未验证 提交 6a9610ed 编写于 作者: 0 0x45f 提交者: GitHub

[Dy2Stat]Support Nest sequtial container (#34246)

* support Nest sequtial container

* rename model path
上级 6883403f
......@@ -88,7 +88,7 @@ def is_unsupported(func):
for v in m.__dict__.values():
func_in_dict = func == v
if isinstance(func_in_dict, (list, numpy.ndarray)):
func_in_dict = any(func_in_dict)
func_in_dict = numpy.array(func_in_dict).any()
if func_in_dict:
translator_logger.log(
2,
......
......@@ -47,10 +47,30 @@ class SequentialNet(paddle.nn.Layer):
return out
class NestSequentialNet(paddle.nn.Layer):
def __init__(self):
super().__init__()
group1 = paddle.nn.Sequential(
paddle.nn.Linear(10, 10),
paddle.nn.Sigmoid(), )
group2 = paddle.nn.Sequential(
paddle.nn.Linear(10, 3),
paddle.nn.ReLU(), )
self.layers = paddle.nn.Sequential(group1, group2)
def forward(self, x):
return self.layers(x)
class TestSequential(unittest.TestCase):
def setUp(self):
paddle.set_device('cpu')
self.seed = 2021
self._init_config()
def _init_config(self):
self.net = SequentialNet(BufferLayers, 10, 3)
self.model_path = './sequential_net'
def _init_seed(self):
paddle.seed(self.seed)
......@@ -58,13 +78,12 @@ class TestSequential(unittest.TestCase):
def _run(self, to_static):
self._init_seed()
net = SequentialNet(BufferLayers, 10, 3)
if to_static:
net = paddle.jit.to_static(net)
self.net = paddle.jit.to_static(self.net)
x = paddle.rand([16, 10], 'float32')
out = net(x)
out = self.net(x)
if to_static:
load_out = self._test_load(net, x)
load_out = self._test_load(self.net, x)
self.assertTrue(
np.allclose(load_out, out),
msg='load_out is {}\st_out is {}'.format(load_out, out))
......@@ -80,12 +99,17 @@ class TestSequential(unittest.TestCase):
msg='dygraph_res is {}\nstatic_res is {}'.format(dy_out, st_out))
def _test_load(self, net, x):
model_path = './sequential_net'
paddle.jit.save(net, model_path)
load_net = paddle.jit.load(model_path)
paddle.jit.save(net, self.model_path)
load_net = paddle.jit.load(self.model_path)
out = load_net(x)
return out
class TestNestSequential(TestSequential):
def _init_config(self):
self.net = NestSequentialNet()
self.model_path = './nested_sequential_net'
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册