From 641d0e77300300e48f6c0ed588618e2712dee2d8 Mon Sep 17 00:00:00 2001 From: wwhu Date: Fri, 5 May 2017 15:49:40 +0800 Subject: [PATCH] bug fix --- scheduled_sampling/scheduled_sampling.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/scheduled_sampling/scheduled_sampling.py b/scheduled_sampling/scheduled_sampling.py index 3caf2300..30d15425 100644 --- a/scheduled_sampling/scheduled_sampling.py +++ b/scheduled_sampling/scheduled_sampling.py @@ -62,7 +62,7 @@ def seqToseq_net(source_dict_dim, target_dict_dim, is_generating=False): decoder_boot += paddle.layer.full_matrix_projection( input=backward_first) - def gru_decoder_with_attention_train(enc_vec, enc_proj, true_word, generated_word, true_token_flag): + def gru_decoder_with_attention_train(enc_vec, enc_proj, true_word, true_token_flag): decoder_mem = paddle.layer.memory( name='gru_decoder', size=decoder_size, boot_layer=decoder_boot) @@ -72,7 +72,10 @@ def seqToseq_net(source_dict_dim, target_dict_dim, is_generating=False): encoded_proj=enc_proj, decoder_state=decoder_mem) - current_word = paddle.layer.multiplex([true_token_flag, true_word, generated_word]) + generated_word_memory = paddle.layer.memory( + name='generated_word', size=1, boot_with_const_id=0) + + current_word = paddle.layer.multiplex([true_token_flag, true_word, generated_word_memory]) with paddle.layer.mixed(size=decoder_size * 3) as decoder_inputs: decoder_inputs += paddle.layer.full_matrix_projection(input=context) @@ -90,6 +93,9 @@ def seqToseq_net(source_dict_dim, target_dict_dim, is_generating=False): bias_attr=True, act=paddle.activation.Softmax()) as out: out += paddle.layer.full_matrix_projection(input=gru_step) + + max_id(input=out, name='generated_word') + return out def gru_decoder_with_attention_test(enc_vec, enc_proj, current_word): -- GitLab