diff --git a/python/paddle/fluid/dygraph/rnn.py b/python/paddle/fluid/dygraph/rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..42fdd82b81064a2ffb2dc893599ac08e6b0287ba --- /dev/null +++ b/python/paddle/fluid/dygraph/rnn.py @@ -0,0 +1,447 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .layers import Layer +from paddle.fluid import layers +import copy + +__all__ = ['LSTMCell', 'GRUCell'] + + +class LSTMCell(Layer): + """ + LSTMCell implementation using basic operators. + There are two LSTMCell version, the default one is compatible with CUDNN LSTM implementation. + The algorithm can be described as the equations below. + .. math:: + i_t &= sigmoid(W_{ix}x_{t} + W_{ih}h_{t-1} + bx_i + bh_i) + f_t &= sigmoid(W_{fx}x_{t} + W_{fh}h_{t-1} + bx_f + bh_f) + o_t &= sigmoid(W_{ox}x_{t} + W_{oh}h_{t-1} + bx_o + bh_o) + \\tilde{c_t} &= tanh(W_{cx}x_t + W_{ch}h_{t-1} + bx_c + bh_c) + c_t &= f_t \\odot c_{t-1} + i_t \\odot \\tilde{c_t} + h_t &= o_t \\odot tanh(c_t) + The other LSTMCell version is compatible with the BasicLSTMUnit used in static graph. + The algorithm can be described as the equations below. + i_t &= sigmoid(W_{ix}x_{t} + W_{ih}h_{t-1} + b_i) + f_t &= sigmoid(W_{fx}x_{t} + W_{fh}h_{t-1} + b_f + forget_bias ) + o_t &= sigmoid(W_{ox}x_{t} + W_{oh}h_{t-1} + b_o) + \\tilde{c_t} &= tanh(W_{cx}x_t + W_{ch}h_{t-1} + b_c) + c_t &= f_t \\odot c_{t-1} + i_t \\odot \\tilde{c_t} + h_t &= o_t \\odot tanh(c_t) + + Args: + hidden_size (integer): The hidden size used in the Cell. + input_size (integer): The input size used in the Cell. + param_attr(ParamAttr|None): The parameter attribute for the learnable + weight matrix. Note: + If it is set to None or one attribute of ParamAttr, lstm_unit will + create ParamAttr as param_attr. If the Initializer of the param_attr + is not set, the parameter is initialized with Xavier. Default: None. + bias_attr (ParamAttr|None): The parameter attribute for the bias + of LSTM unit. + If it is set to None or one attribute of ParamAttr, lstm_unit will + create ParamAttr as bias_attr. If the Initializer of the bias_attr + is not set, the bias is initialized as zero. Default: None. + gate_activation (function|None): The activation function for gates (actGate). + Default: 'fluid.layers.sigmoid' + activation (function|None): The activation function for cells (actNode). + Default: 'fluid.layers.tanh' + forget_bias(float|1.0): forget bias used when computing forget gate. This + is not used in default LSTMCell implementation (CUDNN compatiable) + use_cudnn_impl(bool|True): whether to use CUDNN compatible LSTMCell + dtype(string): data type used in this unit + + Returns: + None + + Examples: + .. code-block:: python + from paddle import fluid + import paddle.fluid.core as core + from paddle.fluid.dygraph.rnn import LSTMCell + import numpy as np + + batch_size = 64 + input_size = 128 + hidden_size = 256 + + step_input_np = np.random.uniform(-0.1, 0.1, ( + batch_size, input_size)).astype('float64') + pre_hidden_np = np.random.uniform(-0.1, 0.1, ( + batch_size, hidden_size)).astype('float64') + pre_cell_np = np.random.uniform(-0.1, 0.1, ( + batch_size, hidden_size)).astype('float64') + + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + else: + place = core.CPUPlace() + + with fluid.dygraph.guard(place): + cudnn_lstm = LSTMCell(hidden_size, input_size) + step_input_var = fluid.dygraph.to_variable(step_input_np) + pre_hidden_var = fluid.dygraph.to_variable(pre_hidden_np) + pre_cell_var = fluid.dygraph.to_variable(pre_cell_np) + new_hidden, new_cell = cudnn_lstm(step_input_var, pre_hidden_var, pre_cell_var) + """ + + def __init__(self, + hidden_size, + input_size, + param_attr=None, + bias_attr=None, + gate_activation=None, + activation=None, + forget_bias=1.0, + use_cudnn_impl=True, + dtype='float64'): + super(LSTMCell, self).__init__(dtype) + + self._hidden_size = hidden_size + self._input_size = input_size + self._param_attr = param_attr + self._bias_attr = bias_attr + self._dtype = dtype + self._gate_activation = gate_activation or layers.sigmoid + self._activation = activation or layers.tanh + self._use_cudnn_impl = use_cudnn_impl + + if self._use_cudnn_impl: + + if self._param_attr is not None and self._param_attr.name is not None: + weight_ih_param_attr = copy.deepcopy(self._param_attr) + weight_hh_param_attr = copy.deepcopy(self._param_attr) + weight_ih_param_attr.name += "_weight_ih" + weight_hh_param_attr.name += "_weight_hh" + else: + weight_ih_param_attr = self._param_attr + weight_hh_param_attr = self._param_attr + + if self._bias_attr is not None and self._bias_attr.name is not None: + bias_ih_param_attr = copy.deepcopy(self._bias_attr) + bias_hh_param_attr = copy.deepcopy(self._bias_attr) + bias_ih_param_attr.name += "_bias_ih" + bias_hh_param_attr.name += "_bias_hh" + else: + bias_ih_param_attr = self._bias_attr + bias_hh_param_attr = self._bias_attr + + self._weight_ih = self.create_parameter( + attr=weight_ih_param_attr, + shape=[self._input_size, 4 * self._hidden_size], + dtype=self._dtype) + + self._weight_hh = self.create_parameter( + attr=weight_hh_param_attr, + shape=[self._hidden_size, 4 * self._hidden_size], + dtype=self._dtype) + + self._bias_ih = self.create_parameter( + attr=bias_ih_param_attr, + shape=[4 * self._hidden_size], + dtype=self._dtype, + is_bias=True) + self._bias_hh = self.create_parameter( + attr=bias_hh_param_attr, + shape=[4 * self._hidden_size], + dtype=self._dtype, + is_bias=True) + + else: + + self._forget_bias = layers.fill_constant( + [1], dtype=dtype, value=forget_bias) + self._forget_bias.stop_gradient = False + + self._weight = self.create_parameter( + attr=self._param_attr, + shape=[ + self._input_size + self._hidden_size, 4 * self._hidden_size + ], + dtype=dtype) + + self._bias = self.create_parameter( + attr=self._bias_attr, + shape=[4 * self._hidden_size], + dtype=dtype, + is_bias=True) + + def forward(self, input, pre_hidden, pre_cell): + + if self._use_cudnn_impl: + + igates = layers.matmul(input, y=self._weight_ih) + igates = layers.elementwise_add(igates, self._bias_ih) + hgates = layers.matmul(pre_hidden, self._weight_hh) + hgates = layers.elementwise_add(hgates, self._bias_hh) + + chunked_igates = layers.split(igates, num_or_sections=4, dim=1) + chunked_hgates = layers.split(hgates, num_or_sections=4, dim=1) + + ingate = layers.elementwise_add(chunked_igates[0], + chunked_hgates[0]) + ingate = self._gate_activation(ingate) + + forgetgate = layers.elementwise_add(chunked_igates[1], + chunked_hgates[1]) + forgetgate = self._gate_activation(forgetgate) + + cellgate = layers.elementwise_add(chunked_igates[2], + chunked_hgates[2]) + cellgate = self._activation(cellgate) + + outgate = layers.elementwise_add(chunked_igates[3], + chunked_hgates[3]) + outgate = self._gate_activation(outgate) + + new_cell = (forgetgate * pre_cell) + (ingate * cellgate) + new_hidden = outgate * self._activation(new_cell) + + else: + + concat_input_hidden = layers.concat([input, pre_hidden], 1) + gate_input = layers.matmul(x=concat_input_hidden, y=self._weight) + + gate_input = layers.elementwise_add(gate_input, self._bias) + i, j, f, o = layers.split(gate_input, num_or_sections=4, dim=-1) + new_cell = layers.elementwise_add( + layers.elementwise_mul( + pre_cell, + self._gate_activation( + layers.elementwise_add(f, self._forget_bias))), + layers.elementwise_mul(layers.sigmoid(i), layers.tanh(j))) + new_hidden = self._activation(new_cell) * self._gate_activation(o) + + return new_hidden, new_cell + + +class GRUCell(Layer): + """ + GRU implementation using basic operators. + There are two GRUCell version, the default one is compatible with CUDNN GRU implementation. + The algorithm can be described as the equations below. + .. math:: + u_t & = sigmoid(W_{ux} x_{t} + b_ux + W_{uh} h_{t-1} + b_uh) + r_t & = sigmoid(W_{rx} x_{t} + b_rx + W_{rh} h_{t-1} + b_rh) + \\tilde{h_{t}} & = tanh(W_{cx} x_{t} + b_cx + r_t \\odot (W_{ch} h_{t-1} + b_ch)) + h_t & = u_t h_{t-1} + (1-u_t) \\tilde{h_{t}} + The other LSTMCell version is compatible with the BasicGRUUnit used in static graph. + The algorithm can be described as the equations below. + u_t & = sigmoid(W_{ux} x_{t} + W_{uh} h_{t-1} + b_u) + r_t & = sigmoid(W_{rx} x_{t} + W_{rh} h_{t-1} + b_r) + \\tilde{h_{t}} & = tanh(W_{cx} x_{t} + W_{ch} \\odot(r_t, h_{t-1}) + b_m) + h_t & = u_t h_{t-1} + (1-u_t) \\tilde{h_{t}} + Args: + hidden_size (integer): The hidden size used in the Cell. + input_size (integer): The input size used in the Cell. + param_attr(ParamAttr|None): The parameter attribute for the learnable + weight matrix. Note: + If it is set to None or one attribute of ParamAttr, gru_unit will + create ParamAttr as param_attr. If the Initializer of the param_attr + is not set, the parameter is initialized with Xavier. Default: None. + bias_attr (ParamAttr|None): The parameter attribute for the bias + of GRU unit. + If it is set to None or one attribute of ParamAttr, gru_unit will + create ParamAttr as bias_attr. If the Initializer of the bias_attr + is not set, the bias is initialized zero. Default: None. + gate_activation (function|None): The activation function for gates (actGate). + Default: 'fluid.layers.sigmoid' + activation (function|None): The activation function for cell (actNode). + Default: 'fluid.layers.tanh' + use_cudnn_impl(bool|True): whether to use CUDNN compatible LSTMCell + dtype(string): data type used in this unit + + Returns: + None + + Examples: + .. code-block:: python + from paddle import fluid + import paddle.fluid.core as core + from paddle.fluid.dygraph.rnn import GRUCell + import numpy as np + + batch_size = 64 + input_size = 128 + hidden_size = 256 + + step_input_np = np.random.uniform(-0.1, 0.1, ( + batch_size, input_size)).astype('float64') + pre_hidden_np = np.random.uniform(-0.1, 0.1, ( + batch_size, hidden_size)).astype('float64') + + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + else: + place = core.CPUPlace() + + with fluid.dygraph.guard(place): + cudnn_gru = GRUCell(hidden_size, input_size) + step_input_var = fluid.dygraph.to_variable(step_input_np) + pre_hidden_var = fluid.dygraph.to_variable(pre_hidden_np) + """ + + def __init__(self, + hidden_size, + input_size, + param_attr=None, + bias_attr=None, + gate_activation=None, + activation=None, + use_cudnn_impl=True, + dtype='float64'): + super(GRUCell, self).__init__() + + self._hidden_size = hidden_size + self._input_size = input_size + self._param_attr = param_attr + self._bias_attr = bias_attr + self._dtype = dtype + self._gate_activation = gate_activation or layers.sigmoid + self._activation = activation or layers.tanh + self._use_cudnn_impl = use_cudnn_impl + + if self._use_cudnn_impl: + + if self._param_attr is not None and self._param_attr.name is not None: + weight_ih_param_attr = copy.deepcopy(self._param_attr) + weight_hh_param_attr = copy.deepcopy(self._param_attr) + weight_ih_param_attr.name += "_weight_ih" + weight_hh_param_attr.name += "_weight_hh" + else: + weight_ih_param_attr = self._param_attr + weight_hh_param_attr = self._param_attr + + if self._bias_attr is not None and self._bias_attr.name is not None: + bias_ih_param_attr = copy.deepcopy(self._bias_attr) + bias_hh_param_attr = copy.deepcopy(self._bias_attr) + bias_ih_param_attr.name += "_bias_ih" + bias_hh_param_attr.name += "_bias_hh" + else: + bias_ih_param_attr = self._bias_attr + bias_hh_param_attr = self._bias_attr + + self._weight_ih = self.create_parameter( + attr=weight_ih_param_attr, + shape=[self._input_size, 3 * self._hidden_size], + dtype=self._dtype) + + self._weight_hh = self.create_parameter( + attr=weight_hh_param_attr, + shape=[self._hidden_size, 3 * self._hidden_size], + dtype=self._dtype) + + self._bias_ih = self.create_parameter( + attr=bias_ih_param_attr, + shape=[3 * self._hidden_size], + dtype=self._dtype, + is_bias=True) + self._bias_hh = self.create_parameter( + attr=bias_hh_param_attr, + shape=[3 * self._hidden_size], + dtype=self._dtype, + is_bias=True) + + else: + + if self._param_attr is not None and self._param_attr.name is not None: + gate_weight_param_attr = copy.deepcopy(self._param_attr) + candidate_weight_param_attr = copy.deepcopy(self._param_attr) + gate_weight_param_attr.name += "_gate_weight" + candidate_weight_param_attr.name += "_candidate_weight" + else: + gate_weight_param_attr = self._param_attr + candidate_weight_param_attr = self._param_attr + + if self._bias_attr is not None and self._bias_attr.name is not None: + gate_bias_param_attr = copy.deepcopy(self._bias_attr) + candidate_bias_param_attr = copy.deepcopy(self._bias_attr) + gate_bias_param_attr.name += "_gate_bias" + candidate_bias_param_attr.name += "_candidate_bias" + else: + gate_bias_param_attr = self._bias_attr + candidate_bias_param_attr = self._bias_attr + + self._gate_weight = self.create_parameter( + attr=gate_weight_param_attr, + shape=[ + self._input_size + self._hidden_size, 2 * self._hidden_size + ], + dtype=dtype) + + self._candidate_weight = self.create_parameter( + attr=candidate_weight_param_attr, + shape=[ + self._input_size + self._hidden_size, self._hidden_size + ], + dtype=dtype) + + self._gate_bias = self.create_parameter( + attr=gate_bias_param_attr, + shape=[2 * self._hidden_size], + dtype=dtype, + is_bias=True) + self._candidate_bias = self.create_parameter( + attr=candidate_bias_param_attr, + shape=[self._hidden_size], + dtype=dtype, + is_bias=True) + + def forward(self, input, pre_hidden): + + if self._use_cudnn_impl: + + igates = layers.matmul(input, y=self._weight_ih) + igates = layers.elementwise_add(igates, self._bias_ih) + hgates = layers.matmul(pre_hidden, self._weight_hh) + hgates = layers.elementwise_add(hgates, self._bias_hh) + + chunked_igates = layers.split(igates, num_or_sections=3, dim=1) + chunked_hgates = layers.split(hgates, num_or_sections=3, dim=1) + + reset_gate = layers.elementwise_add(chunked_igates[0], + chunked_hgates[0]) + reset_gate = self._gate_activation(reset_gate) + + input_gate = layers.elementwise_add(chunked_igates[1], + chunked_hgates[1]) + input_gate = self._gate_activation(input_gate) + + _temp = reset_gate * chunked_hgates[2] + new_gate = layers.elementwise_add(chunked_igates[2], _temp) + new_gate = self._activation(new_gate) + + new_hidden = (pre_hidden - new_gate) * input_gate + new_gate + + else: + + concat_input_hidden = layers.concat([input, pre_hidden], 1) + + gate_input = layers.matmul( + x=concat_input_hidden, y=self._gate_weight) + + gate_input = layers.elementwise_add(gate_input, self._gate_bias) + gate_input = self._gate_activation(gate_input) + r, u = layers.split(gate_input, num_or_sections=2, dim=1) + + r_hidden = r * pre_hidden + + candidate = layers.matmul( + layers.concat([input, r_hidden], 1), self._candidate_weight) + candidate = layers.elementwise_add(candidate, self._candidate_bias) + + c = self._activation(candidate) + new_hidden = u * pre_hidden + (1 - u) * c + + return new_hidden diff --git a/python/paddle/fluid/tests/unittests/test_cudnn_grucell.py b/python/paddle/fluid/tests/unittests/test_cudnn_grucell.py new file mode 100644 index 0000000000000000000000000000000000000000..b3ad8b037d8c87c776a1390076919633ab8e3006 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_cudnn_grucell.py @@ -0,0 +1,233 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid.dygraph.rnn import GRUCell + +import numpy as np + +np.random.seed = 123 + + +def sigmoid(x): + return 1. / (1. + np.exp(-x)) + + +def tanh(x): + return 2. * sigmoid(2. * x) - 1. + + +def cudnn_step(step_input_np, pre_hidden_np, weight_ih, bias_ih, weight_hh, + bias_hh): + igates = np.matmul(step_input_np, weight_ih) + igates += bias_ih + hgates = np.matmul(pre_hidden_np, weight_hh) + hgates += bias_hh + + chunked_igates = np.split(igates, indices_or_sections=3, axis=1) + chunked_hgates = np.split(hgates, indices_or_sections=3, axis=1) + + reset_gate = chunked_igates[0] + chunked_hgates[0] + reset_gate = sigmoid(reset_gate) + + input_gate = chunked_igates[1] + chunked_hgates[1] + input_gate = sigmoid(input_gate) + + _temp = reset_gate * chunked_hgates[2] + new_gate = chunked_igates[2] + _temp + new_gate = tanh(new_gate) + + new_hidden = (pre_hidden_np - new_gate) * input_gate + new_gate + + return new_hidden + + +def non_cudnn_step(step_in, pre_hidden, gate_w, gate_b, candidate_w, + candidate_b): + concat_1 = np.concatenate([step_in, pre_hidden], 1) + + gate_input = np.matmul(concat_1, gate_w) + gate_input += gate_b + gate_input = sigmoid(gate_input) + r, u = np.split(gate_input, indices_or_sections=2, axis=1) + + r_hidden = r * pre_hidden + + candidate = np.matmul(np.concatenate([step_in, r_hidden], 1), candidate_w) + + candidate += candidate_b + c = tanh(candidate) + + new_hidden = u * pre_hidden + (1 - u) * c + + return new_hidden + + +class TestCudnnGRU(unittest.TestCase): + def setUp(self): + self.input_size = 100 + self.hidden_size = 200 + self.batch_size = 64 + + def test_run(self): + + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + else: + place = core.CPUPlace() + + with fluid.dygraph.guard(place): + param_attr = fluid.ParamAttr(name="param_attr") + bias_attr = fluid.ParamAttr(name="bias_attr") + named_cudnn_gru = GRUCell(self.hidden_size, self.input_size, + param_attr, bias_attr) + cudnn_gru = GRUCell(self.hidden_size, self.input_size) + + param_list = cudnn_gru.state_dict() + named_param_list = named_cudnn_gru.state_dict() + + # process weight and bias + + weight_ih_name = "_weight_ih" + bias_ih_name = "_bias_ih" + weight_hh_name = "_weight_hh" + bias_hh_name = "_bias_hh" + + weight_ih = param_list[weight_ih_name].numpy() + weight_ih = np.random.uniform( + -0.1, 0.1, size=weight_ih.shape).astype('float64') + param_list[weight_ih_name].set_value(weight_ih) + named_param_list[weight_ih_name].set_value(weight_ih) + + bias_ih = param_list[bias_ih_name].numpy() + bias_ih = np.random.uniform( + -0.1, 0.1, size=bias_ih.shape).astype('float64') + param_list[bias_ih_name].set_value(bias_ih) + named_param_list[bias_ih_name].set_value(bias_ih) + + weight_hh = param_list[weight_hh_name].numpy() + weight_hh = np.random.uniform( + -0.1, 0.1, size=weight_hh.shape).astype('float64') + param_list[weight_hh_name].set_value(weight_hh) + named_param_list[weight_hh_name].set_value(weight_hh) + + bias_hh = param_list[bias_hh_name].numpy() + bias_hh = np.random.uniform( + -0.1, 0.1, size=bias_hh.shape).astype('float64') + param_list[bias_hh_name].set_value(bias_hh) + named_param_list[bias_hh_name].set_value(bias_hh) + + step_input_np = np.random.uniform(-0.1, 0.1, ( + self.batch_size, self.input_size)).astype('float64') + pre_hidden_np = np.random.uniform(-0.1, 0.1, ( + self.batch_size, self.hidden_size)).astype('float64') + + step_input_var = fluid.dygraph.to_variable(step_input_np) + pre_hidden_var = fluid.dygraph.to_variable(pre_hidden_np) + api_out = cudnn_gru(step_input_var, pre_hidden_var) + named_api_out = named_cudnn_gru(step_input_var, pre_hidden_var) + + np_out = cudnn_step(step_input_np, pre_hidden_np, weight_ih, bias_ih, + weight_hh, bias_hh) + + self.assertTrue(np.allclose(api_out.numpy(), np_out, rtol=1e-5, atol=0)) + self.assertTrue( + np.allclose( + named_api_out.numpy(), np_out, rtol=1e-5, atol=0)) + + +class TestNonCudnnGRU(unittest.TestCase): + def setUp(self): + self.input_size = 100 + self.hidden_size = 200 + self.batch_size = 64 + + def test_run(self): + + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + else: + place = core.CPUPlace() + + with fluid.dygraph.guard(place): + param_attr = fluid.ParamAttr(name="param_attr") + bias_attr = fluid.ParamAttr(name="bias_attr") + named_non_cudnn_gru = GRUCell( + self.hidden_size, + self.input_size, + param_attr, + bias_attr, + use_cudnn_impl=False) + non_cudnn_gru = GRUCell( + self.hidden_size, self.input_size, use_cudnn_impl=False) + + param_list = non_cudnn_gru.state_dict() + named_param_list = named_non_cudnn_gru.state_dict() + + # process weight and bias + + gate_w_name = "_gate_weight" + gate_b_name = "_gate_bias" + candidate_w_name = "_candidate_weight" + candidate_b_name = "_candidate_bias" + + gate_w = param_list[gate_w_name].numpy() + gate_w = np.random.uniform( + -0.1, 0.1, size=gate_w.shape).astype('float64') + param_list[gate_w_name].set_value(gate_w) + named_param_list[gate_w_name].set_value(gate_w) + + gate_b = param_list[gate_b_name].numpy() + gate_b = np.random.uniform( + -0.1, 0.1, size=gate_b.shape).astype('float64') + param_list[gate_b_name].set_value(gate_b) + named_param_list[gate_b_name].set_value(gate_b) + + candidate_w = param_list[candidate_w_name].numpy() + candidate_w = np.random.uniform( + -0.1, 0.1, size=candidate_w.shape).astype('float64') + param_list[candidate_w_name].set_value(candidate_w) + named_param_list[candidate_w_name].set_value(candidate_w) + + candidate_b = param_list[candidate_b_name].numpy() + candidate_b = np.random.uniform( + -0.1, 0.1, size=candidate_b.shape).astype('float64') + param_list[candidate_b_name].set_value(candidate_b) + named_param_list[candidate_b_name].set_value(candidate_b) + + step_input_np = np.random.uniform(-0.1, 0.1, ( + self.batch_size, self.input_size)).astype('float64') + pre_hidden_np = np.random.uniform(-0.1, 0.1, ( + self.batch_size, self.hidden_size)).astype('float64') + + step_input_var = fluid.dygraph.to_variable(step_input_np) + pre_hidden_var = fluid.dygraph.to_variable(pre_hidden_np) + api_out = non_cudnn_gru(step_input_var, pre_hidden_var) + named_api_out = named_non_cudnn_gru(step_input_var, pre_hidden_var) + + np_out = non_cudnn_step(step_input_np, pre_hidden_np, gate_w, gate_b, + candidate_w, candidate_b) + + self.assertTrue(np.allclose(api_out.numpy(), np_out, rtol=1e-5, atol=0)) + self.assertTrue( + np.allclose( + named_api_out.numpy(), np_out, rtol=1e-5, atol=0)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_cudnn_lstmcell.py b/python/paddle/fluid/tests/unittests/test_cudnn_lstmcell.py new file mode 100644 index 0000000000000000000000000000000000000000..69718f304b6968067454780544ca97f4eb2246e0 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_cudnn_lstmcell.py @@ -0,0 +1,254 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid.dygraph.rnn import LSTMCell + +import numpy as np + +np.random.seed = 123 + + +def sigmoid(x): + return 1. / (1. + np.exp(-x)) + + +def tanh(x): + return 2. * sigmoid(2. * x) - 1. + + +def cudnn_step(step_in, pre_hidden, pre_cell, gate_w, gate_b, forget_bias=1.0): + concat_1 = np.concatenate([step_in, pre_hidden], 1) + + gate_input = np.matmul(concat_1, gate_w) + gate_input += gate_b + i, j, f, o = np.split(gate_input, indices_or_sections=4, axis=1) + + new_cell = pre_cell * sigmoid(f + forget_bias) + sigmoid(i) * tanh(j) + new_hidden = tanh(new_cell) * sigmoid(o) + + return new_hidden, new_cell + + +def non_cudnn_step(step_input_np, pre_hidden_np, pre_cell_np, weight_ih, + bias_ih, weight_hh, bias_hh): + + igates = np.matmul(step_input_np, weight_ih) + igates = igates + bias_ih + hgates = np.matmul(pre_hidden_np, weight_hh) + hgates = hgates + bias_hh + + chunked_igates = np.split(igates, indices_or_sections=4, axis=1) + chunked_hgates = np.split(hgates, indices_or_sections=4, axis=1) + + ingate = chunked_igates[0] + chunked_hgates[0] + ingate = sigmoid(ingate) + + forgetgate = chunked_igates[1] + chunked_hgates[1] + forgetgate = sigmoid(forgetgate) + + cellgate = chunked_igates[2] + chunked_hgates[2] + cellgate = tanh(cellgate) + + outgate = chunked_igates[3] + chunked_hgates[3] + outgate = sigmoid(outgate) + + new_cell = (forgetgate * pre_cell_np) + (ingate * cellgate) + new_hidden = outgate * tanh(new_cell) + + return new_hidden, new_cell + + +class TestCudnnLSTM(unittest.TestCase): + def setUp(self): + self.input_size = 100 + self.hidden_size = 200 + self.batch_size = 128 + + def test_run(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + else: + place = core.CPUPlace() + + with fluid.dygraph.guard(place): + param_attr = fluid.ParamAttr(name="param_attr") + bias_attr = fluid.ParamAttr(name="bias_attr") + named_cudnn_lstm = LSTMCell(self.hidden_size, self.input_size, + param_attr, bias_attr) + cudnn_lstm = LSTMCell(self.hidden_size, self.input_size) + + param_list = cudnn_lstm.state_dict() + named_param_list = named_cudnn_lstm.state_dict() + + # process weight and bias + + weight_ih_name = "_weight_ih" + bias_ih_name = "_bias_ih" + weight_hh_name = "_weight_hh" + bias_hh_name = "_bias_hh" + + weight_ih = param_list[weight_ih_name].numpy() + weight_ih = np.random.uniform( + -0.1, 0.1, size=weight_ih.shape).astype('float64') + param_list[weight_ih_name].set_value(weight_ih) + named_param_list[weight_ih_name].set_value(weight_ih) + + bias_ih = param_list[bias_ih_name].numpy() + bias_ih = np.random.uniform( + -0.1, 0.1, size=bias_ih.shape).astype('float64') + param_list[bias_ih_name].set_value(bias_ih) + named_param_list[bias_ih_name].set_value(bias_ih) + + weight_hh = param_list[weight_hh_name].numpy() + weight_hh = np.random.uniform( + -0.1, 0.1, size=weight_hh.shape).astype('float64') + param_list[weight_hh_name].set_value(weight_hh) + named_param_list[weight_hh_name].set_value(weight_hh) + + bias_hh = param_list[bias_hh_name].numpy() + bias_hh = np.random.uniform( + -0.1, 0.1, size=bias_hh.shape).astype('float64') + param_list[bias_hh_name].set_value(bias_hh) + named_param_list[bias_hh_name].set_value(bias_hh) + + step_input_np = np.random.uniform(-0.1, 0.1, ( + self.batch_size, self.input_size)).astype('float64') + pre_hidden_np = np.random.uniform(-0.1, 0.1, ( + self.batch_size, self.hidden_size)).astype('float64') + pre_cell_np = np.random.uniform(-0.1, 0.1, ( + self.batch_size, self.hidden_size)).astype('float64') + + step_input_var = fluid.dygraph.to_variable(step_input_np) + pre_hidden_var = fluid.dygraph.to_variable(pre_hidden_np) + pre_cell_var = fluid.dygraph.to_variable(pre_cell_np) + api_out = cudnn_lstm(step_input_var, pre_hidden_var, pre_cell_var) + named_api_out = named_cudnn_lstm(step_input_var, pre_hidden_var, + pre_cell_var) + + api_hidden_out = api_out[0] + api_cell_out = api_out[1] + named_api_hidden_out = named_api_out[0] + named_api_cell_out = named_api_out[1] + + np_hidden_out, np_cell_out = non_cudnn_step( + step_input_np, pre_hidden_np, pre_cell_np, weight_ih, bias_ih, + weight_hh, bias_hh) + + self.assertTrue( + np.allclose( + api_hidden_out.numpy(), np_hidden_out, rtol=1e-5, atol=0)) + self.assertTrue( + np.allclose( + api_cell_out.numpy(), np_cell_out, rtol=1e-5, atol=0)) + self.assertTrue( + np.allclose( + named_api_hidden_out.numpy(), + np_hidden_out, + rtol=1e-5, + atol=0)) + self.assertTrue( + np.allclose( + named_api_cell_out.numpy(), np_cell_out, rtol=1e-5, atol=0)) + + +class TestNonCudnnLSTM(unittest.TestCase): + def setUp(self): + self.input_size = 100 + self.hidden_size = 200 + self.batch_size = 128 + + def test_run(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + else: + place = core.CPUPlace() + + with fluid.dygraph.guard(place): + param_attr = fluid.ParamAttr(name="param_attr") + bias_attr = fluid.ParamAttr(name="bias_attr") + named_cudnn_lstm = LSTMCell( + self.hidden_size, + self.input_size, + param_attr, + bias_attr, + use_cudnn_impl=False) + cudnn_lstm = LSTMCell( + self.hidden_size, self.input_size, use_cudnn_impl=False) + + param_list = cudnn_lstm.state_dict() + named_param_list = named_cudnn_lstm.state_dict() + + # process weight and bias + + gate_w_name = "_weight" + gate_b_name = "_bias" + + gate_w = param_list[gate_w_name].numpy() + gate_w = np.random.uniform( + -0.1, 0.1, size=gate_w.shape).astype('float64') + param_list[gate_w_name].set_value(gate_w) + named_param_list[gate_w_name].set_value(gate_w) + + gate_b = param_list[gate_b_name].numpy() + gate_b = np.random.uniform( + -0.1, 0.1, size=gate_b.shape).astype('float64') + param_list[gate_b_name].set_value(gate_b) + named_param_list[gate_b_name].set_value(gate_b) + + step_input_np = np.random.uniform(-0.1, 0.1, ( + self.batch_size, self.input_size)).astype('float64') + pre_hidden_np = np.random.uniform(-0.1, 0.1, ( + self.batch_size, self.hidden_size)).astype('float64') + pre_cell_np = np.random.uniform(-0.1, 0.1, ( + self.batch_size, self.hidden_size)).astype('float64') + + step_input_var = fluid.dygraph.to_variable(step_input_np) + pre_hidden_var = fluid.dygraph.to_variable(pre_hidden_np) + pre_cell_var = fluid.dygraph.to_variable(pre_cell_np) + api_out = cudnn_lstm(step_input_var, pre_hidden_var, pre_cell_var) + named_api_out = named_cudnn_lstm(step_input_var, pre_hidden_var, + pre_cell_var) + + api_hidden_out = api_out[0] + api_cell_out = api_out[1] + named_api_hidden_out = named_api_out[0] + named_api_cell_out = named_api_out[1] + + np_hidden_out, np_cell_out = cudnn_step( + step_input_np, pre_hidden_np, pre_cell_np, gate_w, gate_b) + + self.assertTrue( + np.allclose( + api_hidden_out.numpy(), np_hidden_out, rtol=1e-5, atol=0)) + self.assertTrue( + np.allclose( + api_cell_out.numpy(), np_cell_out, rtol=1e-5, atol=0)) + self.assertTrue( + np.allclose( + named_api_hidden_out.numpy(), + np_hidden_out, + rtol=1e-5, + atol=0)) + self.assertTrue( + np.allclose( + named_api_cell_out.numpy(), np_cell_out, rtol=1e-5, atol=0)) + + +if __name__ == '__main__': + unittest.main()