diff --git a/python/paddle/fluid/tests/unittests/test_variable.py b/python/paddle/fluid/tests/unittests/test_variable.py index 4162fa436798f6c5be6705a327fa1aa40344a692..c1956545f55ad1333124bca03608d35d43cf3fd6 100644 --- a/python/paddle/fluid/tests/unittests/test_variable.py +++ b/python/paddle/fluid/tests/unittests/test_variable.py @@ -245,6 +245,34 @@ class TestVariable(unittest.TestCase): with self.assertRaises(TypeError): res = x[[1.2, 0]] + def _test_slice_index_list_bool(self, place): + data = np.random.rand(2, 3).astype("float32") + prog = paddle.static.Program() + with paddle.static.program_guard(prog): + x = paddle.assign(data) + idx0 = [True, False] + idx1 = [False, True] + idx2 = [False, False] + idx3 = [True, True] + + out0 = x[idx0] + out1 = x[idx1] + out2 = x[idx2] + out3 = x[idx3] + + exe = paddle.static.Executor(place) + result = exe.run(prog, fetch_list=[out0, out1, out2, out3]) + + expected = [data[idx0], data[idx1], data[idx2], data[idx3]] + + self.assertTrue((result[0] == expected[0]).all()) + self.assertTrue((result[1] == expected[1]).all()) + self.assertTrue((result[2] == expected[2]).all()) + self.assertTrue((result[3] == expected[3]).all()) + + with self.assertRaises(TypeError): + res = x[[True, 0]] + def test_slice(self): places = [fluid.CPUPlace()] if core.is_compiled_with_cuda(): @@ -255,6 +283,7 @@ class TestVariable(unittest.TestCase): self._test_slice_index_tensor(place) self._test_slice_index_list(place) self._test_slice_index_ellipsis(place) + self._test_slice_index_list_bool(place) def _tostring(self): b = default_main_program().current_block() diff --git a/python/paddle/fluid/variable_index.py b/python/paddle/fluid/variable_index.py index e289ae7f837d5ef939292ef2d5d6d1b6c376c283..c6ddba7feade333b7bc60d66390d1903324b3521 100644 --- a/python/paddle/fluid/variable_index.py +++ b/python/paddle/fluid/variable_index.py @@ -140,19 +140,36 @@ def _getitem_impl_(var, item): end = MAX_INTEGER if end is None else end elif isinstance(slice_item, list): + is_bool_list = False for i in slice_item: - if not isinstance(i, int): - raise TypeError("Only support int value in list") + if not isinstance(i, (int, bool)): + raise TypeError("Only support int or bool in index list.") + + if isinstance(i, bool): + is_bool_list = True + break if len(item) != 1: raise IndexError( "When index contains a list, its length must be 1, but received {}". format(len(item))) + if is_bool_list: + new_slice_item = [] + for idx, ele in enumerate(slice_item): + if not isinstance(ele, bool): + raise TypeError( + "Mixed bool index with other types is not supported." + ) + + if ele is True: + new_slice_item.append(idx) + slice_item = new_slice_item + from .layers import assign from ..tensor import index_select - idx = assign(np.array(slice_item)) + idx = assign(np.array(slice_item).astype("int32")) return index_select(var, index=idx, axis=0) elif isinstance(slice_item, Variable):