未验证 提交 9fd1dd05 编写于 作者: H hong 提交者: GitHub

Fix get item out of range error (#24339) (#24943)

* raise index error when slice out of range; test=develop

* add uni test; test=develop

* fix format error; test=develop

* add comment for py::index_error; test=develop

* polish error message; test=develop

* polish error message; test=develop
上级 5fc42753
...@@ -322,7 +322,18 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index, ...@@ -322,7 +322,18 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index,
if (PyCheckInteger(slice_item)) { if (PyCheckInteger(slice_item)) {
// integer, PyLong_AsLong supports both int and long // integer, PyLong_AsLong supports both int and long
int start = static_cast<int>(PyLong_AsLong(slice_item)); int start = static_cast<int>(PyLong_AsLong(slice_item));
auto s_t = start;
start = start < 0 ? start + dim_len : start; start = start < 0 ? start + dim_len : start;
if (start >= dim_len) {
std::string str_error_message =
"The starting index " + std::to_string(s_t) +
" of slice is out of bounds in tensor " + std::to_string(dim) +
"-th axis, it shound be in the range of [" +
std::to_string(-dim_len) + ", " + std::to_string(dim_len) + ")";
// py::index_error is corresponding to IndexError in Python
// Used to indicate out of bounds access in __getitem__, __setitem__
throw py::index_error(str_error_message);
}
slice_axes->push_back(dim); slice_axes->push_back(dim);
slice_starts->push_back(start); slice_starts->push_back(start);
slice_ends->push_back(start + 1); slice_ends->push_back(start + 1);
......
...@@ -181,14 +181,25 @@ class TestVarBase(unittest.TestCase): ...@@ -181,14 +181,25 @@ class TestVarBase(unittest.TestCase):
self.assertTrue( self.assertTrue(
np.array_equal(local_out[15], tensor_array[::-1, ::-1, ::-1])) np.array_equal(local_out[15], tensor_array[::-1, ::-1, ::-1]))
def _test_for_var(self):
np_value = np.random.random((30, 100, 100)).astype('float32')
w = fluid.dygraph.to_variable(np_value)
for i, e in enumerate(w):
self.assertTrue(np.array_equal(e.numpy(), np_value[i]))
def test_slice(self): def test_slice(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
self._test_slice() self._test_slice()
self._test_for_var()
var = fluid.dygraph.to_variable(self.array) var = fluid.dygraph.to_variable(self.array)
self.assertTrue(np.array_equal(var[1, :].numpy(), self.array[1, :])) self.assertTrue(np.array_equal(var[1, :].numpy(), self.array[1, :]))
self.assertTrue(np.array_equal(var[::-1].numpy(), self.array[::-1])) self.assertTrue(np.array_equal(var[::-1].numpy(), self.array[::-1]))
with self.assertRaises(IndexError):
y = var[self.shape[0]]
def test_var_base_to_np(self): def test_var_base_to_np(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
var = fluid.dygraph.to_variable(self.array) var = fluid.dygraph.to_variable(self.array)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册