未验证 提交 12c15beb 编写于 作者: L liym27 提交者: GitHub

[Static setitem] Support index is ellipsis for setitem in static mode (#30836)

上级 97f7a70c
...@@ -1866,6 +1866,35 @@ class Variable(object): ...@@ -1866,6 +1866,35 @@ class Variable(object):
starts = [] starts = []
ends = [] ends = []
max_integer = sys.maxsize max_integer = sys.maxsize
def replace_ellipsis(item):
# Use slice(None) to replace Ellipsis.
# For var, var.shape = [3,4,5,6]
#
# var[..., 1:2] -> var[:, :, :, 1:2]
# var[0, ...] -> var[0]
# var[0, ..., 1:2] -> var[0, :, :, 1:2]
item = list(item)
ell_count = item.count(Ellipsis)
if ell_count == 0:
return item
elif ell_count > 1:
raise IndexError(
"An index can only have a single ellipsis ('...')")
ell_idx = item.index(Ellipsis)
if ell_idx == len(item) - 1:
return item[:-1]
else:
item[ell_idx:ell_idx + 1] = [slice(None)] * (
len(self.shape) - len(item) + 1)
return item
item = replace_ellipsis(item)
for dim, slice_item in enumerate(item): for dim, slice_item in enumerate(item):
if isinstance(slice_item, slice): if isinstance(slice_item, slice):
start = slice_item.start start = slice_item.start
......
...@@ -52,7 +52,6 @@ class TestSetValueApi(TestSetValueBase): ...@@ -52,7 +52,6 @@ class TestSetValueApi(TestSetValueBase):
exe = paddle.static.Executor(paddle.CPUPlace()) exe = paddle.static.Executor(paddle.CPUPlace())
out = exe.run(self.program, fetch_list=[x]) out = exe.run(self.program, fetch_list=[x])
self._get_answer() self._get_answer()
self.assertTrue( self.assertTrue(
(self.data == out).all(), (self.data == out).all(),
...@@ -60,7 +59,7 @@ class TestSetValueApi(TestSetValueBase): ...@@ -60,7 +59,7 @@ class TestSetValueApi(TestSetValueBase):
self.data, out)) self.data, out))
# 1. Test different type of item: int, python slice # 1. Test different type of item: int, python slice, Ellipsis
class TestSetValueItemInt(TestSetValueApi): class TestSetValueItemInt(TestSetValueApi):
def _call_setitem(self, x): def _call_setitem(self, x):
x[0] = self.value x[0] = self.value
...@@ -101,6 +100,38 @@ class TestSetValueItemSlice4(TestSetValueApi): ...@@ -101,6 +100,38 @@ class TestSetValueItemSlice4(TestSetValueApi):
self.data[0:, 1:2, :] = self.value self.data[0:, 1:2, :] = self.value
class TestSetValueItemEllipsis1(TestSetValueApi):
def _call_setitem(self, x):
x[0:, ..., 1:] = self.value
def _get_answer(self):
self.data[0:, ..., 1:] = self.value
class TestSetValueItemEllipsis2(TestSetValueApi):
def _call_setitem(self, x):
x[0:, ...] = self.value
def _get_answer(self):
self.data[0:, ...] = self.value
class TestSetValueItemEllipsis3(TestSetValueApi):
def _call_setitem(self, x):
x[..., 1:] = self.value
def _get_answer(self):
self.data[..., 1:] = self.value
class TestSetValueItemEllipsis4(TestSetValueApi):
def _call_setitem(self, x):
x[...] = self.value
def _get_answer(self):
self.data[...] = self.value
# 2. Test different type of value: int, float, numpy.ndarray, Tensor # 2. Test different type of value: int, float, numpy.ndarray, Tensor
# 2.1 value is int32, int64, float32, float64, bool # 2.1 value is int32, int64, float32, float64, bool
...@@ -499,6 +530,12 @@ class TestError(TestSetValueBase): ...@@ -499,6 +530,12 @@ class TestError(TestSetValueBase):
x = paddle.ones(shape=self.shape, dtype=self.dtype) x = paddle.ones(shape=self.shape, dtype=self.dtype)
x[0:1:2] = self.value x[0:1:2] = self.value
def _ellipsis_error(self):
with self.assertRaisesRegexp(
IndexError, "An index can only have a single ellipsis"):
x = paddle.ones(shape=self.shape, dtype=self.dtype)
x[..., ...] = self.value
def _broadcast_mismatch(self): def _broadcast_mismatch(self):
program = paddle.static.Program() program = paddle.static.Program()
with paddle.static.program_guard(program): with paddle.static.program_guard(program):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册