From efb32c744bbc5a15ce4ba533ea199d6fd5bc1c0e Mon Sep 17 00:00:00 2001 From: wuxing03 Date: Tue, 19 May 2020 13:47:15 +0000 Subject: [PATCH] fix vae dim error --- PaddleNLP/seq2seq/variational_seq2seq/model.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/PaddleNLP/seq2seq/variational_seq2seq/model.py b/PaddleNLP/seq2seq/variational_seq2seq/model.py index f3f7d77f..0225ba1d 100644 --- a/PaddleNLP/seq2seq/variational_seq2seq/model.py +++ b/PaddleNLP/seq2seq/variational_seq2seq/model.py @@ -230,11 +230,10 @@ class VAE(object): # `sample_output_layer` samples an id from the logits distribution instead of argmax(logits) # it will be used within BeamSearchDecoder sample_output_layer = lambda x: layers.unsqueeze(fluid.one_hot( - layers.unsqueeze( layers.sampling_id( layers.softmax( layers.squeeze(output_layer(x),[1]) - ),dtype='int'), [1]), + ),dtype='int'), depth=self.tar_vocab_size), [1]) if mode == 'train': -- GitLab