From 79b9f47e40cdbaa81d1401dc1be2b7a6bd51c13a Mon Sep 17 00:00:00 2001 From: zyfncg Date: Tue, 16 Nov 2021 16:07:13 +0800 Subject: [PATCH] fix bug of indexing with ellipsis (#37192) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 修复了一维Tensor在使用省略号(...)索引时维度检测异常的问题。 --- paddle/fluid/pybind/imperative.cc | 9 ++++++++- python/paddle/fluid/tests/unittests/test_var_base.py | 5 +++++ python/paddle/fluid/tests/unittests/test_variable.py | 8 ++++++-- 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index 4403eb46972..60b97b76491 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -528,13 +528,20 @@ static void ParseIndexingSlice( // specified_dims is the number of dimensions which indexed by Interger, // Slices. int specified_dims = 0; + int ell_count = 0; for (int dim = 0; dim < size; ++dim) { PyObject *slice_item = PyTuple_GetItem(index, dim); if (PyCheckInteger(slice_item) || PySlice_Check(slice_item)) { specified_dims++; + } else if (slice_item == Py_Ellipsis) { + ell_count++; } } + PADDLE_ENFORCE_LE(ell_count, 1, + platform::errors::InvalidArgument( + "An index can only have a single ellipsis ('...')")); + for (int i = 0, dim = 0; i < size; ++i) { PyObject *slice_item = PyTuple_GetItem(index, i); @@ -639,7 +646,7 @@ static void ParseIndexingSlice( } // valid_index is the number of dimensions exclude None index - const int valid_indexs = size - none_axes->size(); + const int valid_indexs = size - none_axes->size() - ell_count; PADDLE_ENFORCE_EQ(valid_indexs <= rank, true, platform::errors::InvalidArgument( "Too many indices (%d) for tensor of dimension %d.", diff --git a/python/paddle/fluid/tests/unittests/test_var_base.py b/python/paddle/fluid/tests/unittests/test_var_base.py index cfaef15c1d3..e4ef7428202 100644 --- a/python/paddle/fluid/tests/unittests/test_var_base.py +++ b/python/paddle/fluid/tests/unittests/test_var_base.py @@ -702,6 +702,11 @@ class TestVarBase(unittest.TestCase): assert_getitem_ellipsis_index(var_fp32, np_fp32_value) assert_getitem_ellipsis_index(var_int, np_int_value) + # test 1 dim tensor + var_one_dim = paddle.to_tensor([1, 2, 3, 4]) + self.assertTrue( + np.array_equal(var_one_dim[..., 0].numpy(), np.array([1]))) + def _test_none_index(self): shape = (8, 64, 5, 256) np_value = np.random.random(shape).astype('float32') diff --git a/python/paddle/fluid/tests/unittests/test_variable.py b/python/paddle/fluid/tests/unittests/test_variable.py index e9e959266db..2df336fbe8e 100644 --- a/python/paddle/fluid/tests/unittests/test_variable.py +++ b/python/paddle/fluid/tests/unittests/test_variable.py @@ -226,19 +226,22 @@ class TestVariable(unittest.TestCase): prog = paddle.static.Program() with paddle.static.program_guard(prog): x = paddle.assign(data) + y = paddle.assign([1, 2, 3, 4]) out1 = x[0:, ..., 1:] out2 = x[0:, ...] out3 = x[..., 1:] out4 = x[...] out5 = x[[1, 0], [0, 0]] out6 = x[([1, 0], [0, 0])] + out7 = y[..., 0] exe = paddle.static.Executor(place) - result = exe.run(prog, fetch_list=[out1, out2, out3, out4, out5, out6]) + result = exe.run(prog, + fetch_list=[out1, out2, out3, out4, out5, out6, out7]) expected = [ data[0:, ..., 1:], data[0:, ...], data[..., 1:], data[...], - data[[1, 0], [0, 0]], data[([1, 0], [0, 0])] + data[[1, 0], [0, 0]], data[([1, 0], [0, 0])], np.array([1]) ] self.assertTrue((result[0] == expected[0]).all()) @@ -247,6 +250,7 @@ class TestVariable(unittest.TestCase): self.assertTrue((result[3] == expected[3]).all()) self.assertTrue((result[4] == expected[4]).all()) self.assertTrue((result[5] == expected[5]).all()) + self.assertTrue((result[6] == expected[6]).all()) with self.assertRaises(IndexError): res = x[[1.2, 0]] -- GitLab