From eb6d7da947a9ec9151503d069d6329750e5a764c Mon Sep 17 00:00:00 2001 From: FlyingQianMM <245467267@qq.com> Date: Mon, 4 Apr 2022 21:54:50 +0800 Subject: [PATCH] support getitem when index is a all-false bool tensor (#41297) * support getitem when index is a all-false bool tensor * use cond to replace if * add static_graph geitem unit test when index is a bool tensor --- .../fluid/tests/unittests/test_var_base.py | 11 ++-- .../fluid/tests/unittests/test_variable.py | 55 +++++++++++++++++++ python/paddle/fluid/variable_index.py | 49 +++++++++++------ 3 files changed, 94 insertions(+), 21 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_var_base.py b/python/paddle/fluid/tests/unittests/test_var_base.py index 11d77ecc622..ef57ba15302 100644 --- a/python/paddle/fluid/tests/unittests/test_var_base.py +++ b/python/paddle/fluid/tests/unittests/test_var_base.py @@ -795,17 +795,17 @@ class TestVarBase(unittest.TestCase): np_value = np.random.random(shape).astype('float32') var_tensor = paddle.to_tensor(np_value) index = [[True, True, True, True], [True, False, True, True], - [True, False, False, True], [False, 0, 1, True, True]] + [True, False, False, True], [False, 0, 1, True, True], + [False, False, False, False]] index2d = np.array([[True, True], [False, False], [True, False], [True, True]]) tensor_index = paddle.to_tensor(index2d) var = [ - var_tensor[index[0]].numpy(), - var_tensor[index[1]].numpy(), - var_tensor[index[2]].numpy(), - var_tensor[index[3]].numpy(), + var_tensor[index[0]].numpy(), var_tensor[index[1]].numpy(), + var_tensor[index[2]].numpy(), var_tensor[index[3]].numpy(), var_tensor[paddle.to_tensor(index[0])].numpy(), var_tensor[tensor_index].numpy(), + var_tensor[paddle.to_tensor(index[4])].numpy() ] self.assertTrue(np.array_equal(var[0], np_value[index[0]])) self.assertTrue(np.array_equal(var[1], np_value[index[1]])) @@ -813,6 +813,7 @@ class TestVarBase(unittest.TestCase): self.assertTrue(np.array_equal(var[3], np_value[index[3]])) self.assertTrue(np.array_equal(var[4], np_value[index[0]])) self.assertTrue(np.array_equal(var[5], np_value[index2d])) + self.assertTrue(np.array_equal(var[6], np_value[index[4]])) self.assertTrue( np.array_equal(var_tensor[var_tensor > 0.67], np_value[np_value > 0.67])) diff --git a/python/paddle/fluid/tests/unittests/test_variable.py b/python/paddle/fluid/tests/unittests/test_variable.py index b218739ff95..3a924669b00 100644 --- a/python/paddle/fluid/tests/unittests/test_variable.py +++ b/python/paddle/fluid/tests/unittests/test_variable.py @@ -690,6 +690,61 @@ class TestListIndex(unittest.TestCase): y = x[index_t1, index_t2] self.assertTrue(np.array_equal(y.numpy(), y_np)) + def run_getitem_list_index(self, array, index): + x = paddle.static.data(name='x', shape=array.shape, dtype='float32') + + y = x[index] + place = paddle.fluid.CPUPlace() + + prog = paddle.static.default_main_program() + exe = paddle.static.Executor(place) + + exe.run(paddle.static.default_startup_program()) + fetch_list = [y.name] + array2 = array.copy() + + try: + value_np = array2[index] + except: + with self.assertRaises(ValueError): + getitem_pp = exe.run(prog, + feed={x.name: array}, + fetch_list=fetch_list) + return + getitem_pp = exe.run(prog, feed={x.name: array}, fetch_list=fetch_list) + + print(getitem_pp) + self.assertTrue( + np.array_equal(value_np, getitem_pp[0]), + msg='\n numpy:{},\n paddle:{}'.format(value_np, getitem_pp[0])) + + def test_static_graph_getitem_bool_index(self): + paddle.enable_static() + + # case 1: + array = np.ones((4, 2, 3), dtype='float32') + value_np = np.random.random((2, 3)).astype('float32') + index = np.array([True, False, False, False]) + program = paddle.static.Program() + with paddle.static.program_guard(program): + self.run_getitem_list_index(array, index) + + # case 2: + array = np.ones((4, 2, 3), dtype='float32') + value_np = np.random.random((2, 3)).astype('float32') + index = np.array([False, True, False, False]) + program = paddle.static.Program() + with paddle.static.program_guard(program): + self.run_getitem_list_index(array, index) + + # case 3: + array = np.ones((4, 2, 3), dtype='float32') + value_np = np.random.random((2, 3)).astype('float32') + index = np.array([True, True, True, True]) + program = paddle.static.Program() + with paddle.static.program_guard(program): + self.run_getitem_list_index(array, index) + def run_setitem_list_index(self, array, index, value_np): x = paddle.static.data(name='x', shape=array.shape, dtype='float32') diff --git a/python/paddle/fluid/variable_index.py b/python/paddle/fluid/variable_index.py index e6990e25a08..257ddc96d9c 100644 --- a/python/paddle/fluid/variable_index.py +++ b/python/paddle/fluid/variable_index.py @@ -279,6 +279,37 @@ def deal_attrs(attrs, attr, attr_name, tensor_attr_name, inputs, infer_flags): attrs[attr_name] = attr +# the item is a tensor of bool +def get_value_for_bool_tensor(var, item): + if len(item.shape) > len(var.shape): + raise IndexError("The dims of bool index doesn't match indexed array, " + "the dims of bool index except to be equal or less " + "than {}, but received {}.".format( + len(var.shape), len(item.shape))) + for i, dim_len in enumerate(item.shape): + if dim_len != var.shape[i]: + raise IndexError( + "The dimension of bool index doesn't match indexed array along "\ + "dimension {}, the target dimension is {}, but received {}.". + format(i, var.shape[i], dim_len)) + + def idx_not_empty(var, item): + from .layers.nn import where + from ..tensor import gather_nd + + bool_2_idx = where(item == True) + return gather_nd(var, bool_2_idx) + + def idx_empty(var): + var_shape = list(var.shape) + var_shape[0] = 0 + return paddle.empty(var_shape, dtype=var.dtype) + + from .layers.control_flow import cond + return cond(item.any(), lambda: idx_not_empty(var, item), + lambda: idx_empty(var)) + + def _getitem_impl_(var, item): """ Slice the variable. @@ -393,24 +424,10 @@ def _getitem_impl_(var, item): elif isinstance(slice_item, (Variable, core.eager.Tensor)): if len(item) == 1: - from ..tensor import index_select, gather_nd - from .layers.nn import where + from ..tensor import index_select if slice_item.dtype == paddle.bool: - if len(slice_item.shape) > len(var.shape): - raise IndexError( - "The dims of bool index doesn't match indexed array, " - "the dims of bool index except to be equal or less " - "than {}, but received {}.".format( - len(var.shape), len(slice_item.shape))) - for i, dim_len in enumerate(slice_item.shape): - if dim_len != var.shape[i]: - raise IndexError( - "The dimension of bool index doesn't match indexed array along "\ - "dimension {}, the target dimension is {}, but received {}.". - format(i, var.shape[i], dim_len)) - bool_2_idx = where(slice_item == True) - return gather_nd(var, bool_2_idx) + return get_value_for_bool_tensor(var, slice_item) else: if len(slice_item.shape) == 1: return index_select(var, index=slice_item, axis=0) -- GitLab