From 02a7b3ccd8f1a3f2558bf46f2c1030f6d758c9a6 Mon Sep 17 00:00:00 2001 From: JYChen Date: Wed, 16 Aug 2023 12:04:37 +0800 Subject: [PATCH] add range support in indexing (#56272) * add range support in indexing * add getitem ut case --- .../fluid/dygraph/tensor_patch_methods.py | 2 +- python/paddle/fluid/variable_index.py | 8 +++--- test/indexing/test_getitem.py | 22 ++++++++++++++++ test/indexing/test_setitem.py | 25 +++++++++++++++++++ 4 files changed, 53 insertions(+), 4 deletions(-) diff --git a/python/paddle/fluid/dygraph/tensor_patch_methods.py b/python/paddle/fluid/dygraph/tensor_patch_methods.py index a294bc6de4e..765da4bcd88 100644 --- a/python/paddle/fluid/dygraph/tensor_patch_methods.py +++ b/python/paddle/fluid/dygraph/tensor_patch_methods.py @@ -732,7 +732,7 @@ def monkey_patch_tensor(): item = (item,) for slice_item in item: - if isinstance(slice_item, (list, np.ndarray, Variable)): + if isinstance(slice_item, (list, np.ndarray, Variable, range)): return True elif isinstance(slice_item, slice): if ( diff --git a/python/paddle/fluid/variable_index.py b/python/paddle/fluid/variable_index.py index acf30532fbe..78ba5e3cfd7 100644 --- a/python/paddle/fluid/variable_index.py +++ b/python/paddle/fluid/variable_index.py @@ -259,11 +259,13 @@ def replace_ellipsis(var, item): return item -def replace_ndarray(item): +def replace_ndarray_and_range(item): new_item = [] for slice_item in item: if isinstance(slice_item, np.ndarray): new_item.append(paddle.assign(slice_item)) + elif isinstance(slice_item, range): + new_item.append(list(slice_item)) else: new_item.append(slice_item) return new_item @@ -416,7 +418,7 @@ def _setitem_impl_(var, item, value): ends = [] steps = [] - item = replace_ndarray(item) + item = replace_ndarray_and_range(item) item = replace_ellipsis(var, item) item, none_axes = replace_none(item) slice_info = SliceInfo() @@ -700,7 +702,7 @@ def parse_index(x, indices): if not isinstance(indices, tuple): indices = (indices,) - indices = replace_ndarray(indices) + indices = replace_ndarray_and_range(indices) indices = replace_ellipsis(x, indices) indices, none_axes = replace_none(indices) diff --git a/test/indexing/test_getitem.py b/test/indexing/test_getitem.py index bb18a0b7723..6ecd7750ec7 100644 --- a/test/indexing/test_getitem.py +++ b/test/indexing/test_getitem.py @@ -129,6 +129,15 @@ class TestGetitemInDygraph(unittest.TestCase): np.testing.assert_allclose(y.numpy(), np_res) + def test_index_has_range(self): + np_data = np.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6)) + np_res = np_data[:, range(3), 4] + + x = paddle.to_tensor(np_data) + y = x[:, range(3), 4] + + np.testing.assert_allclose(y.numpy(), np_res) + class TestGetitemInStatic(unittest.TestCase): def setUp(self): @@ -312,6 +321,19 @@ class TestGetitemInStatic(unittest.TestCase): np.testing.assert_allclose(res[0], np_res) + def test_index_has_range(self): + # only one bool tensor with all False + np_data = np.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6)) + np_res = np_data[:, range(3), 4] + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x = paddle.to_tensor(np_data) + y = _getitem_static(x, (slice(None, None, None), range(3), 4)) + res = self.exe.run(fetch_list=[y.name]) + + np.testing.assert_allclose(res[0], np_res) + class TestGetItemErrorCase(unittest.TestCase): def setUp(self): diff --git a/test/indexing/test_setitem.py b/test/indexing/test_setitem.py index f412247339f..b5e23ed309d 100644 --- a/test/indexing/test_setitem.py +++ b/test/indexing/test_setitem.py @@ -51,6 +51,15 @@ class TestSetitemInDygraph(unittest.TestCase): np.testing.assert_allclose(x.numpy(), np_data) + def test_index_has_range(self): + np_data = np.ones((3, 4, 5, 6), dtype='int32') + x = paddle.to_tensor(np_data) + + np_data[:, range(3), [1, 2, 4]] = 10 + x[:, range(3), [1, 2, 4]] = 10 + + np.testing.assert_allclose(x.numpy(), np_data) + class TestSetitemInStatic(unittest.TestCase): def setUp(self): @@ -137,3 +146,19 @@ class TestSetitemInStatic(unittest.TestCase): res = self.exe.run(fetch_list=[y.name]) np.testing.assert_allclose(res[0], np_data) + + def test_index_has_range(self): + np_data = np.ones((3, 4, 5, 6), dtype='int32') + np_data[:, range(3), [1, 2, 4]] = 10 + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x = paddle.ones((3, 4, 5, 6), dtype='int32') + y = _setitem_static( + x, + (slice(None, None), range(3), [1, 2, 4]), + 10, + ) + res = self.exe.run(fetch_list=[y.name]) + + np.testing.assert_allclose(res[0], np_data) -- GitLab