Unverified commit e1f8617e authored by Feiyu Chan, committed by GitHub

bugfix: RNN does not initialize the state for the cell correctly (#27644)

1. Fix a bug where paddle.nn.RNN did not initialize the state for the wrapped cell correctly;
2. Add unit tests for paddle.nn.RNN and paddle.nn.BiRNN.
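
For context, a minimal sketch of the call pattern this fix targets (assuming the paddle 2.0-rc dygraph API used elsewhere in this diff): running paddle.nn.RNN without passing initial_states, so a zero state must be derived for the cell.

    import paddle

    paddle.disable_static()
    cell = paddle.nn.GRUCell(16, 32)
    rnn = paddle.nn.RNN(cell)       # batch-major by default (time_major=False)

    x = paddle.randn([4, 23, 16])   # [batch_size, time_steps, input_size]
    y, h = rnn(x)                   # initial_states=None -> zero state for the cell
    print(y.shape)                  # [4, 23, 32]
    print(h.shape)                  # [4, 32]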
Parent 5e4f01f5
@@ -531,6 +531,10 @@ def _rnn_dynamic_graph(cell,
     flat_inputs = flatten(inputs)
     time_steps = flat_inputs[0].shape[time_step_index]
 
+    if initial_states is None:
+        initial_states = cell.get_initial_states(
+            batch_ref=inputs, batch_dim_idx=1 if time_major else 0)
+
     if not time_major:
         inputs = map_structure(_transpose_batch_time, inputs)
......
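The batch_dim_idx choice in the hunk above follows from the input layout: with time_major=True the inputs are [time_steps, batch_size, ...], so the batch dimension is axis 1; in batch-major layout it is axis 0. A small illustration (variable names here are hypothetical):

    import numpy as np

    time_major_input = np.zeros([23, 4, 16])   # [time_steps, batch_size, input_size]
    batch_major_input = np.zeros([4, 23, 16])  # [batch_size, time_steps, input_size]

    # batch_dim_idx = 1 if time_major else 0, as in the hunk above
    assert time_major_input.shape[1] == 4
    assert batch_major_input.shape[0] == 4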
@@ -61,8 +61,8 @@ class SimpleRNNCell(LayerMixin):
         self.bias_ih = None
         self.bias_hh = None
 
-    def init_state(self, inputs):
-        batch_size = inputs.shape[0]
+    def init_state(self, inputs, batch_dim_index=0):
+        batch_size = inputs.shape[batch_dim_index]
         return np.zeros((batch_size, self.hidden_size), dtype=inputs.dtype)
 
     def forward(self, inputs, hx=None):
@@ -103,8 +103,8 @@ class GRUCell(LayerMixin):
         self.bias_ih = None
         self.bias_hh = None
 
-    def init_state(self, inputs):
-        batch_size = inputs.shape[0]
+    def init_state(self, inputs, batch_dim_index=0):
+        batch_size = inputs.shape[batch_dim_index]
         return np.zeros((batch_size, self.hidden_size), dtype=inputs.dtype)
 
     def forward(self, inputs, hx=None):
@@ -117,7 +117,6 @@ class GRUCell(LayerMixin):
         h_gates = np.matmul(pre_hidden, self.weight_hh.T)
         if self.bias_hh is not None:
             h_gates = h_gates + self.bias_hh
-
         x_r, x_z, x_c = np.split(x_gates, 3, 1)
         h_r, h_z, h_c = np.split(h_gates, 3, 1)
@@ -152,8 +151,8 @@ class LSTMCell(LayerMixin):
         self.bias_ih = None
         self.bias_hh = None
 
-    def init_state(self, inputs):
-        batch_size = inputs.shape[0]
+    def init_state(self, inputs, batch_dim_index=0):
+        batch_size = inputs.shape[batch_dim_index]
         init_h = np.zeros((batch_size, self.hidden_size), dtype=inputs.dtype)
         init_c = np.zeros((batch_size, self.hidden_size), dtype=inputs.dtype)
         return init_h, init_c
@@ -206,6 +205,9 @@ def rnn(cell,
     if is_reverse:
         inputs = np.flip(inputs, 0)
 
+    if initial_states is None:
+        initial_states = cell.init_state(inputs, 1)
+
     if sequence_length is None:
         mask = None
     else:
......
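In the numpy reference above, the inputs have already been transposed to time-major (and optionally flipped) by the time the default state is created, so cell.init_state(inputs, 1) reads the batch size from axis 1. A sketch of the zero state it produces, with hypothetical sizes:

    import numpy as np

    hidden_size = 32
    inputs = np.random.randn(23, 4, 16)   # time-major: [time, batch, input]
    batch_size = inputs.shape[1]          # batch_dim_index=1
    init_state = np.zeros((batch_size, hidden_size), dtype=inputs.dtype)
    assert init_state.shape == (4, 32)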
@@ -14,6 +14,7 @@
 import paddle
 paddle.framework.set_default_dtype("float64")
+paddle.enable_static()
 
 import numpy as np
 import unittest
......
@@ -15,6 +15,7 @@
 import paddle
 paddle.set_default_dtype("float64")
 from paddle.fluid.layers import sequence_mask
+paddle.enable_static()
 
 import numpy as np
 import unittest
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
paddle.set_default_dtype("float64")
from paddle.fluid.layers import sequence_mask
import numpy as np
import unittest
from convert import convert_params_for_cell
from rnn_numpy import GRUCell, RNN, BiRNN


class TestRNNWrapper(unittest.TestCase):
    def __init__(self, time_major=True, direction="forward", place="cpu"):
        super(TestRNNWrapper, self).__init__("runTest")
        self.time_major = time_major
        self.direction = direction
        self.place = paddle.CPUPlace() if place == "cpu" \
            else paddle.CUDAPlace(0)

    def setUp(self):
        paddle.disable_static(self.place)
        cell1 = GRUCell(16, 32)
        cell2 = paddle.nn.GRUCell(16, 32)
        convert_params_for_cell(cell1, cell2)
        rnn1 = RNN(cell1,
                   is_reverse=self.direction == "backward",
                   time_major=self.time_major)
        rnn2 = paddle.nn.RNN(cell2,
                             is_reverse=self.direction == "backward",
                             time_major=self.time_major)
        self.rnn1 = rnn1
        self.rnn2 = rnn2

    def test_with_initial_state(self):
        rnn1 = self.rnn1
        rnn2 = self.rnn2
        x = np.random.randn(12, 4, 16)
        if not self.time_major:
            x = np.transpose(x, [1, 0, 2])
        prev_h = np.random.randn(4, 32)
        y1, h1 = rnn1(x, prev_h)
        y2, h2 = rnn2(paddle.to_tensor(x), paddle.to_tensor(prev_h))
        np.testing.assert_allclose(y1, y2.numpy(), atol=1e-8, rtol=1e-5)
        np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5)

    def test_with_zero_state(self):
        rnn1 = self.rnn1
        rnn2 = self.rnn2
        x = np.random.randn(12, 4, 16)
        if not self.time_major:
            x = np.transpose(x, [1, 0, 2])
        y1, h1 = rnn1(x)
        y2, h2 = rnn2(paddle.to_tensor(x))
        np.testing.assert_allclose(y1, y2.numpy(), atol=1e-8, rtol=1e-5)
        np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5)

    def test_with_input_lengths(self):
        rnn1 = self.rnn1
        rnn2 = self.rnn2
        x = np.random.randn(12, 4, 16)
        if not self.time_major:
            x = np.transpose(x, [1, 0, 2])
        sequence_length = np.array([12, 10, 9, 8], dtype=np.int64)
        y1, h1 = rnn1(x, sequence_length=sequence_length)
        seq_len = paddle.to_tensor(sequence_length)
        mask = sequence_mask(seq_len, dtype=paddle.get_default_dtype())
        if self.time_major:
            mask = paddle.transpose(mask, [1, 0])
        y2, h2 = rnn2(paddle.to_tensor(x), sequence_length=seq_len)
        y2 = paddle.multiply(y2, mask, axis=0)
        np.testing.assert_allclose(y1, y2.numpy(), atol=1e-8, rtol=1e-5)
        np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5)

    def runTest(self):
        self.test_with_initial_state()
        self.test_with_zero_state()
        self.test_with_input_lengths()


class TestBiRNNWrapper(unittest.TestCase):
    def __init__(self, time_major=True, place="cpu"):
        super(TestBiRNNWrapper, self).__init__("runTest")
        self.time_major = time_major
        self.place = paddle.CPUPlace() if place == "cpu" \
            else paddle.CUDAPlace(0)

    def setUp(self):
        paddle.disable_static(self.place)
        fw_cell1 = GRUCell(16, 32)
        bw_cell1 = GRUCell(16, 32)
        fw_cell2 = paddle.nn.GRUCell(16, 32)
        bw_cell2 = paddle.nn.GRUCell(16, 32)
        convert_params_for_cell(fw_cell1, fw_cell2)
        convert_params_for_cell(bw_cell1, bw_cell2)
        rnn1 = BiRNN(fw_cell1, bw_cell1, time_major=self.time_major)
        rnn2 = paddle.nn.BiRNN(fw_cell2, bw_cell2, time_major=self.time_major)
        self.rnn1 = rnn1
        self.rnn2 = rnn2

    def test_with_initial_state(self):
        rnn1 = self.rnn1
        rnn2 = self.rnn2
        x = np.random.randn(12, 4, 16)
        if not self.time_major:
            x = np.transpose(x, [1, 0, 2])
        fw_prev_h = np.random.randn(4, 32)
        bw_prev_h = np.random.randn(4, 32)
        y1, (fw_h1, bw_h1) = rnn1(x, (fw_prev_h, bw_prev_h))
        y2, (fw_h2, bw_h2) = rnn2(
            paddle.to_tensor(x),
            (paddle.to_tensor(fw_prev_h), paddle.to_tensor(bw_prev_h)))
        np.testing.assert_allclose(y1, y2.numpy(), atol=1e-8, rtol=1e-5)
        np.testing.assert_allclose(fw_h1, fw_h2.numpy(), atol=1e-8, rtol=1e-5)
        np.testing.assert_allclose(bw_h1, bw_h2.numpy(), atol=1e-8, rtol=1e-5)

    def test_with_zero_state(self):
        rnn1 = self.rnn1
        rnn2 = self.rnn2
        x = np.random.randn(12, 4, 16)
        if not self.time_major:
            x = np.transpose(x, [1, 0, 2])
        y1, (fw_h1, bw_h1) = rnn1(x)
        y2, (fw_h2, bw_h2) = rnn2(paddle.to_tensor(x))
        np.testing.assert_allclose(y1, y2.numpy(), atol=1e-8, rtol=1e-5)
        np.testing.assert_allclose(fw_h1, fw_h2.numpy(), atol=1e-8, rtol=1e-5)
        np.testing.assert_allclose(bw_h1, bw_h2.numpy(), atol=1e-8, rtol=1e-5)

    def test_with_input_lengths(self):
        rnn1 = self.rnn1
        rnn2 = self.rnn2
        x = np.random.randn(12, 4, 16)
        if not self.time_major:
            x = np.transpose(x, [1, 0, 2])
        sequence_length = np.array([12, 10, 9, 8], dtype=np.int64)
        y1, (fw_h1, bw_h1) = rnn1(x, sequence_length=sequence_length)
        seq_len = paddle.to_tensor(sequence_length)
        mask = sequence_mask(seq_len, dtype=paddle.get_default_dtype())
        if self.time_major:
            mask = paddle.transpose(mask, [1, 0])
        y2, (fw_h2, bw_h2) = rnn2(paddle.to_tensor(x), sequence_length=seq_len)
        y2 = paddle.multiply(y2, mask, axis=0)
        np.testing.assert_allclose(y1, y2.numpy(), atol=1e-8, rtol=1e-5)
        np.testing.assert_allclose(fw_h1, fw_h2.numpy(), atol=1e-8, rtol=1e-5)
        np.testing.assert_allclose(bw_h1, bw_h2.numpy(), atol=1e-8, rtol=1e-5)

    def runTest(self):
        self.test_with_initial_state()
        self.test_with_zero_state()
        self.test_with_input_lengths()


def load_tests(loader, tests, pattern):
    suite = unittest.TestSuite()
    devices = ["cpu", "gpu"] if paddle.fluid.is_compiled_with_cuda() \
        else ["cpu"]
    for direction in ["forward", "backward"]:
        for device in devices:
            for time_major in [False]:
                suite.addTest(TestRNNWrapper(time_major, direction, device))
                suite.addTest(TestBiRNNWrapper(time_major, device))
    return suite
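
unittest honors a module-level load_tests hook, so the hand-built suite above replaces name-based discovery and each (time_major, direction, device) combination runs as its own case. The conventional closer for such a module, shown here as a sketch since this view may simply have truncated the file, is:

    if __name__ == "__main__":
        unittest.main()  # routes through the load_tests hook defined above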
@@ -288,7 +288,7 @@ class SimpleRNNCell(RNNCellBase):
             `weight_hh`. Default: None.
         bias_ih_attr (ParamAttr, optional): The parameter attribute for the
             `bias_ih`. Default: None.
-        bias_ih_attr (ParamAttr, optional): The parameter attribute for the
+        bias_hh_attr (ParamAttr, optional): The parameter attribute for the
             `bias_hh`. Default: None.
         name (str, optional): Name for the operation (optional, default is
             None). For more information, please refer to :ref:`api_guide_Name`.
@@ -429,7 +429,7 @@ class LSTMCell(RNNCellBase):
             `weight_hh`. Default: None.
         bias_ih_attr (ParamAttr, optional): The parameter attribute for the
             `bias_ih`. Default: None.
-        bias_ih_attr (ParamAttr, optional): The parameter attribute for the
+        bias_hh_attr (ParamAttr, optional): The parameter attribute for the
             `bias_hh`. Default: None.
         name (str, optional): Name for the operation (optional, default is
             None). For more information, please refer to :ref:`api_guide_Name`.
@@ -582,7 +582,7 @@ class GRUCell(RNNCellBase):
             `weight_hh`. Default: None.
         bias_ih_attr (ParamAttr, optional): The parameter attribute for the
             `bias_ih`. Default: None.
-        bias_ih_attr (ParamAttr, optional): The parameter attribute for the
+        bias_hh_attr (ParamAttr, optional): The parameter attribute for the
             `bias_hh`. Default: None.
         name (str, optional): Name for the operation (optional, default is
             None). For more information, please refer to :ref:`api_guide_Name`.
@@ -778,12 +778,6 @@ class RNN(Layer):
                 initial_states=None,
                 sequence_length=None,
                 **kwargs):
-        if initial_states is None:
-            initial_states = self.cell.get_initial_states(
-                batch_ref=inputs,
-                dtype=inputs.dtype,
-                batch_dim_idx=self.batch_index)
-
         final_outputs, final_states = F.rnn(self.cell,
                                             inputs,
                                             initial_states=initial_states,
@@ -880,8 +874,6 @@ class BiRNN(Layer):
         if isinstance(initial_states, (list, tuple)):
             assert len(initial_states) == 2, \
                 "length of initial_states should be 2 when it is a list/tuple"
-        else:
-            initial_states = [initial_states, initial_states]
 
         outputs, final_states = F.birnn(self.cell_fw, self.cell_bw, inputs,
                                         initial_states, sequence_length,
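
With the duplication removed, a None initial_states now flows through to F.birnn, which (per the hunk at the top of this diff) derives a zero state for each cell instead of the wrapper reusing one object for both directions. A minimal sketch, again assuming the paddle 2.0-rc dygraph API:

    import paddle

    paddle.disable_static()
    fw_cell = paddle.nn.GRUCell(16, 32)
    bw_cell = paddle.nn.GRUCell(16, 32)
    birnn = paddle.nn.BiRNN(fw_cell, bw_cell)

    x = paddle.randn([4, 23, 16])   # [batch_size, time_steps, input_size]
    y, (fw_h, bw_h) = birnn(x)      # initial_states=None -> per-cell zero states
    print(y.shape)                  # [4, 23, 64]: forward/backward outputs concatenated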
@@ -968,7 +960,7 @@ class SimpleRNN(RNNMixin):
             `weight_hh` of each cell. Defaults to None.
         bias_ih_attr (ParamAttr, optional): The parameter attribute for the
             `bias_ih` of each cells. Defaults to None.
-        bias_ih_attr (ParamAttr, optional): The parameter attribute for the
+        bias_hh_attr (ParamAttr, optional): The parameter attribute for the
             `bias_hh` of each cells. Defaults to None.
         name (str, optional): Name for the operation (optional, default is
             None). For more information, please refer to :ref:`api_guide_Name`.
@@ -1111,7 +1103,7 @@ class LSTM(RNNMixin):
             `weight_hh` of each cell. Default: None.
         bias_ih_attr (ParamAttr, optional): The parameter attribute for the
             `bias_ih` of each cells. Default: None.
-        bias_ih_attr (ParamAttr, optional): The parameter attribute for the
+        bias_hh_attr (ParamAttr, optional): The parameter attribute for the
             `bias_hh` of each cells. Default: None.
         name (str, optional): Name for the operation (optional, default is
             None). For more information, please refer to :ref:`api_guide_Name`.
@@ -1247,7 +1239,7 @@ class GRU(RNNMixin):
             `weight_hh` of each cell. Default: None.
         bias_ih_attr (ParamAttr, optional): The parameter attribute for the
             `bias_ih` of each cells. Default: None.
-        bias_ih_attr (ParamAttr, optional): The parameter attribute for the
+        bias_hh_attr (ParamAttr, optional): The parameter attribute for the
             `bias_hh` of each cells. Default: None.
         name (str, optional): Name for the operation (optional, default is
             None). For more information, please refer to :ref:`api_guide_Name`.
......