From 4caf60dfd87974bdce233ab95f98ab266a77431c Mon Sep 17 00:00:00 2001
From: Jack Zhou
Date: Fri, 6 Aug 2021 17:40:42 +0800
Subject: [PATCH] fix simple_rnn_cell, gru_cell and lstm_cell zero_div_error (#34627)

---
 .../tests/unittests/rnn/test_rnn_cells.py | 21 +++++++++++++++++++
 python/paddle/nn/layer/rnn.py | 12 +++++++++++
 2 files changed, 33 insertions(+)

diff --git a/python/paddle/fluid/tests/unittests/rnn/test_rnn_cells.py b/python/paddle/fluid/tests/unittests/rnn/test_rnn_cells.py
index ab1127afa5..cade4b850c 100644
--- a/python/paddle/fluid/tests/unittests/rnn/test_rnn_cells.py
+++ b/python/paddle/fluid/tests/unittests/rnn/test_rnn_cells.py
@@ -60,9 +60,16 @@ class TestSimpleRNNCell(unittest.TestCase):
         y2, h2 = rnn2(paddle.to_tensor(x))
         np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5)
 
+    def test_errors(self):
+        def test_zero_hidden_size():
+            cell = paddle.nn.SimpleRNNCell(-1, 0)
+
+        self.assertRaises(ValueError, test_zero_hidden_size)
+
     def runTest(self):
         self.test_with_initial_state()
         self.test_with_zero_state()
+        self.test_errors()
 
 
 class TestGRUCell(unittest.TestCase):
@@ -103,9 +110,16 @@ class TestGRUCell(unittest.TestCase):
         y2, h2 = rnn2(paddle.to_tensor(x))
         np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5)
 
+    def test_errors(self):
+        def test_zero_hidden_size():
+            cell = paddle.nn.GRUCell(-1, 0)
+
+        self.assertRaises(ValueError, test_zero_hidden_size)
+
     def runTest(self):
         self.test_with_initial_state()
         self.test_with_zero_state()
+        self.test_errors()
 
 
 class TestLSTMCell(unittest.TestCase):
@@ -150,9 +164,16 @@ class TestLSTMCell(unittest.TestCase):
         np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5)
         np.testing.assert_allclose(c1, c2.numpy(), atol=1e-8, rtol=1e-5)
 
+    def test_errors(self):
+        def test_zero_hidden_size():
+            cell = paddle.nn.LSTMCell(-1, 0)
+
+        self.assertRaises(ValueError, test_zero_hidden_size)
+
     def runTest(self):
         self.test_with_initial_state()
         self.test_with_zero_state()
+        self.test_errors()
 
 
 def load_tests(loader, tests, pattern):
diff --git a/python/paddle/nn/layer/rnn.py b/python/paddle/nn/layer/rnn.py
index 77168566d8..fbb648af42 100644
--- a/python/paddle/nn/layer/rnn.py
+++ b/python/paddle/nn/layer/rnn.py
@@ -332,6 +332,10 @@ class SimpleRNNCell(RNNCellBase):
                  bias_hh_attr=None,
                  name=None):
         super(SimpleRNNCell, self).__init__()
+        if hidden_size <= 0:
+            raise ValueError(
+                "hidden_size of {} must be greater than 0, but now equals to {}".
+                format(self.__class__.__name__, hidden_size))
         std = 1.0 / math.sqrt(hidden_size)
         self.weight_ih = self.create_parameter(
             (hidden_size, input_size),
@@ -480,6 +484,10 @@ class LSTMCell(RNNCellBase):
                  bias_hh_attr=None,
                  name=None):
         super(LSTMCell, self).__init__()
+        if hidden_size <= 0:
+            raise ValueError(
+                "hidden_size of {} must be greater than 0, but now equals to {}".
+                format(self.__class__.__name__, hidden_size))
         std = 1.0 / math.sqrt(hidden_size)
         self.weight_ih = self.create_parameter(
             (4 * hidden_size, input_size),
@@ -627,6 +635,10 @@ class GRUCell(RNNCellBase):
                  bias_hh_attr=None,
                  name=None):
         super(GRUCell, self).__init__()
+        if hidden_size <= 0:
+            raise ValueError(
+                "hidden_size of {} must be greater than 0, but now equals to {}".
+                format(self.__class__.__name__, hidden_size))
         std = 1.0 / math.sqrt(hidden_size)
         self.weight_ih = self.create_parameter(
             (3 * hidden_size, input_size),
-- 
GitLab