From 3718b2e706dd6b86c224d757f5e714699798d171 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Sat, 17 Oct 2020 08:48:02 +0800 Subject: [PATCH] Fix test_lstm unittest failed and Add more unittest (#28029) * fix test_lstm unittest failed * add more unittest * modify cmakelist * fix judgement --- .../dygraph_to_static/program_translator.py | 2 +- .../dygraph_to_static/CMakeLists.txt | 1 - .../unittests/dygraph_to_static/test_lstm.py | 43 +++++++++++++++++-- 3 files changed, 41 insertions(+), 5 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py index ba76af98e41..2ff3fe833d6 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py @@ -627,7 +627,7 @@ def _extract_indeed_params_buffers(class_instance): """ params = list(get_parameters(class_instance).values()) buffers = list(get_buffers(class_instance).values()) - buffers = [buffer for buffer in buffers if buffer.shape != []] + buffers = [buffer for buffer in buffers if len(buffer.shape) != 0] return params + buffers diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/CMakeLists.txt b/python/paddle/fluid/tests/unittests/dygraph_to_static/CMakeLists.txt index 821b5bac297..629716cc315 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/CMakeLists.txt @@ -1,7 +1,6 @@ file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") -list(REMOVE_ITEM TEST_OPS test_lstm) foreach(TEST_OP ${TEST_OPS}) py_test_modules(${TEST_OP} MODULES ${TEST_OP}) endforeach(TEST_OP) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_lstm.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_lstm.py index 1ed06f24bd0..279c44d3245 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_lstm.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_lstm.py @@ -24,7 +24,6 @@ class Net(nn.Layer): self.lstm = nn.LSTM( in_channels, hidden_size, direction='bidirectional', num_layers=2) - @paddle.jit.to_static def forward(self, x): x, _ = self.lstm(x) return x @@ -39,6 +38,7 @@ class TestLstm(unittest.TestCase): paddle.static.default_startup_program().random_seed = 1001 net = Net(12, 2) + net = paddle.jit.to_static(net) x = paddle.zeros((2, 10, 12)) y = net(paddle.to_tensor(x)) return y.numpy() @@ -54,16 +54,17 @@ class TestLstm(unittest.TestCase): def test_save_in_eval(self): paddle.jit.ProgramTranslator().enable(True) net = Net(12, 2) + x = paddle.randn((2, 10, 12)) + dygraph_out = net(x) # switch eval mode firstly net.eval() + net = paddle.jit.to_static( net, input_spec=[paddle.static.InputSpec(shape=[-1, 10, 12])]) paddle.jit.save(net, 'simple_lstm') # load saved model load_net = paddle.jit.load('simple_lstm') - x = paddle.randn((2, 10, 12)) - dygraph_out = net(x) static_out = load_net(x) self.assertTrue( np.allclose(dygraph_out.numpy(), static_out.numpy()), @@ -78,5 +79,41 @@ class TestLstm(unittest.TestCase): train_out)) +class LinearNet(nn.Layer): + def __init__(self): + super(LinearNet, self).__init__() + self.fc = nn.Linear(10, 12) + self.dropout = nn.Dropout(0.5) + + @paddle.jit.to_static + def forward(self, x): + y = self.fc(x) + y = self.dropout(y) + return y + + +class TestSaveInEvalMode(unittest.TestCase): + def test_save_in_eval(self): + paddle.jit.ProgramTranslator().enable(True) + net = LinearNet() + # switch eval mode firstly + net.eval() + # save directly + net = paddle.jit.to_static( + net, input_spec=[paddle.static.InputSpec(shape=[-1, 10])]) + paddle.jit.save(net, 'linear_net') + # load saved model + load_net = paddle.jit.load('linear_net') + + x = paddle.randn((2, 10)) + eval_out = net(x) + + infer_out = load_net(x) + self.assertTrue( + np.allclose(eval_out.numpy(), infer_out.numpy()), + msg='eval_out is {}\n infer_out is \n{}'.format(eval_out, + infer_out)) + + if __name__ == "__main__": unittest.main() -- GitLab