From 734e87e55b00418aed0fac5a879b2704d62cf3ab Mon Sep 17 00:00:00 2001 From: yangyaming Date: Fri, 15 Dec 2017 20:08:55 +0800 Subject: [PATCH] Add python wrapper for lstm unit op. --- doc/api/v2/fluid/layers.rst | 11 +- python/paddle/v2/fluid/layers/nn.py | 112 +++++++++++++++++++- python/paddle/v2/fluid/tests/test_layers.py | 17 +++ 3 files changed, 132 insertions(+), 8 deletions(-) diff --git a/doc/api/v2/fluid/layers.rst b/doc/api/v2/fluid/layers.rst index 89e5fec13..0ab36402f 100644 --- a/doc/api/v2/fluid/layers.rst +++ b/doc/api/v2/fluid/layers.rst @@ -188,12 +188,6 @@ beam_search_decode :noindex: -lstm ---------- -.. autofunction:: paddle.v2.fluid.layers.lstm - :noindex: - - lod_rank_table --------- .. autofunction:: paddle.v2.fluid.layers.lod_rank_table @@ -300,3 +294,8 @@ conv2d_transpose .. autofunction:: paddle.v2.fluid.layers.conv2d_transpose :noindex: + +lstm_unit +--------- +.. autofunction:: paddle.v2.fluid.layers.lstm_unit + :noindex: diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py index bad7dbd84..84e62d988 100644 --- a/python/paddle/v2/fluid/layers/nn.py +++ b/python/paddle/v2/fluid/layers/nn.py @@ -5,12 +5,13 @@ All layers just related to the neural network. from ..layer_helper import LayerHelper from ..initializer import Normal, Constant from ..framework import Variable +from tensor import concat __all__ = [ 'fc', 'embedding', 'dynamic_lstm', 'gru_unit', 'linear_chain_crf', 'crf_decoding', 'cos_sim', 'cross_entropy', 'square_error_cost', 'accuracy', 'chunk_eval', 'sequence_conv', 'conv2d', 'sequence_pool', 'pool2d', - 'batch_norm', 'beam_search_decode', 'conv2d_transpose' + 'batch_norm', 'beam_search_decode', 'conv2d_transpose', 'lstm_unit' ] @@ -392,7 +393,7 @@ def chunk_eval(input, excluded_chunk_types=None, **kwargs): """ - This function computes and outputs the precision, recall and + This function computes and outputs the precision, recall and F1-score of chunk detection. """ helper = LayerHelper("chunk_eval", **kwargs) @@ -789,3 +790,110 @@ def conv2d_transpose(input, attrs=op_attr) return out + + +def lstm_unit(x_t, + hidden_t_prev, + cell_t_prev, + forget_bias=0.0, + main_program=None, + startup_program=None): + """Lstm unit layer. The equation of a lstm step is: + + .. math:: + + i_t & = \sigma(W_{x_i}x_{t} + W_{h_i}h_{t-1} + W_{c_i}c_{t-1} + b_i) + + f_t & = \sigma(W_{x_f}x_{t} + W_{h_f}h_{t-1} + W_{c_f}c_{t-1} + b_f) + + c_t & = f_tc_{t-1} + i_t tanh (W_{x_c}x_t+W_{h_c}h_{t-1} + b_c) + + o_t & = \sigma(W_{x_o}x_{t} + W_{h_o}h_{t-1} + W_{c_o}c_t + b_o) + + h_t & = o_t tanh(c_t) + + The inputs of lstm unit includes :math:`x_t`, :math:`h_{t-1}` and + :math:`c_{t-1}`. The implementation separates the linear transformation + and non-linear transformation apart. Here, we take :math:`i_t` as an + example. The linear transformation is applied by calling a `fc` layer and + the equation is: + + .. math:: + + L_{i_t} = W_{x_i}x_{t} + W_{h_i}h_{t-1} + W_{c_i}c_{t-1} + b_i + + The non-linear transformation is applied by calling `lstm_unit_op` and the + equation is: + + .. math:: + + i_t = \sigma(L_{i_t}) + + This layer has two outputs including :math:`o_t` and :math:`h_t`. + + Args: + x_t (Variable): The input value of current step. + hidden_t_prev (Variable): The hidden value of lstm unit. + cell_t_prev (Variable): The cell value of lstm unit. + forget_bias (float): The forget bias of lstm unit. + main_program (Program): The main program. + startup_program (Program): the startup program. + + Returns: + tuple: The cell value and hidden value of lstm unit. + + Raises: + ValueError: The ranks of **x_t**, **hidden_t_prev** and **cell_t_prev**\ + not be 2 or the 1st dimensions of **x_t**, **hidden_t_prev** \ + and **cell_t_prev** not be the same. + + Examples: + + .. code-block:: python + + x_t = fluid.layers.fc(input=x_t_data, size=10) + prev_hidden = fluid.layers.fc(input=prev_hidden_data, size=20) + prev_cell = fluid.layers.fc(input=prev_cell_data, size=30) + cell_value, hidden_value = fluid.layers.lstm_unit(x_t=x_t, + hidden_t_prev=prev_hidden, + cell_t_prev=prev_cell) + """ + helper = LayerHelper('lstm_unit', **locals()) + + if len(x_t.shape) != 2: + raise ValueError("Rank of x_t must be 2.") + + if len(hidden_t_prev.shape) != 2: + raise ValueError("Rank of hidden_t_prev must be 2.") + + if len(cell_t_prev.shape) != 2: + raise ValueError("Rank of cell_t_prev must be 2.") + + if x_t.shape[0] != hidden_t_prev.shape[0] or x_t.shape[ + 0] != cell_t_prev.shape[0]: + raise ValueError("The 1s dimension of x_t, hidden_t_prev and " + "cell_t_prev must be the same.") + + size = cell_t_prev.shape[1] + concat_out = concat( + input=[x_t, hidden_t_prev], + axis=1, + main_program=main_program, + startup_program=startup_program) + fc_out = fc(input=concat_out, + size=4 * size, + main_program=main_program, + startup_program=startup_program) + dtype = x_t.dtype + c = helper.create_tmp_variable(dtype) + h = helper.create_tmp_variable(dtype) + + helper.append_op( + type='lstm_unit', + inputs={"X": fc_out, + "C_prev": cell_t_prev}, + outputs={"C": c, + "H": h}, + attrs={"forget_bias": forget_bias}) + + return c, h diff --git a/python/paddle/v2/fluid/tests/test_layers.py b/python/paddle/v2/fluid/tests/test_layers.py index 9b8808015..468bd4128 100644 --- a/python/paddle/v2/fluid/tests/test_layers.py +++ b/python/paddle/v2/fluid/tests/test_layers.py @@ -161,6 +161,23 @@ class TestBook(unittest.TestCase): x=dat, label=lbl)) print(str(program)) + def test_lstm_unit(self): + program = Program() + with program_guard(program): + x_t_data = layers.data( + name='x_t_data', shape=[10, 10], dtype='float32') + x_t = layers.fc(input=x_t_data, size=10) + prev_hidden_data = layers.data( + name='prev_hidden_data', shape=[10, 20], dtype='float32') + prev_hidden = layers.fc(input=prev_hidden_data, size=20) + prev_cell_data = layers.data( + name='prev_cell', shape=[10, 30], dtype='float32') + prev_cell = layers.fc(input=prev_cell_data, size=30) + self.assertIsNotNone( + layers.lstm_unit( + x_t=x_t, hidden_t_prev=prev_hidden, cell_t_prev=prev_cell)) + print(str(program)) + if __name__ == '__main__': unittest.main() -- GitLab