未验证 提交 3718b2e7 编写于 作者: A Aurelius84 提交者: GitHub

Fix test_lstm unittest failed and Add more unittest (#28029)

* fix test_lstm unittest failed

* add more unittest

* modify cmakelist

* fix judgement
上级 bf5325f3
...@@ -627,7 +627,7 @@ def _extract_indeed_params_buffers(class_instance): ...@@ -627,7 +627,7 @@ def _extract_indeed_params_buffers(class_instance):
""" """
params = list(get_parameters(class_instance).values()) params = list(get_parameters(class_instance).values())
buffers = list(get_buffers(class_instance).values()) buffers = list(get_buffers(class_instance).values())
buffers = [buffer for buffer in buffers if buffer.shape != []] buffers = [buffer for buffer in buffers if len(buffer.shape) != 0]
return params + buffers return params + buffers
......
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
list(REMOVE_ITEM TEST_OPS test_lstm)
foreach(TEST_OP ${TEST_OPS}) foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP}) py_test_modules(${TEST_OP} MODULES ${TEST_OP})
endforeach(TEST_OP) endforeach(TEST_OP)
......
...@@ -24,7 +24,6 @@ class Net(nn.Layer): ...@@ -24,7 +24,6 @@ class Net(nn.Layer):
self.lstm = nn.LSTM( self.lstm = nn.LSTM(
in_channels, hidden_size, direction='bidirectional', num_layers=2) in_channels, hidden_size, direction='bidirectional', num_layers=2)
@paddle.jit.to_static
def forward(self, x): def forward(self, x):
x, _ = self.lstm(x) x, _ = self.lstm(x)
return x return x
...@@ -39,6 +38,7 @@ class TestLstm(unittest.TestCase): ...@@ -39,6 +38,7 @@ class TestLstm(unittest.TestCase):
paddle.static.default_startup_program().random_seed = 1001 paddle.static.default_startup_program().random_seed = 1001
net = Net(12, 2) net = Net(12, 2)
net = paddle.jit.to_static(net)
x = paddle.zeros((2, 10, 12)) x = paddle.zeros((2, 10, 12))
y = net(paddle.to_tensor(x)) y = net(paddle.to_tensor(x))
return y.numpy() return y.numpy()
...@@ -54,16 +54,17 @@ class TestLstm(unittest.TestCase): ...@@ -54,16 +54,17 @@ class TestLstm(unittest.TestCase):
def test_save_in_eval(self): def test_save_in_eval(self):
paddle.jit.ProgramTranslator().enable(True) paddle.jit.ProgramTranslator().enable(True)
net = Net(12, 2) net = Net(12, 2)
x = paddle.randn((2, 10, 12))
dygraph_out = net(x)
# switch eval mode firstly # switch eval mode firstly
net.eval() net.eval()
net = paddle.jit.to_static( net = paddle.jit.to_static(
net, input_spec=[paddle.static.InputSpec(shape=[-1, 10, 12])]) net, input_spec=[paddle.static.InputSpec(shape=[-1, 10, 12])])
paddle.jit.save(net, 'simple_lstm') paddle.jit.save(net, 'simple_lstm')
# load saved model # load saved model
load_net = paddle.jit.load('simple_lstm') load_net = paddle.jit.load('simple_lstm')
x = paddle.randn((2, 10, 12))
dygraph_out = net(x)
static_out = load_net(x) static_out = load_net(x)
self.assertTrue( self.assertTrue(
np.allclose(dygraph_out.numpy(), static_out.numpy()), np.allclose(dygraph_out.numpy(), static_out.numpy()),
...@@ -78,5 +79,41 @@ class TestLstm(unittest.TestCase): ...@@ -78,5 +79,41 @@ class TestLstm(unittest.TestCase):
train_out)) train_out))
class LinearNet(nn.Layer):
def __init__(self):
super(LinearNet, self).__init__()
self.fc = nn.Linear(10, 12)
self.dropout = nn.Dropout(0.5)
@paddle.jit.to_static
def forward(self, x):
y = self.fc(x)
y = self.dropout(y)
return y
class TestSaveInEvalMode(unittest.TestCase):
def test_save_in_eval(self):
paddle.jit.ProgramTranslator().enable(True)
net = LinearNet()
# switch eval mode firstly
net.eval()
# save directly
net = paddle.jit.to_static(
net, input_spec=[paddle.static.InputSpec(shape=[-1, 10])])
paddle.jit.save(net, 'linear_net')
# load saved model
load_net = paddle.jit.load('linear_net')
x = paddle.randn((2, 10))
eval_out = net(x)
infer_out = load_net(x)
self.assertTrue(
np.allclose(eval_out.numpy(), infer_out.numpy()),
msg='eval_out is {}\n infer_out is \n{}'.format(eval_out,
infer_out))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册