未验证 提交 4caf60df 编写于 作者: J Jack Zhou 提交者: GitHub

fix simple_rnn_cell, gru_cell and lstm_cell zero_div_error (#34627)

上级 21beef91
...@@ -60,9 +60,16 @@ class TestSimpleRNNCell(unittest.TestCase): ...@@ -60,9 +60,16 @@ class TestSimpleRNNCell(unittest.TestCase):
y2, h2 = rnn2(paddle.to_tensor(x)) y2, h2 = rnn2(paddle.to_tensor(x))
np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5) np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5)
def test_errors(self):
def test_zero_hidden_size():
cell = paddle.nn.SimpleRNNCell(-1, 0)
self.assertRaises(ValueError, test_zero_hidden_size)
def runTest(self): def runTest(self):
self.test_with_initial_state() self.test_with_initial_state()
self.test_with_zero_state() self.test_with_zero_state()
self.test_errors()
class TestGRUCell(unittest.TestCase): class TestGRUCell(unittest.TestCase):
...@@ -103,9 +110,16 @@ class TestGRUCell(unittest.TestCase): ...@@ -103,9 +110,16 @@ class TestGRUCell(unittest.TestCase):
y2, h2 = rnn2(paddle.to_tensor(x)) y2, h2 = rnn2(paddle.to_tensor(x))
np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5) np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5)
def test_errors(self):
def test_zero_hidden_size():
cell = paddle.nn.GRUCell(-1, 0)
self.assertRaises(ValueError, test_zero_hidden_size)
def runTest(self): def runTest(self):
self.test_with_initial_state() self.test_with_initial_state()
self.test_with_zero_state() self.test_with_zero_state()
self.test_errors()
class TestLSTMCell(unittest.TestCase): class TestLSTMCell(unittest.TestCase):
...@@ -150,9 +164,16 @@ class TestLSTMCell(unittest.TestCase): ...@@ -150,9 +164,16 @@ class TestLSTMCell(unittest.TestCase):
np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5) np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5)
np.testing.assert_allclose(c1, c2.numpy(), atol=1e-8, rtol=1e-5) np.testing.assert_allclose(c1, c2.numpy(), atol=1e-8, rtol=1e-5)
def test_errors(self):
def test_zero_hidden_size():
cell = paddle.nn.LSTMCell(-1, 0)
self.assertRaises(ValueError, test_zero_hidden_size)
def runTest(self): def runTest(self):
self.test_with_initial_state() self.test_with_initial_state()
self.test_with_zero_state() self.test_with_zero_state()
self.test_errors()
def load_tests(loader, tests, pattern): def load_tests(loader, tests, pattern):
......
...@@ -332,6 +332,10 @@ class SimpleRNNCell(RNNCellBase): ...@@ -332,6 +332,10 @@ class SimpleRNNCell(RNNCellBase):
bias_hh_attr=None, bias_hh_attr=None,
name=None): name=None):
super(SimpleRNNCell, self).__init__() super(SimpleRNNCell, self).__init__()
if hidden_size <= 0:
raise ValueError(
"hidden_size of {} must be greater than 0, but now equals to {}".
format(self.__class__.__name__, hidden_size))
std = 1.0 / math.sqrt(hidden_size) std = 1.0 / math.sqrt(hidden_size)
self.weight_ih = self.create_parameter( self.weight_ih = self.create_parameter(
(hidden_size, input_size), (hidden_size, input_size),
...@@ -480,6 +484,10 @@ class LSTMCell(RNNCellBase): ...@@ -480,6 +484,10 @@ class LSTMCell(RNNCellBase):
bias_hh_attr=None, bias_hh_attr=None,
name=None): name=None):
super(LSTMCell, self).__init__() super(LSTMCell, self).__init__()
if hidden_size <= 0:
raise ValueError(
"hidden_size of {} must be greater than 0, but now equals to {}".
format(self.__class__.__name__, hidden_size))
std = 1.0 / math.sqrt(hidden_size) std = 1.0 / math.sqrt(hidden_size)
self.weight_ih = self.create_parameter( self.weight_ih = self.create_parameter(
(4 * hidden_size, input_size), (4 * hidden_size, input_size),
...@@ -627,6 +635,10 @@ class GRUCell(RNNCellBase): ...@@ -627,6 +635,10 @@ class GRUCell(RNNCellBase):
bias_hh_attr=None, bias_hh_attr=None,
name=None): name=None):
super(GRUCell, self).__init__() super(GRUCell, self).__init__()
if hidden_size <= 0:
raise ValueError(
"hidden_size of {} must be greater than 0, but now equals to {}".
format(self.__class__.__name__, hidden_size))
std = 1.0 / math.sqrt(hidden_size) std = 1.0 / math.sqrt(hidden_size)
self.weight_ih = self.create_parameter( self.weight_ih = self.create_parameter(
(3 * hidden_size, input_size), (3 * hidden_size, input_size),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册