未验证 提交 c6482444 编写于 作者: G Guo Sheng 提交者: GitHub

Merge pull request #7766 from guoshengCS/add-python-GRU

Add python wrapper for GRU
......@@ -18,6 +18,11 @@ dynamic_lstm
.. autofunction:: paddle.v2.fluid.layers.dynamic_lstm
:noindex:
dynamic_gru
-----------
.. autofunction:: paddle.v2.fluid.layers.dynamic_gru
:noindex:
data
----
.. autofunction:: paddle.v2.fluid.layers.data
......
......@@ -26,6 +26,7 @@ __all__ = [
'fc',
'embedding',
'dynamic_lstm',
'dynamic_gru',
'gru_unit',
'linear_chain_crf',
'crf_decoding',
......@@ -368,6 +369,113 @@ def dynamic_lstm(input,
return hidden, cell
def dynamic_gru(input,
size,
param_attr=None,
bias_attr=None,
is_reverse=False,
gate_activation='sigmoid',
candidate_activation='tanh',
h_0=None):
"""
**Dynamic GRU Layer**
Refer to `Empirical Evaluation of Gated Recurrent Neural Networks on
Sequence Modeling <https://arxiv.org/abs/1412.3555>`_
The formula is as follows:
.. math::
u_t & = act_g(W_{ux}x_{t} + W_{uh}h_{t-1} + b_u)
r_t & = act_g(W_{rx}x_{t} + W_{rh}h_{t-1} + b_r)
\\tilde{h_t} & = act_c(W_{cx}x_{t} + W_{ch}(r_t \odot h_{t-1}) + b_c)
h_t & = (1-u_t) \odot h_{t-1} + u_t \odot \\tilde{h_t}
The :math:`\odot` is the element-wise product of the vectors. :math:`act_g`
is the update gate and reset gate activation function and :math:`sigmoid`
is usually used for it. :math:`act_c` is the activation function for
candidate hidden state and :math:`tanh` is usually used for it.
Note that these :math:`W_{ux}x_{t}, W_{rx}x_{t}, W_{cx}x_{t}` operations on
the input :math:`x_{t}` are NOT included in this operator. Users can choose
to use fully-connect layer before GRU layer.
Args:
input(Variable): The input of dynamic_gru layer, which supports
variable-time length input sequence. The underlying tensor in this
Variable is a matrix with shape :math:`(T \\times 3D)`, where
:math:`T` is the total time steps in this mini-batch, :math:`D`
is the hidden size.
size(int): The dimension of the gru cell.
param_attr(ParamAttr|None): The parameter attribute for the learnable
hidden-hidden weight matrix. Note:
- The shape of the weight matrix is :math:`(T \\times 3D)`, where
:math:`D` is the hidden size.
- All elements in the weight matrix can be divided into two parts.
The first part are weights of the update gate and reset gate with
shape :math:`(D \\times 2D)`, and the second part are weights for
candidate hidden state with shape :math:`(D \\times D)`.
bias_attr(ParamAttr): The parameter attribute for learnable the
hidden-hidden bias.
is_reverse(bool): Whether to compute reversed GRU, default
:attr:`False`.
gate_activation(str): The activation for update gate and reset gate.
Choices = ["sigmoid", "tanh", "relu", "identity"], default "sigmoid".
activation(str): The activation for candidate hidden state.
Choices = ["sigmoid", "tanh", "relu", "identity"], default "tanh".
Returns:
Variable: The hidden state of GRU. The shape is (T \\times D), and lod \
is the same with the input.
Examples:
.. code-block:: python
hidden_dim = 512
x = fluid.layers.fc(input=data, size=hidden_dim * 3)
hidden = fluid.layers.dynamic_gru(input=x, dim=hidden_dim)
"""
helper = LayerHelper('gru', **locals())
dtype = helper.input_dtype()
weight = helper.create_parameter(
attr=helper.param_attr, shape=[size, 3 * size], dtype=dtype)
bias = helper.create_parameter(
attr=helper.bias_attr, shape=[1, 3 * size], dtype=dtype, is_bias=True)
inputs = {'Input': input, 'Weight': weight, 'Bias': bias}
if h_0 != None:
assert h_0.shape == (
size, size), 'The shape of h0 should be(%d, %d)' % (size, size)
inputs['h0'] = h_0
hidden = helper.create_tmp_variable(dtype)
batch_gate = helper.create_tmp_variable(dtype)
batch_reset_hidden_prev = helper.create_tmp_variable(dtype)
batch_hidden = helper.create_tmp_variable(dtype)
helper.append_op(
type='gru',
inputs=inputs,
outputs={
'Hidden': hidden,
'BatchGate': batch_gate,
'BatchResetHiddenPrev': batch_reset_hidden_prev,
'BatchHidden': batch_hidden
},
attrs={
'is_reverse': is_reverse,
'gate_activation': gate_activation,
'activation': candidate_activation
})
return hidden
def gru_unit(input,
hidden,
size,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册