diff --git a/paddlespeech/s2t/__init__.py b/paddlespeech/s2t/__init__.py index 3c704b272bcd62de08ae50f7ecb801a88f9ed57e..b67322cdc2f4ac305b42b399ea5feb87207e0375 100644 --- a/paddlespeech/s2t/__init__.py +++ b/paddlespeech/s2t/__init__.py @@ -166,15 +166,9 @@ def broadcast_shape(shp1, shp2): def masked_fill(xs: paddle.Tensor, mask: paddle.Tensor, value: Union[float, int]): - bshape = broadcast_shape(xs.shape, mask.shape) mask.stop_gradient = True - tmp = paddle.ones(shape=[len(bshape)], dtype='int32') - for index in range(len(bshape)): - tmp[index] = bshape[index] - mask = mask.broadcast_to(tmp) - trues = paddle.ones_like(xs) * value - xs = paddle.where(mask, trues, xs) - return xs + mask = mask.astype(xs.dtype) + return xs * (1.0 - mask) + mask * value if not hasattr(paddle.Tensor, 'masked_fill'):