From b6dc16cb771ba637290cfaaa5c48f6ac0d940203 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Mon, 23 Aug 2021 15:07:23 +0800 Subject: [PATCH] Support gettiem by Bool index (#35026) * Support getitem by Bool index * delete some debug info of bool index * support the case that the shape of bool index is different from indexed tensor --- paddle/fluid/operators/index_select_op.cc | 4 + paddle/fluid/pybind/imperative.cc | 107 ++++++++++++++---- .../fluid/tests/unittests/test_var_base.py | 40 +++++++ .../fluid/tests/unittests/test_variable.py | 31 +++-- python/paddle/fluid/variable_index.py | 57 +++++++--- 5 files changed, 196 insertions(+), 43 deletions(-) diff --git a/paddle/fluid/operators/index_select_op.cc b/paddle/fluid/operators/index_select_op.cc index 60ca7e2fe7c..5bf16960170 100644 --- a/paddle/fluid/operators/index_select_op.cc +++ b/paddle/fluid/operators/index_select_op.cc @@ -54,6 +54,10 @@ class IndexSelectOp : public framework::OperatorWithKernel { "the dimension of Input(Index) is [%d].", index_dim, index_dim.size())); + PADDLE_ENFORCE_EQ(index_dim[0] != 0, true, + platform::errors::InvalidArgument( + "The length of Input(Index) can't be 0.")); + auto output_dim = framework::vectorize(input_dim); if (dim < 0) { dim += input_dim.size(); diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index 6c4213979a4..777d68ea2f9 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -414,17 +414,15 @@ static int _PySlice_GetIndices(PySliceObject *r, Py_ssize_t length, return 0; } -static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index, - std::vector *slice_axes, - std::vector *slice_starts, - std::vector *slice_ends, - std::vector *slice_strides, - std::vector *decrease_axis, - std::vector *none_axes, - std::vector *infer_flags) { - // We allow indexing by Integers, Slices, and tuples of those - // types. - // Ellipsis and None are not supported yet. +static void ParseIndexingSlice( + framework::LoDTensor *tensor, PyObject *_index, + std::vector *slice_axes, std::vector *slice_starts, + std::vector *slice_ends, std::vector *slice_strides, + std::vector *decrease_axis, std::vector *none_axes, + std::vector *infer_flags, std::vector *list_select_idxs, + bool *list_select_flag) { + // We allow indexing by Integers, Slices, Ellipsis, None, tuples of those + // types, and list of Bool and Integers. // wrap to tuple PyObject *index = !PyTuple_Check(_index) ? PyTuple_Pack(1, _index) : _index; PADDLE_ENFORCE_EQ( @@ -490,11 +488,58 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index, dim += rank - specified_dims; } else if (slice_item == Py_None) { none_axes->push_back(dim); + } else if (PyList_Check(slice_item)) { + *list_select_flag = true; + if (size != 1) { + PADDLE_THROW(platform::errors::InvalidArgument( + "When index contains a list, its length is excepted to 1, " + "but received %d", + size)); + } + bool all_bool = true; + int list_size = PyList_GET_SIZE(slice_item); + for (int j = 0; j < list_size; ++j) { + PyObject *list_item = PyList_GetItem(slice_item, j); + if (PyCheckInteger(list_item)) { + all_bool = false; + } else if (!PyBool_Check(list_item)) { + PADDLE_THROW(platform::errors::InvalidArgument( + "Only support int or bool in index list.")); + } + } + if (all_bool) { + if (list_size != shape[0]) { + PADDLE_THROW(platform::errors::InvalidArgument( + "The dimension of bool index doesn't match indexed array along " + "dimension 0, the target dimension is %d, but received %d.", + shape[0], list_size)); + } + for (int j = 0; j < list_size; ++j) { + PyObject *list_item = PyList_GetItem(slice_item, j); + if (list_item == Py_True) { + list_select_idxs->push_back(j); + } + } + } else { + for (int j = 0; j < list_size; ++j) { + PyObject *list_item = PyList_GetItem(slice_item, j); + if (PyCheckInteger(list_item)) { + list_select_idxs->push_back( + static_cast(PyLong_AsLong(list_item))); + } else if (list_item == Py_True) { + list_select_idxs->push_back(1); + } else { + list_select_idxs->push_back(0); + } + } + } + } else { PADDLE_THROW(platform::errors::InvalidArgument( - "Currently, VarBase.__getitem__() only allows indexing" - "by Integers, Slices, Ellipsis, None and tuples of " - "these types, but received %s in %dth slice item", + "Currently, VarBase.__getitem__() only allows indexing " + "by Integers, Slices, Ellipsis, None, tuples of these types " + "and list of Bool and Integers, but received " + "%s in %dth slice item", std::string(Py_TYPE(slice_item)->tp_name), i + 1)); } } @@ -798,10 +843,13 @@ void BindImperative(py::module *m_ptr) { // copys data to cpu place, which reduces performance. if (parse_index && value_is_tensor) { std::vector axes, starts, ends, steps, decrease_axes, - none_axes, infer_flags; + none_axes, infer_flags, list_select_idxs; + // if index is a list, list_select_flag will be true + bool list_select_flag; ParseIndexingSlice(self_tensor, index_ptr, &axes, &starts, &ends, &steps, &decrease_axes, &none_axes, - &infer_flags); + &infer_flags, &list_select_idxs, + &list_select_flag); framework::AttributeMap attrs = { {"axes", axes}, @@ -860,21 +908,26 @@ void BindImperative(py::module *m_ptr) { .def("_getitem_index_not_tensor", [](std::shared_ptr &self, py::handle _index) { std::vector slice_axes, slice_starts, slice_ends, - slice_strides, decrease_axis, none_axes, infer_flags; + slice_strides, decrease_axis, none_axes, infer_flags, + list_select_idxs; + // if index is a list, list_select_flag will be true + bool list_select_flag = false; auto tensor = self->MutableVar()->GetMutable(); ParseIndexingSlice(tensor, _index.ptr(), &slice_axes, &slice_starts, &slice_ends, &slice_strides, - &decrease_axis, &none_axes, &infer_flags); + &decrease_axis, &none_axes, &infer_flags, + &list_select_idxs, &list_select_flag); // release gil and do tracing py::gil_scoped_release release; const auto &tracer = imperative::GetCurrentTracer(); - auto out = slice_axes.empty() + auto out = slice_axes.empty() && !list_select_flag ? self : std::shared_ptr( new imperative::VarBase( tracer->GenerateUniqueName())); + if (!slice_axes.empty()) { imperative::NameVarBaseMap ins = {{"Input", {self}}}; framework::AttributeMap attrs = { @@ -960,6 +1013,22 @@ void BindImperative(py::module *m_ptr) { } } + // the index is a list + if (list_select_flag) { + auto select_index = std::shared_ptr( + new imperative::VarBase(tracer->GenerateUniqueName())); + auto *idx_tensor = select_index->MutableVar() + ->GetMutable(); + auto *dev_ctx = platform::DeviceContextPool::Instance().Get( + tracer->ExpectedPlace()); + TensorFromVector(list_select_idxs, *dev_ctx, idx_tensor); + + imperative::NameVarBaseMap ins = {{"X", {self}}, + {"Index", {select_index}}}; + imperative::NameVarBaseMap outs = {{"Out", {out}}}; + tracer->TraceOp("index_select", ins, outs, {{"dim", 0}}); + } + return out; }) .def( diff --git a/python/paddle/fluid/tests/unittests/test_var_base.py b/python/paddle/fluid/tests/unittests/test_var_base.py index cdf34c27c0a..416f125caa2 100644 --- a/python/paddle/fluid/tests/unittests/test_var_base.py +++ b/python/paddle/fluid/tests/unittests/test_var_base.py @@ -733,6 +733,45 @@ class TestVarBase(unittest.TestCase): # self.assertTrue( # np.array_equal(var[10], np_value[0, 1:10:2, None, None, ...])) + def _test_bool_index(self): + shape = (4, 2, 5, 64) + np_value = np.random.random(shape).astype('float32') + var_tensor = paddle.to_tensor(np_value) + index = [[True, True, True, True], [True, False, True, True], + [True, False, False, True], [False, 0, 1, True, True]] + index2d = np.array([[True, True], [False, False], [True, False], + [True, True]]) + tensor_index = paddle.to_tensor(index2d) + var = [ + var_tensor[index[0]].numpy(), + var_tensor[index[1]].numpy(), + var_tensor[index[2]].numpy(), + var_tensor[index[3]].numpy(), + var_tensor[paddle.to_tensor(index[0])].numpy(), + var_tensor[tensor_index].numpy(), + ] + self.assertTrue(np.array_equal(var[0], np_value[index[0]])) + self.assertTrue(np.array_equal(var[1], np_value[index[1]])) + self.assertTrue(np.array_equal(var[2], np_value[index[2]])) + self.assertTrue(np.array_equal(var[3], np_value[index[3]])) + self.assertTrue(np.array_equal(var[4], np_value[index[0]])) + self.assertTrue(np.array_equal(var[5], np_value[index2d])) + self.assertTrue( + np.array_equal(var_tensor[var_tensor > 0.67], np_value[np_value > + 0.67])) + self.assertTrue( + np.array_equal(var_tensor[var_tensor < 0.55], np_value[np_value < + 0.55])) + + with self.assertRaises(ValueError): + var_tensor[[False, False, False, False]] + with self.assertRaises(ValueError): + var_tensor[[True, False]] + with self.assertRaises(ValueError): + var_tensor[[True, False, False, False, False]] + with self.assertRaises(IndexError): + var_tensor[paddle.to_tensor([[True, False, False, False]])] + def _test_for_var(self): np_value = np.random.random((30, 100, 100)).astype('float32') w = fluid.dygraph.to_variable(np_value) @@ -747,6 +786,7 @@ class TestVarBase(unittest.TestCase): self._test_for_var() self._test_for_getitem_ellipsis_index() self._test_none_index() + self._test_bool_index() var = fluid.dygraph.to_variable(self.array) self.assertTrue(np.array_equal(var[1, :].numpy(), self.array[1, :])) diff --git a/python/paddle/fluid/tests/unittests/test_variable.py b/python/paddle/fluid/tests/unittests/test_variable.py index a998d58fdbc..0c120100faf 100644 --- a/python/paddle/fluid/tests/unittests/test_variable.py +++ b/python/paddle/fluid/tests/unittests/test_variable.py @@ -246,32 +246,49 @@ class TestVariable(unittest.TestCase): res = x[[1.2, 0]] def _test_slice_index_list_bool(self, place): - data = np.random.rand(2, 3).astype("float32") + data = np.random.rand(2, 3, 4).astype("float32") + np_idx = np.array([[True, False, False], [True, False, True]]) prog = paddle.static.Program() with paddle.static.program_guard(prog): x = paddle.assign(data) idx0 = [True, False] idx1 = [False, True] - idx2 = [False, False] - idx3 = [True, True] + idx2 = [True, True] + idx3 = [False, False, 1] + idx4 = [True, False, 0] + idx5 = paddle.assign(np_idx) out0 = x[idx0] out1 = x[idx1] out2 = x[idx2] out3 = x[idx3] + out4 = x[idx4] + out5 = x[idx5] + out6 = x[x < 0.36] + out7 = x[x > 0.6] exe = paddle.static.Executor(place) - result = exe.run(prog, fetch_list=[out0, out1, out2, out3]) + result = exe.run( + prog, fetch_list=[out0, out1, out2, out3, out4, out5, out6, out7]) - expected = [data[idx0], data[idx1], data[idx2], data[idx3]] + expected = [ + data[idx0], data[idx1], data[idx2], data[idx3], data[idx4], + data[np_idx], data[data < 0.36], data[data > 0.6] + ] self.assertTrue((result[0] == expected[0]).all()) self.assertTrue((result[1] == expected[1]).all()) self.assertTrue((result[2] == expected[2]).all()) self.assertTrue((result[3] == expected[3]).all()) + self.assertTrue((result[4] == expected[4]).all()) + self.assertTrue((result[5] == expected[5]).all()) + self.assertTrue((result[6] == expected[6]).all()) + self.assertTrue((result[7] == expected[7]).all()) - with self.assertRaises(TypeError): - res = x[[True, 0]] + with self.assertRaises(IndexError): + res = x[[True, False, False]] + with self.assertRaises(ValueError): + res = x[[False, False]] def test_slice(self): places = [fluid.CPUPlace()] diff --git a/python/paddle/fluid/variable_index.py b/python/paddle/fluid/variable_index.py index 1ba44cea763..5bdf0451b54 100644 --- a/python/paddle/fluid/variable_index.py +++ b/python/paddle/fluid/variable_index.py @@ -150,31 +150,37 @@ def _getitem_impl_(var, item): end = MAX_INTEGER if step > 0 else -1 elif isinstance(slice_item, list): - is_bool_list = False + all_bool = True for i in slice_item: - if not isinstance(i, (int, bool)): + if type(i) is int: + all_bool = False + elif not isinstance(i, bool): raise TypeError("Only support int or bool in index list.") - if isinstance(i, bool): - is_bool_list = True - break - if len(item) != 1: raise IndexError( - "When index contains a list, its length must be 1, but received {}". + "When index contains a list, its length must be 1, but received {}.". format(len(item))) - - if is_bool_list: - new_slice_item = [] + new_slice_item = [] + if all_bool: + if len(slice_item) != var.shape[0]: + raise IndexError( + "The dimension of bool index doesn't match indexed array along "\ + "dimension 0, the target dimension is {}, but received {}.". + format(var.shape[0], len(slice_item))) for idx, ele in enumerate(slice_item): - if not isinstance(ele, bool): - raise TypeError( - "Mixed bool index with other types is not supported." - ) - if ele is True: new_slice_item.append(idx) slice_item = new_slice_item + else: + for idx, ele in enumerate(slice_item): + if type(ele) is int: + new_slice_item.append(ele) + elif ele is True: + new_slice_item.append(1) + else: + new_slice_item.append(0) + slice_item = new_slice_item from .layers import assign from ..tensor import index_select @@ -185,10 +191,27 @@ def _getitem_impl_(var, item): elif isinstance(slice_item, Variable): if len(item) != 1: raise IndexError( - "When index contains a Tensor, its length must be 1, but received {}". + "When index contains a Tensor, its length must be 1, but received {}.". format(len(item))) - from ..tensor import index_select + from ..tensor import index_select, gather_nd + from .layers.nn import where + + if slice_item.dtype == core.VarDesc.VarType.BOOL: + if len(slice_item.shape) > len(var.shape): + raise IndexError( + "The dims of bool index doesn't match indexed array, " + "the dims of bool index except to be equal or less " + "than {}, but received {}.".format( + len(var.shape), len(slice_item.shape))) + for i, dim_len in enumerate(slice_item.shape): + if dim_len != var.shape[i]: + raise IndexError( + "The dimension of bool index doesn't match indexed array along "\ + "dimension {}, the target dimension is {}, but received {}.". + format(i, var.shape[i], dim_len)) + bool_2_idx = where(slice_item == True) + return gather_nd(var, bool_2_idx) return index_select(var, index=slice_item, axis=0) else: -- GitLab