From a22563669dd03c363b4a7c4636f1d4c9c74c6fb8 Mon Sep 17 00:00:00 2001 From: liym27 <33742067+liym27@users.noreply.github.com> Date: Thu, 10 Jun 2021 11:53:59 +0800 Subject: [PATCH] [static getitem]Support index is list bool for getitem in static mode (#33298) --- .../fluid/tests/unittests/test_variable.py | 29 +++++++++++++++++++ python/paddle/fluid/variable_index.py | 23 +++++++++++++-- 2 files changed, 49 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_variable.py b/python/paddle/fluid/tests/unittests/test_variable.py index 4162fa43679..c1956545f55 100644 --- a/python/paddle/fluid/tests/unittests/test_variable.py +++ b/python/paddle/fluid/tests/unittests/test_variable.py @@ -245,6 +245,34 @@ class TestVariable(unittest.TestCase): with self.assertRaises(TypeError): res = x[[1.2, 0]] + def _test_slice_index_list_bool(self, place): + data = np.random.rand(2, 3).astype("float32") + prog = paddle.static.Program() + with paddle.static.program_guard(prog): + x = paddle.assign(data) + idx0 = [True, False] + idx1 = [False, True] + idx2 = [False, False] + idx3 = [True, True] + + out0 = x[idx0] + out1 = x[idx1] + out2 = x[idx2] + out3 = x[idx3] + + exe = paddle.static.Executor(place) + result = exe.run(prog, fetch_list=[out0, out1, out2, out3]) + + expected = [data[idx0], data[idx1], data[idx2], data[idx3]] + + self.assertTrue((result[0] == expected[0]).all()) + self.assertTrue((result[1] == expected[1]).all()) + self.assertTrue((result[2] == expected[2]).all()) + self.assertTrue((result[3] == expected[3]).all()) + + with self.assertRaises(TypeError): + res = x[[True, 0]] + def test_slice(self): places = [fluid.CPUPlace()] if core.is_compiled_with_cuda(): @@ -255,6 +283,7 @@ class TestVariable(unittest.TestCase): self._test_slice_index_tensor(place) self._test_slice_index_list(place) self._test_slice_index_ellipsis(place) + self._test_slice_index_list_bool(place) def _tostring(self): b = default_main_program().current_block() diff --git a/python/paddle/fluid/variable_index.py b/python/paddle/fluid/variable_index.py index e289ae7f837..c6ddba7fead 100644 --- a/python/paddle/fluid/variable_index.py +++ b/python/paddle/fluid/variable_index.py @@ -140,19 +140,36 @@ def _getitem_impl_(var, item): end = MAX_INTEGER if end is None else end elif isinstance(slice_item, list): + is_bool_list = False for i in slice_item: - if not isinstance(i, int): - raise TypeError("Only support int value in list") + if not isinstance(i, (int, bool)): + raise TypeError("Only support int or bool in index list.") + + if isinstance(i, bool): + is_bool_list = True + break if len(item) != 1: raise IndexError( "When index contains a list, its length must be 1, but received {}". format(len(item))) + if is_bool_list: + new_slice_item = [] + for idx, ele in enumerate(slice_item): + if not isinstance(ele, bool): + raise TypeError( + "Mixed bool index with other types is not supported." + ) + + if ele is True: + new_slice_item.append(idx) + slice_item = new_slice_item + from .layers import assign from ..tensor import index_select - idx = assign(np.array(slice_item)) + idx = assign(np.array(slice_item).astype("int32")) return index_select(var, index=idx, axis=0) elif isinstance(slice_item, Variable): -- GitLab