diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index 619301e3b45d3116a545dd16ef1d5dc165a4f210..7b99c7df188f3550524476ba23d6341c870cfbb4 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -432,19 +432,24 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index, const auto &shape = tensor->dims(); const int rank = shape.size(); const int size = PyTuple_GET_SIZE(index); + + // specified_dims is the number of dimensions which indexed by Interger, + // Slices. + int specified_dims = 0; + for (int dim = 0; dim < size; ++dim) { + PyObject *slice_item = PyTuple_GetItem(index, dim); + if (PyCheckInteger(slice_item) || PySlice_Check(slice_item)) { + specified_dims++; + } + } + PADDLE_ENFORCE_EQ( size <= rank, true, platform::errors::InvalidArgument( "too many indices (%d) for tensor of dimension %d", size, rank)); - for (int dim = 0; dim < size; ++dim) { - PyObject *slice_item = PyTuple_GetItem(index, dim); - PADDLE_ENFORCE_EQ(PyCheckInteger(slice_item) || PySlice_Check(slice_item), - true, - platform::errors::InvalidArgument( - "Currently, VarBase.__getitem__() only allows " - "indexing by Integers, Slices, and tuples of " - "these types, but received %s in %dth slice item", - std::string(Py_TYPE(slice_item)->tp_name), dim + 1)); + for (int i = 0, dim = 0; i < size; ++i) { + PyObject *slice_item = PyTuple_GetItem(index, i); + infer_flags->push_back(1); int dim_len = shape[dim]; if (PyCheckInteger(slice_item)) { @@ -467,7 +472,8 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index, slice_ends->push_back(start + 1); slice_strides->push_back(1); decrease_axis->push_back(dim); - } else { + dim++; + } else if (PySlice_Check(slice_item)) { // slice item Py_ssize_t start, end, step; PySliceObject *p = reinterpret_cast(slice_item); @@ -475,12 +481,22 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index, // :: or : or 0:dim_len:1 if (start == 0 && end == dim_len && step == 1) { + dim++; continue; } slice_axes->push_back(dim); slice_starts->push_back(start); slice_ends->push_back(end); slice_strides->push_back(step); + dim++; + } else if (slice_item == Py_Ellipsis) { + dim += rank - specified_dims; + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Currently, VarBase.__getitem__() only allows " + "indexing by Integers, Slices, Ellipsis, and tuples of " + "these types, but received %s in %dth slice item", + std::string(Py_TYPE(slice_item)->tp_name), i + 1)); } } if (!PyTuple_Check(_index)) Py_DecRef(index); diff --git a/python/paddle/fluid/tests/unittests/test_var_base.py b/python/paddle/fluid/tests/unittests/test_var_base.py index 644e46f10815890a56d35708747447af72612497..87594f0f2d0be99715ce64979600ab82ae2d8dd7 100644 --- a/python/paddle/fluid/tests/unittests/test_var_base.py +++ b/python/paddle/fluid/tests/unittests/test_var_base.py @@ -652,6 +652,43 @@ class TestVarBase(unittest.TestCase): np.array_equal(local_out[15], tensor_array[::-1, ::-1, ::-1])) self.assertTrue(np.array_equal(local_out[16], tensor_array[-4:4])) + def _test_for_getitem_ellipsis_index(self): + shape = (64, 3, 5, 256) + np_fp32_value = np.random.random(shape).astype('float32') + np_int_value = np.random.randint(1, 100, shape) + + var_fp32 = paddle.to_tensor(np_fp32_value) + var_int = paddle.to_tensor(np_int_value) + + def assert_getitem_ellipsis_index(var_tensor, var_np): + var = [ + var_tensor[..., 0].numpy(), + var_tensor[..., 1, 0].numpy(), + var_tensor[0, ..., 1, 0].numpy(), + var_tensor[1, ..., 1].numpy(), + var_tensor[2, ...].numpy(), + var_tensor[2, 0, ...].numpy(), + var_tensor[2, 0, 1, ...].numpy(), + var_tensor[...].numpy(), + var_tensor[:, ..., 100].numpy(), + ] + + self.assertTrue(np.array_equal(var[0], var_np[..., 0])) + self.assertTrue(np.array_equal(var[1], var_np[..., 1, 0])) + self.assertTrue(np.array_equal(var[2], var_np[0, ..., 1, 0])) + self.assertTrue(np.array_equal(var[3], var_np[1, ..., 1])) + self.assertTrue(np.array_equal(var[4], var_np[2, ...])) + self.assertTrue(np.array_equal(var[5], var_np[2, 0, ...])) + self.assertTrue(np.array_equal(var[6], var_np[2, 0, 1, ...])) + self.assertTrue(np.array_equal(var[7], var_np[...])) + self.assertTrue(np.array_equal(var[8], var_np[:, ..., 100])) + + var_fp32 = paddle.to_tensor(np_fp32_value) + var_int = paddle.to_tensor(np_int_value) + + assert_getitem_ellipsis_index(var_fp32, np_fp32_value) + assert_getitem_ellipsis_index(var_int, np_int_value) + def _test_for_var(self): np_value = np.random.random((30, 100, 100)).astype('float32') w = fluid.dygraph.to_variable(np_value) @@ -664,6 +701,7 @@ class TestVarBase(unittest.TestCase): self._test_slice() self._test_slice_for_tensor_attr() self._test_for_var() + self._test_for_getitem_ellipsis_index() var = fluid.dygraph.to_variable(self.array) self.assertTrue(np.array_equal(var[1, :].numpy(), self.array[1, :]))