From d7a9bb6e19dd601a554cc157bb741685485cd789 Mon Sep 17 00:00:00 2001 From: Luo Tao Date: Thu, 21 Dec 2017 11:31:54 +0800 Subject: [PATCH] add python wrap for sequence_first/last_step --- python/paddle/v2/fluid/layers/nn.py | 10 +++++++++- .../v2/fluid/tests/book/test_machine_translation.py | 2 +- python/paddle/v2/fluid/tests/test_dyn_rnn.py | 5 ++--- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py index 59212e8497d..ca073b29144 100644 --- a/python/paddle/v2/fluid/layers/nn.py +++ b/python/paddle/v2/fluid/layers/nn.py @@ -13,7 +13,7 @@ __all__ = [ 'crf_decoding', 'cos_sim', 'cross_entropy', 'square_error_cost', 'accuracy', 'chunk_eval', 'sequence_conv', 'conv2d', 'sequence_pool', 'pool2d', 'batch_norm', 'beam_search_decode', 'conv2d_transpose', 'sequence_expand', - 'lstm_unit', 'reduce_sum' + 'lstm_unit', 'reduce_sum', 'sequence_first_step', 'sequence_last_step' ] @@ -583,6 +583,14 @@ def sequence_pool(input, pool_type, **kwargs): return pool_out +def sequence_first_step(input, **kwargs): + return sequence_pool(input=input, pool_type="first") + + +def sequence_last_step(input, **kwargs): + return sequence_pool(input=input, pool_type="last") + + def pool2d(input, pool_size, pool_type, diff --git a/python/paddle/v2/fluid/tests/book/test_machine_translation.py b/python/paddle/v2/fluid/tests/book/test_machine_translation.py index 80ffc5a544c..e79864b3977 100644 --- a/python/paddle/v2/fluid/tests/book/test_machine_translation.py +++ b/python/paddle/v2/fluid/tests/book/test_machine_translation.py @@ -33,7 +33,7 @@ def encoder_decoder(): fc1 = fluid.layers.fc(input=src_embedding, size=hidden_dim * 4, act='tanh') lstm_hidden0, lstm_0 = layers.dynamic_lstm(input=fc1, size=hidden_dim * 4) - encoder_out = layers.sequence_pool(input=lstm_hidden0, pool_type="last") + encoder_out = layers.sequence_last_step(input=lstm_hidden0) # decoder trg_language_word = layers.data( diff --git a/python/paddle/v2/fluid/tests/test_dyn_rnn.py b/python/paddle/v2/fluid/tests/test_dyn_rnn.py index 034266c26f4..8090c5f4781 100644 --- a/python/paddle/v2/fluid/tests/test_dyn_rnn.py +++ b/python/paddle/v2/fluid/tests/test_dyn_rnn.py @@ -63,8 +63,7 @@ class TestDynRNN(unittest.TestCase): all_timesteps = fluid.layers.array_to_lod_tensor( x=out, table=rank_table) - last = fluid.layers.sequence_pool( - input=all_timesteps, pool_type='last') + last = fluid.layers.sequence_last_step(input=all_timesteps) logits = fluid.layers.fc(input=last, size=1, act=None) loss = fluid.layers.sigmoid_cross_entropy_with_logits( x=logits, label=label) @@ -101,7 +100,7 @@ class TestDynRNN(unittest.TestCase): rnn.update_memory(mem, out_) rnn.output(out_) - last = fluid.layers.sequence_pool(input=rnn(), pool_type='last') + last = fluid.layers.sequence_last_step(input=rnn()) logits = fluid.layers.fc(input=last, size=1, act=None) label = fluid.layers.data(name='label', shape=[1], dtype='float32') loss = fluid.layers.sigmoid_cross_entropy_with_logits( -- GitLab