未验证 提交 1b48f2f7 编写于 作者: S smallv0221 提交者: GitHub

Fix en doc for rnn.py. test=document_fix (#27835)

* Fix en doc for rnn.py. test=document_fix
上级 049696bf
...@@ -48,7 +48,7 @@ def split_states(states, bidirectional=False, state_components=1): ...@@ -48,7 +48,7 @@ def split_states(states, bidirectional=False, state_components=1):
Split states of RNN network into possibly nested list or tuple of Split states of RNN network into possibly nested list or tuple of
states of each RNN cells of the RNN network. states of each RNN cells of the RNN network.
Arguments: Parameters:
states (Tensor|tuple|list): the concatenated states for RNN network. states (Tensor|tuple|list): the concatenated states for RNN network.
When `state_components` is 1, states in a Tensor with shape When `state_components` is 1, states in a Tensor with shape
`(L*D, N, C)` where `L` is the number of layers of the RNN `(L*D, N, C)` where `L` is the number of layers of the RNN
...@@ -101,7 +101,7 @@ def concat_states(states, bidirectional=False, state_components=1): ...@@ -101,7 +101,7 @@ def concat_states(states, bidirectional=False, state_components=1):
Concatenate a possibly nested list or tuple of RNN cell states into a Concatenate a possibly nested list or tuple of RNN cell states into a
compact form. compact form.
Arguments: Parameters:
states (list|tuple): a possibly nested list or tuple of RNN cell states (list|tuple): a possibly nested list or tuple of RNN cell
states. states.
If `bidirectional` is True, it can be indexed twice to get an If `bidirectional` is True, it can be indexed twice to get an
...@@ -154,13 +154,14 @@ class RNNCellBase(Layer): ...@@ -154,13 +154,14 @@ class RNNCellBase(Layer):
r""" r"""
Generate initialized states according to provided shape, data type and Generate initialized states according to provided shape, data type and
value. value.
Arguments:
Parameters:
batch_ref (Tensor): A tensor, which shape would be used to batch_ref (Tensor): A tensor, which shape would be used to
determine the batch size, which is used to generate initial determine the batch size, which is used to generate initial
states. For `batch_ref`'s shape d, `d[batch_dim_idx]` is states. For `batch_ref`'s shape d, `d[batch_dim_idx]` is
treated as batch size. treated as batch size.
shape (list|tuple, optional): A (possibly nested structure of) shape[s], shape (list|tuple, optional): A (possibly nested structure of) shape[s],
where a shape is a list/tuple of integer). `-1` (for batch size) where a shape is a list/tuple of integer. `-1` (for batch size)
will be automatically prepended if a shape does not starts with will be automatically prepended if a shape does not starts with
it. If None, property `state_shape` will be used. Defaults to it. If None, property `state_shape` will be used. Defaults to
None. None.
...@@ -174,6 +175,7 @@ class RNNCellBase(Layer): ...@@ -174,6 +175,7 @@ class RNNCellBase(Layer):
Defaults to 0. Defaults to 0.
batch_dim_idx (int, optional): An integer indicating which batch_dim_idx (int, optional): An integer indicating which
dimension of the of `batch_ref` represents batch. Defaults to 0. dimension of the of `batch_ref` represents batch. Defaults to 0.
Returns: Returns:
init_states (Tensor|tuple|list): tensor of the provided shape and init_states (Tensor|tuple|list): tensor of the provided shape and
dtype, or list of tensors that each satisfies the requirements, dtype, or list of tensors that each satisfies the requirements,
...@@ -268,16 +270,14 @@ class SimpleRNNCell(RNNCellBase): ...@@ -268,16 +270,14 @@ class SimpleRNNCell(RNNCellBase):
The formula used is as follows: The formula used is as follows:
.. math:: .. math::
h_{t} & = \mathrm{tanh}(W_{ih}x_{t} + b_{ih} + W_{hh}h{t-1} + b_{hh}) h_{t} & = \mathrm{tanh}(W_{ih}x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh})
y_{t} & = h_{t}
where :math:`\sigma` is the sigmoid fucntion, and \* is the elemetwise y_{t} & = h_{t}
multiplication operator.
Please refer to `Finding Structure in Time Please refer to `Finding Structure in Time
<https://crl.ucsd.edu/~elman/Papers/fsit.pdf>`_ for more details. <https://crl.ucsd.edu/~elman/Papers/fsit.pdf>`_ for more details.
Arguments: Parameters:
input_size (int): The input size. input_size (int): The input size.
hidden_size (int): The hidden size. hidden_size (int): The hidden size.
activation (str, optional): The activation in the SimpleRNN cell. activation (str, optional): The activation in the SimpleRNN cell.
...@@ -293,7 +293,7 @@ class SimpleRNNCell(RNNCellBase): ...@@ -293,7 +293,7 @@ class SimpleRNNCell(RNNCellBase):
name (str, optional): Name for the operation (optional, default is name (str, optional): Name for the operation (optional, default is
None). For more information, please refer to :ref:`api_guide_Name`. None). For more information, please refer to :ref:`api_guide_Name`.
Parameters: Attributes:
weight_ih (Parameter): shape (hidden_size, input_size), input to hidden weight_ih (Parameter): shape (hidden_size, input_size), input to hidden
weight, corresponding to :math:`W_{ih}` in the formula. weight, corresponding to :math:`W_{ih}` in the formula.
weight_hh (Parameter): shape (hidden_size, hidden_size), hidden to weight_hh (Parameter): shape (hidden_size, hidden_size), hidden to
...@@ -329,13 +329,15 @@ class SimpleRNNCell(RNNCellBase): ...@@ -329,13 +329,15 @@ class SimpleRNNCell(RNNCellBase):
.. code-block:: python .. code-block:: python
import paddle import paddle
paddle.disable_static()
x = paddle.randn((4, 16)) x = paddle.randn((4, 16))
prev_h = paddle.randn((4, 32)) prev_h = paddle.randn((4, 32))
cell = paddle.nn.SimpleRNNCell(16, 32) cell = paddle.nn.SimpleRNNCell(16, 32)
y, h = cell(x, prev_h) y, h = cell(x, prev_h)
print(y.shape)
#[4,32]
""" """
...@@ -407,20 +409,26 @@ class LSTMCell(RNNCellBase): ...@@ -407,20 +409,26 @@ class LSTMCell(RNNCellBase):
.. math:: .. math::
i_{t} & = \sigma(W_{ii}x_{t} + b_{ii} + W_{hi}h_{t-1} + b_{hi}) i_{t} & = \sigma(W_{ii}x_{t} + b_{ii} + W_{hi}h_{t-1} + b_{hi})
f_{t} & = \sigma(W_{if}x_{t} + b_{if} + W_{hf}h_{t-1} + b_{hf}) f_{t} & = \sigma(W_{if}x_{t} + b_{if} + W_{hf}h_{t-1} + b_{hf})
o_{t} & = \sigma(W_{io}x_{t} + b_{io} + W_{ho}h_{t-1} + b_{ho}) o_{t} & = \sigma(W_{io}x_{t} + b_{io} + W_{ho}h_{t-1} + b_{ho})
\\widetilde{c}_{t} & = \\tanh (W_{ig}x_{t} + b_{ig} + W_{hg}h_{t-1} + b_{hg})
c_{t} & = f_{t} \* c{t-1} + i{t} \* \\widetile{c}_{t} \widetilde{c}_{t} & = \tanh (W_{ig}x_{t} + b_{ig} + W_{hg}h_{t-1} + b_{hg})
h_{t} & = o_{t} \* \\tanh(c_{t})
c_{t} & = f_{t} * c_{t-1} + i_{t} * \widetilde{c}_{t}
h_{t} & = o_{t} * \tanh(c_{t})
y_{t} & = h_{t} y_{t} & = h_{t}
where :math:`\sigma` is the sigmoid fucntion, and \* is the elemetwise where :math:`\sigma` is the sigmoid fucntion, and * is the elemetwise
multiplication operator. multiplication operator.
Please refer to `An Empirical Exploration of Recurrent Network Architectures Please refer to `An Empirical Exploration of Recurrent Network Architectures
<http://proceedings.mlr.press/v37/jozefowicz15.pdf>`_ for more details. <http://proceedings.mlr.press/v37/jozefowicz15.pdf>`_ for more details.
Arguments: Parameters:
input_size (int): The input size. input_size (int): The input size.
hidden_size (int): The hidden size. hidden_size (int): The hidden size.
weight_ih_attr(ParamAttr, optional): The parameter attribute for weight_ih_attr(ParamAttr, optional): The parameter attribute for
...@@ -434,7 +442,7 @@ class LSTMCell(RNNCellBase): ...@@ -434,7 +442,7 @@ class LSTMCell(RNNCellBase):
name (str, optional): Name for the operation (optional, default is name (str, optional): Name for the operation (optional, default is
None). For more information, please refer to :ref:`api_guide_Name`. None). For more information, please refer to :ref:`api_guide_Name`.
Parameters: Attributes:
weight_ih (Parameter): shape (4 * hidden_size, input_size), input to weight_ih (Parameter): shape (4 * hidden_size, input_size), input to
hidden weight, which corresponds to the concatenation of hidden weight, which corresponds to the concatenation of
:math:`W_{ii}, W_{if}, W_{ig}, W_{io}` in the formula. :math:`W_{ii}, W_{if}, W_{ig}, W_{io}` in the formula.
...@@ -462,7 +470,7 @@ class LSTMCell(RNNCellBase): ...@@ -462,7 +470,7 @@ class LSTMCell(RNNCellBase):
corresponding to :math:`h_{t}` in the formula. corresponding to :math:`h_{t}` in the formula.
states (tuple): a tuple of two tensors, each of shape states (tuple): a tuple of two tensors, each of shape
`[batch_size, hidden_size]`, the new hidden states, `[batch_size, hidden_size]`, the new hidden states,
corresponding to :math:`h_{t}, c{t}` in the formula. corresponding to :math:`h_{t}, c_{t}` in the formula.
Notes: Notes:
All the weights and bias are initialized with `Uniform(-std, std)` by All the weights and bias are initialized with `Uniform(-std, std)` by
...@@ -475,7 +483,6 @@ class LSTMCell(RNNCellBase): ...@@ -475,7 +483,6 @@ class LSTMCell(RNNCellBase):
.. code-block:: python .. code-block:: python
import paddle import paddle
paddle.disable_static()
x = paddle.randn((4, 16)) x = paddle.randn((4, 16))
prev_h = paddle.randn((4, 32)) prev_h = paddle.randn((4, 32))
...@@ -484,6 +491,14 @@ class LSTMCell(RNNCellBase): ...@@ -484,6 +491,14 @@ class LSTMCell(RNNCellBase):
cell = paddle.nn.LSTMCell(16, 32) cell = paddle.nn.LSTMCell(16, 32)
y, (h, c) = cell(x, (prev_h, prev_c)) y, (h, c) = cell(x, (prev_h, prev_c))
print(y.shape)
print(h.shape)
print(c.shape)
#[4,32]
#[4,32]
#[4,32]
""" """
def __init__(self, def __init__(self,
...@@ -562,12 +577,16 @@ class GRUCell(RNNCellBase): ...@@ -562,12 +577,16 @@ class GRUCell(RNNCellBase):
.. math:: .. math::
r_{t} & = \sigma(W_{ir}x_{t} + b_{ir} + W_{hr}x_{t} + b_{hr}) r_{t} & = \sigma(W_{ir}x_{t} + b_{ir} + W_{hr}x_{t} + b_{hr})
z_{t} & = \sigma(W_{iz)x_{t} + b_{iz} + W_{hz}x_{t} + b_{hz})
\\widetilde{h}_{t} & = \\tanh(W_{ic)x_{t} + b_{ic} + r_{t} \* (W_{hc}x_{t} + b{hc})) z_{t} & = \sigma(W_{iz}x_{t} + b_{iz} + W_{hz}x_{t} + b_{hz})
h_{t} & = z_{t} \* h_{t-1} + (1 - z_{t}) \* \\widetilde{h}_{t}
\widetilde{h}_{t} & = \tanh(W_{ic}x_{t} + b_{ic} + r_{t} * (W_{hc}x_{t} + b_{hc}))
h_{t} & = z_{t} * h_{t-1} + (1 - z_{t}) * \widetilde{h}_{t}
y_{t} & = h_{t} y_{t} & = h_{t}
where :math:`\sigma` is the sigmoid fucntion, and \* is the elemetwise where :math:`\sigma` is the sigmoid fucntion, and * is the elemetwise
multiplication operator. multiplication operator.
Please refer to `An Empirical Exploration of Recurrent Network Architectures Please refer to `An Empirical Exploration of Recurrent Network Architectures
...@@ -587,7 +606,7 @@ class GRUCell(RNNCellBase): ...@@ -587,7 +606,7 @@ class GRUCell(RNNCellBase):
name (str, optional): Name for the operation (optional, default is name (str, optional): Name for the operation (optional, default is
None). For more information, please refer to :ref:`api_guide_Name`. None). For more information, please refer to :ref:`api_guide_Name`.
Parameters: Attributes:
weight_ih (Parameter): shape (3 * hidden_size, input_size), input to weight_ih (Parameter): shape (3 * hidden_size, input_size), input to
hidden weight, which corresponds to the concatenation of hidden weight, which corresponds to the concatenation of
:math:`W_{ir}, W_{iz}, W_{ic}` in the formula. :math:`W_{ir}, W_{iz}, W_{ic}` in the formula.
...@@ -625,7 +644,6 @@ class GRUCell(RNNCellBase): ...@@ -625,7 +644,6 @@ class GRUCell(RNNCellBase):
.. code-block:: python .. code-block:: python
import paddle import paddle
paddle.disable_static()
x = paddle.randn((4, 16)) x = paddle.randn((4, 16))
prev_h = paddle.randn((4, 32)) prev_h = paddle.randn((4, 32))
...@@ -633,6 +651,12 @@ class GRUCell(RNNCellBase): ...@@ -633,6 +651,12 @@ class GRUCell(RNNCellBase):
cell = paddle.nn.GRUCell(16, 32) cell = paddle.nn.GRUCell(16, 32)
y, h = cell(x, prev_h) y, h = cell(x, prev_h)
print(y.shape)
print(h.shape)
#[4,32]
#[4,32]
""" """
def __init__(self, def __init__(self,
...@@ -707,7 +731,7 @@ class RNN(Layer): ...@@ -707,7 +731,7 @@ class RNN(Layer):
It performs :code:`cell.forward()` repeatedly until reaches to the maximum It performs :code:`cell.forward()` repeatedly until reaches to the maximum
length of `inputs`. length of `inputs`.
Arguments: Parameters:
cell(RNNCellBase): An instance of `RNNCellBase`. cell(RNNCellBase): An instance of `RNNCellBase`.
is_reverse (bool, optional): Indicate whether to calculate in the reverse is_reverse (bool, optional): Indicate whether to calculate in the reverse
order of input sequences. Defaults to False. order of input sequences. Defaults to False.
...@@ -717,8 +741,8 @@ class RNN(Layer): ...@@ -717,8 +741,8 @@ class RNN(Layer):
Inputs: Inputs:
inputs (Tensor): A (possibly nested structure of) tensor[s]. The input inputs (Tensor): A (possibly nested structure of) tensor[s]. The input
sequences. sequences.
If time major is True, the shape is `[batch_size, time_steps, input_size]` If time major is False, the shape is `[batch_size, time_steps, input_size]`
If time major is False, the shape is [time_steps, batch_size, input_size]` If time major is True, the shape is `[time_steps, batch_size, input_size]`
where `input_size` is the input size of the cell. where `input_size` is the input size of the cell.
initial_states (Tensor|list|tuple, optional): Tensor of a possibly initial_states (Tensor|list|tuple, optional): Tensor of a possibly
nested structure of tensors, representing the initial state for nested structure of tensors, representing the initial state for
...@@ -753,7 +777,6 @@ class RNN(Layer): ...@@ -753,7 +777,6 @@ class RNN(Layer):
.. code-block:: python .. code-block:: python
import paddle import paddle
paddle.disable_static()
inputs = paddle.rand((4, 23, 16)) inputs = paddle.rand((4, 23, 16))
prev_h = paddle.randn((4, 32)) prev_h = paddle.randn((4, 32))
...@@ -762,6 +785,12 @@ class RNN(Layer): ...@@ -762,6 +785,12 @@ class RNN(Layer):
rnn = paddle.nn.RNN(cell) rnn = paddle.nn.RNN(cell)
outputs, final_states = rnn(inputs, prev_h) outputs, final_states = rnn(inputs, prev_h)
print(outputs.shape)
print(final_states.shape)
#[4,23,32]
#[4,32]
""" """
def __init__(self, cell, is_reverse=False, time_major=False): def __init__(self, cell, is_reverse=False, time_major=False):
...@@ -795,7 +824,7 @@ class BiRNN(Layer): ...@@ -795,7 +824,7 @@ class BiRNN(Layer):
backward RNN with coresponding cells separately and concats the outputs backward RNN with coresponding cells separately and concats the outputs
along the last axis. along the last axis.
Arguments: Parameters:
cell_fw (RNNCellBase): A RNNCellBase instance used for forward RNN. cell_fw (RNNCellBase): A RNNCellBase instance used for forward RNN.
cell_bw (RNNCellBase): A RNNCellBase instance used for backward RNN. cell_bw (RNNCellBase): A RNNCellBase instance used for backward RNN.
time_major (bool): Whether the first dimension of the input means the time_major (bool): Whether the first dimension of the input means the
...@@ -841,7 +870,6 @@ class BiRNN(Layer): ...@@ -841,7 +870,6 @@ class BiRNN(Layer):
.. code-block:: python .. code-block:: python
import paddle import paddle
paddle.disable_static()
cell_fw = paddle.nn.LSTMCell(16, 32) cell_fw = paddle.nn.LSTMCell(16, 32)
cell_bw = paddle.nn.LSTMCell(16, 32) cell_bw = paddle.nn.LSTMCell(16, 32)
...@@ -850,6 +878,12 @@ class BiRNN(Layer): ...@@ -850,6 +878,12 @@ class BiRNN(Layer):
inputs = paddle.rand((2, 23, 16)) inputs = paddle.rand((2, 23, 16))
outputs, final_states = rnn(inputs) outputs, final_states = rnn(inputs)
print(outputs.shape)
print(final_states[0][0].shape,len(final_states),len(final_states[0]))
#[4,23,64]
#[2,32] 2 2
""" """
def __init__(self, cell_fw, cell_bw, time_major=False): def __init__(self, cell_fw, cell_bw, time_major=False):
...@@ -936,13 +970,11 @@ class SimpleRNN(RNNMixin): ...@@ -936,13 +970,11 @@ class SimpleRNN(RNNMixin):
.. math:: .. math::
h_{t} & = \mathrm{tanh}(W_{ih}x_{t} + b_{ih} + W_{hh}h{t-1} + b_{hh}) h_{t} & = \mathrm{tanh}(W_{ih}x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh})
y_{t} & = h_{t}
where :math:`\sigma` is the sigmoid fucntion, and \* is the elemetwise y_{t} & = h_{t}
multiplication operator.
Arguments: Parameters:
input_size (int): The input size for the first layer's cell. input_size (int): The input size for the first layer's cell.
hidden_size (int): The hidden size for each layer's cell. hidden_size (int): The hidden size for each layer's cell.
num_layers (int, optional): Number of layers. Defaults to 1. num_layers (int, optional): Number of layers. Defaults to 1.
...@@ -997,7 +1029,6 @@ class SimpleRNN(RNNMixin): ...@@ -997,7 +1029,6 @@ class SimpleRNN(RNNMixin):
.. code-block:: python .. code-block:: python
import paddle import paddle
paddle.disable_static()
rnn = paddle.nn.SimpleRNN(16, 32, 2) rnn = paddle.nn.SimpleRNN(16, 32, 2)
...@@ -1005,6 +1036,12 @@ class SimpleRNN(RNNMixin): ...@@ -1005,6 +1036,12 @@ class SimpleRNN(RNNMixin):
prev_h = paddle.randn((2, 4, 32)) prev_h = paddle.randn((2, 4, 32))
y, h = rnn(x, prev_h) y, h = rnn(x, prev_h)
print(y.shape)
print(h.shape)
#[4,23,32]
#[2,4,32]
""" """
def __init__(self, def __init__(self,
...@@ -1077,17 +1114,23 @@ class LSTM(RNNMixin): ...@@ -1077,17 +1114,23 @@ class LSTM(RNNMixin):
.. math:: .. math::
i_{t} & = \sigma(W_{ii}x_{t} + b_{ii} + W_{hi}h_{t-1} + b_{hi}) i_{t} & = \sigma(W_{ii}x_{t} + b_{ii} + W_{hi}h_{t-1} + b_{hi})
f_{t} & = \sigma(W_{if}x_{t} + b_{if} + W_{hf}h_{t-1} + b_{hf}) f_{t} & = \sigma(W_{if}x_{t} + b_{if} + W_{hf}h_{t-1} + b_{hf})
o_{t} & = \sigma(W_{io}x_{t} + b_{io} + W_{ho}h_{t-1} + b_{ho}) o_{t} & = \sigma(W_{io}x_{t} + b_{io} + W_{ho}h_{t-1} + b_{ho})
\\widetilde{c}_{t} & = \\tanh (W_{ig}x_{t} + b_{ig} + W_{hg}h_{t-1} + b_{hg})
c_{t} & = f_{t} \* c{t-1} + i{t} \* \\widetile{c}_{t} \widetilde{c}_{t} & = \tanh (W_{ig}x_{t} + b_{ig} + W_{hg}h_{t-1} + b_{hg})
h_{t} & = o_{t} \* \\tanh(c_{t})
c_{t} & = f_{t} * c_{t-1} + i_{t} * \widetilde{c}_{t}
h_{t} & = o_{t} * \tanh(c_{t})
y_{t} & = h_{t} y_{t} & = h_{t}
where :math:`\sigma` is the sigmoid fucntion, and \* is the elemetwise where :math:`\sigma` is the sigmoid fucntion, and * is the elemetwise
multiplication operator. multiplication operator.
Arguments: Parameters:
input_size (int): The input size for the first layer's cell. input_size (int): The input size for the first layer's cell.
hidden_size (int): The hidden size for each layer's cell. hidden_size (int): The hidden size for each layer's cell.
num_layers (int, optional): Number of layers. Defaults to 1. num_layers (int, optional): Number of layers. Defaults to 1.
...@@ -1130,7 +1173,7 @@ class LSTM(RNNMixin): ...@@ -1130,7 +1173,7 @@ class LSTM(RNNMixin):
`[batch_size, time_steps, num_directions * hidden_size]`. `[batch_size, time_steps, num_directions * hidden_size]`.
Note that `num_directions` is 2 if direction is "bidirectional" Note that `num_directions` is 2 if direction is "bidirectional"
else 1. else 1.
final_states (Tensor): the final state, a tuple of two tensors, h and c. final_states (tuple): the final state, a tuple of two tensors, h and c.
The shape of each is The shape of each is
`[num_lauers * num_directions, batch_size, hidden_size]`. `[num_lauers * num_directions, batch_size, hidden_size]`.
Note that `num_directions` is 2 if direction is "bidirectional" Note that `num_directions` is 2 if direction is "bidirectional"
...@@ -1141,7 +1184,6 @@ class LSTM(RNNMixin): ...@@ -1141,7 +1184,6 @@ class LSTM(RNNMixin):
.. code-block:: python .. code-block:: python
import paddle import paddle
paddle.disable_static()
rnn = paddle.nn.LSTM(16, 32, 2) rnn = paddle.nn.LSTM(16, 32, 2)
...@@ -1150,6 +1192,14 @@ class LSTM(RNNMixin): ...@@ -1150,6 +1192,14 @@ class LSTM(RNNMixin):
prev_c = paddle.randn((2, 4, 32)) prev_c = paddle.randn((2, 4, 32))
y, (h, c) = rnn(x, (prev_h, prev_c)) y, (h, c) = rnn(x, (prev_h, prev_c))
print(y.shape)
print(h.shape)
print(c.shape)
#[4,23,32]
#[2,4,32]
#[2,4,32]
""" """
def __init__(self, def __init__(self,
...@@ -1215,15 +1265,19 @@ class GRU(RNNMixin): ...@@ -1215,15 +1265,19 @@ class GRU(RNNMixin):
.. math:: .. math::
r_{t} & = \sigma(W_{ir}x_{t} + b_{ir} + W_{hr}x_{t} + b_{hr}) r_{t} & = \sigma(W_{ir}x_{t} + b_{ir} + W_{hr}x_{t} + b_{hr})
z_{t} & = \sigma(W_{iz)x_{t} + b_{iz} + W_{hz}x_{t} + b_{hz})
\\widetilde{h}_{t} & = \\tanh(W_{ic)x_{t} + b_{ic} + r_{t} \* (W_{hc}x_{t} + b{hc})) z_{t} & = \sigma(W_{iz}x_{t} + b_{iz} + W_{hz}x_{t} + b_{hz})
h_{t} & = z_{t} \* h_{t-1} + (1 - z_{t}) \* \\widetilde{h}_{t}
\widetilde{h}_{t} & = \tanh(W_{ic}x_{t} + b_{ic} + r_{t} * (W_{hc}x_{t} + b_{hc}))
h_{t} & = z_{t} * h_{t-1} + (1 - z_{t}) * \widetilde{h}_{t}
y_{t} & = h_{t} y_{t} & = h_{t}
where :math:`\sigma` is the sigmoid fucntion, and \* is the elemetwise where :math:`\sigma` is the sigmoid fucntion, and * is the elemetwise
multiplication operator. multiplication operator.
Arguments: Parameters:
input_size (int): The input size for the first layer's cell. input_size (int): The input size for the first layer's cell.
hidden_size (int): The hidden size for each layer's cell. hidden_size (int): The hidden size for each layer's cell.
num_layers (int, optional): Number of layers. Defaults to 1. num_layers (int, optional): Number of layers. Defaults to 1.
...@@ -1277,7 +1331,6 @@ class GRU(RNNMixin): ...@@ -1277,7 +1331,6 @@ class GRU(RNNMixin):
.. code-block:: python .. code-block:: python
import paddle import paddle
paddle.disable_static()
rnn = paddle.nn.GRU(16, 32, 2) rnn = paddle.nn.GRU(16, 32, 2)
...@@ -1285,6 +1338,12 @@ class GRU(RNNMixin): ...@@ -1285,6 +1338,12 @@ class GRU(RNNMixin):
prev_h = paddle.randn((2, 4, 32)) prev_h = paddle.randn((2, 4, 32))
y, h = rnn(x, prev_h) y, h = rnn(x, prev_h)
print(y.shape)
print(h.shape)
#[4,23,32]
#[2,4,32]
""" """
def __init__(self, def __init__(self,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册