Commit efb63225 authored by Hui Zhang

fix mask_fill

Parent 77e8940c
@@ -19,6 +19,7 @@ import paddle
 from paddle import nn
 from typeguard import check_argument_types
+from deepspeech.utils.tensor_utils import masked_fill
 from deepspeech.utils.log import Log
 logger = Log(__name__).getlog()
@@ -121,7 +122,8 @@ class ConvolutionModule(nn.Layer):
         # mask batch padding
         if mask_pad is not None:
-            x = x.masked_fill(mask_pad, 0.0)
+            # TODO(Hui Zhang): `x = x.masked_fill(mask_pad, 0.0)` for jit
+            x = masked_fill(x, mask_pad, 0.0)
         if self.lorder > 0:
             if cache is None:
......
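The call-site change above swaps the `Tensor.masked_fill` method for the functional `masked_fill` helper imported from `deepspeech.utils.tensor_utils`, presumably because the method form does not work under `paddle.jit` yet, as the TODO comment suggests. A minimal usage sketch, assuming the `deepspeech` package is importable and that `x` is a `(batch, channels, time)` activation with `mask_pad` a boolean `(batch, 1, time)` padding mask (shapes are illustrative, not taken from the diff):

```python
import paddle
from deepspeech.utils.tensor_utils import masked_fill

# Illustrative shapes only: 2 utterances, 4 channels, 8 frames.
B, C, T = 2, 4, 8
x = paddle.randn([B, C, T])

# Build a padding mask: True marks padded frames (utterance 2 has 6 valid frames).
lengths = paddle.to_tensor([8, 6])
frame_idx = paddle.arange(T).unsqueeze(0)                    # (1, T)
mask_pad = (frame_idx >= lengths.unsqueeze(1)).unsqueeze(1)  # (B, 1, T), bool

# Functional form used in the diff; zeroes out padded frames before the conv.
x = masked_fill(x, mask_pad, 0.0)
```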
@@ -40,9 +40,11 @@ def masked_fill(xs: paddle.Tensor,
                 value: Union[float, int]):
     if paddle.in_dynamic_mode():
         assert is_broadcastable(xs.shape, mask.shape) is True
-    # broadcast_shape input should be `list`
-    bshape = paddle.broadcast_shape(list(xs.shape), list(mask.shape))
-    mask = mask.broadcast_to(bshape)
+    # TODO(Hui Zhang): broadcast_shape input should be `list`
+    bshape = paddle.broadcast_shape(list(xs.shape), list(mask.shape))
+    # TODO(Hui Zhang): broadcast_to use `-1` as copy dim
+    # mask = mask.broadcast_to(bshape)
+    mask = mask.broadcast_to(bshape).reshape(list(xs.shape))
     trues = paddle.ones_like(xs) * value
     xs = paddle.where(mask, trues, xs)
     return xs
......
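For context, the full helper in `deepspeech/utils/tensor_utils.py` can be sketched as below from the lines visible in the diff. `is_broadcastable` is not shown in this commit, so the version here is an assumed NumPy-style, right-aligned shape check; the final `reshape(list(xs.shape))` is the workaround the TODO describes for `broadcast_to` treating `-1` (unknown) dimensions as copy dims.

```python
from typing import Union

import paddle


def is_broadcastable(shp1, shp2) -> bool:
    # Assumed implementation: NumPy-style right-aligned broadcast check.
    for a, b in zip(shp1[::-1], shp2[::-1]):
        if a == 1 or b == 1 or a == b:
            continue
        return False
    return True


def masked_fill(xs: paddle.Tensor,
                mask: paddle.Tensor,
                value: Union[float, int]) -> paddle.Tensor:
    """Fill `xs` with `value` wherever `mask` is True."""
    if paddle.in_dynamic_mode():
        assert is_broadcastable(xs.shape, mask.shape) is True
    # TODO(Hui Zhang): broadcast_shape input should be `list`
    bshape = paddle.broadcast_shape(list(xs.shape), list(mask.shape))
    # TODO(Hui Zhang): broadcast_to uses `-1` as a copy dim, so pin the
    # result back to the shape of `xs` with an explicit reshape.
    mask = mask.broadcast_to(bshape).reshape(list(xs.shape))
    trues = paddle.ones_like(xs) * value
    xs = paddle.where(mask, trues, xs)
    return xs


if __name__ == "__main__":
    # Quick check in dynamic mode: a (2, 1) mask broadcasts over the columns.
    x = paddle.arange(6, dtype="float32").reshape([2, 3])
    m = paddle.to_tensor([[True], [False]])
    print(masked_fill(x, m, 0.0))
    # First row becomes zeros; second row is unchanged.
```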