From 60fecce43db68281112a91198d85a79a972f03f9 Mon Sep 17 00:00:00 2001 From: yangyaming Date: Wed, 3 Jan 2018 15:20:00 +0800 Subject: [PATCH] Fix unit test for lstm_unit. --- python/paddle/v2/fluid/layers/nn.py | 9 ++++++--- python/paddle/v2/fluid/tests/test_layers.py | 4 ++-- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py index 5442cce49..1c1c09dd2 100644 --- a/python/paddle/v2/fluid/layers/nn.py +++ b/python/paddle/v2/fluid/layers/nn.py @@ -1199,9 +1199,12 @@ def lstm_unit(x_t, This layer has two outputs including :math:`h_t` and :math:`o_t`. Args: - x_t (Variable): The input value of current step, a 2-D tensor. - hidden_t_prev (Variable): The hidden value of lstm unit, a 2-D tensor. - cell_t_prev (Variable): The cell value of lstm unit, a 2-D tensor. + x_t (Variable): The input value of current step, a 2-D tensor with shape + M x N, M for batch size and N for input size. + hidden_t_prev (Variable): The hidden value of lstm unit, a 2-D tensor + with shape M x S, M for batch size and S for size of lstm unit. + cell_t_prev (Variable): The cell value of lstm unit, a 2-D tensor with + shape M x S, M for batch size and S for size of lstm unit. forget_bias (float): The forget bias of lstm unit. param_attr (ParamAttr): The attributes of parameter weights, used to set initializer, name etc. diff --git a/python/paddle/v2/fluid/tests/test_layers.py b/python/paddle/v2/fluid/tests/test_layers.py index 9d2dcca56..77f0f11f1 100644 --- a/python/paddle/v2/fluid/tests/test_layers.py +++ b/python/paddle/v2/fluid/tests/test_layers.py @@ -177,8 +177,8 @@ class TestBook(unittest.TestCase): name='x_t_data', shape=[10, 10], dtype='float32') x_t = layers.fc(input=x_t_data, size=10) prev_hidden_data = layers.data( - name='prev_hidden_data', shape=[10, 20], dtype='float32') - prev_hidden = layers.fc(input=prev_hidden_data, size=20) + name='prev_hidden_data', shape=[10, 30], dtype='float32') + prev_hidden = layers.fc(input=prev_hidden_data, size=30) prev_cell_data = layers.data( name='prev_cell', shape=[10, 30], dtype='float32') prev_cell = layers.fc(input=prev_cell_data, size=30) -- GitLab