diff --git a/imperative/python/megengine/core/tensor/array_method.py b/imperative/python/megengine/core/tensor/array_method.py
index fe15d2bf461f60b0b3dfdd97fe32969e272e53bc..a7c086c8d9f257c74c2d540f7c97ef8279ea1337 100644
--- a/imperative/python/megengine/core/tensor/array_method.py
+++ b/imperative/python/megengine/core/tensor/array_method.py
@@ -99,7 +99,32 @@ def _transpose(data, axes):
 
 
 def _broadcast(inp, shape):
-    shape = astensor1d(shape, inp, dtype="int32", device=inp.device)
+    auto_infer = False
+    if isinstance(shape, (list, tuple)):
+        shape_tuple = list(shape)
+        for i, s in enumerate(shape_tuple):
+            if s is None:
+                # fill this dim from the input shape, aligned from the right
+                right = i - len(shape_tuple)
+                inp_shape = inp._tuple_shape
+                if len(inp_shape) + right >= 0:
+                    shape_tuple[right] = list(inp_shape)[right]
+                    auto_infer = True
+                    continue
+                raise ValueError("invalid broadcast shape")
+            if s < 0:
+                raise ValueError(
+                    "expect shape[{}] >= 0 or use `None` to auto infer, got {}".format(
+                        i, s
+                    )
+                )
+    if auto_infer:
+        shape = tuple(shape_tuple)
+    try:
+        shape_tuple = make_shape_tuple(shape)
+    except ValueError:
+        shape_tuple = shape
+    shape = astensor1d(shape_tuple, inp, dtype="int32", device=inp.device)
     (result,) = apply(builtin.Broadcast(), inp, shape)
     return result
 
diff --git a/imperative/python/test/unit/functional/test_tensor.py b/imperative/python/test/unit/functional/test_tensor.py
index bc7645379084d3c7d9f056afbb48736bd9c362d1..8f4f8d2562f99f9c1e8a3fe86d8d37bf8f6357d1 100644
--- a/imperative/python/test/unit/functional/test_tensor.py
+++ b/imperative/python/test/unit/functional/test_tensor.py
@@ -223,6 +223,34 @@ def test_reshape(is_varnode):
         np.testing.assert_equal(yy.numpy(), y)
 
 
+@pytest.mark.parametrize("is_varnode", [True, False])
+def test_broadcast_auto_infer(is_varnode):
+    if is_varnode:
+        network = Network()
+    else:
+        network = None
+
+    x = np.random.random((1, 2, 3)).astype(np.float32)
+    xx = make_tensor(x, network)
+
+    for shape in [
+        (1, 2, 3),
+        (1, None, 3),
+    ]:
+        yy = F.broadcast_to(xx, shape)
+        np.testing.assert_equal(yy.numpy(), x)
+
+    with pytest.raises(ValueError):
+        F.broadcast_to(xx, (1, -1, 3))
+
+    with pytest.raises(ValueError):
+        F.broadcast_to(xx, (None, 1, 2, 3))
+
+    F.broadcast_to(xx, (1, None, 2, 3))
+    t = tensor(2, dtype=np.int32)
+    F.broadcast_to(xx, (t, None, 2, 3))
+
+
 @pytest.mark.parametrize("is_trace", [True, False])
 def test_reshape_on_empty_tensor(is_trace):
     input1_shape = (100, 0, 1)
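
A minimal usage sketch of the `None` auto-infer behavior this patch adds to F.broadcast_to, assuming a MegEngine environment with the change applied; the example values are illustrative, not from the patch itself:

import numpy as np
import megengine.functional as F
from megengine import tensor

x = tensor(np.random.random((1, 2, 3)).astype(np.float32))

# `None` dims are filled from the input shape, aligned from the right.
y = F.broadcast_to(x, (4, None, None))   # result shape: (4, 2, 3)
z = F.broadcast_to(x, (1, None, 2, 3))   # result shape: (1, 1, 2, 3)

# Negative dims are rejected, and `None` cannot refer to a dim the input
# does not have (both raise ValueError, as the new test exercises):
# F.broadcast_to(x, (1, -1, 3))
# F.broadcast_to(x, (None, 1, 2, 3))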