diff --git a/deepspeech/modules/attention.py b/deepspeech/modules/attention.py
index 4ff57a94a0c428672beef1c248bae8b50ef11fa2..eef1e95f0298814855a866f002f98a002ca7d180 100644
--- a/deepspeech/modules/attention.py
+++ b/deepspeech/modules/attention.py
@@ -98,7 +98,8 @@ class MultiHeadedAttention(nn.Layer):
         """
         n_batch = value.shape[0]
         if mask is not None:
-            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
+            mask = mask.unsqueeze(1).equal(
+                paddle.to_tensor(0, dtype=mask.dtype))  # (batch, 1, *, time2)
             scores = scores.masked_fill(mask, -float('inf'))
             attn = paddle.softmax(
                 scores, axis=-1).masked_fill(mask,
@@ -109,8 +110,8 @@ class MultiHeadedAttention(nn.Layer):
 
         p_attn = self.dropout(attn)
         x = paddle.matmul(p_attn, value)  # (batch, head, time1, d_k)
-        x = x.transpose([0, 2, 1, 3]).contiguous().reshape([
-            n_batch, -1, self.h * self.d_k])  # (batch, time1, d_model)
+        x = x.transpose([0, 2, 1, 3]).contiguous().reshape(
+            [n_batch, -1, self.h * self.d_k])  # (batch, time1, d_model)
 
         return self.linear_out(x)  # (batch, time1, d_model)
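
For reference, below the patch: a minimal sketch of what the new mask construction evaluates to, assuming a Paddle 2.x environment; the example tensor and variable names (`pad_mask`, `inv_mask`) are invented for illustration and are not part of the patch, and the `masked_fill` / softmax steps from the hunk are not repeated here.

    import paddle

    # (batch, time2) padding mask from the data pipeline: 1 = valid frame, 0 = padding.
    pad_mask = paddle.to_tensor([[1, 1, 1, 0]], dtype='int64')

    # Same pattern as the added lines: compare against a 0-tensor of matching dtype
    # instead of the torch-style .eq(0) call.
    inv_mask = pad_mask.unsqueeze(1).equal(
        paddle.to_tensor(0, dtype=pad_mask.dtype))

    print(inv_mask.shape)  # [1, 1, 4] -> (batch, 1, time2)
    print(inv_mask.dtype)  # bool: True at padded positions, i.e. the positions
                           # that forward_attention fills with -inf before the softmax.

The elementwise `equal` against a `paddle.to_tensor(0, dtype=mask.dtype)` keeps the comparison inside Paddle's own API (presumably because `eq` is not a native `paddle.Tensor` method), while producing the same boolean "is padding" mask that the surrounding `masked_fill` calls expect.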