提交 3c2dcfad 编写于 作者: H Hui Zhang

masked_fill and contiguous

上级 fa04af90
......@@ -21,6 +21,7 @@ from paddle import nn
from paddle.nn import initializer as I
from deepspeech.utils.log import Log
from deepspeech.utils.tensor_utils import masked_fill
logger = Log(__name__).getlog()
......@@ -100,17 +101,17 @@ class MultiHeadedAttention(nn.Layer):
if mask is not None:
mask = mask.unsqueeze(1).equal(
paddle.to_tensor(0, dtype=mask.dtype)) # (batch, 1, *, time2)
scores = scores.masked_fill(mask, -float('inf'))
scores = masked_fill(scores, mask, -float('inf'))
attn = paddle.softmax(
scores, axis=-1).masked_fill(mask,
0.0) # (batch, head, time1, time2)
scores, axis=-1)
attn = masked_fill(attn, mask, 0.0) # (batch, head, time1, time2)
else:
attn = paddle.softmax(
scores, axis=-1) # (batch, head, time1, time2)
p_attn = self.dropout(attn)
x = paddle.matmul(p_attn, value) # (batch, head, time1, d_k)
x = x.transpose([0, 2, 1, 3]).contiguous().reshape(
x = x.transpose([0, 2, 1, 3]).reshape(
[n_batch, -1, self.h * self.d_k]) # (batch, time1, d_model)
return self.linear_out(x) # (batch, time1, d_model)
......
......@@ -14,15 +14,35 @@
"""Unility functions for Transformer."""
from typing import List
from typing import Tuple
from typing import Optional
from typing import Union
import paddle
from deepspeech.utils.log import Log
__all__ = ["pad_sequence", "add_sos_eos", "th_accuracy"]
__all__ = ["masked_fill", "pad_sequence", "add_sos_eos", "th_accuracy"]
logger = Log(__name__).getlog()
def is_broadcastable(shp1, shp2):
    """Return whether two shapes are compatible under broadcasting rules.

    The shapes are aligned at their trailing dimension and compared pairwise.
    A pair is compatible when both sizes are equal or either size is 1.
    Leading dimensions that exist only in the longer shape are not inspected,
    because ``zip`` stops at the end of the shorter shape.

    Args:
        shp1: Sliceable sequence of dimension sizes (e.g. a list or tuple).
        shp2: Sliceable sequence of dimension sizes (e.g. a list or tuple).

    Returns:
        bool: True if every aligned pair of dimensions is compatible,
        False otherwise.
    """
    return all(a == 1 or b == 1 or a == b
               for a, b in zip(shp1[::-1], shp2[::-1]))
def masked_fill(xs: paddle.Tensor,
                mask: paddle.Tensor,
                value: Union[float, int]):
    """Fill the positions of ``xs`` selected by ``mask`` with ``value``.

    Functional stand-in for a ``Tensor.masked_fill`` method: the mask is
    broadcast against ``xs`` and ``paddle.where`` picks ``value`` where the
    mask is true and the original element elsewhere.

    Args:
        xs: Input tensor.
        mask: Tensor used as the ``paddle.where`` condition; its shape must
            be broadcast-compatible with ``xs``.
            NOTE(review): presumably a boolean tensor -- confirm with callers.
        value: Scalar written at the masked positions.

    Returns:
        The tensor produced by ``paddle.where``.

    Raises:
        AssertionError: In dynamic mode, when the two shapes are not
            broadcast-compatible.
    """
    # The shape check is only performed in dynamic mode.
    # NOTE(review): presumably because static-graph shapes may contain
    # unknown dimensions -- confirm.
    if paddle.in_dynamic_mode():
        assert is_broadcastable(xs.shape, mask.shape), (
            f"masked_fill: shapes are not broadcastable: "
            f"xs {xs.shape} vs mask {mask.shape}")

    # Expand the mask to the common broadcast shape before selecting.
    bshape = paddle.broadcast_shape(xs.shape, mask.shape)
    mask = mask.broadcast_to(bshape)

    # Tensor shaped like ``xs`` holding ``value`` everywhere.
    trues = paddle.ones_like(xs) * value
    xs = paddle.where(mask, trues, xs)
    return xs
def pad_sequence(sequences: List[paddle.Tensor],
batch_first: bool=False,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册