提交 7a2e6dea 编写于 作者: P peterzhang2029

fix test_rnn_encoder_decoder

上级 7278aa7b
......@@ -118,12 +118,13 @@ def seq_to_seq_net():
src_forward, src_backward = bi_lstm_encoder(
input_seq=src_embedding, hidden_size=encoder_size)
encoded_vector = fluid.layers.concat(
input=[src_forward, src_backward], axis=1)
src_forward_last = fluid.layers.sequence_last_step(input=src_forward)
src_backward_first = fluid.layers.sequence_first_step(input=src_backward)
enc_vec_last = fluid.layers.sequence_last_step(input=encoded_vector)
encoded_vector = fluid.layers.concat(
input=[src_forward_last, src_backward_first], axis=1)
decoder_boot = fluid.layers.fc(input=enc_vec_last,
decoder_boot = fluid.layers.fc(input=encoded_vector,
size=decoder_size,
bias_attr=False,
act='tanh')
......@@ -137,7 +138,7 @@ def seq_to_seq_net():
dtype='float32')
prediction = lstm_decoder_without_attention(trg_embedding, decoder_boot,
enc_vec_last, decoder_size)
encoded_vector, decoder_size)
label = fluid.layers.data(
name='label_sequence', shape=[1], dtype='int64', lod_level=1)
cost = fluid.layers.cross_entropy(input=prediction, label=label)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册