提交 2937ea0e 编写于 作者: M Megvii Engine Team

fix(imperative): fix indexing error when slice start is negative

GitOrigin-RevId: 12e422f1166e1d1aa44d3b662a8928bb27f98917
上级 75f1f0c5
......@@ -1247,6 +1247,29 @@ py::object _fastpath_getitem_cpp(py::handle inp_hdl, py::tuple tuple_val) {
return ret[0];
}
py::tuple _get_tuple_idx(py::handle idx_hdl) {
if (py::isinstance<py::list>(idx_hdl)) {
bool is_all_int = true;
py::list idx = py::reinterpret_borrow<py::list>(idx_hdl);
for (size_t i = 0; i < idx.size(); i++) {
if (py::int_::check_(idx[i])) {
continue;
} else if (py::float_::check_(idx[i])) {
throw py::value_error("float idx is vaild in subtensor");
} else {
is_all_int = false;
}
}
if (is_all_int) {
return py::make_tuple(idx_hdl);
} else {
return py::tuple(idx);
}
} else {
return py::make_tuple(idx_hdl);
}
}
py::object _getitem_cpp(py::handle inp_hdl, py::handle idx_hdl) {
py::tuple try_res = _try_cond_take(inp_hdl, idx_hdl);
if (try_res.size() == 2) {
......@@ -1256,7 +1279,7 @@ py::object _getitem_cpp(py::handle inp_hdl, py::handle idx_hdl) {
if (py::isinstance<py::tuple>(idx_hdl)) {
tuple_val = py::reinterpret_borrow<py::tuple>(idx_hdl);
} else {
tuple_val = py::make_tuple(idx_hdl);
tuple_val = _get_tuple_idx(idx_hdl);
}
if (subtensor_fastpath(inp_hdl, tuple_val)) {
return _fastpath_getitem_cpp(inp_hdl, tuple_val);
......
......@@ -422,6 +422,8 @@ def test_advance_indexing_high_level(test_varnode):
np.testing.assert_equal(x[1, :], get_value(xx[1, :]))
np.testing.assert_equal(x[:, 1], get_value(xx[:, 1]))
np.testing.assert_equal(x[1:3, :], get_value(xx[1:3, :]))
np.testing.assert_equal(x[-2:], get_value(xx[-2:]))
np.testing.assert_equal(x[:, -1:], get_value(xx[:, -1:]))
np.testing.assert_equal(x[:, :], get_value(xx[:, :]))
np.testing.assert_equal(x[1, 1], get_value(xx[1, 1]))
......
......@@ -202,12 +202,13 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
ax_val = ax_val < 0 ? layout.shape[axis] + ax_val : ax_val;
offset += ax_val * layout.stride[axis] * dtype_size;
} else {
int shape_axis = src->layout().shape[axis];
if (s_val < 0) {
int shape_axis = src->layout().shape[axis];
start = b_val == INT_MIN ? shape_axis - 1 : b_val;
start = mod_size(start, shape_axis);
}
start = std::max(start, 0);
start = start == INT_MIN ? 0 : start;
start = start < 0 ? start + shape_axis : start;
offset += start * layout.stride[axis] * dtype_size;
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册