From 2937ea0e339ca80ac75299cf521d5c6386ecd385 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 26 Jun 2023 18:20:10 +0800 Subject: [PATCH] fix(imperative): fix indexing error when slice start is negative GitOrigin-RevId: 12e422f1166e1d1aa44d3b662a8928bb27f98917 --- imperative/python/src/tensor_utils.cpp | 25 ++++++++++++++++++- .../python/test/unit/core/test_indexing_op.py | 2 ++ imperative/src/impl/ops/subtensor.cpp | 5 ++-- 3 files changed, 29 insertions(+), 3 deletions(-) diff --git a/imperative/python/src/tensor_utils.cpp b/imperative/python/src/tensor_utils.cpp index 28f7e4ee6..a82db85b3 100644 --- a/imperative/python/src/tensor_utils.cpp +++ b/imperative/python/src/tensor_utils.cpp @@ -1247,6 +1247,29 @@ py::object _fastpath_getitem_cpp(py::handle inp_hdl, py::tuple tuple_val) { return ret[0]; } +py::tuple _get_tuple_idx(py::handle idx_hdl) { + if (py::isinstance(idx_hdl)) { + bool is_all_int = true; + py::list idx = py::reinterpret_borrow(idx_hdl); + for (size_t i = 0; i < idx.size(); i++) { + if (py::int_::check_(idx[i])) { + continue; + } else if (py::float_::check_(idx[i])) { + throw py::value_error("float idx is vaild in subtensor"); + } else { + is_all_int = false; + } + } + if (is_all_int) { + return py::make_tuple(idx_hdl); + } else { + return py::tuple(idx); + } + } else { + return py::make_tuple(idx_hdl); + } +} + py::object _getitem_cpp(py::handle inp_hdl, py::handle idx_hdl) { py::tuple try_res = _try_cond_take(inp_hdl, idx_hdl); if (try_res.size() == 2) { @@ -1256,7 +1279,7 @@ py::object _getitem_cpp(py::handle inp_hdl, py::handle idx_hdl) { if (py::isinstance(idx_hdl)) { tuple_val = py::reinterpret_borrow(idx_hdl); } else { - tuple_val = py::make_tuple(idx_hdl); + tuple_val = _get_tuple_idx(idx_hdl); } if (subtensor_fastpath(inp_hdl, tuple_val)) { return _fastpath_getitem_cpp(inp_hdl, tuple_val); diff --git a/imperative/python/test/unit/core/test_indexing_op.py b/imperative/python/test/unit/core/test_indexing_op.py index 1f6360916..c5dd67762 100644 --- a/imperative/python/test/unit/core/test_indexing_op.py +++ b/imperative/python/test/unit/core/test_indexing_op.py @@ -422,6 +422,8 @@ def test_advance_indexing_high_level(test_varnode): np.testing.assert_equal(x[1, :], get_value(xx[1, :])) np.testing.assert_equal(x[:, 1], get_value(xx[:, 1])) np.testing.assert_equal(x[1:3, :], get_value(xx[1:3, :])) + np.testing.assert_equal(x[-2:], get_value(xx[-2:])) + np.testing.assert_equal(x[:, -1:], get_value(xx[:, -1:])) np.testing.assert_equal(x[:, :], get_value(xx[:, :])) np.testing.assert_equal(x[1, 1], get_value(xx[1, 1])) diff --git a/imperative/src/impl/ops/subtensor.cpp b/imperative/src/impl/ops/subtensor.cpp index 18fff76b4..56dc8a151 100644 --- a/imperative/src/impl/ops/subtensor.cpp +++ b/imperative/src/impl/ops/subtensor.cpp @@ -202,12 +202,13 @@ SmallVector apply_on_physical_tensor( ax_val = ax_val < 0 ? layout.shape[axis] + ax_val : ax_val; offset += ax_val * layout.stride[axis] * dtype_size; } else { + int shape_axis = src->layout().shape[axis]; if (s_val < 0) { - int shape_axis = src->layout().shape[axis]; start = b_val == INT_MIN ? shape_axis - 1 : b_val; start = mod_size(start, shape_axis); } - start = std::max(start, 0); + start = start == INT_MIN ? 0 : start; + start = start < 0 ? start + shape_axis : start; offset += start * layout.stride[axis] * dtype_size; } } -- GitLab