提交 77e8940c 编写于 作者: H Hui Zhang

fix broadcast_shape

上级 42a5bdf8
......@@ -25,6 +25,7 @@ __all__ = ["masked_fill", "pad_sequence", "add_sos_eos", "th_accuracy"]
logger = Log(__name__).getlog()
@paddle.jit.not_to_static
def is_broadcastable(shp1, shp2):
for a, b in zip(shp1[::-1], shp2[::-1]):
if a == 1 or b == 1 or a == b:
......@@ -39,10 +40,8 @@ def masked_fill(xs: paddle.Tensor,
value: Union[float, int]):
if paddle.in_dynamic_mode():
assert is_broadcastable(xs.shape, mask.shape) is True
bshape = paddle.broadcast_shape(xs.shape, mask.shape)
else:
# TODO(Hui Zhang): support broadcast_shape in static graph
bshape = xs.shape
# broadcast_shape input should be `list`
bshape = paddle.broadcast_shape(list(xs.shape), list(mask.shape))
mask = mask.broadcast_to(bshape)
trues = paddle.ones_like(xs) * value
xs = paddle.where(mask, trues, xs)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册