提交 68cde873 编写于 作者: M Megvii Engine Team

fix(mge/imperative): support broadcast with None

GitOrigin-RevId: dd330a2a1dc603ea52a655b350ee1a421015e7d7
上级 0bdd0b14
......@@ -99,7 +99,39 @@ def _transpose(data, axes):
def _broadcast(inp, shape):
shape = astensor1d(shape, inp, dtype="int32", device=inp.device)
auto_infer = False
if isinstance(shape, (list, tuple)):
shape_tuple = list(shape)
for i, s in enumerate(shape_tuple):
if isinstance(s, type(None)):
if s is None:
right = i - len(shape_tuple)
inp_shape = inp._tuple_shape
if len(inp_shape) + right >= 0:
shape_tuple[right] = list(inp_shape)[right]
auto_infer = True
continue
else:
raise ValueError("invalided Broadcast shape")
else:
raise ValueError(
"expect shape[{}] >= 0 or use `None` or 'x' and 'X' to auto infer, got {}".format(
i, s
)
)
if s < 0:
raise ValueError(
"expect shape[{}] >= 0 or use `None` or 'x' and 'X' to auto infer, got {}".format(
i, s
)
)
if auto_infer:
shape = tuple(shape_tuple)
try:
shape_tuple = make_shape_tuple(shape)
except ValueError:
shape_tuple = shape
shape = astensor1d(shape_tuple, inp, dtype="int32", device=inp.device)
(result,) = apply(builtin.Broadcast(), inp, shape)
return result
......
......@@ -223,6 +223,34 @@ def test_reshape(is_varnode):
np.testing.assert_equal(yy.numpy(), y)
@pytest.mark.parametrize("is_varnode", [True, False])
def test_broadcast_auto_infer(is_varnode):
if is_varnode:
network = Network()
else:
network = None
x = np.random.random((1, 2, 3)).astype(np.float32)
xx = make_tensor(x, network)
for shape in [
(1, 2, 3),
(1, None, 3),
]:
yy = F.broadcast_to(xx, shape)
np.testing.assert_equal(yy.numpy(), x)
with pytest.raises(ValueError):
F.broadcast_to(xx, (1, -1, 3))
with pytest.raises(ValueError):
F.broadcast_to(xx, (None, 1, 2, 3))
F.broadcast_to(xx, (1, None, 2, 3))
t = tensor(2, dtype=np.int32)
F.broadcast_to(xx, (t, None, 2, 3))
@pytest.mark.parametrize("is_trace", [True, False])
def test_reshape_on_empty_tensor(is_trace):
input1_shape = (100, 0, 1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册