From 77e8940cbc8bebcf57e253b6bfd62cc7db0292ea Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Tue, 13 Jul 2021 07:45:15 +0000 Subject: [PATCH] fix broadcast_shape --- deepspeech/utils/tensor_utils.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/deepspeech/utils/tensor_utils.py b/deepspeech/utils/tensor_utils.py index 66834c5d..c6fbe1b3 100644 --- a/deepspeech/utils/tensor_utils.py +++ b/deepspeech/utils/tensor_utils.py @@ -25,6 +25,7 @@ __all__ = ["masked_fill", "pad_sequence", "add_sos_eos", "th_accuracy"] logger = Log(__name__).getlog() +@paddle.jit.not_to_static def is_broadcastable(shp1, shp2): for a, b in zip(shp1[::-1], shp2[::-1]): if a == 1 or b == 1 or a == b: @@ -39,10 +40,8 @@ def masked_fill(xs: paddle.Tensor, value: Union[float, int]): if paddle.in_dynamic_mode(): assert is_broadcastable(xs.shape, mask.shape) is True - bshape = paddle.broadcast_shape(xs.shape, mask.shape) - else: - # TODO(Hui Zhang): support broadcast_shape in static graph - bshape = xs.shape + # broadcast_shape input should be `list` + bshape = paddle.broadcast_shape(list(xs.shape), list(mask.shape)) mask = mask.broadcast_to(bshape) trues = paddle.ones_like(xs) * value xs = paddle.where(mask, trues, xs) -- GitLab