From 0f5e90a2bcc15e9342a691c77693ed569640706c Mon Sep 17 00:00:00 2001 From: FlyingQianMM <245467267@qq.com> Date: Fri, 25 Mar 2022 10:04:25 +0800 Subject: [PATCH] support get_item where the index is a bool scalar tensor (#40829) * support get_item where the index is a bool scalar tensor * add unittests for supporting get_item where the index is a bool scalar tensor --- .../fluid/tests/unittests/test_var_base.py | 92 +++++++++++++++++++ .../fluid/tests/unittests/test_variable.py | 56 ++++++++++- python/paddle/fluid/variable_index.py | 13 ++- 3 files changed, 158 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_var_base.py b/python/paddle/fluid/tests/unittests/test_var_base.py index 4b3e935426f..9772a343f49 100644 --- a/python/paddle/fluid/tests/unittests/test_var_base.py +++ b/python/paddle/fluid/tests/unittests/test_var_base.py @@ -829,6 +829,15 @@ class TestVarBase(unittest.TestCase): with self.assertRaises(IndexError): var_tensor[paddle.to_tensor([[True, False, False, False]])] + def _test_scalar_bool_index(self): + shape = (1, 2, 5, 64) + np_value = np.random.random(shape).astype('float32') + var_tensor = paddle.to_tensor(np_value) + index = [True] + tensor_index = paddle.to_tensor(index) + var = [var_tensor[tensor_index].numpy(), ] + self.assertTrue(np.array_equal(var[0], np_value[index])) + def _test_for_var(self): np_value = np.random.random((30, 100, 100)).astype('float32') w = fluid.dygraph.to_variable(np_value) @@ -883,6 +892,7 @@ class TestVarBase(unittest.TestCase): self._test_for_getitem_ellipsis_index() self._test_none_index() self._test_bool_index() + self._test_scalar_bool_index() self._test_numpy_index() self._test_list_index() @@ -1219,6 +1229,88 @@ class TestVarBaseSetitemFp64(TestVarBaseSetitem): self.dtype = "float64" +class TestVarBaseSetitemBoolIndex(unittest.TestCase): + def setUp(self): + paddle.disable_static() + self.set_dtype() + self.set_input() + + def set_input(self): + self.tensor_x = paddle.to_tensor(np.ones((4, 2, 3)).astype(self.dtype)) + self.np_value = np.random.random((2, 3)).astype(self.dtype) + self.tensor_value = paddle.to_tensor(self.np_value) + + def set_dtype(self): + self.dtype = "int32" + + def _test(self, value): + paddle.disable_static() + self.assertEqual(self.tensor_x.inplace_version, 0) + + id_origin = id(self.tensor_x) + index_1 = paddle.to_tensor(np.array([True, False, False, False])) + self.tensor_x[index_1] = value + self.assertEqual(self.tensor_x.inplace_version, 1) + + if isinstance(value, (six.integer_types, float)): + result = np.zeros((2, 3)).astype(self.dtype) + value + + else: + result = self.np_value + + self.assertTrue(np.array_equal(self.tensor_x[0].numpy(), result)) + self.assertEqual(id_origin, id(self.tensor_x)) + + index_2 = paddle.to_tensor(np.array([False, True, False, False])) + self.tensor_x[index_2] = value + self.assertEqual(self.tensor_x.inplace_version, 2) + self.assertTrue(np.array_equal(self.tensor_x[1].numpy(), result)) + self.assertEqual(id_origin, id(self.tensor_x)) + + index_3 = paddle.to_tensor(np.array([True, True, True, True])) + self.tensor_x[index_3] = value + self.assertEqual(self.tensor_x.inplace_version, 3) + self.assertTrue(np.array_equal(self.tensor_x[3].numpy(), result)) + self.assertEqual(id_origin, id(self.tensor_x)) + + def test_value_tensor(self): + paddle.disable_static() + self._test(self.tensor_value) + + def test_value_numpy(self): + paddle.disable_static() + self._test(self.np_value) + + def test_value_int(self): + paddle.disable_static() + self._test(10) + + +class TestVarBaseSetitemBoolScalarIndex(unittest.TestCase): + def set_input(self): + self.tensor_x = paddle.to_tensor(np.ones((1, 2, 3)).astype(self.dtype)) + self.np_value = np.random.random((2, 3)).astype(self.dtype) + self.tensor_value = paddle.to_tensor(self.np_value) + + def _test(self, value): + paddle.disable_static() + self.assertEqual(self.tensor_x.inplace_version, 0) + + id_origin = id(self.tensor_x) + index = paddle.to_tensor(np.array([True])) + self.tensor_x[index] = value + self.assertEqual(self.tensor_x.inplace_version, 1) + + if isinstance(value, (six.integer_types, float)): + result = np.zeros((2, 3)).astype(self.dtype) + value + + else: + result = self.np_value + + self.assertTrue(np.array_equal(self.tensor_x[0].numpy(), result)) + self.assertEqual(id_origin, id(self.tensor_x)) + + class TestVarBaseInplaceVersion(unittest.TestCase): def test_setitem(self): paddle.disable_static() diff --git a/python/paddle/fluid/tests/unittests/test_variable.py b/python/paddle/fluid/tests/unittests/test_variable.py index beaf361379b..afa1ac0ad65 100644 --- a/python/paddle/fluid/tests/unittests/test_variable.py +++ b/python/paddle/fluid/tests/unittests/test_variable.py @@ -336,6 +336,23 @@ class TestVariable(unittest.TestCase): with paddle.static.program_guard(prog): res = x[[False, False]] + def _test_slice_index_scalar_bool(self, place): + data = np.random.rand(1, 3, 4).astype("float32") + np_idx = np.array([True]) + prog = paddle.static.Program() + with paddle.static.program_guard(prog): + x = paddle.assign(data) + idx = paddle.assign(np_idx) + + out = x[idx] + + exe = paddle.static.Executor(place) + result = exe.run(prog, fetch_list=[out]) + + expected = [data[np_idx]] + + self.assertTrue((result[0] == expected[0]).all()) + def test_slice(self): places = [fluid.CPUPlace()] if core.is_compiled_with_cuda(): @@ -347,6 +364,7 @@ class TestVariable(unittest.TestCase): self._test_slice_index_list(place) self._test_slice_index_ellipsis(place) self._test_slice_index_list_bool(place) + self._test_slice_index_scalar_bool(place) def _tostring(self): b = default_main_program().current_block() @@ -705,7 +723,7 @@ class TestListIndex(unittest.TestCase): fetch_list=fetch_list) self.assertTrue( - np.array_equal(array2, setitem_pp[0]), + np.allclose(array2, setitem_pp[0]), msg='\n numpy:{},\n paddle:{}'.format(array2, setitem_pp[0])) def test_static_graph_setitem_list_index(self): @@ -769,6 +787,42 @@ class TestListIndex(unittest.TestCase): index_mod = (index % (min(array.shape))).tolist() self.run_setitem_list_index(array, index_mod, value_np) + def test_static_graph_setitem_bool_index(self): + paddle.enable_static() + + # case 1: + array = np.ones((4, 2, 3), dtype='float32') + value_np = np.random.random((2, 3)).astype('float32') + index = np.array([True, False, False, False]) + program = paddle.static.Program() + with paddle.static.program_guard(program): + self.run_setitem_list_index(array, index, value_np) + + # case 2: + array = np.ones((4, 2, 3), dtype='float32') + value_np = np.random.random((2, 3)).astype('float32') + index = np.array([False, True, False, False]) + program = paddle.static.Program() + with paddle.static.program_guard(program): + self.run_setitem_list_index(array, index, value_np) + + # case 3: + array = np.ones((4, 2, 3), dtype='float32') + value_np = np.random.random((2, 3)).astype('float32') + index = np.array([True, True, True, True]) + program = paddle.static.Program() + with paddle.static.program_guard(program): + self.run_setitem_list_index(array, index, value_np) + + def test_static_graph_setitem_bool_scalar_index(self): + paddle.enable_static() + array = np.ones((1, 2, 3), dtype='float32') + value_np = np.random.random((2, 3)).astype('float32') + index = np.array([True]) + program = paddle.static.Program() + with paddle.static.program_guard(program): + self.run_setitem_list_index(array, index, value_np) + def test_static_graph_tensor_index_setitem_muti_dim(self): paddle.enable_static() inps_shape = [3, 4, 5, 4] diff --git a/python/paddle/fluid/variable_index.py b/python/paddle/fluid/variable_index.py index 1c7e4fb5f1a..d94664dd77f 100644 --- a/python/paddle/fluid/variable_index.py +++ b/python/paddle/fluid/variable_index.py @@ -255,6 +255,13 @@ def is_integer_or_scalar_tensor(ele): return False +def is_bool_tensor(ele): + from .framework import Variable + if isinstance(ele, Variable) and ele.dtype == paddle.bool: + return True + return False + + def deal_attrs(attrs, attr, attr_name, tensor_attr_name, inputs, infer_flags): from .framework import Variable from .layers import utils @@ -304,7 +311,8 @@ def _getitem_impl_(var, item): slice_info = SliceInfo() for dim, slice_item in enumerate(item): - if is_integer_or_scalar_tensor(slice_item): + if is_integer_or_scalar_tensor(slice_item) and not is_bool_tensor( + slice_item): if isinstance(slice_item, int) and var.shape[dim] is not None and var.shape[ dim] >= 0 and slice_item >= var.shape[dim]: @@ -523,7 +531,8 @@ def _setitem_impl_(var, item, value): slice_info = SliceInfo() dim = 0 for _, slice_item in enumerate(item): - if is_integer_or_scalar_tensor(slice_item): + if is_integer_or_scalar_tensor(slice_item) and not is_bool_tensor( + slice_item): decrease_axes.append(dim) start = slice_item end = slice_item + 1 if slice_item != -1 else MAX_INTEGER -- GitLab