提交 9f4bffbd 编写于 作者: M Megvii Engine Team

fix(mge/tensor): fix valid_broadcast

GitOrigin-RevId: 562b7664e23cd336d942568203df03958b67a4b7
上级 af349d61
...@@ -173,7 +173,7 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True): ...@@ -173,7 +173,7 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
item.append(True) item.append(True)
v = get_index(v) v = get_index(v)
assert np.issubdtype(v.dtype, np.integer) or np.issubdtype( assert np.issubdtype(v.dtype, np.integer) or np.issubdtype(
v.dtype, np.bool v.dtype, np.bool_
), "var type in the subscript must be int or bool" ), "var type in the subscript must be int or bool"
tensors.append(v) tensors.append(v)
......
...@@ -65,10 +65,10 @@ def _broadcast(inp, shape): ...@@ -65,10 +65,10 @@ def _broadcast(inp, shape):
) )
) )
if isinstance(src, (Tensor, TensorWrapperBase)): if isinstance(src, (TensorBase, TensorWrapperBase)):
src = src.numpy() src = src.numpy()
if isinstance(tar, (Tensor, TensorWrapperBase)): if isinstance(tar, (TensorBase, TensorWrapperBase)):
tar = tar.numpy() tar = tar.numpy()
if len(src) > len(tar): if len(src) > len(tar):
...@@ -78,8 +78,8 @@ def _broadcast(inp, shape): ...@@ -78,8 +78,8 @@ def _broadcast(inp, shape):
if src[-i - 1] != 1 and src[-i - 1] != tar[-i - 1]: if src[-i - 1] != 1 and src[-i - 1] != tar[-i - 1]:
failed() failed()
valid_broadcast(inp.shape, shape)
shape = utils.astensor1d(shape, inp, dtype="int32", device=inp.device) shape = utils.astensor1d(shape, inp, dtype="int32", device=inp.device)
valid_broadcast(inp.shape, shape)
(result,) = apply(builtin.Broadcast(), inp, shape) (result,) = apply(builtin.Broadcast(), inp, shape)
return result return result
......
...@@ -379,3 +379,18 @@ def test_trace_nms(): ...@@ -379,3 +379,18 @@ def test_trace_nms():
f(*make_inputs(10)) f(*make_inputs(10))
f(*make_inputs(20)) f(*make_inputs(20))
f(*make_inputs(30)) f(*make_inputs(30))
def test_trace_valid_broadcast():
set_tensor_shape(True)
x1 = tensor(np.random.randn(1, 1))
x2 = tensor(np.random.randn(1, 2))
shape = (tensor([2]), tensor([2]))
@trace(symbolic=False)
def f(x, shape):
y = F.broadcast_to(x, shape)
return y
f(x1, shape)
f(x2, shape)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册