提交 57c4eccf 编写于 作者: M Megvii Engine Team

fix(mge/functional): add shape check for bc

GitOrigin-RevId: e152c1928c6336102995ce9c51ac6a874b1cd7d1
上级 c92044c5
......@@ -57,6 +57,28 @@ def _transpose(data, axes):
def _broadcast(inp, shape):
def valid_broadcast(src, tar):
def failed():
raise ValueError(
"the input shape {} can not be broadcasted to target shape {}".format(
src, tar
)
)
if isinstance(src, (Tensor, TensorWrapperBase)):
src = src.numpy()
if isinstance(tar, (Tensor, TensorWrapperBase)):
tar = tar.numpy()
if len(src) > len(tar):
failed()
for i in range(min(len(src), len(tar))):
if src[-i - 1] != 1 and src[-i - 1] != tar[-i - 1]:
failed()
valid_broadcast(inp.shape, shape)
shape = utils.astensor1d(shape, inp, dtype="int32", device=inp.device)
(result,) = apply(builtin.Broadcast(), inp, shape)
return result
......
......@@ -240,7 +240,7 @@ def test_broadcast():
output1_shape = (30, 20, 30)
data1 = np.random.random(input1_shape).astype(np.float32)
input2_shape = (10, 20)
input2_shape = (10, 1)
output2_shape = (20, 10, 20)
data2 = np.random.random(input2_shape).astype(np.float32)
......@@ -253,6 +253,16 @@ def test_broadcast():
]
opr_test(cases, F.broadcast, compare_fn=compare_fn)
x = F.ones((2, 1, 3))
with pytest.raises(ValueError):
F.broadcast(x, (2, 3, 4))
with pytest.raises(ValueError):
F.broadcast(x, (4, 1, 3))
with pytest.raises(ValueError):
F.broadcast(x, (1, 3))
def test_utils_astensor1d():
reference = tensor(0)
......
......@@ -340,3 +340,20 @@ def test_raise_on_trace():
step_count += 1
assert catch_count == 1
def test_trace_broadcast():
for symbolic in [False, True]:
set_tensor_shape(True)
x1 = tensor(np.random.randn(3, 1, 1))
x2 = tensor(np.random.randn(1, 4, 1))
x3 = tensor(np.random.randn(1, 1, 5))
@trace(symbolic=symbolic, capture_as_const=True)
def f(x):
y = x.broadcast((3, 4, 5))
return y
f(x1)
f(x2)
f(x3)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册