Commit d7a9bb6e authored by Luo Tao

add Python wrappers for sequence_first/last_step

Parent 22022017
......@@ -13,7 +13,7 @@ __all__ = [
    'crf_decoding', 'cos_sim', 'cross_entropy', 'square_error_cost', 'accuracy',
    'chunk_eval', 'sequence_conv', 'conv2d', 'sequence_pool', 'pool2d',
    'batch_norm', 'beam_search_decode', 'conv2d_transpose', 'sequence_expand',
-    'lstm_unit', 'reduce_sum'
+    'lstm_unit', 'reduce_sum', 'sequence_first_step', 'sequence_last_step'
]
......@@ -583,6 +583,14 @@ def sequence_pool(input, pool_type, **kwargs):
    return pool_out
+
+
+def sequence_first_step(input, **kwargs):
+    return sequence_pool(input=input, pool_type="first")
+
+
+def sequence_last_step(input, **kwargs):
+    return sequence_pool(input=input, pool_type="last")
def pool2d(input,
           pool_size,
           pool_type,
......
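(Not part of the diff.) The two new layers are one-line aliases over `sequence_pool`, selecting the first or last timestep of each sequence. Below is a minimal usage sketch, assuming the `paddle.v2.fluid` API at this commit; the variable name `seq` and the feature size are illustrative assumptions, not taken from the change:

```python
# Illustrative sketch only (not from the commit): builds a tiny graph showing
# that each new wrapper matches sequence_pool with the corresponding pool_type.
import paddle.v2.fluid as fluid

# A sequence input; at runtime the fed tensor is expected to carry LoD
# (sequence) information. Name and shape here are assumptions for the example.
seq = fluid.layers.data(name='seq', shape=[16], dtype='float32')

# Equivalent pairs: each wrapper simply forwards to sequence_pool.
first_a = fluid.layers.sequence_pool(input=seq, pool_type="first")
first_b = fluid.layers.sequence_first_step(input=seq)

last_a = fluid.layers.sequence_pool(input=seq, pool_type="last")
last_b = fluid.layers.sequence_last_step(input=seq)
```

The wrappers trade a `pool_type` keyword for a more readable name, which is why the call sites in the hunks below are updated in the same commit.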
......@@ -33,7 +33,7 @@ def encoder_decoder():
    fc1 = fluid.layers.fc(input=src_embedding, size=hidden_dim * 4, act='tanh')
    lstm_hidden0, lstm_0 = layers.dynamic_lstm(input=fc1, size=hidden_dim * 4)
-    encoder_out = layers.sequence_pool(input=lstm_hidden0, pool_type="last")
+    encoder_out = layers.sequence_last_step(input=lstm_hidden0)
    # decoder
    trg_language_word = layers.data(
......
......@@ -63,8 +63,7 @@ class TestDynRNN(unittest.TestCase):
        all_timesteps = fluid.layers.array_to_lod_tensor(
            x=out, table=rank_table)
-        last = fluid.layers.sequence_pool(
-            input=all_timesteps, pool_type='last')
+        last = fluid.layers.sequence_last_step(input=all_timesteps)
        logits = fluid.layers.fc(input=last, size=1, act=None)
        loss = fluid.layers.sigmoid_cross_entropy_with_logits(
            x=logits, label=label)
......@@ -101,7 +100,7 @@ class TestDynRNN(unittest.TestCase):
            rnn.update_memory(mem, out_)
            rnn.output(out_)
-        last = fluid.layers.sequence_pool(input=rnn(), pool_type='last')
+        last = fluid.layers.sequence_last_step(input=rnn())
        logits = fluid.layers.fc(input=last, size=1, act=None)
        label = fluid.layers.data(name='label', shape=[1], dtype='float32')
        loss = fluid.layers.sigmoid_cross_entropy_with_logits(
......