From b20bf7d5dee23eef82ef4a810db2eafe8752e6d8 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Tue, 27 Sep 2022 08:47:22 +0000 Subject: [PATCH] masked_fill by multiply, remove while --- paddlespeech/s2t/__init__.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/paddlespeech/s2t/__init__.py b/paddlespeech/s2t/__init__.py index 3c704b27..b67322cd 100644 --- a/paddlespeech/s2t/__init__.py +++ b/paddlespeech/s2t/__init__.py @@ -166,15 +166,9 @@ def broadcast_shape(shp1, shp2): def masked_fill(xs: paddle.Tensor, mask: paddle.Tensor, value: Union[float, int]): - bshape = broadcast_shape(xs.shape, mask.shape) mask.stop_gradient = True - tmp = paddle.ones(shape=[len(bshape)], dtype='int32') - for index in range(len(bshape)): - tmp[index] = bshape[index] - mask = mask.broadcast_to(tmp) - trues = paddle.ones_like(xs) * value - xs = paddle.where(mask, trues, xs) - return xs + mask = mask.astype(xs.dtype) + return xs * (1.0 - mask) + mask * value if not hasattr(paddle.Tensor, 'masked_fill'): -- GitLab