From 12c15bebe421cd9f1aa1c63fc15310c91f8857d3 Mon Sep 17 00:00:00 2001 From: liym27 <33742067+liym27@users.noreply.github.com> Date: Mon, 8 Feb 2021 11:25:16 +0800 Subject: [PATCH] [Static setitem] Support index is ellipsis for setitem in static mode (#30836) --- python/paddle/fluid/framework.py | 29 +++++++++++++ .../tests/unittests/test_set_value_op.py | 41 ++++++++++++++++++- 2 files changed, 68 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 508afac2cd..43e2733162 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -1866,6 +1866,35 @@ class Variable(object): starts = [] ends = [] max_integer = sys.maxsize + + def replace_ellipsis(item): + # Use slice(None) to replace Ellipsis. + # For var, var.shape = [3,4,5,6] + # + # var[..., 1:2] -> var[:, :, :, 1:2] + # var[0, ...] -> var[0] + # var[0, ..., 1:2] -> var[0, :, :, 1:2] + + item = list(item) + ell_count = item.count(Ellipsis) + if ell_count == 0: + return item + elif ell_count > 1: + raise IndexError( + "An index can only have a single ellipsis ('...')") + + ell_idx = item.index(Ellipsis) + + if ell_idx == len(item) - 1: + return item[:-1] + else: + item[ell_idx:ell_idx + 1] = [slice(None)] * ( + len(self.shape) - len(item) + 1) + + return item + + item = replace_ellipsis(item) + for dim, slice_item in enumerate(item): if isinstance(slice_item, slice): start = slice_item.start diff --git a/python/paddle/fluid/tests/unittests/test_set_value_op.py b/python/paddle/fluid/tests/unittests/test_set_value_op.py index aca685a410..79b270f162 100644 --- a/python/paddle/fluid/tests/unittests/test_set_value_op.py +++ b/python/paddle/fluid/tests/unittests/test_set_value_op.py @@ -52,7 +52,6 @@ class TestSetValueApi(TestSetValueBase): exe = paddle.static.Executor(paddle.CPUPlace()) out = exe.run(self.program, fetch_list=[x]) - self._get_answer() self.assertTrue( (self.data == out).all(), @@ -60,7 +59,7 @@ class TestSetValueApi(TestSetValueBase): self.data, out)) -# 1. Test different type of item: int, python slice +# 1. Test different type of item: int, python slice, Ellipsis class TestSetValueItemInt(TestSetValueApi): def _call_setitem(self, x): x[0] = self.value @@ -101,6 +100,38 @@ class TestSetValueItemSlice4(TestSetValueApi): self.data[0:, 1:2, :] = self.value +class TestSetValueItemEllipsis1(TestSetValueApi): + def _call_setitem(self, x): + x[0:, ..., 1:] = self.value + + def _get_answer(self): + self.data[0:, ..., 1:] = self.value + + +class TestSetValueItemEllipsis2(TestSetValueApi): + def _call_setitem(self, x): + x[0:, ...] = self.value + + def _get_answer(self): + self.data[0:, ...] = self.value + + +class TestSetValueItemEllipsis3(TestSetValueApi): + def _call_setitem(self, x): + x[..., 1:] = self.value + + def _get_answer(self): + self.data[..., 1:] = self.value + + +class TestSetValueItemEllipsis4(TestSetValueApi): + def _call_setitem(self, x): + x[...] = self.value + + def _get_answer(self): + self.data[...] = self.value + + # 2. Test different type of value: int, float, numpy.ndarray, Tensor # 2.1 value is int32, int64, float32, float64, bool @@ -499,6 +530,12 @@ class TestError(TestSetValueBase): x = paddle.ones(shape=self.shape, dtype=self.dtype) x[0:1:2] = self.value + def _ellipsis_error(self): + with self.assertRaisesRegexp( + IndexError, "An index can only have a single ellipsis"): + x = paddle.ones(shape=self.shape, dtype=self.dtype) + x[..., ...] = self.value + def _broadcast_mismatch(self): program = paddle.static.Program() with paddle.static.program_guard(program): -- GitLab