未验证 提交 a9fce95b 编写于 作者: X Xing Wu 提交者: GitHub

fix vae dim error (#4650)

上级 29e4dc5e
......@@ -230,11 +230,10 @@ class VAE(object):
# `sample_output_layer` samples an id from the logits distribution instead of argmax(logits)
# it will be used within BeamSearchDecoder
sample_output_layer = lambda x: layers.unsqueeze(fluid.one_hot(
layers.unsqueeze(
layers.sampling_id(
layers.softmax(
layers.squeeze(output_layer(x),[1])
),dtype='int'), [1]),
),dtype='int'),
depth=self.tar_vocab_size), [1])
if mode == 'train':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册