未验证 提交 02a7b3cc 编写于 作者: J JYChen 提交者: GitHub

add range support in indexing (#56272)

* add range support in indexing

* add getitem ut case
上级 84482da8
......@@ -732,7 +732,7 @@ def monkey_patch_tensor():
item = (item,)
for slice_item in item:
if isinstance(slice_item, (list, np.ndarray, Variable)):
if isinstance(slice_item, (list, np.ndarray, Variable, range)):
return True
elif isinstance(slice_item, slice):
if (
......
......@@ -259,11 +259,13 @@ def replace_ellipsis(var, item):
return item
def replace_ndarray(item):
def replace_ndarray_and_range(item):
new_item = []
for slice_item in item:
if isinstance(slice_item, np.ndarray):
new_item.append(paddle.assign(slice_item))
elif isinstance(slice_item, range):
new_item.append(list(slice_item))
else:
new_item.append(slice_item)
return new_item
......@@ -416,7 +418,7 @@ def _setitem_impl_(var, item, value):
ends = []
steps = []
item = replace_ndarray(item)
item = replace_ndarray_and_range(item)
item = replace_ellipsis(var, item)
item, none_axes = replace_none(item)
slice_info = SliceInfo()
......@@ -700,7 +702,7 @@ def parse_index(x, indices):
if not isinstance(indices, tuple):
indices = (indices,)
indices = replace_ndarray(indices)
indices = replace_ndarray_and_range(indices)
indices = replace_ellipsis(x, indices)
indices, none_axes = replace_none(indices)
......
......@@ -129,6 +129,15 @@ class TestGetitemInDygraph(unittest.TestCase):
np.testing.assert_allclose(y.numpy(), np_res)
def test_index_has_range(self):
np_data = np.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6))
np_res = np_data[:, range(3), 4]
x = paddle.to_tensor(np_data)
y = x[:, range(3), 4]
np.testing.assert_allclose(y.numpy(), np_res)
class TestGetitemInStatic(unittest.TestCase):
def setUp(self):
......@@ -312,6 +321,19 @@ class TestGetitemInStatic(unittest.TestCase):
np.testing.assert_allclose(res[0], np_res)
def test_index_has_range(self):
# only one bool tensor with all False
np_data = np.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6))
np_res = np_data[:, range(3), 4]
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.to_tensor(np_data)
y = _getitem_static(x, (slice(None, None, None), range(3), 4))
res = self.exe.run(fetch_list=[y.name])
np.testing.assert_allclose(res[0], np_res)
class TestGetItemErrorCase(unittest.TestCase):
def setUp(self):
......
......@@ -51,6 +51,15 @@ class TestSetitemInDygraph(unittest.TestCase):
np.testing.assert_allclose(x.numpy(), np_data)
def test_index_has_range(self):
np_data = np.ones((3, 4, 5, 6), dtype='int32')
x = paddle.to_tensor(np_data)
np_data[:, range(3), [1, 2, 4]] = 10
x[:, range(3), [1, 2, 4]] = 10
np.testing.assert_allclose(x.numpy(), np_data)
class TestSetitemInStatic(unittest.TestCase):
def setUp(self):
......@@ -137,3 +146,19 @@ class TestSetitemInStatic(unittest.TestCase):
res = self.exe.run(fetch_list=[y.name])
np.testing.assert_allclose(res[0], np_data)
def test_index_has_range(self):
np_data = np.ones((3, 4, 5, 6), dtype='int32')
np_data[:, range(3), [1, 2, 4]] = 10
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.ones((3, 4, 5, 6), dtype='int32')
y = _setitem_static(
x,
(slice(None, None), range(3), [1, 2, 4]),
10,
)
res = self.exe.run(fetch_list=[y.name])
np.testing.assert_allclose(res[0], np_data)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册