You need to sign in or sign up before continuing.
提交 f0e797e5 编写于 作者: Y yangyaming

Doc fix and enhancement for lstm_unit python wrapper.

上级 39502e6e
...@@ -1168,25 +1168,26 @@ def lstm_unit(x_t, ...@@ -1168,25 +1168,26 @@ def lstm_unit(x_t,
.. math:: .. math::
i_t & = \sigma(W_{x_i}x_{t} + W_{h_i}h_{t-1} + W_{c_i}c_{t-1} + b_i) i_t & = \sigma(W_{x_i}x_{t} + W_{h_i}h_{t-1} + b_i)
f_t & = \sigma(W_{x_f}x_{t} + W_{h_f}h_{t-1} + W_{c_f}c_{t-1} + b_f) f_t & = \sigma(W_{x_f}x_{t} + W_{h_f}h_{t-1} + b_f)
c_t & = f_tc_{t-1} + i_t tanh (W_{x_c}x_t+W_{h_c}h_{t-1} + b_c) c_t & = f_tc_{t-1} + i_t tanh (W_{x_c}x_t + W_{h_c}h_{t-1} + b_c)
o_t & = \sigma(W_{x_o}x_{t} + W_{h_o}h_{t-1} + W_{c_o}c_t + b_o) o_t & = \sigma(W_{x_o}x_{t} + W_{h_o}h_{t-1} + b_o)
h_t & = o_t tanh(c_t) h_t & = o_t tanh(c_t)
The inputs of lstm unit includes :math:`x_t`, :math:`h_{t-1}` and The inputs of lstm unit include :math:`x_t`, :math:`h_{t-1}` and
:math:`c_{t-1}`. The implementation separates the linear transformation :math:`c_{t-1}`. The 2nd dimensions of :math:`h_{t-1}` and :math:`c_{t-1}`
and non-linear transformation apart. Here, we take :math:`i_t` as an should be same. The implementation separates the linear transformation and
example. The linear transformation is applied by calling a `fc` layer and non-linear transformation apart. Here, we take :math:`i_t` as an example.
the equation is: The linear transformation is applied by calling a `fc` layer and the
equation is:
.. math:: .. math::
L_{i_t} = W_{x_i}x_{t} + W_{h_i}h_{t-1} + W_{c_i}c_{t-1} + b_i L_{i_t} = W_{x_i}x_{t} + W_{h_i}h_{t-1} + b_i
The non-linear transformation is applied by calling `lstm_unit_op` and the The non-linear transformation is applied by calling `lstm_unit_op` and the
equation is: equation is:
...@@ -1213,14 +1214,15 @@ def lstm_unit(x_t, ...@@ -1213,14 +1214,15 @@ def lstm_unit(x_t,
Raises: Raises:
ValueError: The ranks of **x_t**, **hidden_t_prev** and **cell_t_prev**\ ValueError: The ranks of **x_t**, **hidden_t_prev** and **cell_t_prev**\
not be 2 or the 1st dimensions of **x_t**, **hidden_t_prev** \ not be 2 or the 1st dimensions of **x_t**, **hidden_t_prev** \
and **cell_t_prev** not be the same. and **cell_t_prev** not be the same or the 2nd dimensions of \
**hidden_t_prev** and **cell_t_prev** not be the same.
Examples: Examples:
.. code-block:: python .. code-block:: python
x_t = fluid.layers.fc(input=x_t_data, size=10) x_t = fluid.layers.fc(input=x_t_data, size=10)
prev_hidden = fluid.layers.fc(input=prev_hidden_data, size=20) prev_hidden = fluid.layers.fc(input=prev_hidden_data, size=30)
prev_cell = fluid.layers.fc(input=prev_cell_data, size=30) prev_cell = fluid.layers.fc(input=prev_cell_data, size=30)
hidden_value, cell_value = fluid.layers.lstm_unit(x_t=x_t, hidden_value, cell_value = fluid.layers.lstm_unit(x_t=x_t,
hidden_t_prev=prev_hidden, hidden_t_prev=prev_hidden,
...@@ -1239,7 +1241,11 @@ def lstm_unit(x_t, ...@@ -1239,7 +1241,11 @@ def lstm_unit(x_t,
if x_t.shape[0] != hidden_t_prev.shape[0] or x_t.shape[ if x_t.shape[0] != hidden_t_prev.shape[0] or x_t.shape[
0] != cell_t_prev.shape[0]: 0] != cell_t_prev.shape[0]:
raise ValueError("The 1s dimension of x_t, hidden_t_prev and " raise ValueError("The 1s dimensions of x_t, hidden_t_prev and "
"cell_t_prev must be the same.")
if hidden_t_prev.shape[1] != cell_t_prev.shape[1]:
raise ValueError("The 2nd dimensions of hidden_t_prev and "
"cell_t_prev must be the same.") "cell_t_prev must be the same.")
if bias_attr is None: if bias_attr is None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册