提交 60fecce4 编写于 作者: Y yangyaming

Fix unit test for lstm_unit.

上级 c0f6f492
...@@ -1199,9 +1199,12 @@ def lstm_unit(x_t, ...@@ -1199,9 +1199,12 @@ def lstm_unit(x_t,
This layer has two outputs including :math:`h_t` and :math:`o_t`. This layer has two outputs including :math:`h_t` and :math:`o_t`.
Args: Args:
x_t (Variable): The input value of current step, a 2-D tensor. x_t (Variable): The input value of current step, a 2-D tensor with shape
hidden_t_prev (Variable): The hidden value of lstm unit, a 2-D tensor. M x N, M for batch size and N for input size.
cell_t_prev (Variable): The cell value of lstm unit, a 2-D tensor. hidden_t_prev (Variable): The hidden value of lstm unit, a 2-D tensor
with shape M x S, M for batch size and S for size of lstm unit.
cell_t_prev (Variable): The cell value of lstm unit, a 2-D tensor with
shape M x S, M for batch size and S for size of lstm unit.
forget_bias (float): The forget bias of lstm unit. forget_bias (float): The forget bias of lstm unit.
param_attr (ParamAttr): The attributes of parameter weights, used to set param_attr (ParamAttr): The attributes of parameter weights, used to set
initializer, name etc. initializer, name etc.
......
...@@ -177,8 +177,8 @@ class TestBook(unittest.TestCase): ...@@ -177,8 +177,8 @@ class TestBook(unittest.TestCase):
name='x_t_data', shape=[10, 10], dtype='float32') name='x_t_data', shape=[10, 10], dtype='float32')
x_t = layers.fc(input=x_t_data, size=10) x_t = layers.fc(input=x_t_data, size=10)
prev_hidden_data = layers.data( prev_hidden_data = layers.data(
name='prev_hidden_data', shape=[10, 20], dtype='float32') name='prev_hidden_data', shape=[10, 30], dtype='float32')
prev_hidden = layers.fc(input=prev_hidden_data, size=20) prev_hidden = layers.fc(input=prev_hidden_data, size=30)
prev_cell_data = layers.data( prev_cell_data = layers.data(
name='prev_cell', shape=[10, 30], dtype='float32') name='prev_cell', shape=[10, 30], dtype='float32')
prev_cell = layers.fc(input=prev_cell_data, size=30) prev_cell = layers.fc(input=prev_cell_data, size=30)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册