diff --git a/deepspeech/utils/tensor_utils.py b/deepspeech/utils/tensor_utils.py index 66834c5d860cdb3f6d5526aa9e134985df2a7529..c6fbe1b3fe682808158e9f7186435876003f31ae 100644 --- a/deepspeech/utils/tensor_utils.py +++ b/deepspeech/utils/tensor_utils.py @@ -25,6 +25,7 @@ __all__ = ["masked_fill", "pad_sequence", "add_sos_eos", "th_accuracy"] logger = Log(__name__).getlog() +@paddle.jit.not_to_static def is_broadcastable(shp1, shp2): for a, b in zip(shp1[::-1], shp2[::-1]): if a == 1 or b == 1 or a == b: @@ -39,10 +40,8 @@ def masked_fill(xs: paddle.Tensor, value: Union[float, int]): if paddle.in_dynamic_mode(): assert is_broadcastable(xs.shape, mask.shape) is True - bshape = paddle.broadcast_shape(xs.shape, mask.shape) - else: - # TODO(Hui Zhang): support broadcast_shape in static graph - bshape = xs.shape + # broadcast_shape input should be `list` + bshape = paddle.broadcast_shape(list(xs.shape), list(mask.shape)) mask = mask.broadcast_to(bshape) trues = paddle.ones_like(xs) * value xs = paddle.where(mask, trues, xs)