diff --git a/python/paddle/fluid/dygraph/rnn.py b/python/paddle/fluid/dygraph/rnn.py index 42fdd82b81064a2ffb2dc893599ac08e6b0287ba..15483b4dc192336970aee12918e0dce2b91a16e1 100644 --- a/python/paddle/fluid/dygraph/rnn.py +++ b/python/paddle/fluid/dygraph/rnn.py @@ -39,7 +39,6 @@ class LSTMCell(Layer): \\tilde{c_t} &= tanh(W_{cx}x_t + W_{ch}h_{t-1} + b_c) c_t &= f_t \\odot c_{t-1} + i_t \\odot \\tilde{c_t} h_t &= o_t \\odot tanh(c_t) - Args: hidden_size (integer): The hidden size used in the Cell. input_size (integer): The input size used in the Cell. @@ -64,30 +63,25 @@ class LSTMCell(Layer): Returns: None - Examples: .. code-block:: python from paddle import fluid import paddle.fluid.core as core from paddle.fluid.dygraph.rnn import LSTMCell import numpy as np - batch_size = 64 input_size = 128 hidden_size = 256 - step_input_np = np.random.uniform(-0.1, 0.1, ( batch_size, input_size)).astype('float64') pre_hidden_np = np.random.uniform(-0.1, 0.1, ( batch_size, hidden_size)).astype('float64') pre_cell_np = np.random.uniform(-0.1, 0.1, ( batch_size, hidden_size)).astype('float64') - if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) else: place = core.CPUPlace() - with fluid.dygraph.guard(place): cudnn_lstm = LSTMCell(hidden_size, input_size) step_input_var = fluid.dygraph.to_variable(step_input_np) @@ -139,12 +133,12 @@ class LSTMCell(Layer): self._weight_ih = self.create_parameter( attr=weight_ih_param_attr, - shape=[self._input_size, 4 * self._hidden_size], + shape=[4 * self._hidden_size, self._input_size], dtype=self._dtype) self._weight_hh = self.create_parameter( attr=weight_hh_param_attr, - shape=[self._hidden_size, 4 * self._hidden_size], + shape=[4 * self._hidden_size, self._hidden_size], dtype=self._dtype) self._bias_ih = self.create_parameter( @@ -180,10 +174,10 @@ class LSTMCell(Layer): def forward(self, input, pre_hidden, pre_cell): if self._use_cudnn_impl: - - igates = layers.matmul(input, y=self._weight_ih) + igates = layers.matmul(input, y=self._weight_ih, transpose_y=True) igates = layers.elementwise_add(igates, self._bias_ih) - hgates = layers.matmul(pre_hidden, self._weight_hh) + hgates = layers.matmul( + pre_hidden, self._weight_hh, transpose_y=True) hgates = layers.elementwise_add(hgates, self._bias_hh) chunked_igates = layers.split(igates, num_or_sections=4, dim=1) @@ -264,28 +258,23 @@ class GRUCell(Layer): Returns: None - Examples: .. code-block:: python from paddle import fluid import paddle.fluid.core as core from paddle.fluid.dygraph.rnn import GRUCell import numpy as np - batch_size = 64 input_size = 128 hidden_size = 256 - step_input_np = np.random.uniform(-0.1, 0.1, ( batch_size, input_size)).astype('float64') pre_hidden_np = np.random.uniform(-0.1, 0.1, ( batch_size, hidden_size)).astype('float64') - if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) else: place = core.CPUPlace() - with fluid.dygraph.guard(place): cudnn_gru = GRUCell(hidden_size, input_size) step_input_var = fluid.dygraph.to_variable(step_input_np) @@ -334,12 +323,12 @@ class GRUCell(Layer): self._weight_ih = self.create_parameter( attr=weight_ih_param_attr, - shape=[self._input_size, 3 * self._hidden_size], + shape=[3 * self._hidden_size, self._input_size], dtype=self._dtype) self._weight_hh = self.create_parameter( attr=weight_hh_param_attr, - shape=[self._hidden_size, 3 * self._hidden_size], + shape=[3 * self._hidden_size, self._hidden_size], dtype=self._dtype) self._bias_ih = self.create_parameter( @@ -402,9 +391,10 @@ class GRUCell(Layer): if self._use_cudnn_impl: - igates = layers.matmul(input, y=self._weight_ih) + igates = layers.matmul(input, y=self._weight_ih, transpose_y=True) igates = layers.elementwise_add(igates, self._bias_ih) - hgates = layers.matmul(pre_hidden, self._weight_hh) + hgates = layers.matmul( + pre_hidden, self._weight_hh, transpose_y=True) hgates = layers.elementwise_add(hgates, self._bias_hh) chunked_igates = layers.split(igates, num_or_sections=3, dim=1) diff --git a/python/paddle/fluid/tests/unittests/test_cudnn_grucell.py b/python/paddle/fluid/tests/unittests/test_cudnn_grucell.py index b3ad8b037d8c87c776a1390076919633ab8e3006..091233f6d70ad824cd2cc3814bf1864eb9b7ae5a 100644 --- a/python/paddle/fluid/tests/unittests/test_cudnn_grucell.py +++ b/python/paddle/fluid/tests/unittests/test_cudnn_grucell.py @@ -34,9 +34,9 @@ def tanh(x): def cudnn_step(step_input_np, pre_hidden_np, weight_ih, bias_ih, weight_hh, bias_hh): - igates = np.matmul(step_input_np, weight_ih) + igates = np.matmul(step_input_np, weight_ih.transpose(1, 0)) igates += bias_ih - hgates = np.matmul(pre_hidden_np, weight_hh) + hgates = np.matmul(pre_hidden_np, weight_hh.transpose(1, 0)) hgates += bias_hh chunked_igates = np.split(igates, indices_or_sections=3, axis=1) diff --git a/python/paddle/fluid/tests/unittests/test_cudnn_lstmcell.py b/python/paddle/fluid/tests/unittests/test_cudnn_lstmcell.py index 69718f304b6968067454780544ca97f4eb2246e0..201648690f2980162c02c0d3591209c7553041be 100644 --- a/python/paddle/fluid/tests/unittests/test_cudnn_lstmcell.py +++ b/python/paddle/fluid/tests/unittests/test_cudnn_lstmcell.py @@ -32,7 +32,12 @@ def tanh(x): return 2. * sigmoid(2. * x) - 1. -def cudnn_step(step_in, pre_hidden, pre_cell, gate_w, gate_b, forget_bias=1.0): +def non_cudnn_step(step_in, + pre_hidden, + pre_cell, + gate_w, + gate_b, + forget_bias=1.0): concat_1 = np.concatenate([step_in, pre_hidden], 1) gate_input = np.matmul(concat_1, gate_w) @@ -45,12 +50,12 @@ def cudnn_step(step_in, pre_hidden, pre_cell, gate_w, gate_b, forget_bias=1.0): return new_hidden, new_cell -def non_cudnn_step(step_input_np, pre_hidden_np, pre_cell_np, weight_ih, - bias_ih, weight_hh, bias_hh): +def cudnn_step(step_input_np, pre_hidden_np, pre_cell_np, weight_ih, bias_ih, + weight_hh, bias_hh): - igates = np.matmul(step_input_np, weight_ih) + igates = np.matmul(step_input_np, weight_ih.transpose(1, 0)) igates = igates + bias_ih - hgates = np.matmul(pre_hidden_np, weight_hh) + hgates = np.matmul(pre_hidden_np, weight_hh.transpose(1, 0)) hgates = hgates + bias_hh chunked_igates = np.split(igates, indices_or_sections=4, axis=1) @@ -102,7 +107,6 @@ class TestCudnnLSTM(unittest.TestCase): bias_ih_name = "_bias_ih" weight_hh_name = "_weight_hh" bias_hh_name = "_bias_hh" - weight_ih = param_list[weight_ih_name].numpy() weight_ih = np.random.uniform( -0.1, 0.1, size=weight_ih.shape).astype('float64') @@ -146,10 +150,9 @@ class TestCudnnLSTM(unittest.TestCase): named_api_hidden_out = named_api_out[0] named_api_cell_out = named_api_out[1] - np_hidden_out, np_cell_out = non_cudnn_step( + np_hidden_out, np_cell_out = cudnn_step( step_input_np, pre_hidden_np, pre_cell_np, weight_ih, bias_ih, weight_hh, bias_hh) - self.assertTrue( np.allclose( api_hidden_out.numpy(), np_hidden_out, rtol=1e-5, atol=0)) @@ -230,7 +233,7 @@ class TestNonCudnnLSTM(unittest.TestCase): named_api_hidden_out = named_api_out[0] named_api_cell_out = named_api_out[1] - np_hidden_out, np_cell_out = cudnn_step( + np_hidden_out, np_cell_out = non_cudnn_step( step_input_np, pre_hidden_np, pre_cell_np, gate_w, gate_b) self.assertTrue(