Commit 3c2dcfad authored by: Hui Zhang

masked_fill and continues

Parent fa04af90
@@ -21,6 +21,7 @@ from paddle import nn
 from paddle.nn import initializer as I
 from deepspeech.utils.log import Log
+from deepspeech.utils.tensor_utils import masked_fill
 logger = Log(__name__).getlog()
@@ -100,17 +101,17 @@ class MultiHeadedAttention(nn.Layer):
         if mask is not None:
             mask = mask.unsqueeze(1).equal(
                 paddle.to_tensor(0, dtype=mask.dtype))  # (batch, 1, *, time2)
-            scores = scores.masked_fill(mask, -float('inf'))
-            attn = paddle.softmax(
-                scores, axis=-1).masked_fill(mask,
-                                             0.0)  # (batch, head, time1, time2)
+            scores = masked_fill(scores, mask, -float('inf'))
+            attn = paddle.softmax(scores, axis=-1)
+            attn = masked_fill(attn, mask, 0.0)  # (batch, head, time1, time2)
         else:
             attn = paddle.softmax(
                 scores, axis=-1)  # (batch, head, time1, time2)
         p_attn = self.dropout(attn)
         x = paddle.matmul(p_attn, value)  # (batch, head, time1, d_k)
-        x = x.transpose([0, 2, 1, 3]).contiguous().reshape(
+        x = x.transpose([0, 2, 1, 3]).reshape(
             [n_batch, -1, self.h * self.d_k])  # (batch, time1, d_model)
         return self.linear_out(x)  # (batch, time1, d_model)
......
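For context, the hunk above stops calling the `masked_fill` tensor method and uses the new free function from `deepspeech.utils.tensor_utils` instead. Below is a minimal sketch of the same masking pattern outside the class; the shapes and the toy padding mask are invented for illustration and assume Paddle's default dynamic mode.

import paddle
from deepspeech.utils.tensor_utils import masked_fill

# Toy sizes: batch=2, heads=4, time1=time2=5; score values are random.
scores = paddle.randn([2, 4, 5, 5])
# Padding mask: 1 = valid frame, 0 = padding (trailing frames are padding here).
pad = paddle.to_tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 0]], dtype='int64')
mask = pad.unsqueeze(1).unsqueeze(1)                      # (batch, 1, 1, time2)
mask = mask.equal(paddle.to_tensor(0, dtype=mask.dtype))  # True where padded

scores = masked_fill(scores, mask, -float('inf'))  # padded keys get -inf logits
attn = paddle.softmax(scores, axis=-1)
attn = masked_fill(attn, mask, 0.0)                # pin padded columns to exact zeros
print(attn.sum(axis=-1))                           # rows with valid frames sum to 1

The second `masked_fill` mirrors the diff above: after the softmax it forces the padded attention weights to exact zeros, which also stops fully masked rows (all `-inf` logits) from propagating NaNs.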
@@ -14,15 +14,35 @@
 """Unility functions for Transformer."""
 from typing import List
 from typing import Tuple
+from typing import Optional
+from typing import Union
 import paddle
 from deepspeech.utils.log import Log
-__all__ = ["pad_sequence", "add_sos_eos", "th_accuracy"]
+__all__ = ["masked_fill", "pad_sequence", "add_sos_eos", "th_accuracy"]
 logger = Log(__name__).getlog()
+
+
+def is_broadcastable(shp1, shp2):
+    for a, b in zip(shp1[::-1], shp2[::-1]):
+        if a == 1 or b == 1 or a == b:
+            pass
+        else:
+            return False
+    return True
+
+
+def masked_fill(xs: paddle.Tensor,
+                mask: paddle.Tensor,
+                value: Union[float, int]):
+    if paddle.in_dynamic_mode():
+        assert is_broadcastable(xs.shape, mask.shape) is True
+    bshape = paddle.broadcast_shape(xs.shape, mask.shape)
+    mask = mask.broadcast_to(bshape)
+    trues = paddle.ones_like(xs) * value
+    xs = paddle.where(mask, trues, xs)
+    return xs
+
+
 def pad_sequence(sequences: List[paddle.Tensor],
                  batch_first: bool=False,
@@ -164,4 +184,4 @@ def th_accuracy(pad_outputs: paddle.Tensor,
     #TODO(Hui Zhang): sum not support bool type
     # denominator = paddle.sum(mask)
     denominator = paddle.sum(mask.astype(pad_targets.dtype))
-    return float(numerator) / float(denominator)
\ No newline at end of file
+    return float(numerator) / float(denominator)
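The helper added above reproduces the familiar `Tensor.masked_fill` semantics with `broadcast_to` plus `paddle.where`, which is why the attention code switches from the method call to this function. A small usage sketch, with values invented for illustration:

import paddle
from deepspeech.utils.tensor_utils import masked_fill, is_broadcastable

x = paddle.arange(6, dtype='float32').reshape([2, 3])  # [[0, 1, 2], [3, 4, 5]]
mask = paddle.to_tensor([[False, True, False]])         # (1, 3), broadcast over rows

assert is_broadcastable(x.shape, mask.shape)
y = masked_fill(x, mask, -1.0)
print(y.numpy())  # [[ 0. -1.  2.]
                  #  [ 3. -1.  5.]]
# Roughly the PyTorch-style idiom it replaces: y = x.masked_fill(mask, -1.0)

Note that the helper returns a new tensor built by `paddle.where` rather than writing in place, so callers must use the return value, exactly as the attention code above does.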