def is_broadcastable(shp1, shp2) -> bool:
    """Return True if tensors of shapes ``shp1`` and ``shp2`` broadcast.

    Implements the numpy broadcasting rule: aligned from the trailing
    dimension, each pair of sizes must be equal or contain a 1.  Extra
    leading dimensions on either side are always compatible (``zip``
    stops at the shorter shape).

    Args:
        shp1 (Sequence[int]): shape of the first tensor.
        shp2 (Sequence[int]): shape of the second tensor.

    Returns:
        bool: True when the shapes are broadcast-compatible.
    """
    for a, b in zip(shp1[::-1], shp2[::-1]):
        # Direct guard instead of if/pass/else: fail fast on the first
        # incompatible dimension pair.
        if a != 1 and b != 1 and a != b:
            return False
    return True


def masked_fill(xs: paddle.Tensor,
                mask: paddle.Tensor,
                value: Union[float, int]) -> paddle.Tensor:
    """Return a tensor equal to ``xs`` with ``value`` where ``mask`` is True.

    Functional stand-in for ``Tensor.masked_fill`` (not provided by
    paddle).  ``mask`` and ``xs`` are both broadcast to their joint
    shape first, so either argument may have fewer dimensions than the
    other — the original code only broadcast ``mask``, which made
    ``paddle.where`` fail whenever ``mask`` out-ranked ``xs``.

    Args:
        xs (paddle.Tensor): input tensor (not mutated; a new tensor is
            returned).
        mask (paddle.Tensor): boolean mask, broadcast-compatible with
            ``xs``.
        value (float or int): fill value for the masked positions.

    Returns:
        paddle.Tensor: tensor of the joint broadcast shape.
    """
    if paddle.in_dynamic_mode():
        # Static-graph shapes may contain -1 placeholders, so the shape
        # check is only meaningful (and only performed) in dynamic mode.
        assert is_broadcastable(xs.shape, mask.shape)
    bshape = paddle.broadcast_shape(xs.shape, mask.shape)
    mask = mask.broadcast_to(bshape)
    xs = xs.broadcast_to(bshape)
    # full_like is the direct spelling of the original ones_like * value.
    return paddle.where(mask, paddle.full_like(xs, value), xs)