提交 efb63225 编写于 作者: H Hui Zhang

fix mask_fill

上级 77e8940c
...@@ -19,6 +19,7 @@ import paddle ...@@ -19,6 +19,7 @@ import paddle
from paddle import nn from paddle import nn
from typeguard import check_argument_types from typeguard import check_argument_types
from deepspeech.utils.tensor_utils import masked_fill
from deepspeech.utils.log import Log from deepspeech.utils.log import Log
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
...@@ -121,7 +122,8 @@ class ConvolutionModule(nn.Layer): ...@@ -121,7 +122,8 @@ class ConvolutionModule(nn.Layer):
# mask batch padding # mask batch padding
if mask_pad is not None: if mask_pad is not None:
x = x.masked_fill(mask_pad, 0.0) # TODO(Hui Zhang): `x = x.masked_fill(mask_pad, 0.0)` for jit
x = masked_fill(x, mask_pad, 0.0)
if self.lorder > 0: if self.lorder > 0:
if cache is None: if cache is None:
......
...@@ -40,9 +40,11 @@ def masked_fill(xs: paddle.Tensor, ...@@ -40,9 +40,11 @@ def masked_fill(xs: paddle.Tensor,
value: Union[float, int]): value: Union[float, int]):
if paddle.in_dynamic_mode(): if paddle.in_dynamic_mode():
assert is_broadcastable(xs.shape, mask.shape) is True assert is_broadcastable(xs.shape, mask.shape) is True
# broadcast_shape input should be `list` # TODO(Hui Zhang): broadcast_shape input should be `list`
bshape = paddle.broadcast_shape(list(xs.shape), list(mask.shape)) bshape = paddle.broadcast_shape(list(xs.shape), list(mask.shape))
mask = mask.broadcast_to(bshape) # TODO(Hui Zhang): broadcast_to use `-1` as copy dim
# mask = mask.broadcast_to(bshape)
mask = mask.broadcast_to(bshape).reshape(list(xs.shape))
trues = paddle.ones_like(xs) * value trues = paddle.ones_like(xs) * value
xs = paddle.where(mask, trues, xs) xs = paddle.where(mask, trues, xs)
return xs return xs
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册