diff --git a/imperative/python/src/tensor_utils.cpp b/imperative/python/src/tensor_utils.cpp index 28f7e4ee6d77e52ad44103a0c3f508804aa9f5b9..a82db85b39f42aaf5f4c530c341dd4fc8047190a 100644 --- a/imperative/python/src/tensor_utils.cpp +++ b/imperative/python/src/tensor_utils.cpp @@ -1247,6 +1247,29 @@ py::object _fastpath_getitem_cpp(py::handle inp_hdl, py::tuple tuple_val) { return ret[0]; } +py::tuple _get_tuple_idx(py::handle idx_hdl) { + if (py::isinstance(idx_hdl)) { + bool is_all_int = true; + py::list idx = py::reinterpret_borrow(idx_hdl); + for (size_t i = 0; i < idx.size(); i++) { + if (py::int_::check_(idx[i])) { + continue; + } else if (py::float_::check_(idx[i])) { + throw py::value_error("float idx is vaild in subtensor"); + } else { + is_all_int = false; + } + } + if (is_all_int) { + return py::make_tuple(idx_hdl); + } else { + return py::tuple(idx); + } + } else { + return py::make_tuple(idx_hdl); + } +} + py::object _getitem_cpp(py::handle inp_hdl, py::handle idx_hdl) { py::tuple try_res = _try_cond_take(inp_hdl, idx_hdl); if (try_res.size() == 2) { @@ -1256,7 +1279,7 @@ py::object _getitem_cpp(py::handle inp_hdl, py::handle idx_hdl) { if (py::isinstance(idx_hdl)) { tuple_val = py::reinterpret_borrow(idx_hdl); } else { - tuple_val = py::make_tuple(idx_hdl); + tuple_val = _get_tuple_idx(idx_hdl); } if (subtensor_fastpath(inp_hdl, tuple_val)) { return _fastpath_getitem_cpp(inp_hdl, tuple_val); diff --git a/imperative/python/test/unit/core/test_indexing_op.py b/imperative/python/test/unit/core/test_indexing_op.py index 1f636091669f84a2ef30e8cdeb3199abb27377a3..c5dd67762c04ac62840d58443deeed0e99d6a175 100644 --- a/imperative/python/test/unit/core/test_indexing_op.py +++ b/imperative/python/test/unit/core/test_indexing_op.py @@ -422,6 +422,8 @@ def test_advance_indexing_high_level(test_varnode): np.testing.assert_equal(x[1, :], get_value(xx[1, :])) np.testing.assert_equal(x[:, 1], get_value(xx[:, 1])) np.testing.assert_equal(x[1:3, :], get_value(xx[1:3, :])) + np.testing.assert_equal(x[-2:], get_value(xx[-2:])) + np.testing.assert_equal(x[:, -1:], get_value(xx[:, -1:])) np.testing.assert_equal(x[:, :], get_value(xx[:, :])) np.testing.assert_equal(x[1, 1], get_value(xx[1, 1])) diff --git a/imperative/src/impl/ops/subtensor.cpp b/imperative/src/impl/ops/subtensor.cpp index 18fff76b41ec2e683ac5fd195bc807cb417fc5cb..56dc8a151726e9f2a8f72224f8067e4683467f41 100644 --- a/imperative/src/impl/ops/subtensor.cpp +++ b/imperative/src/impl/ops/subtensor.cpp @@ -202,12 +202,13 @@ SmallVector apply_on_physical_tensor( ax_val = ax_val < 0 ? layout.shape[axis] + ax_val : ax_val; offset += ax_val * layout.stride[axis] * dtype_size; } else { + int shape_axis = src->layout().shape[axis]; if (s_val < 0) { - int shape_axis = src->layout().shape[axis]; start = b_val == INT_MIN ? shape_axis - 1 : b_val; start = mod_size(start, shape_axis); } - start = std::max(start, 0); + start = start == INT_MIN ? 0 : start; + start = start < 0 ? start + shape_axis : start; offset += start * layout.stride[axis] * dtype_size; } }