未验证 提交 a039fd7b 编写于 作者: L liym27 提交者: GitHub

[Static getitem] Support static Variable getitem for Ellipsis index (#32876)

上级 741811e0
...@@ -17,6 +17,7 @@ from __future__ import print_function ...@@ -17,6 +17,7 @@ from __future__ import print_function
import unittest import unittest
import paddle import paddle
from paddle.fluid.framework import default_main_program, Program, convert_np_dtype_to_dtype_, in_dygraph_mode from paddle.fluid.framework import default_main_program, Program, convert_np_dtype_to_dtype_, in_dygraph_mode
import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.layers as layers import paddle.fluid.layers as layers
import paddle.fluid.core as core import paddle.fluid.core as core
...@@ -218,6 +219,26 @@ class TestVariable(unittest.TestCase): ...@@ -218,6 +219,26 @@ class TestVariable(unittest.TestCase):
self.assertTrue((result[2] == expected[2]).all()) self.assertTrue((result[2] == expected[2]).all())
self.assertTrue((result[3] == expected[3]).all()) self.assertTrue((result[3] == expected[3]).all())
def _test_slice_index_ellipsis(self, place):
data = np.random.rand(2, 3, 4).astype("float32")
prog = paddle.static.Program()
with paddle.static.program_guard(prog):
x = paddle.assign(data)
out1 = x[0:, ..., 1:]
out2 = x[0:, ...]
out3 = x[..., 1:]
out4 = x[...]
exe = paddle.static.Executor(place)
result = exe.run(prog, fetch_list=[out1, out2, out3, out4])
expected = [data[0:, ..., 1:], data[0:, ...], data[..., 1:], data[...]]
self.assertTrue((result[0] == expected[0]).all())
self.assertTrue((result[1] == expected[1]).all())
self.assertTrue((result[2] == expected[2]).all())
self.assertTrue((result[3] == expected[3]).all())
with self.assertRaises(IndexError): with self.assertRaises(IndexError):
res = x[[1, 0], [0, 0]] res = x[[1, 0], [0, 0]]
...@@ -233,6 +254,7 @@ class TestVariable(unittest.TestCase): ...@@ -233,6 +254,7 @@ class TestVariable(unittest.TestCase):
self._test_slice(place) self._test_slice(place)
self._test_slice_index_tensor(place) self._test_slice_index_tensor(place)
self._test_slice_index_list(place) self._test_slice_index_list(place)
self._test_slice_index_ellipsis(place)
def _tostring(self): def _tostring(self):
b = default_main_program().current_block() b = default_main_program().current_block()
......
...@@ -112,6 +112,7 @@ def _getitem_impl_(var, item): ...@@ -112,6 +112,7 @@ def _getitem_impl_(var, item):
use_strided_slice = False use_strided_slice = False
item, none_axes = replace_none(item) item, none_axes = replace_none(item)
item = replace_ellipsis(var, item)
for dim, slice_item in enumerate(item): for dim, slice_item in enumerate(item):
if is_integer_or_scalar_tensor(slice_item): if is_integer_or_scalar_tensor(slice_item):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册