diff --git a/python/paddle/fluid/layers/rnn.py b/python/paddle/fluid/layers/rnn.py
index fe8ed83923e88be2a0c98a8a539f26500b43b7cb..3f36d47159cd1297d8464c96c8fe60de9c7a85d3 100644
--- a/python/paddle/fluid/layers/rnn.py
+++ b/python/paddle/fluid/layers/rnn.py
@@ -531,6 +531,10 @@ def _rnn_dynamic_graph(cell,
     flat_inputs = flatten(inputs)
     time_steps = flat_inputs[0].shape[time_step_index]
 
+    if initial_states is None:
+        initial_states = cell.get_initial_states(
+            batch_ref=inputs, batch_dim_idx=1 if time_major else 0)
+
     if not time_major:
         inputs = map_structure(_transpose_batch_time, inputs)
 
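With the hunk above, the dynamic-graph path of the functional `rnn` builds zero initial states on its own when `initial_states` is `None`, taking the batch size from axis 1 for time-major inputs and axis 0 otherwise. A minimal numpy sketch of that default (the helper name `zero_initial_state` is illustrative, not part of the patch):

```python
import numpy as np

def zero_initial_state(inputs, hidden_size, time_major):
    # Mirrors the new default: the batch size comes from axis 1 of a
    # time-major [time, batch, input] tensor, else from axis 0.
    batch_dim_idx = 1 if time_major else 0
    batch_size = inputs.shape[batch_dim_idx]
    return np.zeros((batch_size, hidden_size), dtype=inputs.dtype)

x = np.random.randn(12, 4, 16)  # time-major: [time, batch, input]
h0 = zero_initial_state(x, 32, time_major=True)
print(h0.shape)  # (4, 32)
```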
diff --git a/python/paddle/fluid/tests/unittests/rnn/rnn_numpy.py b/python/paddle/fluid/tests/unittests/rnn/rnn_numpy.py
index 7e0b8374b95cf334b4eced550a79d7c717c07aa7..317be28da43e31cd2bc09423c48ca60fde06315a 100644
--- a/python/paddle/fluid/tests/unittests/rnn/rnn_numpy.py
+++ b/python/paddle/fluid/tests/unittests/rnn/rnn_numpy.py
@@ -61,8 +61,8 @@ class SimpleRNNCell(LayerMixin):
         self.bias_ih = None
         self.bias_hh = None
 
-    def init_state(self, inputs):
-        batch_size = inputs.shape[0]
+    def init_state(self, inputs, batch_dim_index=0):
+        batch_size = inputs.shape[batch_dim_index]
         return np.zeros((batch_size, self.hidden_size), dtype=inputs.dtype)
 
     def forward(self, inputs, hx=None):
@@ -103,8 +103,8 @@ class GRUCell(LayerMixin):
         self.bias_ih = None
         self.bias_hh = None
 
-    def init_state(self, inputs):
-        batch_size = inputs.shape[0]
+    def init_state(self, inputs, batch_dim_index=0):
+        batch_size = inputs.shape[batch_dim_index]
         return np.zeros((batch_size, self.hidden_size), dtype=inputs.dtype)
 
     def forward(self, inputs, hx=None):
@@ -117,7 +117,6 @@ class GRUCell(LayerMixin):
         h_gates = np.matmul(pre_hidden, self.weight_hh.T)
         if self.bias_hh is not None:
             h_gates = h_gates + self.bias_hh
-
         x_r, x_z, x_c = np.split(x_gates, 3, 1)
         h_r, h_z, h_c = np.split(h_gates, 3, 1)
 
@@ -152,8 +151,8 @@ class LSTMCell(LayerMixin):
         self.bias_ih = None
         self.bias_hh = None
 
-    def init_state(self, inputs):
-        batch_size = inputs.shape[0]
+    def init_state(self, inputs, batch_dim_index=0):
+        batch_size = inputs.shape[batch_dim_index]
         init_h = np.zeros((batch_size, self.hidden_size), dtype=inputs.dtype)
         init_c = np.zeros((batch_size, self.hidden_size), dtype=inputs.dtype)
         return init_h, init_c
@@ -206,6 +205,9 @@ def rnn(cell,
     if is_reverse:
         inputs = np.flip(inputs, 0)
 
+    if initial_states is None:
+        initial_states = cell.init_state(inputs, 1)
+
     if sequence_length is None:
         mask = None
     else:
diff --git a/python/paddle/fluid/tests/unittests/rnn/test_rnn_cells_static.py b/python/paddle/fluid/tests/unittests/rnn/test_rnn_cells_static.py
index 948e47d5b99462c363015936f84058e222d548e2..41c252c2aa0a74b98e42b5c92dcbf886c057e20c 100644
--- a/python/paddle/fluid/tests/unittests/rnn/test_rnn_cells_static.py
+++ b/python/paddle/fluid/tests/unittests/rnn/test_rnn_cells_static.py
@@ -14,6 +14,7 @@
 
 import paddle
 paddle.framework.set_default_dtype("float64")
+paddle.enable_static()
 
 import numpy as np
 import unittest
diff --git a/python/paddle/fluid/tests/unittests/rnn/test_rnn_nets_static.py b/python/paddle/fluid/tests/unittests/rnn/test_rnn_nets_static.py
index 90ed6b8b4c9075f5a3e3925bb80e24c81a37869c..71a0b5b7bcb3407c187c6ebb89fc5758cd098f3a 100644
--- a/python/paddle/fluid/tests/unittests/rnn/test_rnn_nets_static.py
+++ b/python/paddle/fluid/tests/unittests/rnn/test_rnn_nets_static.py
@@ -15,6 +15,7 @@
 import paddle
 paddle.set_default_dtype("float64")
 from paddle.fluid.layers import sequence_mask
+paddle.enable_static()
 
 import numpy as np
 import unittest
diff --git a/python/paddle/fluid/tests/unittests/rnn/test_wrappers.py b/python/paddle/fluid/tests/unittests/rnn/test_wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..0fa76c9bcb1b761570ceedd5012728f5ed3ca017
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/rnn/test_wrappers.py
@@ -0,0 +1,193 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+paddle.set_default_dtype("float64")
+from paddle.fluid.layers import sequence_mask
+
+import numpy as np
+import unittest
+
+from convert import convert_params_for_cell
+from rnn_numpy import GRUCell, RNN, BiRNN
+
+
+class TestRNNWrapper(unittest.TestCase):
+    def __init__(self, time_major=True, direction="forward", place="cpu"):
+        super(TestRNNWrapper, self).__init__("runTest")
+        self.time_major = time_major
+        self.direction = direction
+        self.place = paddle.CPUPlace() if place == "cpu" \
+            else paddle.CUDAPlace(0)
+
+    def setUp(self):
+        paddle.disable_static(self.place)
+        cell1 = GRUCell(16, 32)
+        cell2 = paddle.nn.GRUCell(16, 32)
+        convert_params_for_cell(cell1, cell2)
+        rnn1 = RNN(cell1,
+                   is_reverse=self.direction == "backward",
+                   time_major=self.time_major)
+        rnn2 = paddle.nn.RNN(cell2,
+                             is_reverse=self.direction == "backward",
+                             time_major=self.time_major)
+
+        self.rnn1 = rnn1
+        self.rnn2 = rnn2
+
+    def test_with_initial_state(self):
+        rnn1 = self.rnn1
+        rnn2 = self.rnn2
+
+        x = np.random.randn(12, 4, 16)
+        if not self.time_major:
+            x = np.transpose(x, [1, 0, 2])
+        prev_h = np.random.randn(4, 32)
+
+        y1, h1 = rnn1(x, prev_h)
+        y2, h2 = rnn2(paddle.to_tensor(x), paddle.to_tensor(prev_h))
+        np.testing.assert_allclose(y1, y2.numpy(), atol=1e-8, rtol=1e-5)
+        np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5)
+
+    def test_with_zero_state(self):
+        rnn1 = self.rnn1
+        rnn2 = self.rnn2
+
+        x = np.random.randn(12, 4, 16)
+        if not self.time_major:
+            x = np.transpose(x, [1, 0, 2])
+
+        y1, h1 = rnn1(x)
+        y2, h2 = rnn2(paddle.to_tensor(x))
+        np.testing.assert_allclose(y1, y2.numpy(), atol=1e-8, rtol=1e-5)
+        np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5)
+
+    def test_with_input_lengths(self):
+        rnn1 = self.rnn1
+        rnn2 = self.rnn2
+
+        x = np.random.randn(12, 4, 16)
+        if not self.time_major:
+            x = np.transpose(x, [1, 0, 2])
+        sequence_length = np.array([12, 10, 9, 8], dtype=np.int64)
+
+        y1, h1 = rnn1(x, sequence_length=sequence_length)
+
+        seq_len = paddle.to_tensor(sequence_length)
+        mask = sequence_mask(seq_len, dtype=paddle.get_default_dtype())
+        if self.time_major:
+            mask = paddle.transpose(mask, [1, 0])
+        y2, h2 = rnn2(paddle.to_tensor(x), sequence_length=seq_len)
+        y2 = paddle.multiply(y2, mask, axis=0)
+
+        np.testing.assert_allclose(y1, y2.numpy(), atol=1e-8, rtol=1e-5)
+        np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5)
+
+    def runTest(self):
+        self.test_with_initial_state()
+        self.test_with_zero_state()
+        self.test_with_input_lengths()
+
+
+class TestBiRNNWrapper(unittest.TestCase):
+    def __init__(self, time_major=True, place="cpu"):
+        super(TestBiRNNWrapper, self).__init__("runTest")
+        self.time_major = time_major
+        self.place = paddle.CPUPlace() if place == "cpu" \
+            else paddle.CUDAPlace(0)
+
+    def setUp(self):
+        paddle.disable_static(self.place)
+        fw_cell1 = GRUCell(16, 32)
+        bw_cell1 = GRUCell(16, 32)
+        fw_cell2 = paddle.nn.GRUCell(16, 32)
+        bw_cell2 = paddle.nn.GRUCell(16, 32)
+        convert_params_for_cell(fw_cell1, fw_cell2)
+        convert_params_for_cell(bw_cell1, bw_cell2)
+        rnn1 = BiRNN(fw_cell1, bw_cell1, time_major=self.time_major)
+        rnn2 = paddle.nn.BiRNN(fw_cell2, bw_cell2, time_major=self.time_major)
+
+        self.rnn1 = rnn1
+        self.rnn2 = rnn2
+
+    def test_with_initial_state(self):
+        rnn1 = self.rnn1
+        rnn2 = self.rnn2
+
+        x = np.random.randn(12, 4, 16)
+        if not self.time_major:
+            x = np.transpose(x, [1, 0, 2])
+        fw_prev_h = np.random.randn(4, 32)
+        bw_prev_h = np.random.randn(4, 32)
+
+        y1, (fw_h1, bw_h1) = rnn1(x, (fw_prev_h, bw_prev_h))
+        y2, (fw_h2, bw_h2) = rnn2(
+            paddle.to_tensor(x),
+            (paddle.to_tensor(fw_prev_h), paddle.to_tensor(bw_prev_h)))
+        np.testing.assert_allclose(y1, y2.numpy(), atol=1e-8, rtol=1e-5)
+        np.testing.assert_allclose(fw_h1, fw_h2.numpy(), atol=1e-8, rtol=1e-5)
+        np.testing.assert_allclose(bw_h1, bw_h2.numpy(), atol=1e-8, rtol=1e-5)
+
+    def test_with_zero_state(self):
+        rnn1 = self.rnn1
+        rnn2 = self.rnn2
+
+        x = np.random.randn(12, 4, 16)
+        if not self.time_major:
+            x = np.transpose(x, [1, 0, 2])
+
+        y1, (fw_h1, bw_h1) = rnn1(x)
+        y2, (fw_h2, bw_h2) = rnn2(paddle.to_tensor(x))
+        np.testing.assert_allclose(y1, y2.numpy(), atol=1e-8, rtol=1e-5)
+        np.testing.assert_allclose(fw_h1, fw_h2.numpy(), atol=1e-8, rtol=1e-5)
+        np.testing.assert_allclose(bw_h1, bw_h2.numpy(), atol=1e-8, rtol=1e-5)
+
+    def test_with_input_lengths(self):
+        rnn1 = self.rnn1
+        rnn2 = self.rnn2
+
+        x = np.random.randn(12, 4, 16)
+        if not self.time_major:
+            x = np.transpose(x, [1, 0, 2])
+        sequence_length = np.array([12, 10, 9, 8], dtype=np.int64)
+
+        y1, (fw_h1, bw_h1) = rnn1(x, sequence_length=sequence_length)
+
+        seq_len = paddle.to_tensor(sequence_length)
+        mask = sequence_mask(seq_len, dtype=paddle.get_default_dtype())
+        if self.time_major:
+            mask = paddle.transpose(mask, [1, 0])
+        y2, (fw_h2, bw_h2) = rnn2(paddle.to_tensor(x), sequence_length=seq_len)
+        y2 = paddle.multiply(y2, mask, axis=0)
+
+        np.testing.assert_allclose(y1, y2.numpy(), atol=1e-8, rtol=1e-5)
+        np.testing.assert_allclose(fw_h1, fw_h2.numpy(), atol=1e-8, rtol=1e-5)
+        np.testing.assert_allclose(bw_h1, bw_h2.numpy(), atol=1e-8, rtol=1e-5)
+
+    def runTest(self):
+        self.test_with_initial_state()
+        self.test_with_zero_state()
+        self.test_with_input_lengths()
+
+
+def load_tests(loader, tests, pattern):
+    suite = unittest.TestSuite()
+    devices = ["cpu", "gpu"] if paddle.fluid.is_compiled_with_cuda() \
+        else ["cpu"]
+    for direction in ["forward", "backward"]:
+        for device in devices:
+            for time_major in [False]:
+                suite.addTest(TestRNNWrapper(time_major, direction, device))
+                suite.addTest(TestBiRNNWrapper(time_major, device))
+    return suite
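The `test_with_input_lengths` cases above mask the paddle outputs before comparing them with the numpy reference, which zeroes every step past a sequence's length. A standalone numpy sketch of that masking, assuming the batch-major shapes used in the tests (all names here are illustrative):

```python
import numpy as np

B, T, C = 4, 12, 32
sequence_length = np.array([12, 10, 9, 8], dtype=np.int64)

# mask[b, t] is 1.0 while t < sequence_length[b], as sequence_mask builds it.
mask = (np.arange(T)[None, :] < sequence_length[:, None]).astype(np.float64)

# Broadcasting the [B, T] mask over the feature axis zeroes the padded
# steps, mirroring paddle.multiply(y2, mask, axis=0) in the tests.
y = np.random.randn(B, T, C)
masked = y * mask[:, :, None]
assert np.all(masked[1, 10:] == 0)  # batch item 1 has length 10
```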
diff --git a/python/paddle/nn/layer/rnn.py b/python/paddle/nn/layer/rnn.py
index 6f1c5f199ac99692840ad3c5cffdb726baf5fa19..0687fefe00506a305a485bcf95c9e0c41acbcf87 100644
--- a/python/paddle/nn/layer/rnn.py
+++ b/python/paddle/nn/layer/rnn.py
@@ -288,7 +288,7 @@ class SimpleRNNCell(RNNCellBase):
             `weight_hh`. Default: None.
         bias_ih_attr (ParamAttr, optional): The parameter attribute for the
             `bias_ih`. Default: None.
-        bias_ih_attr (ParamAttr, optional): The parameter attribute for the
+        bias_hh_attr (ParamAttr, optional): The parameter attribute for the
             `bias_hh`. Default: None.
         name (str, optional): Name for the operation (optional, default is
             None). For more information, please refer to :ref:`api_guide_Name`.
@@ -429,7 +429,7 @@ class LSTMCell(RNNCellBase):
             `weight_hh`. Default: None.
         bias_ih_attr (ParamAttr, optional): The parameter attribute for the
             `bias_ih`. Default: None.
-        bias_ih_attr (ParamAttr, optional): The parameter attribute for the
+        bias_hh_attr (ParamAttr, optional): The parameter attribute for the
             `bias_hh`. Default: None.
         name (str, optional): Name for the operation (optional, default is
             None). For more information, please refer to :ref:`api_guide_Name`.
@@ -582,7 +582,7 @@ class GRUCell(RNNCellBase):
             `weight_hh`. Default: None.
         bias_ih_attr (ParamAttr, optional): The parameter attribute for the
             `bias_ih`. Default: None.
-        bias_ih_attr (ParamAttr, optional): The parameter attribute for the
+        bias_hh_attr (ParamAttr, optional): The parameter attribute for the
             `bias_hh`. Default: None.
         name (str, optional): Name for the operation (optional, default is
             None). For more information, please refer to :ref:`api_guide_Name`.
@@ -778,12 +778,6 @@ class RNN(Layer):
                 initial_states=None,
                 sequence_length=None,
                 **kwargs):
-        if initial_states is None:
-            initial_states = self.cell.get_initial_states(
-                batch_ref=inputs,
-                dtype=inputs.dtype,
-                batch_dim_idx=self.batch_index)
-
         final_outputs, final_states = F.rnn(self.cell,
                                             inputs,
                                             initial_states=initial_states,
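With the deletion above, `RNN.forward` no longer precomputes default states; the `initial_states is None` case is handled inside `F.rnn`, so the layer and the functional API stay in sync. A dygraph usage sketch under the 2.0-beta API (cell type and sizes are illustrative):

```python
import numpy as np
import paddle

paddle.disable_static()

cell = paddle.nn.GRUCell(16, 32)
rnn = paddle.nn.RNN(cell, time_major=False)

x = paddle.to_tensor(np.random.randn(4, 12, 16))
# No initial_states argument: zero states of shape [4, 32] are now
# created inside the functional rnn rather than in RNN.forward.
y, h = rnn(x)
print(y.shape, h.shape)  # [4, 12, 32] [4, 32]
```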
@@ -880,8 +874,6 @@ class BiRNN(Layer):
         if isinstance(initial_states, (list, tuple)):
             assert len(initial_states) == 2, \
                 "length of initial_states should be 2 when it is a list/tuple"
-        else:
-            initial_states = [initial_states, initial_states]
 
         outputs, final_states = F.birnn(self.cell_fw, self.cell_bw, inputs,
                                         initial_states, sequence_length,
@@ -968,7 +960,7 @@ class SimpleRNN(RNNMixin):
             `weight_hh` of each cell. Defaults to None.
         bias_ih_attr (ParamAttr, optional): The parameter attribute for the
             `bias_ih` of each cells. Defaults to None.
-        bias_ih_attr (ParamAttr, optional): The parameter attribute for the
+        bias_hh_attr (ParamAttr, optional): The parameter attribute for the
             `bias_hh` of each cells. Defaults to None.
         name (str, optional): Name for the operation (optional, default is
             None). For more information, please refer to :ref:`api_guide_Name`.
@@ -1111,7 +1103,7 @@ class LSTM(RNNMixin):
             `weight_hh` of each cell. Default: None.
         bias_ih_attr (ParamAttr, optional): The parameter attribute for the
             `bias_ih` of each cells. Default: None.
-        bias_ih_attr (ParamAttr, optional): The parameter attribute for the
+        bias_hh_attr (ParamAttr, optional): The parameter attribute for the
             `bias_hh` of each cells. Default: None.
         name (str, optional): Name for the operation (optional, default is
             None). For more information, please refer to :ref:`api_guide_Name`.
@@ -1247,7 +1239,7 @@ class GRU(RNNMixin):
             `weight_hh` of each cell. Default: None.
         bias_ih_attr (ParamAttr, optional): The parameter attribute for the
             `bias_ih` of each cells. Default: None.
-        bias_ih_attr (ParamAttr, optional): The parameter attribute for the
+        bias_hh_attr (ParamAttr, optional): The parameter attribute for the
             `bias_hh` of each cells. Default: None.
         name (str, optional): Name for the operation (optional, default is
             None). For more information, please refer to :ref:`api_guide_Name`.
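Since `BiRNN.forward` no longer duplicates a single `initial_states` value for both directions, a non-None argument is expected to be a pair with one state per cell, and `None` falls through to the default zero states built downstream. A dygraph usage sketch (sizes are illustrative):

```python
import numpy as np
import paddle

paddle.disable_static()

fw_cell = paddle.nn.GRUCell(16, 32)
bw_cell = paddle.nn.GRUCell(16, 32)
birnn = paddle.nn.BiRNN(fw_cell, bw_cell, time_major=False)

x = paddle.to_tensor(np.random.randn(4, 12, 16))
fw_h0 = paddle.to_tensor(np.random.randn(4, 32))
bw_h0 = paddle.to_tensor(np.random.randn(4, 32))

# One initial state per direction; passing None uses zero states instead.
y, (fw_h, bw_h) = birnn(x, (fw_h0, bw_h0))
print(y.shape)  # [4, 12, 64]: forward and backward outputs concatenated
```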