diff --git a/imperative/python/megengine/core/tensor/indexing.py b/imperative/python/megengine/core/tensor/indexing.py index 30ac98a699a96c6fbdfb641d048a8e76486d130e..b559d18be7736fcaf3a237e1ce3c2cf778d2e645 100644 --- a/imperative/python/megengine/core/tensor/indexing.py +++ b/imperative/python/megengine/core/tensor/indexing.py @@ -176,6 +176,8 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True): def is_bool_list(x): if not isinstance(x, list): return False + if len(x) == 0: + return False for i in x: if not isinstance(i, bool): return False @@ -246,17 +248,6 @@ def getitem(tensor, index): if len(try_result) == 2: return try_result[0] tensor, tensors, items, use_subtensor, ret_scalar = unpack_getitem(tensor, index) - for v in tensors: - if v.shape is None: - break - if isinstance(v.shape, v.__class__): - break - if len(v.shape) > 0 and v.shape[0] == 0: - (empty_tensor,) = Const([], dtype=tensor.dtype, device=tensor.device)( - tensor - ) - return empty_tensor - if use_subtensor: op = builtin.Subtensor(items=items) else: diff --git a/imperative/python/test/unit/core/test_indexing_op.py b/imperative/python/test/unit/core/test_indexing_op.py index 5f3c6689d5247675dd9c2a55b1e08fd76c088919..8d57aa5ea80887ac24f3128fa181fa2d258db926 100644 --- a/imperative/python/test/unit/core/test_indexing_op.py +++ b/imperative/python/test/unit/core/test_indexing_op.py @@ -610,6 +610,25 @@ def test_subtensor_on_empty_tensor(symbolic): run_test(lambda x: x[100:200, 300:400, 500:600]) +@pytest.mark.parametrize("symbolic", [True, False, None]) +def test_indexingMultiAxisVec_on_empty_tensor(symbolic): + np_x = np.array([], dtype=np.float32).reshape(10, 10, 0) + mge_x = megengine.tensor(np_x) + + def run_test(fn): + out_ref = fn(np_x) + if symbolic is not None: + fn = jit.trace(symbolic=symbolic)(fn) + for i in range(3): + out = fn(mge_x) + np.testing.assert_equal(out.numpy(), out_ref) + + run_test(lambda x: x[[1, 2, 3]]) + run_test(lambda x: x[[1, 2, 3], [4, 5, 6]]) + run_test(lambda x: x[[]]) + run_test(lambda x: x[[], [], []]) + + @pytest.mark.parametrize("symbolic", [True, False, None]) def test_setsubtensor_on_empty_tensor(symbolic): def run_test(inp_shp, fn): @@ -655,3 +674,39 @@ def test_setsubtensor_on_empty_tensor(symbolic): run_test((10, 10, 10), test4) run_test((10, 10, 10), test5) run_test((10, 10, 10), test6) + + +@pytest.mark.parametrize("symbolic", [True, False, None]) +def test_indexingSetMultiAxisVec_on_empty_tensor(symbolic): + def run_test(inp_shp, fn): + np_x = np.random.randn(*inp_shp).astype(np.float32) + mge_x = megengine.tensor(np_x) + out_ref = fn(np_x) + if symbolic is not None: + fn = jit.trace(symbolic=symbolic)(fn) + for i in range(3): + out = fn(mge_x) + np.testing.assert_equal(out.numpy(), out_ref) + + def test1(x): + x[[1, 2, 3]] = x[[1, 2, 3]] + return x + + def test2(x): + x[[1, 2, 3], [1, 2, 3]] = x[[1, 2, 3], [1, 2, 3]] + return x + + def test3(x): + x[[]] = x[[]] + return x + + def test4(x): + x[[], [], []] = x[[], [], []] + return x + + run_test((10, 10, 0), test1) + run_test((10, 10, 0), test2) + run_test((10, 10, 0), test3) + run_test((10, 10, 0), test4) + run_test((10, 10, 10), test3) + run_test((10, 10, 10), test4) diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py index 037b0d05403388764e631d0ec65ef615dd42485b..36c7beb1f2514575895ee3ba08db93cfb208eb39 100644 --- a/imperative/python/test/unit/functional/test_functional.py +++ b/imperative/python/test/unit/functional/test_functional.py @@ -860,8 +860,8 @@ def test_condtake(): np.testing.assert_equal(idx.numpy(), np.where(y.reshape(-1))[0]) -# @pytest.mark.parametrize("is_symbolic", [None, False, True]) -def test_condtake(is_symbolic=None): +@pytest.mark.parametrize("is_symbolic", [None, False, True]) +def test_condtake(is_symbolic): shapes = [ (3, 3, 3), (0,), diff --git a/src/opr/impl/indexing.cpp b/src/opr/impl/indexing.cpp index ad56bb48ca52f1122aa913f31bdea1a2d38af044..276bf816d839406e63198ed29a198ad219aa191c 100644 --- a/src/opr/impl/indexing.cpp +++ b/src/opr/impl/indexing.cpp @@ -292,8 +292,6 @@ cg::OperatorNodeBase::NodeProp* IndexingMultiAxisVecBase::do_make_node_prop() const { auto prop = Super::do_make_node_prop(); using DT = NodeProp::DepType; - // TODO: should also allow input shape is empty if any - // indexer's shape is empty prop->add_dep_type_existing_var(input(0), DT::VALUE_ALLOW_EMPTY); for (auto i: m_input2idxonly_axis_indexer) { if (i) {