From 57c4eccf3b768e6ef279b27efbf4f4423202e7f4 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 22 Sep 2020 19:21:29 +0800 Subject: [PATCH] fix(mge/functional): add shape check for bc GitOrigin-RevId: e152c1928c6336102995ce9c51ac6a874b1cd7d1 --- .../megengine/core/tensor/tensor_wrapper.py | 22 +++++++++++++++++++ .../test/unit/functional/test_tensor.py | 12 +++++++++- imperative/python/test/unit/test_tracing.py | 17 ++++++++++++++ 3 files changed, 50 insertions(+), 1 deletion(-) diff --git a/imperative/python/megengine/core/tensor/tensor_wrapper.py b/imperative/python/megengine/core/tensor/tensor_wrapper.py index afe180b3f..a2b509a76 100644 --- a/imperative/python/megengine/core/tensor/tensor_wrapper.py +++ b/imperative/python/megengine/core/tensor/tensor_wrapper.py @@ -57,6 +57,28 @@ def _transpose(data, axes): def _broadcast(inp, shape): + def valid_broadcast(src, tar): + def failed(): + raise ValueError( + "the input shape {} can not be broadcasted to target shape {}".format( + src, tar + ) + ) + + if isinstance(src, (Tensor, TensorWrapperBase)): + src = src.numpy() + + if isinstance(tar, (Tensor, TensorWrapperBase)): + tar = tar.numpy() + + if len(src) > len(tar): + failed() + + for i in range(min(len(src), len(tar))): + if src[-i - 1] != 1 and src[-i - 1] != tar[-i - 1]: + failed() + + valid_broadcast(inp.shape, shape) shape = utils.astensor1d(shape, inp, dtype="int32", device=inp.device) (result,) = apply(builtin.Broadcast(), inp, shape) return result diff --git a/imperative/python/test/unit/functional/test_tensor.py b/imperative/python/test/unit/functional/test_tensor.py index 72da5d861..88b1bcdaf 100644 --- a/imperative/python/test/unit/functional/test_tensor.py +++ b/imperative/python/test/unit/functional/test_tensor.py @@ -240,7 +240,7 @@ def test_broadcast(): output1_shape = (30, 20, 30) data1 = np.random.random(input1_shape).astype(np.float32) - input2_shape = (10, 20) + input2_shape = (10, 1) output2_shape = (20, 10, 20) data2 = np.random.random(input2_shape).astype(np.float32) @@ -253,6 +253,16 @@ def test_broadcast(): ] opr_test(cases, F.broadcast, compare_fn=compare_fn) + x = F.ones((2, 1, 3)) + with pytest.raises(ValueError): + F.broadcast(x, (2, 3, 4)) + + with pytest.raises(ValueError): + F.broadcast(x, (4, 1, 3)) + + with pytest.raises(ValueError): + F.broadcast(x, (1, 3)) + def test_utils_astensor1d(): reference = tensor(0) diff --git a/imperative/python/test/unit/test_tracing.py b/imperative/python/test/unit/test_tracing.py index a022b9eb7..c9d7abd2a 100644 --- a/imperative/python/test/unit/test_tracing.py +++ b/imperative/python/test/unit/test_tracing.py @@ -340,3 +340,20 @@ def test_raise_on_trace(): step_count += 1 assert catch_count == 1 + + +def test_trace_broadcast(): + for symbolic in [False, True]: + set_tensor_shape(True) + x1 = tensor(np.random.randn(3, 1, 1)) + x2 = tensor(np.random.randn(1, 4, 1)) + x3 = tensor(np.random.randn(1, 1, 5)) + + @trace(symbolic=symbolic, capture_as_const=True) + def f(x): + y = x.broadcast((3, 4, 5)) + return y + + f(x1) + f(x2) + f(x3) -- GitLab