diff --git a/python/paddle/fluid/tests/unittests/rnn/test_rnn_cells.py b/python/paddle/fluid/tests/unittests/rnn/test_rnn_cells.py index ab1127afa58dd93aa92688eebdf82292990f59b1..cade4b850cd1d635b2c480d941c5411cdd6c0e71 100644 --- a/python/paddle/fluid/tests/unittests/rnn/test_rnn_cells.py +++ b/python/paddle/fluid/tests/unittests/rnn/test_rnn_cells.py @@ -60,9 +60,16 @@ class TestSimpleRNNCell(unittest.TestCase): y2, h2 = rnn2(paddle.to_tensor(x)) np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5) + def test_errors(self): + def test_zero_hidden_size(): + cell = paddle.nn.SimpleRNNCell(-1, 0) + + self.assertRaises(ValueError, test_zero_hidden_size) + def runTest(self): self.test_with_initial_state() self.test_with_zero_state() + self.test_errors() class TestGRUCell(unittest.TestCase): @@ -103,9 +110,16 @@ class TestGRUCell(unittest.TestCase): y2, h2 = rnn2(paddle.to_tensor(x)) np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5) + def test_errors(self): + def test_zero_hidden_size(): + cell = paddle.nn.GRUCell(-1, 0) + + self.assertRaises(ValueError, test_zero_hidden_size) + def runTest(self): self.test_with_initial_state() self.test_with_zero_state() + self.test_errors() class TestLSTMCell(unittest.TestCase): @@ -150,9 +164,16 @@ class TestLSTMCell(unittest.TestCase): np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5) np.testing.assert_allclose(c1, c2.numpy(), atol=1e-8, rtol=1e-5) + def test_errors(self): + def test_zero_hidden_size(): + cell = paddle.nn.LSTMCell(-1, 0) + + self.assertRaises(ValueError, test_zero_hidden_size) + def runTest(self): self.test_with_initial_state() self.test_with_zero_state() + self.test_errors() def load_tests(loader, tests, pattern): diff --git a/python/paddle/nn/layer/rnn.py b/python/paddle/nn/layer/rnn.py index 77168566d88c6055bdce3a8f168b102a1ef29343..fbb648af42a337419dac44d666425172fd368032 100644 --- a/python/paddle/nn/layer/rnn.py +++ b/python/paddle/nn/layer/rnn.py @@ -332,6 +332,10 @@ class SimpleRNNCell(RNNCellBase): bias_hh_attr=None, name=None): super(SimpleRNNCell, self).__init__() + if hidden_size <= 0: + raise ValueError( + "hidden_size of {} must be greater than 0, but now equals to {}". + format(self.__class__.__name__, hidden_size)) std = 1.0 / math.sqrt(hidden_size) self.weight_ih = self.create_parameter( (hidden_size, input_size), @@ -480,6 +484,10 @@ class LSTMCell(RNNCellBase): bias_hh_attr=None, name=None): super(LSTMCell, self).__init__() + if hidden_size <= 0: + raise ValueError( + "hidden_size of {} must be greater than 0, but now equals to {}". + format(self.__class__.__name__, hidden_size)) std = 1.0 / math.sqrt(hidden_size) self.weight_ih = self.create_parameter( (4 * hidden_size, input_size), @@ -627,6 +635,10 @@ class GRUCell(RNNCellBase): bias_hh_attr=None, name=None): super(GRUCell, self).__init__() + if hidden_size <= 0: + raise ValueError( + "hidden_size of {} must be greater than 0, but now equals to {}". + format(self.__class__.__name__, hidden_size)) std = 1.0 / math.sqrt(hidden_size) self.weight_ih = self.create_parameter( (3 * hidden_size, input_size),