diff --git a/PaddleNLP/seq2seq/variational_seq2seq/model.py b/PaddleNLP/seq2seq/variational_seq2seq/model.py index f3f7d77f87a028627dad65b830f7b9b7aa683f80..0225ba1d194995c05e24e27a5cfc72420855f960 100644 --- a/PaddleNLP/seq2seq/variational_seq2seq/model.py +++ b/PaddleNLP/seq2seq/variational_seq2seq/model.py @@ -230,11 +230,10 @@ class VAE(object): # `sample_output_layer` samples an id from the logits distribution instead of argmax(logits) # it will be used within BeamSearchDecoder sample_output_layer = lambda x: layers.unsqueeze(fluid.one_hot( - layers.unsqueeze( layers.sampling_id( layers.softmax( layers.squeeze(output_layer(x),[1]) - ),dtype='int'), [1]), + ),dtype='int'), depth=self.tar_vocab_size), [1]) if mode == 'train':