From 9c52adefd664e8cd81dcd2916e59bbac334387be Mon Sep 17 00:00:00 2001 From: liym27 <33742067+liym27@users.noreply.github.com> Date: Wed, 2 Jun 2021 21:08:36 +0800 Subject: [PATCH] [slice getitem] Support getitem idx is Tensor or List (#33000) --- .../fluid/tests/unittests/test_variable.py | 71 +++++++++++++++++-- python/paddle/fluid/variable_index.py | 27 ++++++- 2 files changed, 93 insertions(+), 5 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_variable.py b/python/paddle/fluid/tests/unittests/test_variable.py index 71051689db..6b99e85591 100644 --- a/python/paddle/fluid/tests/unittests/test_variable.py +++ b/python/paddle/fluid/tests/unittests/test_variable.py @@ -164,12 +164,75 @@ class TestVariable(unittest.TestCase): self.assertTrue( np.array_equal(local_out[15], tensor_array[::-1, ::-1, ::-1])) - def test_slice(self): - place = fluid.CPUPlace() - self._test_slice(place) + def _test_slice_index_tensor(self, place): + data = np.random.rand(2, 3).astype("float32") + prog = paddle.static.Program() + with paddle.static.program_guard(prog): + x = paddle.assign(data) + idx0 = [1, 0] + idx1 = [0, 1] + idx2 = [0, 0] + idx3 = [1, 1] + + out0 = x[paddle.assign(np.array(idx0))] + out1 = x[paddle.assign(np.array(idx1))] + out2 = x[paddle.assign(np.array(idx2))] + out3 = x[paddle.assign(np.array(idx3))] + + exe = paddle.static.Executor(place) + result = exe.run(prog, fetch_list=[out0, out1, out2, out3]) + + expected = [data[idx0], data[idx1], data[idx2], data[idx3]] + + self.assertTrue((result[0] == expected[0]).all()) + self.assertTrue((result[1] == expected[1]).all()) + self.assertTrue((result[2] == expected[2]).all()) + self.assertTrue((result[3] == expected[3]).all()) + + with self.assertRaises(IndexError): + one = paddle.ones(shape=[1]) + res = x[one, [0, 0]] + + def _test_slice_index_list(self, place): + data = np.random.rand(2, 3).astype("float32") + prog = paddle.static.Program() + with paddle.static.program_guard(prog): + x = paddle.assign(data) + idx0 = [1, 0] + idx1 = [0, 1] + idx2 = [0, 0] + idx3 = [1, 1] + + out0 = x[idx0] + out1 = x[idx1] + out2 = x[idx2] + out3 = x[idx3] + + exe = paddle.static.Executor(place) + result = exe.run(prog, fetch_list=[out0, out1, out2, out3]) + + expected = [data[idx0], data[idx1], data[idx2], data[idx3]] + + self.assertTrue((result[0] == expected[0]).all()) + self.assertTrue((result[1] == expected[1]).all()) + self.assertTrue((result[2] == expected[2]).all()) + self.assertTrue((result[3] == expected[3]).all()) + + with self.assertRaises(IndexError): + res = x[[1, 0], [0, 0]] + + with self.assertRaises(TypeError): + res = x[[1.2, 0]] + def test_slice(self): + places = [fluid.CPUPlace()] if core.is_compiled_with_cuda(): - self._test_slice(core.CUDAPlace(0)) + places.append(core.CUDAPlace(0)) + + for place in places: + self._test_slice(place) + self._test_slice_index_tensor(place) + self._test_slice_index_list(place) def _tostring(self): b = default_main_program().current_block() diff --git a/python/paddle/fluid/variable_index.py b/python/paddle/fluid/variable_index.py index 242b5b14db..31545a8400 100644 --- a/python/paddle/fluid/variable_index.py +++ b/python/paddle/fluid/variable_index.py @@ -87,7 +87,7 @@ def _getitem_impl_(var, item): Returns: Sliced variable """ - from .framework import default_main_program + from .framework import default_main_program, Variable if not isinstance(item, tuple): item = (item, ) @@ -126,6 +126,31 @@ def _getitem_impl_(var, item): start = 0 if start is None else start end = MAX_INTEGER if end is None else end + elif isinstance(slice_item, list): + for i in slice_item: + if not isinstance(i, int): + raise TypeError("Only support int value in list") + + if len(item) != 1: + raise IndexError( + "When index contains a list, its length must be 1, but received {}". + format(len(item))) + + from .layers import assign + from ..tensor import index_select + + idx = assign(np.array(slice_item)) + return index_select(var, index=idx, axis=0) + + elif isinstance(slice_item, Variable): + if len(item) != 1: + raise IndexError( + "When index contains a Tensor, its length must be 1, but received {}". + format(len(item))) + + from ..tensor import index_select + return index_select(var, index=slice_item, axis=0) + else: raise IndexError( "Valid index accept int or slice or ellipsis, but received {}.". -- GitLab