diff --git a/python/paddle/fluid/tests/unittests/test_variable.py b/python/paddle/fluid/tests/unittests/test_variable.py index 6ffecd33f8f48d69ffc7593cb684a93f2d4be226..4162fa436798f6c5be6705a327fa1aa40344a692 100644 --- a/python/paddle/fluid/tests/unittests/test_variable.py +++ b/python/paddle/fluid/tests/unittests/test_variable.py @@ -17,6 +17,7 @@ from __future__ import print_function import unittest import paddle from paddle.fluid.framework import default_main_program, Program, convert_np_dtype_to_dtype_, in_dygraph_mode +import paddle import paddle.fluid as fluid import paddle.fluid.layers as layers import paddle.fluid.core as core @@ -218,6 +219,26 @@ class TestVariable(unittest.TestCase): self.assertTrue((result[2] == expected[2]).all()) self.assertTrue((result[3] == expected[3]).all()) + def _test_slice_index_ellipsis(self, place): + data = np.random.rand(2, 3, 4).astype("float32") + prog = paddle.static.Program() + with paddle.static.program_guard(prog): + x = paddle.assign(data) + out1 = x[0:, ..., 1:] + out2 = x[0:, ...] + out3 = x[..., 1:] + out4 = x[...] + + exe = paddle.static.Executor(place) + result = exe.run(prog, fetch_list=[out1, out2, out3, out4]) + + expected = [data[0:, ..., 1:], data[0:, ...], data[..., 1:], data[...]] + + self.assertTrue((result[0] == expected[0]).all()) + self.assertTrue((result[1] == expected[1]).all()) + self.assertTrue((result[2] == expected[2]).all()) + self.assertTrue((result[3] == expected[3]).all()) + with self.assertRaises(IndexError): res = x[[1, 0], [0, 0]] @@ -233,6 +254,7 @@ class TestVariable(unittest.TestCase): self._test_slice(place) self._test_slice_index_tensor(place) self._test_slice_index_list(place) + self._test_slice_index_ellipsis(place) def _tostring(self): b = default_main_program().current_block() diff --git a/python/paddle/fluid/variable_index.py b/python/paddle/fluid/variable_index.py index aed8c82d43b4dda373d30916ac291b4eff8a1064..e289ae7f837d5ef939292ef2d5d6d1b6c376c283 100644 --- a/python/paddle/fluid/variable_index.py +++ b/python/paddle/fluid/variable_index.py @@ -112,6 +112,7 @@ def _getitem_impl_(var, item): use_strided_slice = False item, none_axes = replace_none(item) + item = replace_ellipsis(var, item) for dim, slice_item in enumerate(item): if is_integer_or_scalar_tensor(slice_item):