diff --git a/imperative/python/test/unit/functional/test_elemwise.py b/imperative/python/test/unit/functional/test_elemwise.py
index 09f1a6b91d3ab3002e556b2594241cc2e0047218..25dce2436ea0ea0d8e7fbf4999ddc336183f642c 100644
--- a/imperative/python/test/unit/functional/test_elemwise.py
+++ b/imperative/python/test/unit/functional/test_elemwise.py
@@ -7,12 +7,14 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import numpy as np
+import pytest
 
 import megengine.functional as F
 import megengine.functional.elemwise as elemwise
 from megengine import tensor
 from megengine.core.tensor import dtype
 from megengine.functional.elemwise import Elemwise, _elwise
+from megengine.jit import trace
 
 
 def test_abs():
@@ -180,3 +182,80 @@ def test_int32_input():
         inp = (x,) * nargs
         y = op(*inp)
         y.numpy()
+
+
+@pytest.mark.parametrize("is_trace", [True, False])
+def test_empty_tensor(is_trace):
+    binary_func = []
+    unary_func = []
+    for op_name in elemwise.__all__:
+        op = getattr(elemwise, op_name)
+        nargs = op.__code__.co_argcount
+        if op_name == "clip":
+            unary_func.append(["clip", lambda x, f=op: f(x, lower=0, upper=1)])
+        elif op_name.endswith("_shift"):
+            unary_func.append(
+                [op_name, lambda x, f=op: f(tensor(x.numpy(), dtype="int32"), 1)]
+            )
+        elif op_name.startswith("logical_"):  # logical_xxx ops only accept boolean type
+            if nargs == 1:
+                unary_func.append(
+                    [op_name, lambda x, f=op: f(tensor(x.numpy(), dtype="bool"))]
+                )
+            else:
+                assert nargs == 2
+                binary_func.append(
+                    [
+                        op_name,
+                        lambda x, y, f=op: f(
+                            tensor(x.numpy(), dtype="bool"),
+                            tensor(y.numpy(), dtype="bool"),
+                        ),
+                    ]
+                )
+        elif nargs == 1:
+            unary_func.append([op_name, op])
+        elif nargs == 2:
+            binary_func.append([op_name, op])
+        else:
+            print(nargs)
+            raise NotImplementedError
+
+    def run_test(func, args, ref_shape, is_trace, sym=False):
+        args = [tensor(t, dtype="float32") for t in args]
+        if is_trace:
+            func = trace(symbolic=sym)(func)
+            for _ in range(3):
+                out = func(*args)
+                assert out.numpy().shape == ref_shape
+        else:
+            out = func(*args)
+            assert out.numpy().shape == ref_shape
+        print(out.numpy().shape)
+
+    inps = [
+        np.array([]).astype("float32"),
+        np.random.randn(2, 0, 3).astype("float32"),
+        123,
+    ]
+    for op_name, op in unary_func:
+        if is_trace:
+            for sym in [True, False]:
+                run_test(op, [inps[0],], inps[0].shape, True, sym)
+                run_test(op, [inps[1],], inps[1].shape, True, sym)
+        else:
+            run_test(op, [inps[0],], inps[0].shape, False)
+            run_test(op, [inps[1],], inps[1].shape, False)
+
+    for op_name, op in binary_func:
+        if is_trace:
+            for sym in [True, False]:
+                run_test(op, [inps[0], inps[0]], (inps[0] + inps[0]).shape, True, sym)
+                run_test(op, [inps[1], inps[1]], (inps[1] + inps[1]).shape, True, sym)
+                run_test(op, [inps[0], inps[2]], (inps[0] + inps[2]).shape, True, sym)
+                run_test(op, [inps[1], inps[2]], (inps[1] + inps[2]).shape, True, sym)
+        else:
+            run_test(op, [inps[0], inps[0]], (inps[0] + inps[0]).shape, False)
+            run_test(op, [inps[1], inps[1]], (inps[1] + inps[1]).shape, False)
+            run_test(op, [inps[0], inps[2]], (inps[0] + inps[2]).shape, False)
+            run_test(op, [inps[1], inps[2]], (inps[1] + inps[2]).shape, False)
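
A note on the `f=op` default arguments used throughout `test_empty_tensor` above: Python closures bind names late, so a plain `lambda x: op(x)` created in a loop would call whatever `op` refers to after the loop finishes. Binding the current value through a default argument (`f=op`) snapshots it per iteration. A minimal standalone sketch of the pitfall, in plain Python with no MegEngine dependency (`late`/`early` are illustrative names):

    # Late binding: each closure reads `i` at call time, after the loop ended.
    late = [lambda: i for i in range(3)]
    assert [f() for f in late] == [2, 2, 2]

    # Default-argument binding (the `f=op` idiom above) evaluates `i` once,
    # when the lambda is created, so each closure keeps its own value.
    early = [lambda i=i: i for i in range(3)]
    assert [f() for f in early] == [0, 1, 2]
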
diff --git a/imperative/python/test/unit/functional/test_tensor.py b/imperative/python/test/unit/functional/test_tensor.py
index c87a5e747cb3e071ada48c37b7645b1901ad7612..5f0305a4ae648370738777552b79372418a4099d 100644
--- a/imperative/python/test/unit/functional/test_tensor.py
+++ b/imperative/python/test/unit/functional/test_tensor.py
@@ -19,6 +19,7 @@ from megengine.core._trace_option import use_symbolic_shape
 from megengine.core.tensor import megbrain_graph as G
 from megengine.core.tensor.utils import astensor1d
 from megengine.distributed.helper import get_device_count_by_fork
+from megengine.jit import trace
 from megengine.utils.network import Network, set_symbolic_shape
 from megengine.utils.network_node import VarNode
 
@@ -177,6 +178,48 @@ def test_reshape(is_varnode):
     np.testing.assert_equal(yy.numpy(), y)
 
 
+@pytest.mark.parametrize("is_trace", [True, False])
+def test_reshape_on_empty_tensor(is_trace):
+    input1_shape = (100, 0, 1)
+    output1_shape = (100, 0, 10)
+    data1 = tensor(np.random.random(input1_shape).astype(np.float32))
+
+    input2_shape = (10, 0)
+    output2_shape = (0,)
+    data2 = tensor(np.random.random(input2_shape).astype(np.float32))
+
+    input3_shape = (10, 0, 10)
+    output3_shape = (0, 1, 2, 3)
+    data3 = tensor(np.random.random(input3_shape).astype(np.float32))
+
+    def comp(out, target_shp):
+        assert out._tuple_shape == target_shp
+
+    def func(x, shp):
+        return F.reshape(x, shp)
+
+    cases = [
+        [data1, output1_shape],
+        [data2, output2_shape],
+        [data3, output3_shape],
+    ]
+
+    def test(func, inp, comp, target_shp):
+        out = func(inp, target_shp)
+        comp(out, target_shp)
+
+    if is_trace:
+        for symbolic in [False, True]:
+            for inp, target_shp in cases:
+                func_traced = trace(symbolic=symbolic)(func)
+                test(func_traced, inp, comp, target_shp)
+                test(func_traced, inp, comp, target_shp)
+                test(func_traced, inp, comp, target_shp)
+    else:
+        for inp, target_shp in cases:
+            test(func, inp, comp, target_shp)
+
+
 @pytest.mark.parametrize("is_varnode", [True, False])
 def test_reshape_shape_inference(is_varnode):
     if is_varnode:
@@ -480,6 +523,48 @@ def test_broadcast(is_varnode):
         F.broadcast_to(x, (1, 3))
 
 
+@pytest.mark.parametrize("is_trace", [True, False])
+def test_broadcast_on_empty_tensor(is_trace):
+    input1_shape = (100, 0, 1)
+    output1_shape = (100, 0, 10)
+    data1 = tensor(np.random.random(input1_shape).astype(np.float32))
+
+    input2_shape = (10, 0)
+    output2_shape = (10, 10, 0)
+    data2 = tensor(np.random.random(input2_shape).astype(np.float32))
+
+    input3_shape = (0, 0, 1, 10)
+    output3_shape = (10, 0, 0, 10, 10)
+    data3 = tensor(np.random.random(input3_shape).astype(np.float32))
+
+    def comp(out, target_shp):
+        assert out._tuple_shape == target_shp
+
+    def func(x, shp):
+        return F.broadcast_to(x, shp)
+
+    cases = [
+        [data1, output1_shape],
+        [data2, output2_shape],
+        [data3, output3_shape],
+    ]
+
+    def test(func, inp, comp, target_shp):
+        out = func(inp, target_shp)
+        comp(out, target_shp)
+
+    if is_trace:
+        for symbolic in [False, True]:
+            for inp, target_shp in cases:
+                func_traced = trace(symbolic=symbolic)(func)
+                test(func_traced, inp, comp, target_shp)
+                test(func_traced, inp, comp, target_shp)
+                test(func_traced, inp, comp, target_shp)
+    else:
+        for inp, target_shp in cases:
+            test(func, inp, comp, target_shp)
+
+
 @pytest.mark.parametrize("is_varnode", [True, False])
 def test_utils_astensor1d(is_varnode):
     if is_varnode:
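
The traced cases above deliberately call each compiled function three times: with `megengine.jit.trace`, the first call records the graph and later calls replay it, so empty shapes are exercised on both the tracing and replay paths, under both `symbolic=False` and `symbolic=True`. A minimal sketch of the same pattern, assuming a working MegEngine install (the hard-coded target shape is only for illustration):

    import numpy as np
    import megengine.functional as F
    from megengine import tensor
    from megengine.jit import trace

    @trace(symbolic=True)
    def reshape_fn(x):
        # reshape a (10, 0) empty tensor into a 1-d empty tensor
        return F.reshape(x, (0,))

    x = tensor(np.random.random((10, 0)).astype(np.float32))
    for _ in range(3):  # first call traces, subsequent calls replay
        assert reshape_fn(x).numpy().shape == (0,)
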
diff --git a/src/opr/impl/basic_arith.cpp b/src/opr/impl/basic_arith.cpp
index fc24d053315be1ac674c2d4c4d93e6c72a50ce22..c87e7a1199f86ec4e12032124e41c182cd397fb2 100644
--- a/src/opr/impl/basic_arith.cpp
+++ b/src/opr/impl/basic_arith.cpp
@@ -259,6 +259,10 @@ void Elemwise::perform(
             mgb_assert(t.comp_node() == out_cn);
             mgb_assert(t.dtype() == out_dt);
         }
+        if (t.shape().is_empty()) {
+            mgb_assert(dest.empty());
+            return;
+        }
         inp_shapes[i] = t.shape();
     }
     if (!opr) {
diff --git a/src/opr/test/basic_arith/elemwise.cpp b/src/opr/test/basic_arith/elemwise.cpp
index 87aac28eafea35ac75a3dd5a678c286be5e31357..8f8d233d9244e47ac12b232eb163d87d1b298f11 100644
--- a/src/opr/test/basic_arith/elemwise.cpp
+++ b/src/opr/test/basic_arith/elemwise.cpp
@@ -1064,4 +1064,37 @@ TEST(TestOprBasicArithElemwise, EmptyInputOutputBinary) {
     MGB_ASSERT_SHAPE_EQ(host_z.shape(), TensorShape({0, 8, 1, 7}));
 }
 
+TEST(TestOprBasicArithElemwise, PerformEmptyIO) {
+    auto cn = CompNode::load("xpu0");
+    HostTensorGenerator<> gen;
+    auto host_x1 = gen({2, 0, 3, 4}),
+         host_x2 = gen({1});
+    auto dev_x1 = std::make_shared<DeviceTensorND>(cn),
+         dev_x2 = std::make_shared<DeviceTensorND>(cn);
+    dev_x1->copy_from(*host_x1);
+    dev_x2->copy_from(*host_x2);
+
+    auto dev_y = std::make_shared<DeviceTensorND>(cn, dev_x1->dtype());
+    dev_y->resize(dev_x1->shape());
+    auto&& dnn_opr = opr::intl::create_megdnn_opr<megdnn::Elemwise>(cn);
+
+    // test unary mode
+    for (auto mode: {Mode::NEGATE, Mode::EXP, Mode::LOG}) {
+        SmallVector<DeviceTensorND> inputs = {*dev_x1};
+        ASSERT_NO_THROW(opr::Elemwise::perform(mode, *dev_y, inputs, dnn_opr));
+        ASSERT_TRUE(dev_y->empty());
+        ASSERT_TRUE(dev_y->shape().is_empty());
+        MGB_ASSERT_SHAPE_EQ(dev_y->shape(), dev_x1->shape());
+    }
+
+    // test binary mode
+    for (auto mode: {Mode::ADD, Mode::MUL, Mode::LT}) {
+        SmallVector<DeviceTensorND> inputs = {*dev_x1, *dev_x2};
+        ASSERT_NO_THROW(opr::Elemwise::perform(mode, *dev_y, inputs, dnn_opr));
+        ASSERT_TRUE(dev_y->empty());
+        ASSERT_TRUE(dev_y->shape().is_empty());
+        MGB_ASSERT_SHAPE_EQ(dev_y->shape(), dev_x1->shape());
+    }
+}
+
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
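
Taken together, the C++ change makes `Elemwise::perform` return early when any input shape has a zero extent, asserting that the destination is already empty instead of launching a kernel, and the tests pin this down from both the C++ and Python sides. At the Python layer the observable behavior is simply that elementwise ops propagate emptiness; a hedged sketch, assuming a working MegEngine install:

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    x = tensor(np.random.randn(2, 0, 3).astype("float32"))
    y = F.exp(x)               # unary op on an empty tensor stays empty
    z = F.add(x, tensor(1.0))  # binary op broadcast against a scalar
    assert y.numpy().shape == (2, 0, 3)
    assert z.numpy().shape == (2, 0, 3)
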