未验证 提交 3d0c8b12 编写于 作者: A Aurelius84 提交者: GitHub

[Unitttet] Fix axes error from migrating paddle.squeeze in test_seq2seq (#48620)

上级 1da6f2e3
......@@ -297,7 +297,7 @@ class BaseModel(fluid.dygraph.Layer):
loss = fluid.layers.softmax_with_cross_entropy(
logits=dec_output, label=label, soft_label=False
)
loss = paddle.squeeze(loss, axes=[2])
loss = paddle.squeeze(loss, axis=[2])
max_tar_seq_len = fluid.layers.shape(tar)[1]
tar_mask = fluid.layers.sequence_mask(
tar_sequence_length, maxlen=max_tar_seq_len, dtype='float32'
......@@ -831,7 +831,7 @@ class AttentionModel(fluid.dygraph.Layer):
loss = fluid.layers.softmax_with_cross_entropy(
logits=dec_output, label=label, soft_label=False
)
loss = paddle.squeeze(loss, axes=[2])
loss = paddle.squeeze(loss, axis=[2])
max_tar_seq_len = fluid.layers.shape(tar)[1]
tar_mask = fluid.layers.sequence_mask(
tar_sequence_length, maxlen=max_tar_seq_len, dtype='float32'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册