未验证 提交 6e17e661 编写于 作者: J JYChen 提交者: GitHub

Normalize multi dim list in indexing (#56893)

* add unit test for bool-list index

* normative semantics of multi-dim List

* Adapt to different CI environment Numpy versions
上级 fcb67605
...@@ -41,16 +41,6 @@ def is_list_tuple(index, contain_type): ...@@ -41,16 +41,6 @@ def is_list_tuple(index, contain_type):
return True return True
def is_one_dim_list(index, contain_type):
if isinstance(index, list):
for i in index:
if not isinstance(i, contain_type):
return False
else:
return False
return True
def get_list_index_shape(var_dims, index_dims): def get_list_index_shape(var_dims, index_dims):
var_dims_size = len(var_dims) var_dims_size = len(var_dims)
index_dims_size = len(index_dims) index_dims_size = len(index_dims)
...@@ -405,9 +395,7 @@ def _setitem_impl_(var, item, value): ...@@ -405,9 +395,7 @@ def _setitem_impl_(var, item, value):
return _setitem_for_tensor_array(var, item, value) return _setitem_for_tensor_array(var, item, value)
inputs = {'Input': var} inputs = {'Input': var}
if isinstance(item, list):
if not is_one_dim_list(item, int):
item = tuple(item)
# 1. Parse item # 1. Parse item
if not isinstance(item, tuple): if not isinstance(item, tuple):
item = (item,) item = (item,)
...@@ -702,9 +690,6 @@ def parse_index(x, indices): ...@@ -702,9 +690,6 @@ def parse_index(x, indices):
use_strided_slice = False use_strided_slice = False
has_advanced_index = False has_advanced_index = False
if isinstance(indices, list) and not is_one_dim_list(indices, int):
indices = tuple(indices)
if not isinstance(indices, tuple): if not isinstance(indices, tuple):
indices = (indices,) indices = (indices,)
......
...@@ -138,6 +138,46 @@ class TestGetitemInDygraph(unittest.TestCase): ...@@ -138,6 +138,46 @@ class TestGetitemInDygraph(unittest.TestCase):
np.testing.assert_allclose(y.numpy(), np_res) np.testing.assert_allclose(y.numpy(), np_res)
def test_indexing_with_bool_list1(self):
# test bool-list indexing when axes num less than x.rank
np_data = np.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6))
np_res = np_data[[True, False, True], [False, False, False, True]]
x = paddle.to_tensor(np_data)
y = x[[True, False, True], [False, False, False, True]]
np.testing.assert_allclose(y.numpy(), np_res)
def test_indexing_with_bool_list2(self):
# test bool-list indexing when axes num less than x.rank
np_data = np.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6))
np_res = np_data[
[True, False, True],
[False, False, True, False],
[True, False, False, True, False],
]
x = paddle.to_tensor(np_data)
y = x[
[True, False, True],
[False, False, True, False],
[True, False, False, True, False],
]
np.testing.assert_allclose(y.numpy(), np_res)
def test_indexing_is_multi_dim_list(self):
# indexing is multi-dim int list, should be treat as one index, like numpy>=1.23
np_data = np.arange(3 * 4 * 5 * 6).reshape((6, 5, 4, 3))
np_res = np_data[np.array([[2, 3, 4], [1, 2, 5]])]
x = paddle.to_tensor(np_data)
y = x[[[2, 3, 4], [1, 2, 5]]]
y_index_tensor = x[paddle.to_tensor([[2, 3, 4], [1, 2, 5]])]
np.testing.assert_allclose(y.numpy(), np_res)
np.testing.assert_allclose(y.numpy(), y_index_tensor.numpy())
class TestGetitemInStatic(unittest.TestCase): class TestGetitemInStatic(unittest.TestCase):
def setUp(self): def setUp(self):
...@@ -334,6 +374,64 @@ class TestGetitemInStatic(unittest.TestCase): ...@@ -334,6 +374,64 @@ class TestGetitemInStatic(unittest.TestCase):
np.testing.assert_allclose(res[0], np_res) np.testing.assert_allclose(res[0], np_res)
def test_indexing_with_bool_list1(self):
# test bool-list indexing when axes num less than x.rank
np_data = np.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6))
np_res = np_data[[True, False, True], [False, False, False, True]]
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.to_tensor(np_data)
y = _getitem_static(
x, ([True, False, True], [False, False, False, True])
)
res = self.exe.run(fetch_list=[y.name])
np.testing.assert_allclose(res[0], np_res)
def test_indexing_with_bool_list2(self):
# test bool-list indexing when axes num less than x.rank
np_data = np.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6))
np_res = np_data[
[True, False, True],
[False, False, True, False],
[True, False, False, True, False],
]
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.to_tensor(np_data)
y = _getitem_static(
x,
(
[True, False, True],
[False, False, True, False],
[True, False, False, True, False],
),
)
res = self.exe.run(fetch_list=[y.name])
np.testing.assert_allclose(res[0], np_res)
def test_indexing_is_multi_dim_list(self):
# indexing is multi-dim int list, should be treat as one index, like numpy>=1.23
np_data = np.arange(3 * 4 * 5 * 6).reshape((6, 5, 4, 3))
np_res = np_data[np.array([[2, 3, 4], [1, 2, 5]])]
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.to_tensor(np_data)
y = _getitem_static(x, ([[2, 3, 4], [1, 2, 5]]))
y_index_tensor = _getitem_static(
x, paddle.to_tensor([[2, 3, 4], [1, 2, 5]])
)
res = self.exe.run(fetch_list=[y.name, y_index_tensor.name])
np.testing.assert_allclose(res[0], np_res)
np.testing.assert_allclose(res[1], np_res)
class TestGetItemErrorCase(unittest.TestCase): class TestGetItemErrorCase(unittest.TestCase):
def setUp(self): def setUp(self):
......
...@@ -60,6 +60,44 @@ class TestSetitemInDygraph(unittest.TestCase): ...@@ -60,6 +60,44 @@ class TestSetitemInDygraph(unittest.TestCase):
np.testing.assert_allclose(x.numpy(), np_data) np.testing.assert_allclose(x.numpy(), np_data)
def test_indexing_with_bool_list1(self):
# test bool-list indexing when axes num less than x.rank
np_data = np.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6))
np_data[[True, False, True], [False, False, False, True]] = 7
x = paddle.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6))
x[[True, False, True], [False, False, False, True]] = 7
np.testing.assert_allclose(x.numpy(), np_data)
def test_indexing_with_bool_list2(self):
# test bool-list indexing when axes num less than x.rank
np_data = np.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6))
np_data[
[True, False, True],
[False, False, True, False],
[True, False, False, True, False],
] = 8
x = paddle.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6))
x[
[True, False, True],
[False, False, True, False],
[True, False, False, True, False],
] = 8
np.testing.assert_allclose(x.numpy(), np_data)
def test_indexing_is_multi_dim_list(self):
# indexing is multi-dim int list, should be treat as one index, like numpy>=1.23
np_data = np.arange(3 * 4 * 5 * 6).reshape((6, 5, 4, 3))
np_data[np.array([[2, 3, 4], [1, 2, 5]])] = 100
x = paddle.arange(3 * 4 * 5 * 6).reshape((6, 5, 4, 3))
x[[[2, 3, 4], [1, 2, 5]]] = 100
np.testing.assert_allclose(x.numpy(), np_data)
class TestSetitemInStatic(unittest.TestCase): class TestSetitemInStatic(unittest.TestCase):
def setUp(self): def setUp(self):
...@@ -162,3 +200,58 @@ class TestSetitemInStatic(unittest.TestCase): ...@@ -162,3 +200,58 @@ class TestSetitemInStatic(unittest.TestCase):
res = self.exe.run(fetch_list=[y.name]) res = self.exe.run(fetch_list=[y.name])
np.testing.assert_allclose(res[0], np_data) np.testing.assert_allclose(res[0], np_data)
def test_indexing_with_bool_list1(self):
# test bool-list indexing when axes num less than x.rank
np_data = np.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6))
np_data[[True, False, True], [False, False, False, True]] = 7
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6))
y = _setitem_static(
x, ([True, False, True], [False, False, False, True]), 7
)
res = self.exe.run(fetch_list=[y.name])
np.testing.assert_allclose(res[0], np_data)
def test_indexing_with_bool_list2(self):
# test bool-list indexing when axes num less than x.rank
np_data = np.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6))
np_data[
[True, False, True],
[False, False, True, False],
[True, False, False, True, False],
] = 8
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6))
y = _setitem_static(
x,
(
[True, False, True],
[False, False, True, False],
[True, False, False, True, False],
),
8,
)
res = self.exe.run(fetch_list=[y.name])
np.testing.assert_allclose(res[0], np_data)
def test_indexing_is_multi_dim_list(self):
# indexing is multi-dim int list, should be treat as one index, like numpy>=1.23
np_data = np.arange(3 * 4 * 5 * 6).reshape((6, 5, 4, 3))
np_data[np.array([[2, 3, 4], [1, 2, 5]])] = 10
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.arange(3 * 4 * 5 * 6).reshape((6, 5, 4, 3))
y = _setitem_static(x, [[[2, 3, 4], [1, 2, 5]]], 10)
res = self.exe.run(fetch_list=[y.name])
np.testing.assert_allclose(res[0], np_data)
...@@ -991,21 +991,11 @@ class TestVarBase(unittest.TestCase): ...@@ -991,21 +991,11 @@ class TestVarBase(unittest.TestCase):
x = paddle.to_tensor(array) x = paddle.to_tensor(array)
py_idx = [[0, 2, 0, 1, 3], [0, 0, 1, 2, 0]] py_idx = [[0, 2, 0, 1, 3], [0, 0, 1, 2, 0]]
# note(chenjianye):
# Non-tuple sequence for multidimensional indexing is supported in numpy < 1.23.
# For List case, the outermost `[]` will be treated as tuple `()` in version less than 1.23,
# which is used to wrap index elements for multiple axes.
# And from 1.23, this will be treat as a whole and only works on one axis.
#
# e.g. x[[[0],[1]]] == x[([0],[1])] == x[[0],[1]] (in version < 1.23)
# x[[[0],[1]]] == x[array([[0],[1]])] (in version >= 1.23)
#
# Here, we just modify the code to remove the impact of numpy version changes,
# changing x[[[0],[1]]] to x[tuple([[0],[1]])] == x[([0],[1])] == x[[0],[1]].
# Whether the paddle behavior in this case will change is still up for debate.
idx = [paddle.to_tensor(py_idx[0]), paddle.to_tensor(py_idx[1])] idx = [paddle.to_tensor(py_idx[0]), paddle.to_tensor(py_idx[1])]
np.testing.assert_array_equal(x[idx].numpy(), array[tuple(py_idx)]) np.testing.assert_array_equal(x[idx].numpy(), array[np.array(py_idx)])
np.testing.assert_array_equal(x[py_idx].numpy(), array[tuple(py_idx)]) np.testing.assert_array_equal(
x[py_idx].numpy(), array[np.array(py_idx)]
)
# case2: # case2:
tensor_x = paddle.to_tensor( tensor_x = paddle.to_tensor(
np.zeros(12).reshape(2, 6).astype(np.float32) np.zeros(12).reshape(2, 6).astype(np.float32)
......
...@@ -585,19 +585,6 @@ class TestVariableSlice(unittest.TestCase): ...@@ -585,19 +585,6 @@ class TestVariableSlice(unittest.TestCase):
class TestListIndex(unittest.TestCase): class TestListIndex(unittest.TestCase):
# note(chenjianye):
# Non-tuple sequence for multidimensional indexing is supported in numpy < 1.23.
# For List case, the outermost `[]` will be treated as tuple `()` in version less than 1.23,
# which is used to wrap index elements for multiple axes.
# And from 1.23, this will be treat as a whole and only works on one axis.
#
# e.g. x[[[0],[1]]] == x[([0],[1])] == x[[0],[1]] (in version < 1.23)
# x[[[0],[1]]] == x[array([[0],[1]])] (in version >= 1.23)
#
# Here, we just modify the code to remove the impact of numpy version changes,
# changing x[[[0],[1]]] to x[tuple([[0],[1]])] == x[([0],[1])] == x[[0],[1]].
# Whether the paddle behavior in this case will change is still up for debate.
def setUp(self): def setUp(self):
np.random.seed(2022) np.random.seed(2022)
...@@ -639,7 +626,7 @@ class TestListIndex(unittest.TestCase): ...@@ -639,7 +626,7 @@ class TestListIndex(unittest.TestCase):
exe.run(paddle.static.default_startup_program()) exe.run(paddle.static.default_startup_program())
fetch_list = [y.name] fetch_list = [y.name]
getitem_np = array[tuple(index_mod)] getitem_np = array[np.array(index_mod)]
getitem_pp = exe.run( getitem_pp = exe.run(
prog, feed={x.name: array}, fetch_list=fetch_list prog, feed={x.name: array}, fetch_list=fetch_list
) )
...@@ -660,7 +647,7 @@ class TestListIndex(unittest.TestCase): ...@@ -660,7 +647,7 @@ class TestListIndex(unittest.TestCase):
pt = paddle.to_tensor(array) pt = paddle.to_tensor(array)
index_mod = (index % (array.shape[-1])).tolist() index_mod = (index % (array.shape[-1])).tolist()
try: try:
getitem_np = array[tuple(index_mod)] getitem_np = array[np.array(index_mod)]
except: except:
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
...@@ -845,7 +832,7 @@ class TestListIndex(unittest.TestCase): ...@@ -845,7 +832,7 @@ class TestListIndex(unittest.TestCase):
array2 = array.copy() array2 = array.copy()
try: try:
index = ( index = (
tuple(index) np.array(index)
if isinstance(index, list) and isinstance(index[0], list) if isinstance(index, list) and isinstance(index[0], list)
else index else index
) )
...@@ -869,12 +856,12 @@ class TestListIndex(unittest.TestCase): ...@@ -869,12 +856,12 @@ class TestListIndex(unittest.TestCase):
def test_static_graph_setitem_list_index(self): def test_static_graph_setitem_list_index(self):
paddle.enable_static() paddle.enable_static()
# case 1: # case 1:
inps_shape = [3, 4, 5, 2, 3] inps_shape = [4, 5, 2]
array = np.arange(self.numel(inps_shape), dtype='float32').reshape( array = np.arange(self.numel(inps_shape), dtype='float32').reshape(
inps_shape inps_shape
) )
index_shape = [3, 3, 1, 2] index_shape = [3, 3, 1]
index = np.arange(self.numel(index_shape)).reshape(index_shape) index = np.arange(self.numel(index_shape)).reshape(index_shape)
value_shape = inps_shape[3:] value_shape = inps_shape[3:]
...@@ -897,12 +884,12 @@ class TestListIndex(unittest.TestCase): ...@@ -897,12 +884,12 @@ class TestListIndex(unittest.TestCase):
index = index[0] index = index[0]
# case 2: # case 2:
inps_shape = [3, 4, 5, 4, 3] inps_shape = [4, 5, 4]
array = np.arange(self.numel(inps_shape), dtype='float32').reshape( array = np.arange(self.numel(inps_shape), dtype='float32').reshape(
inps_shape inps_shape
) )
index_shape = [4, 3, 2, 2] index_shape = [4, 3, 2]
index = np.arange(self.numel(index_shape)).reshape(index_shape) index = np.arange(self.numel(index_shape)).reshape(index_shape)
value_shape = [3] value_shape = [3]
...@@ -913,7 +900,7 @@ class TestListIndex(unittest.TestCase): ...@@ -913,7 +900,7 @@ class TestListIndex(unittest.TestCase):
+ 100 + 100
) )
for _ in range(4): for _ in range(3):
program = paddle.static.Program() program = paddle.static.Program()
index_mod = (index % (min(array.shape))).tolist() index_mod = (index % (min(array.shape))).tolist()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册