提交 fa04af90 编写于 作者: H Hui Zhang

eq to equal

上级 5e32139e
......@@ -98,7 +98,8 @@ class MultiHeadedAttention(nn.Layer):
"""
n_batch = value.shape[0]
if mask is not None:
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
mask = mask.unsqueeze(1).equal(
paddle.to_tensor(0, dtype=mask.dtype)) # (batch, 1, *, time2)
scores = scores.masked_fill(mask, -float('inf'))
attn = paddle.softmax(
scores, axis=-1).masked_fill(mask,
......@@ -109,8 +110,8 @@ class MultiHeadedAttention(nn.Layer):
p_attn = self.dropout(attn)
x = paddle.matmul(p_attn, value) # (batch, head, time1, d_k)
x = x.transpose([0, 2, 1, 3]).contiguous().reshape([
n_batch, -1, self.h * self.d_k]) # (batch, time1, d_model)
x = x.transpose([0, 2, 1, 3]).contiguous().reshape(
[n_batch, -1, self.h * self.d_k]) # (batch, time1, d_model)
return self.linear_out(x) # (batch, time1, d_model)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册