From 9ee9fefd2de46f2383309f489033fc6d94cd8628 Mon Sep 17 00:00:00 2001 From: yangyaming Date: Tue, 19 Dec 2017 11:27:35 +0800 Subject: [PATCH] Change the return order to h, c. --- python/paddle/v2/fluid/layers/nn.py | 8 ++++---- python/paddle/v2/fluid/tests/test_layers.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py index 31a0a312d..dd6bb5459 100644 --- a/python/paddle/v2/fluid/layers/nn.py +++ b/python/paddle/v2/fluid/layers/nn.py @@ -900,7 +900,7 @@ def lstm_unit(x_t, i_t = \sigma(L_{i_t}) - This layer has two outputs including :math:`o_t` and :math:`h_t`. + This layer has two outputs including :math:`h_t` and :math:`o_t`. Args: x_t (Variable): The input value of current step. @@ -915,7 +915,7 @@ def lstm_unit(x_t, startup_program (Program): the startup program. Returns: - tuple: The cell value and hidden value of lstm unit. + tuple: The hidden value and cell value of lstm unit. Raises: ValueError: The ranks of **x_t**, **hidden_t_prev** and **cell_t_prev**\ @@ -929,7 +929,7 @@ def lstm_unit(x_t, x_t = fluid.layers.fc(input=x_t_data, size=10) prev_hidden = fluid.layers.fc(input=prev_hidden_data, size=20) prev_cell = fluid.layers.fc(input=prev_cell_data, size=30) - cell_value, hidden_value = fluid.layers.lstm_unit(x_t=x_t, + hidden_value, cell_value = fluid.layers.lstm_unit(x_t=x_t, hidden_t_prev=prev_hidden, cell_t_prev=prev_cell) """ @@ -977,4 +977,4 @@ def lstm_unit(x_t, "H": h}, attrs={"forget_bias": forget_bias}) - return c, h + return h, c diff --git a/python/paddle/v2/fluid/tests/test_layers.py b/python/paddle/v2/fluid/tests/test_layers.py index 7b56ae464..d4a95bf6f 100644 --- a/python/paddle/v2/fluid/tests/test_layers.py +++ b/python/paddle/v2/fluid/tests/test_layers.py @@ -161,7 +161,7 @@ class TestBook(unittest.TestCase): x=dat, label=lbl)) print(str(program)) - def test_seq_expand(self): + def test_sequence_expand(self): program = Program() with program_guard(program): x = layers.data(name='x', shape=[10], dtype='float32') -- GitLab