Commit fa04af90 authored by Hui Zhang

eq to equal

Parent 5e32139e
...
@@ -98,7 +98,8 @@ class MultiHeadedAttention(nn.Layer):
         """
         n_batch = value.shape[0]
         if mask is not None:
-            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
+            mask = mask.unsqueeze(1).equal(
+                paddle.to_tensor(0, dtype=mask.dtype))  # (batch, 1, *, time2)
             scores = scores.masked_fill(mask, -float('inf'))
             attn = paddle.softmax(
                 scores, axis=-1).masked_fill(mask,
@@ -109,8 +110,8 @@ class MultiHeadedAttention(nn.Layer):
         p_attn = self.dropout(attn)
         x = paddle.matmul(p_attn, value)  # (batch, head, time1, d_k)
-        x = x.transpose([0, 2, 1, 3]).contiguous().reshape([
-            n_batch, -1, self.h * self.d_k])  # (batch, time1, d_model)
+        x = x.transpose([0, 2, 1, 3]).contiguous().reshape(
+            [n_batch, -1, self.h * self.d_k])  # (batch, time1, d_model)
         return self.linear_out(x)  # (batch, time1, d_model)
...
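
For context, here is a minimal sketch of the masking pattern this hunk changes: Paddle's `Tensor.equal` compares elementwise against another tensor, so the scalar `0` from the old torch-style `.eq(0)` is wrapped with `paddle.to_tensor`. The shapes below are hypothetical, and `masked_fill` is rebuilt from `paddle.where` because it appears to be patched onto `Tensor` elsewhere in this repo rather than being a core Paddle API.

import paddle

# Hypothetical inputs: `mask` flags valid (1) vs padded (0) positions.
mask = paddle.to_tensor([[1, 1, 0]])  # (batch=1, time2=3)
scores = paddle.rand([1, 2, 1, 3])    # (batch, head, time1, time2)

# The change in this commit: `.equal` needs a tensor operand, so the
# scalar 0 is wrapped with paddle.to_tensor and matched to mask's dtype.
pad = mask.unsqueeze(1).equal(
    paddle.to_tensor(0, dtype=mask.dtype))  # (batch, 1, time2), True at padding

def masked_fill(x, m, value):
    # Stand-in for the repo's patched Tensor.masked_fill (assumption).
    m = paddle.broadcast_to(m, x.shape)
    return paddle.where(m, paddle.full_like(x, value), x)

scores = masked_fill(scores, pad.unsqueeze(1), -float('inf'))
attn = paddle.softmax(scores, axis=-1)  # padded positions get zero weight

Setting masked positions to -inf before the softmax drives their attention weights to zero, so padding never contributes to the weighted sum over `value`.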