diff --git a/python/paddle/fluid/dygraph/__init__.py b/python/paddle/fluid/dygraph/__init__.py
index 600347f439076977815fe0ad4434325362516dda..22de864dd696100cd7859e33ad935cd6bb10b9f5 100644
--- a/python/paddle/fluid/dygraph/__init__.py
+++ b/python/paddle/fluid/dygraph/__init__.py
@@ -50,6 +50,9 @@ from .static_runner import StaticModelRunner
 from . import dygraph_to_static
 from .dygraph_to_static import ProgramTranslator

+from . import rnn
+from .rnn import *
+
 __all__ = []
 __all__ += layers.__all__
 __all__ += base.__all__
@@ -60,4 +63,5 @@ __all__ += checkpoint.__all__
 __all__ += learning_rate_scheduler.__all__
 __all__ += backward_strategy.__all__
 __all__ += jit.__all__
+__all__ += rnn.__all__
 __all__ += ['ProgramTranslator']
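With rnn imported and re-exported above, both cells become importable directly from paddle.fluid.dygraph, in addition to the existing paddle.fluid.dygraph.rnn path. A minimal sketch of the new import surface (sizes are arbitrary; construction only):

    import paddle.fluid as fluid
    from paddle.fluid.dygraph import LSTMCell, GRUCell  # top-level path added by this patch

    with fluid.dygraph.guard():
        lstm = LSTMCell(hidden_size=256, input_size=128)
        gru = GRUCell(hidden_size=256, input_size=128)
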
diff --git a/python/paddle/fluid/dygraph/rnn.py b/python/paddle/fluid/dygraph/rnn.py
index 15483b4dc192336970aee12918e0dce2b91a16e1..9df4188fb7eb872d21ac9e6a1f851074a682ca54 100644
--- a/python/paddle/fluid/dygraph/rnn.py
+++ b/python/paddle/fluid/dygraph/rnn.py
@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from .layers import Layer
-from paddle.fluid import layers
+from . import Layer
+from ..layers import sigmoid, tanh, concat, fill_constant, matmul, elementwise_add, elementwise_mul, split
 import copy

 __all__ = ['LSTMCell', 'GRUCell']
@@ -24,32 +24,49 @@ class LSTMCell(Layer):
     LSTMCell implementation using basic operators.
     There are two LSTMCell versions, the default one is compatible with the CUDNN LSTM implementation.
     The algorithm can be described as the equations below.

+        .. math::
+            i_t &= sigmoid(W_{ix}x_{t} + W_{ih}h_{t-1} + bx_i + bh_i)
+            f_t &= sigmoid(W_{fx}x_{t} + W_{fh}h_{t-1} + bx_f + bh_f)
+            o_t &= sigmoid(W_{ox}x_{t} + W_{oh}h_{t-1} + bx_o + bh_o)
+            \\tilde{c_t} &= tanh(W_{cx}x_t + W_{ch}h_{t-1} + bx_c + bh_c)
+            c_t &= f_t \\odot c_{t-1} + i_t \\odot \\tilde{c_t}
+            h_t &= o_t \\odot tanh(c_t)
+
     The other LSTMCell version is compatible with the BasicLSTMUnit used in static graph.
     The algorithm can be described as the equations below.
+
+        .. math::
+            i_t &= sigmoid(W_{ix}x_{t} + W_{ih}h_{t-1} + b_i)
+            f_t &= sigmoid(W_{fx}x_{t} + W_{fh}h_{t-1} + b_f + forget_bias)
+            o_t &= sigmoid(W_{ox}x_{t} + W_{oh}h_{t-1} + b_o)
+            \\tilde{c_t} &= tanh(W_{cx}x_t + W_{ch}h_{t-1} + b_c)
+            c_t &= f_t \\odot c_{t-1} + i_t \\odot \\tilde{c_t}
+            h_t &= o_t \\odot tanh(c_t)
+
     Args:
         hidden_size (integer): The hidden size used in the Cell.
         input_size (integer): The input size used in the Cell.
         param_attr(ParamAttr|None): The parameter attribute for the learnable
             weight matrix. Note:
-            If it is set to None or one attribute of ParamAttr, lstm_unit will
+            If it is set to None or one attribute of ParamAttr, LSTMCell will
             create ParamAttr as param_attr. If the Initializer of the param_attr
             is not set, the parameter is initialized with Xavier. Default: None.
         bias_attr (ParamAttr|None): The parameter attribute for the bias
-            of LSTM unit.
-            If it is set to None or one attribute of ParamAttr, lstm_unit will
+            of LSTMCell.
+            If it is set to None or one attribute of ParamAttr, LSTMCell will
             create ParamAttr as bias_attr. If the Initializer of the bias_attr
             is not set, the bias is initialized as zero. Default: None.
         gate_activation (function|None): The activation function for gates (actGate).
@@ -59,15 +76,18 @@ class LSTMCell(Layer):
         forget_bias(float|1.0): forget bias used when computing forget gate.
            This is not used in default LSTMCell implementation (CUDNN compatible)
         use_cudnn_impl(bool|True): whether to use CUDNN compatible LSTMCell
-        dtype(string): data type used in this unit
+        dtype(string): data type used in this cell

     Returns:
         None

+    Examples:
+        .. code-block:: python
+
             from paddle import fluid
             import paddle.fluid.core as core
-            from paddle.fluid.dygraph.rnn import LSTMCell
+            from paddle.fluid.dygraph import LSTMCell
             import numpy as np

             batch_size = 64
             input_size = 128
@@ -88,6 +108,7 @@ class LSTMCell(Layer):
             pre_hidden_var = fluid.dygraph.to_variable(pre_hidden_np)
             pre_cell_var = fluid.dygraph.to_variable(pre_cell_np)
             new_hidden, new_cell = cudnn_lstm(step_input_var, pre_hidden_var, pre_cell_var)
+
     """

     def __init__(self,
@@ -107,8 +128,8 @@ class LSTMCell(Layer):
         self._param_attr = param_attr
         self._bias_attr = bias_attr
         self._dtype = dtype
-        self._gate_activation = gate_activation or layers.sigmoid
-        self._activation = activation or layers.tanh
+        self._gate_activation = gate_activation or sigmoid
+        self._activation = activation or tanh
         self._use_cudnn_impl = use_cudnn_impl

         if self._use_cudnn_impl:
@@ -154,7 +175,7 @@ class LSTMCell(Layer):

         else:

-            self._forget_bias = layers.fill_constant(
+            self._forget_bias = fill_constant(
                 [1], dtype=dtype, value=forget_bias)
             self._forget_bias.stop_gradient = False

@@ -174,29 +195,24 @@
     def forward(self, input, pre_hidden, pre_cell):

         if self._use_cudnn_impl:
-            igates = layers.matmul(input, y=self._weight_ih, transpose_y=True)
-            igates = layers.elementwise_add(igates, self._bias_ih)
-            hgates = layers.matmul(
-                pre_hidden, self._weight_hh, transpose_y=True)
-            hgates = layers.elementwise_add(hgates, self._bias_hh)
+            igates = matmul(input, y=self._weight_ih, transpose_y=True)
+            igates = elementwise_add(igates, self._bias_ih)
+            hgates = matmul(pre_hidden, self._weight_hh, transpose_y=True)
+            hgates = elementwise_add(hgates, self._bias_hh)

-            chunked_igates = layers.split(igates, num_or_sections=4, dim=1)
-            chunked_hgates = layers.split(hgates, num_or_sections=4, dim=1)
+            chunked_igates = split(igates, num_or_sections=4, dim=1)
+            chunked_hgates = split(hgates, num_or_sections=4, dim=1)

-            ingate = layers.elementwise_add(chunked_igates[0],
-                                            chunked_hgates[0])
+            ingate = elementwise_add(chunked_igates[0], chunked_hgates[0])
             ingate = self._gate_activation(ingate)

-            forgetgate = layers.elementwise_add(chunked_igates[1],
-                                                chunked_hgates[1])
+            forgetgate = elementwise_add(chunked_igates[1], chunked_hgates[1])
             forgetgate = self._gate_activation(forgetgate)

-            cellgate = layers.elementwise_add(chunked_igates[2],
-                                              chunked_hgates[2])
+            cellgate = elementwise_add(chunked_igates[2], chunked_hgates[2])
             cellgate = self._activation(cellgate)

-            outgate = layers.elementwise_add(chunked_igates[3],
-                                             chunked_hgates[3])
+            outgate = elementwise_add(chunked_igates[3], chunked_hgates[3])
             outgate = self._gate_activation(outgate)

             new_cell = (forgetgate * pre_cell) + (ingate * cellgate)
@@ -204,17 +220,16 @@

         else:

-            concat_input_hidden = layers.concat([input, pre_hidden], 1)
-            gate_input = layers.matmul(x=concat_input_hidden, y=self._weight)
-
-            gate_input = layers.elementwise_add(gate_input, self._bias)
-            i, j, f, o = layers.split(gate_input, num_or_sections=4, dim=-1)
-            new_cell = layers.elementwise_add(
-                layers.elementwise_mul(
-                    pre_cell,
-                    self._gate_activation(
-                        layers.elementwise_add(f, self._forget_bias))),
-                layers.elementwise_mul(layers.sigmoid(i), layers.tanh(j)))
+            concat_input_hidden = concat([input, pre_hidden], 1)
+            gate_input = matmul(x=concat_input_hidden, y=self._weight)
+
+            gate_input = elementwise_add(gate_input, self._bias)
+            i, j, f, o = split(gate_input, num_or_sections=4, dim=-1)
+            new_cell = elementwise_add(
+                elementwise_mul(pre_cell,
+                                self._gate_activation(
+                                    elementwise_add(f, self._forget_bias))),
+                elementwise_mul(sigmoid(i), tanh(j)))
             new_hidden = self._activation(new_cell) * self._gate_activation(o)

         return new_hidden, new_cell
@@ -225,28 +240,41 @@ class GRUCell(Layer):
     GRU implementation using basic operators.
     There are two GRUCell versions, the default one is compatible with the CUDNN GRU implementation.
     The algorithm can be described as the equations below.

+        .. math::
+            u_t & = sigmoid(W_{ux} x_{t} + b_ux + W_{uh} h_{t-1} + b_uh)
+            r_t & = sigmoid(W_{rx} x_{t} + b_rx + W_{rh} h_{t-1} + b_rh)
+            \\tilde{h_{t}} & = tanh(W_{cx} x_{t} + b_cx + r_t \\odot (W_{ch} h_{t-1} + b_ch))
+            h_t & = u_t h_{t-1} + (1-u_t) \\tilde{h_{t}}
+
     The other GRUCell version is compatible with the BasicGRUUnit used in static graph.
     The algorithm can be described as the equations below.
+
+        .. math::
+            u_t & = sigmoid(W_{ux} x_{t} + W_{uh} h_{t-1} + b_u)
+            r_t & = sigmoid(W_{rx} x_{t} + W_{rh} h_{t-1} + b_r)
+            \\tilde{h_{t}} & = tanh(W_{cx} x_{t} + W_{ch} \\odot(r_t, h_{t-1}) + b_m)
+            h_t & = u_t h_{t-1} + (1-u_t) \\tilde{h_{t}}
+
     Args:
         hidden_size (integer): The hidden size used in the Cell.
         input_size (integer): The input size used in the Cell.
         param_attr(ParamAttr|None): The parameter attribute for the learnable
             weight matrix. Note:
-            If it is set to None or one attribute of ParamAttr, gru_unit will
+            If it is set to None or one attribute of ParamAttr, GRUCell will
             create ParamAttr as param_attr. If the Initializer of the param_attr
             is not set, the parameter is initialized with Xavier. Default: None.
         bias_attr (ParamAttr|None): The parameter attribute for the bias
-            of GRU unit.
-            If it is set to None or one attribute of ParamAttr, gru_unit will
+            of GRUCell.
+            If it is set to None or one attribute of ParamAttr, GRUCell will
             create ParamAttr as bias_attr. If the Initializer of the bias_attr
             is not set, the bias is initialized as zero. Default: None.
         gate_activation (function|None): The activation function for gates (actGate).
@@ -254,15 +282,18 @@
         activation (function|None): The activation function for cell (actNode).
             Default: 'fluid.layers.tanh'
         use_cudnn_impl(bool|True): whether to use CUDNN compatible GRUCell
-        dtype(string): data type used in this unit
+        dtype(string): data type used in this cell

     Returns:
         None

+    Examples:
+        .. code-block:: python
+
             from paddle import fluid
             import paddle.fluid.core as core
-            from paddle.fluid.dygraph.rnn import GRUCell
+            from paddle.fluid.dygraph import GRUCell
             import numpy as np

             batch_size = 64
             input_size = 128
@@ -279,6 +310,7 @@ class GRUCell(Layer):
             cudnn_gru = GRUCell(hidden_size, input_size)
             step_input_var = fluid.dygraph.to_variable(step_input_np)
             pre_hidden_var = fluid.dygraph.to_variable(pre_hidden_np)
+
     """

     def __init__(self,
@@ -297,8 +329,8 @@ class GRUCell(Layer):
         self._param_attr = param_attr
         self._bias_attr = bias_attr
         self._dtype = dtype
-        self._gate_activation = gate_activation or layers.sigmoid
-        self._activation = activation or layers.tanh
+        self._gate_activation = gate_activation or sigmoid
+        self._activation = activation or tanh
         self._use_cudnn_impl = use_cudnn_impl

         if self._use_cudnn_impl:
@@ -391,45 +423,41 @@

         if self._use_cudnn_impl:
-            igates = layers.matmul(input, y=self._weight_ih, transpose_y=True)
-            igates = layers.elementwise_add(igates, self._bias_ih)
-            hgates = layers.matmul(
-                pre_hidden, self._weight_hh, transpose_y=True)
-            hgates = layers.elementwise_add(hgates, self._bias_hh)
+            igates = matmul(input, y=self._weight_ih, transpose_y=True)
+            igates = elementwise_add(igates, self._bias_ih)
+            hgates = matmul(pre_hidden, self._weight_hh, transpose_y=True)
+            hgates = elementwise_add(hgates, self._bias_hh)

-            chunked_igates = layers.split(igates, num_or_sections=3, dim=1)
-            chunked_hgates = layers.split(hgates, num_or_sections=3, dim=1)
+            chunked_igates = split(igates, num_or_sections=3, dim=1)
+            chunked_hgates = split(hgates, num_or_sections=3, dim=1)

-            reset_gate = layers.elementwise_add(chunked_igates[0],
-                                                chunked_hgates[0])
+            reset_gate = elementwise_add(chunked_igates[0], chunked_hgates[0])
             reset_gate = self._gate_activation(reset_gate)

-            input_gate = layers.elementwise_add(chunked_igates[1],
-                                                chunked_hgates[1])
+            input_gate = elementwise_add(chunked_igates[1], chunked_hgates[1])
             input_gate = self._gate_activation(input_gate)

             _temp = reset_gate * chunked_hgates[2]
-            new_gate = layers.elementwise_add(chunked_igates[2], _temp)
+            new_gate = elementwise_add(chunked_igates[2], _temp)
             new_gate = self._activation(new_gate)

             new_hidden = (pre_hidden - new_gate) * input_gate + new_gate

         else:

-            concat_input_hidden = layers.concat([input, pre_hidden], 1)
+            concat_input_hidden = concat([input, pre_hidden], 1)

-            gate_input = layers.matmul(
-                x=concat_input_hidden, y=self._gate_weight)
+            gate_input = matmul(x=concat_input_hidden, y=self._gate_weight)

-            gate_input = layers.elementwise_add(gate_input, self._gate_bias)
+            gate_input = elementwise_add(gate_input, self._gate_bias)

             gate_input = self._gate_activation(gate_input)
-            r, u = layers.split(gate_input, num_or_sections=2, dim=1)
+            r, u = split(gate_input, num_or_sections=2, dim=1)

             r_hidden = r * pre_hidden

-            candidate = layers.matmul(
-                layers.concat([input, r_hidden], 1), self._candidate_weight)
-            candidate = layers.elementwise_add(candidate, self._candidate_bias)
+            candidate = matmul(
+                concat([input, r_hidden], 1), self._candidate_weight)
+            candidate = elementwise_add(candidate, self._candidate_bias)

             c = self._activation(candidate)
             new_hidden = u * pre_hidden + (1 - u) * c
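As a cross-check on the CUDNN-compatible LSTMCell.forward branch above, here is a minimal numpy sketch of one step. It assumes the layout the code uses: _weight_ih stacked as [4*hidden_size, input_size], _weight_hh as [4*hidden_size, hidden_size], and gate chunks in i, f, c, o order as in the split(..., num_or_sections=4, dim=1) calls. The helper names are illustrative, not part of the patch or the Paddle API:

    import numpy as np

    def _sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    def lstm_step_ref(x, h_prev, c_prev, w_ih, w_hh, b_ih, b_hh):
        # gates stacked in i, f, c, o order, matching chunked_igates/chunked_hgates
        gates = x.dot(w_ih.T) + b_ih + h_prev.dot(w_hh.T) + b_hh
        i, f, c_hat, o = np.split(gates, 4, axis=1)
        c = _sigmoid(f) * c_prev + _sigmoid(i) * np.tanh(c_hat)
        h = _sigmoid(o) * np.tanh(c)
        return h, c
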
diff --git a/python/paddle/fluid/tests/unittests/test_cudnn_grucell.py b/python/paddle/fluid/tests/unittests/test_cudnn_grucell.py
index 091233f6d70ad824cd2cc3814bf1864eb9b7ae5a..2335293b22e7e59b5ced164736390fc93cf5a683 100644
--- a/python/paddle/fluid/tests/unittests/test_cudnn_grucell.py
+++ b/python/paddle/fluid/tests/unittests/test_cudnn_grucell.py
@@ -17,7 +17,7 @@ from __future__ import print_function

 import unittest
 import paddle.fluid as fluid
 import paddle.fluid.core as core
-from paddle.fluid.dygraph.rnn import GRUCell
+from paddle.fluid.dygraph import GRUCell
 import numpy as np

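The GRU test above exercises the cell end to end; for reading the CUDNN-compatible GRUCell.forward branch, a numpy sketch of one step may help. It assumes gate chunks stacked in reset, update, candidate order, matching the split(..., num_or_sections=3, dim=1) calls; input_gate in the code plays the role of the update gate u below. Helper names are illustrative, not part of the patch:

    import numpy as np

    def _sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    def gru_step_ref(x, h_prev, w_ih, w_hh, b_ih, b_hh):
        ir, iu, ic = np.split(x.dot(w_ih.T) + b_ih, 3, axis=1)
        hr, hu, hc = np.split(h_prev.dot(w_hh.T) + b_hh, 3, axis=1)
        r = _sigmoid(ir + hr)              # reset_gate
        u = _sigmoid(iu + hu)              # input_gate (update gate)
        c = np.tanh(ic + r * hc)           # new_gate: reset scales only the hidden term
        return u * h_prev + (1.0 - u) * c  # == (h_prev - new_gate) * u + new_gate
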
diff --git a/python/paddle/fluid/tests/unittests/test_cudnn_lstmcell.py b/python/paddle/fluid/tests/unittests/test_cudnn_lstmcell.py
index 201648690f2980162c02c0d3591209c7553041be..ddba6bc69d25e3abd17592301611f05dfa3510cc 100644
--- a/python/paddle/fluid/tests/unittests/test_cudnn_lstmcell.py
+++ b/python/paddle/fluid/tests/unittests/test_cudnn_lstmcell.py
@@ -17,7 +17,7 @@ from __future__ import print_function

 import unittest
 import paddle.fluid as fluid
 import paddle.fluid.core as core
-from paddle.fluid.dygraph.rnn import LSTMCell
+from paddle.fluid.dygraph import LSTMCell
 import numpy as np

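Finally, the GRUCell docstring example in the patch stops right after building pre_hidden_var; a completed usage sketch with the forward call, under the same shapes and float64 dtype as the docstring examples (this assumes the default CUDNN-compatible mode):

    import numpy as np
    import paddle.fluid as fluid
    from paddle.fluid.dygraph import GRUCell

    batch_size, input_size, hidden_size = 64, 128, 256
    step_input_np = np.random.uniform(-0.1, 0.1, (batch_size, input_size)).astype('float64')
    pre_hidden_np = np.random.uniform(-0.1, 0.1, (batch_size, hidden_size)).astype('float64')

    with fluid.dygraph.guard():
        cudnn_gru = GRUCell(hidden_size, input_size)
        step_input_var = fluid.dygraph.to_variable(step_input_np)
        pre_hidden_var = fluid.dygraph.to_variable(pre_hidden_np)
        new_hidden = cudnn_gru(step_input_var, pre_hidden_var)  # [batch_size, hidden_size]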