From bc9aa47ad87bba5e94f58280fe9d7e073ff01733 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Mon, 28 Feb 2022 13:57:52 +0800
Subject: [PATCH] feat(mge/indexing): support newaxis

GitOrigin-RevId: 8338c4b47542671f07cee9d68bbe8c35c25c5f16
---
 imperative/python/src/tensor_utils.cpp        | 110 +++++++++---------
 .../python/test/unit/core/test_indexing_op.py |  28 ++++-
 2 files changed, 78 insertions(+), 60 deletions(-)

diff --git a/imperative/python/src/tensor_utils.cpp b/imperative/python/src/tensor_utils.cpp
index 09695cf47..0b90cc042 100644
--- a/imperative/python/src/tensor_utils.cpp
+++ b/imperative/python/src/tensor_utils.cpp
@@ -459,12 +459,8 @@ py::object _astype_cpp(py::handle tensor, py::handle dtype_hdl) {
     if (!dtype_equal(cur, descr)) {
         std::shared_ptr<OpDef> op = TypeCvt::make(npy::dtype_np2mgb_descr(descr));
         py::object Op = py::cast(op);
-        std::vector<PyObject*> p;
-        p.resize(2);
-        p[0] = Op.ptr();
-        p[1] = tensor.ptr();
-        py::tuple ret =
-                py::reinterpret_steal<py::tuple>(py_apply(NULL, p.data(), p.size()));
+        PyObject* p[2] = {Op.ptr(), tensor.ptr()};
+        py::tuple ret = py::reinterpret_steal<py::tuple>(py_apply(NULL, p, 2));
         return ret[0];
     } else {
         return py::reinterpret_borrow<py::object>(tensor);
@@ -514,7 +510,7 @@ py::object _convert_inputs_cpp(
         }
     }
     auto convert = [&](py::object value) {
-        if (value.ptr() == Py_None) {
+        if (value.is_none()) {
             return value;
         }
         return _convert_single_value_cpp(value, dtype, device);
@@ -545,12 +541,9 @@ py::object _astensor1d_cpp(
     if (device.ptr() != Py_None) {
         std::shared_ptr<OpDef> op = Copy::make(device_obj.cast<CompNode>());
         py::object Op = py::cast(op);
-        std::vector<PyObject*> p;
-        p.resize(2);
-        p[0] = Op.ptr();
-        p[1] = ret.ptr();
-        py::tuple copy_ret = py::reinterpret_steal<py::tuple>(
-                py_apply(NULL, p.data(), p.size()));
+        PyObject* p[2] = {Op.ptr(), ret.ptr()};
+        py::tuple copy_ret =
+                py::reinterpret_steal<py::tuple>(py_apply(NULL, p, 2));
         return copy_ret[0];
     }
     return ret;
@@ -590,7 +583,7 @@ py::object _astensor1d_cpp(
     c_args[lis.size()] = Py_None;
     py::tuple inp_tup = py::reinterpret_steal<py::tuple>(
             convert_inputs_cpp(NULL, c_args.data(), c_args.size()));
-    if (device_obj.ptr() == Py_None) {
+    if (device_obj.is_none()) {
         std::vector<PyObject*> inp(inp_tup.size());
         for (size_t i = 0; i < inp_tup.size(); ++i) {
             inp[i] = inp_tup[i].ptr();
@@ -637,15 +630,10 @@ py::object _get_index(py::object tensor, py::object src) {
             return tensor;
         }
     }
-    static std::shared_ptr<OpDef> op = CondTake::make();
-    std::vector<PyObject*> p;
-    p.resize(3);
+    std::shared_ptr<OpDef> op = CondTake::make();
     py::object Op = py::cast(op);
-    p[0] = Op.ptr();
-    p[1] = tensor.ptr();
-    p[2] = tensor.ptr();
-    py::tuple ret =
-            py::reinterpret_steal<py::tuple>(py_apply(NULL, p.data(), p.size()));
+    PyObject* p[3] = {Op.ptr(), tensor.ptr(), tensor.ptr()};
+    py::tuple ret = py::reinterpret_steal<py::tuple>(py_apply(NULL, p, 3));
     return ret[1];
 }
 
@@ -666,15 +654,10 @@ py::tuple _try_cond_take(py::handle tensor, py::handle index) {
     } else {
         iobj = py::reinterpret_borrow<py::object>(index);
     }
-    static std::shared_ptr<OpDef> op = CondTake::make();
-    std::vector<PyObject*> p;
-    p.resize(3);
+    std::shared_ptr<OpDef> op = CondTake::make();
     py::object Op = py::cast(op);
-    p[0] = Op.ptr();
-    p[1] = tensor.ptr();
-    p[2] = iobj.ptr();
-    py::tuple ret =
-            py::reinterpret_steal<py::tuple>(py_apply(NULL, p.data(), p.size()));
+    PyObject* p[3] = {Op.ptr(), tensor.ptr(), iobj.ptr()};
+    py::tuple ret = py::reinterpret_steal<py::tuple>(py_apply(NULL, p, 3));
     return ret;
 }
 
@@ -685,7 +668,9 @@ py::tuple _remove_ellipsis(py::object tensor, py::tuple tuple_val) {
     bool has_unknown_ndim_bool_index = false;
     for (size_t i = 0; i < tuple_size; ++i) {
         py::object handle = tuple_val[i];
-        if (handle.ptr() == Py_Ellipsis) {
+        if (handle.is_none()) {
+            continue;
+        } else if (handle.ptr() == Py_Ellipsis) {
             pos = static_cast<int>(i);
             for (size_t j = 0; j < i; ++j) {
                 py::object t = tuple_val[j];
@@ -749,8 +734,14 @@ py::tuple _expand_bool_dim(py::object tensor, py::tuple tuple_val) {
 
     size_t offset = 0;
     size_t tdim = 0;
+    size_t nonedim = 0;
     for (size_t i = 0; i < tuple_val.size(); ++i) {
         py::handle k = tuple_val[i];
+        if (k.ptr() == Py_None) {
+            nonedim++;
+            new_tuple_val.append(k);
+            continue;
+        }
         if (is_bool_dtype(k.ptr())) {
             size_t ndim = getattr(k, "ndim").cast<size_t>();
             if (ndim > 1) {
@@ -777,7 +768,7 @@ py::tuple _expand_bool_dim(py::object tensor, py::tuple tuple_val) {
                 Py_XDECREF(sym);
                 if (is_sym) {
                     py::object tshape = getattr(tensor, "shape");
-                    for (size_t j = 0; j < i; ++j) {
+                    for (size_t j = 0; j < i - nonedim; ++j) {
                         new_shape.append(tshape[py::int_(j)]);
                     }
                     new_shape.append(kshape[py::int_(0)]);
@@ -789,7 +780,7 @@ py::tuple _expand_bool_dim(py::object tensor, py::tuple tuple_val) {
                     tensor = _reshape_cpp(tensor, shape_tensor);
                     cur_shape = _make_shape_tuple(shape_tensor);
                 } else {
-                    for (size_t j = 0; j < i; ++j) {
+                    for (size_t j = 0; j < i - nonedim; ++j) {
                         new_shape.append(cur_shape[j]);
                     }
                     new_shape.append(py::reinterpret_borrow<py::tuple>(kshape)[0]);
@@ -838,8 +829,8 @@ py::tuple _unpack_indexes(py::handle inp_hdl, py::handle idx_hdl) {
     size_t idx_ndim = 0;
     for (size_t i = 0; i < tuple_val.size(); ++i) {
         py::object k = tuple_val[i];
-        if (k.ptr() == Py_None) {
-            throw py::index_error("newaxis is not allowed here");
+        if (k.is_none()) {
+            continue;
         } else if (k.ptr() == Py_Ellipsis) {
             need_remove_ellipsis = true;
         } else {
@@ -878,6 +869,20 @@ py::tuple _unpack_indexes(py::handle inp_hdl, py::handle idx_hdl) {
         }
     }
 
+    std::vector<int32_t> axis;
+    for (size_t i = 0; i < tuple_val.size(); ++i) {
+        if (tuple_val[i].is_none()) {
+            axis.push_back(i);
+        }
+    }
+    if (axis.size()) {
+        std::shared_ptr<OpDef> op = AddAxis::make(axis);
+        py::object Op = py::cast(op);
+        PyObject* p[2] = {Op.ptr(), inp.ptr()};
+        py::tuple ret = py::reinterpret_steal<py::tuple>(py_apply(NULL, p, 2));
+        inp = ret[0];
+    }
+
     py::list items;
     py::list tensors;
     int cur_axis = -1;
@@ -885,6 +890,9 @@ py::tuple _unpack_indexes(py::handle inp_hdl, py::handle idx_hdl) {
     for (size_t i = 0; i < tuple_val.size(); ++i) {
         py::object handle = tuple_val[i];
         cur_axis++;
+        if (handle.is_none()) {
+            continue;
+        }
         if (!is_scalar(handle.ptr()) && !PySlice_Check(handle.ptr())) {
             use_subtensor = false;
         }
@@ -970,11 +978,11 @@ py::object _broadcast_cpp(py::handle inp_hdl, py::handle args) {
     if (PyList_Check(shape_hdl.ptr()) || PyTuple_Check(shape_hdl.ptr())) {
         lis = py::reinterpret_steal<py::list>(PySequence_List(shape_hdl.ptr()));
         for (size_t i = 0; i < lis.size(); ++i) {
-            if (lis[i].ptr() == Py_None) {
+            if (lis[i].is_none()) {
                 auto_infer = true;
                 size_t right = lis.size() - i;
                 py::object tshp = getattr(inp_hdl, "_tuple_shape");
-                if (tshp.ptr() == Py_None) {
+                if (tshp.is_none()) {
                     throw py::index_error("does not support `None` with unknown shape");
                 }
                 py::tuple inp_shape = py::reinterpret_borrow<py::tuple>(tshp);
@@ -1116,7 +1124,7 @@ py::object _getitem_cpp(py::handle inp_hdl, py::handle idx_hdl) {
                 {item[0].cast<int8_t>(), item[1].cast<bool>(), item[2].cast<bool>(),
                  item[3].cast<bool>(), item[4].cast<bool>()});
     }
-    static std::shared_ptr<OpDef> op;
+    std::shared_ptr<OpDef> op;
     if (up[3].cast<bool>()) {
         op = Subtensor::make(cpp_items);
     } else {
@@ -1155,7 +1163,7 @@ py::object _setitem_cpp(py::handle inp_hdl, py::handle idx_hdl, py::handle val_h
                 {item[0].cast<int8_t>(), item[1].cast<bool>(), item[2].cast<bool>(),
                  item[3].cast<bool>(), item[4].cast<bool>()});
     }
-    static std::shared_ptr<OpDef> op, set_op;
+    std::shared_ptr<OpDef> op, set_op;
     if (up[3].cast<bool>()) {
         op = Subtensor::make(cpp_items);
     } else {
@@ -1340,13 +1348,9 @@ py::object _expand_dims_cpp(py::handle inp_hdl, py::handle axis_hdl) {
     }
     std::sort(axis.begin(), axis.end());
     std::shared_ptr<OpDef> op = AddAxis::make(axis = axis);
-    std::vector<PyObject*> p;
-    p.resize(2);
     py::object Op = py::cast(op);
-    p[0] = Op.ptr();
-    p[1] = inp_hdl.ptr();
-    py::tuple ret =
-            py::reinterpret_steal<py::tuple>(py_apply(NULL, p.data(), p.size()));
+    PyObject* p[2] = {Op.ptr(), inp_hdl.ptr()};
+    py::tuple ret = py::reinterpret_steal<py::tuple>(py_apply(NULL, p, 2));
     return ret[0];
 }
 
@@ -1390,13 +1394,9 @@ py::object _squeeze_cpp(py::handle inp_hdl, py::handle axis_hdl) {
         axis[i] -= static_cast<int32_t>(i);
     }
     std::shared_ptr<OpDef> op = RemoveAxis::make(axis = axis);
-    std::vector<PyObject*> p;
-    p.resize(2);
     py::object Op = py::cast(op);
-    p[0] = Op.ptr();
-    p[1] = inp_hdl.ptr();
-    py::tuple ret =
-            py::reinterpret_steal<py::tuple>(py_apply(NULL, p.data(), p.size()));
+    PyObject* p[2] = {Op.ptr(), inp_hdl.ptr()};
+    py::tuple ret = py::reinterpret_steal<py::tuple>(py_apply(NULL, p, 2));
     return ret[0];
 }
 py::object _transpose_cpp(py::handle inp_hdl, py::handle args) {
@@ -1437,13 +1437,9 @@ py::object _transpose_cpp(py::handle inp_hdl, py::handle args) {
         }
     }
     std::shared_ptr<OpDef> op = Dimshuffle::make(pattern);
-    std::vector<PyObject*> p;
-    p.resize(2);
     py::object Op = py::cast(op);
-    p[0] = Op.ptr();
-    p[1] = inp_hdl.ptr();
-    py::tuple ret =
-            py::reinterpret_steal<py::tuple>(py_apply(NULL, p.data(), p.size()));
+    PyObject* p[2] = {Op.ptr(), inp_hdl.ptr()};
+    py::tuple ret = py::reinterpret_steal<py::tuple>(py_apply(NULL, p, 2));
     return ret[0];
 }
 
diff --git a/imperative/python/test/unit/core/test_indexing_op.py b/imperative/python/test/unit/core/test_indexing_op.py
index dd44657a5..ad4220822 100644
--- a/imperative/python/test/unit/core/test_indexing_op.py
+++ b/imperative/python/test/unit/core/test_indexing_op.py
@@ -436,6 +436,8 @@ def test_advance_indexing_high_level(test_varnode):
 
     x = np.arange(27).reshape(3, 3, 3).astype("int32")
     xx = make_tensor(x, network)
+    y = np.array([0, 2], dtype=np.int32)
+    z = np.array([[0, 1], [1, 2]], dtype=np.int32)
 
     np.testing.assert_equal(x[1, :, :], get_value(xx[1, :, :]))
     np.testing.assert_equal(x[1, :, 1], get_value(xx[1, :, 1]))
@@ -444,6 +446,21 @@ def test_advance_indexing_high_level(test_varnode):
     np.testing.assert_equal(x[:, 1, 1], get_value(xx[:, 1, 1]))
     np.testing.assert_equal(x[:, 1], get_value(xx[:, 1]))
     np.testing.assert_equal(x[1, 1:2], get_value(xx[1, 1:2]))
+    np.testing.assert_equal(x[:2, y, [0, 1]], get_value(xx[:2, y, [0, 1]]))
+    np.testing.assert_equal(x[None, None], get_value(xx[None, None]))
+    np.testing.assert_equal(x[:, None, ...], get_value(xx[:, None, ...]))
+    np.testing.assert_equal(x[1, None, :, 1], get_value(xx[1, None, :, 1]))
+    np.testing.assert_equal(x[:, None, 1, None], get_value(xx[:, None, 1, None]))
+    np.testing.assert_equal(x[:2, y, None, [0, 1]], get_value(xx[:2, y, None, [0, 1]]))
+    np.testing.assert_equal(
+        x[None, :, None, [0, 2], None, [1, 2]],
+        get_value(xx[None, :, None, [0, 2], None, [1, 2]]),
+    )
+    np.testing.assert_equal(x[z], get_value(xx[z]))
+    np.testing.assert_equal(x[z, None], get_value(xx[z, None]))
+    np.testing.assert_equal(x[None, z], get_value(xx[None, z]))
+    np.testing.assert_equal(x[z, None, z], get_value(xx[z, None, z]))
+    np.testing.assert_equal(x[None, z, None], get_value(xx[None, z, None]))
 
     x_ = x.copy()
     x_[1, 1, 1] = -1
@@ -592,16 +609,24 @@ def test_advance_indexing_with_bool(test_varnode):
     b = (np.random.sample((2, 3, 4)) > 0.5).astype("bool")
     bb = make_tensor(b, network)
     np.testing.assert_equal(a[b, :, 0:4:2], get_value(aa[bb, :, 0:4:2]))
+    np.testing.assert_equal(a[None, b, :, 0:4:2], get_value(aa[None, bb, :, 0:4:2]))
 
     b = (np.random.sample((4, 3, 4)) > 0.5).astype("bool")
     bb = make_tensor(b, network)
     np.testing.assert_equal(a[..., b, 0:2], get_value(aa[..., bb, 0:2]))
+    np.testing.assert_equal(
+        a[None, ..., b, None, 0:2], get_value(aa[None, ..., bb, None, 0:2])
+    )
 
     b = (np.random.sample((3, 4, 3)) > 0.5).astype("bool")
     bb = make_tensor(b, network)
     np.testing.assert_equal(
         a[:, b, 0:2, [True, False]], get_value(aa[:, bb, 0:2, [True, False]])
     )
+    np.testing.assert_equal(
+        a[:, b, None, 0:2, [True, False]],
+        get_value(aa[:, bb, None, 0:2, [True, False]]),
+    )
 
 
 @pytest.mark.parametrize("symbolic", [True, False, None])
@@ -781,9 +806,6 @@ def test_indexing_error(test_varnode):
     aa = make_tensor(a, network)
     bb = make_tensor(b, network)
 
-    with pytest.raises(IndexError):
-        aa[None]  # newaxis is not allowed
-
     with pytest.raises(IndexError):
         aa[..., ...]  # only one ellipsis is allowed
 
--
GitLab
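
A minimal usage sketch of the behavior this patch enables (not part of the patch itself; assumes a MegEngine build that includes this change — shapes follow NumPy's newaxis semantics, which the tests above check against):

    import numpy as np
    import megengine as mge

    x = mge.tensor(np.arange(27).reshape(3, 3, 3).astype("int32"))

    # `None` in an index now inserts a new axis of size 1, as in NumPy.
    print(x[None].shape)           # (1, 3, 3, 3)
    print(x[:, None, 1].shape)     # (3, 1, 3)
    print(x[None, :, None].shape)  # (1, 3, 1, 3, 3)

Internally, each `None` found in the index tuple is turned into an AddAxis op applied to the input before the remaining indices are lowered, and `None` entries are skipped when the remaining axes are counted.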