From 3bd88f6a6734c1dd0bdb96583e5991868a7c9517 Mon Sep 17 00:00:00 2001 From: wwhu Date: Fri, 5 May 2017 16:09:11 +0800 Subject: [PATCH] bug fix --- scheduled_sampling/scheduled_sampling.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/scheduled_sampling/scheduled_sampling.py b/scheduled_sampling/scheduled_sampling.py index 3e58786f..f52ac215 100644 --- a/scheduled_sampling/scheduled_sampling.py +++ b/scheduled_sampling/scheduled_sampling.py @@ -75,7 +75,12 @@ def seqToseq_net(source_dict_dim, target_dict_dim, is_generating=False): generated_word_memory = paddle.layer.memory( name='generated_word', size=1, boot_with_const_id=0) - current_word = paddle.layer.multiplex(input=[true_token_flag, true_word, generated_word_memory]) + generated_word_emb = embedding( + input=generated_word_memory, + size=word_vector_dim, + param_attr=paddle.attr.ParamAttr(name='_target_language_embedding')) + + current_word = paddle.layer.multiplex(input=[true_token_flag, true_word, generated_word_emb]) with paddle.layer.mixed(size=decoder_size * 3) as decoder_inputs: decoder_inputs += paddle.layer.full_matrix_projection(input=context) -- GitLab