未验证 提交 9658f49b 编写于 作者: J JYChen 提交者: GitHub

fix UT when np >= 1.23 (#51466)

* fix UT when np >= 1.24

* optimize decription of this change
上级 9cd99f7e
...@@ -952,9 +952,22 @@ class TestVarBase(unittest.TestCase): ...@@ -952,9 +952,22 @@ class TestVarBase(unittest.TestCase):
array = np.arange(120).reshape([6, 5, 4]) array = np.arange(120).reshape([6, 5, 4])
x = paddle.to_tensor(array) x = paddle.to_tensor(array)
py_idx = [[0, 2, 0, 1, 3], [0, 0, 1, 2, 0]] py_idx = [[0, 2, 0, 1, 3], [0, 0, 1, 2, 0]]
# note(chenjianye):
# Non-tuple sequence for multidimensional indexing is supported in numpy < 1.23.
# For List case, the outermost `[]` will be treated as tuple `()` in version less than 1.23,
# which is used to wrap index elements for multiple axes.
# And from 1.23, this will be treat as a whole and only works on one axis.
#
# e.g. x[[[0],[1]]] == x[([0],[1])] == x[[0],[1]] (in version < 1.23)
# x[[[0],[1]]] == x[array([[0],[1]])] (in version >= 1.23)
#
# Here, we just modify the code to remove the impact of numpy version changes,
# changing x[[[0],[1]]] to x[tuple([[0],[1]])] == x[([0],[1])] == x[[0],[1]].
# Whether the paddle behavior in this case will change is still up for debate.
idx = [paddle.to_tensor(py_idx[0]), paddle.to_tensor(py_idx[1])] idx = [paddle.to_tensor(py_idx[0]), paddle.to_tensor(py_idx[1])]
np.testing.assert_array_equal(x[idx].numpy(), array[py_idx]) np.testing.assert_array_equal(x[idx].numpy(), array[tuple(py_idx)])
np.testing.assert_array_equal(x[py_idx].numpy(), array[py_idx]) np.testing.assert_array_equal(x[py_idx].numpy(), array[tuple(py_idx)])
# case2: # case2:
tensor_x = paddle.to_tensor( tensor_x = paddle.to_tensor(
np.zeros(12).reshape(2, 6).astype(np.float32) np.zeros(12).reshape(2, 6).astype(np.float32)
......
...@@ -596,6 +596,19 @@ class TestVariableSlice(unittest.TestCase): ...@@ -596,6 +596,19 @@ class TestVariableSlice(unittest.TestCase):
class TestListIndex(unittest.TestCase): class TestListIndex(unittest.TestCase):
# note(chenjianye):
# Non-tuple sequence for multidimensional indexing is supported in numpy < 1.23.
# For List case, the outermost `[]` will be treated as tuple `()` in version less than 1.23,
# which is used to wrap index elements for multiple axes.
# And from 1.23, this will be treat as a whole and only works on one axis.
#
# e.g. x[[[0],[1]]] == x[([0],[1])] == x[[0],[1]] (in version < 1.23)
# x[[[0],[1]]] == x[array([[0],[1]])] (in version >= 1.23)
#
# Here, we just modify the code to remove the impact of numpy version changes,
# changing x[[[0],[1]]] to x[tuple([[0],[1]])] == x[([0],[1])] == x[[0],[1]].
# Whether the paddle behavior in this case will change is still up for debate.
def setUp(self): def setUp(self):
np.random.seed(2022) np.random.seed(2022)
...@@ -637,7 +650,7 @@ class TestListIndex(unittest.TestCase): ...@@ -637,7 +650,7 @@ class TestListIndex(unittest.TestCase):
exe.run(paddle.static.default_startup_program()) exe.run(paddle.static.default_startup_program())
fetch_list = [y.name] fetch_list = [y.name]
getitem_np = array[index_mod] getitem_np = array[tuple(index_mod)]
getitem_pp = exe.run( getitem_pp = exe.run(
prog, feed={x.name: array}, fetch_list=fetch_list prog, feed={x.name: array}, fetch_list=fetch_list
) )
...@@ -659,7 +672,7 @@ class TestListIndex(unittest.TestCase): ...@@ -659,7 +672,7 @@ class TestListIndex(unittest.TestCase):
pt = paddle.to_tensor(array) pt = paddle.to_tensor(array)
index_mod = (index % (array.shape[-1])).tolist() index_mod = (index % (array.shape[-1])).tolist()
try: try:
getitem_np = array[index_mod] getitem_np = array[tuple(index_mod)]
except: except:
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
...@@ -844,8 +857,12 @@ class TestListIndex(unittest.TestCase): ...@@ -844,8 +857,12 @@ class TestListIndex(unittest.TestCase):
exe.run(paddle.static.default_startup_program()) exe.run(paddle.static.default_startup_program())
fetch_list = [y.name] fetch_list = [y.name]
array2 = array.copy() array2 = array.copy()
try: try:
index = (
tuple(index)
if isinstance(index, list) and isinstance(index[0], list)
else index
)
array2[index] = value_np array2[index] = value_np
except: except:
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册