未验证 提交 9c52adef 编写于 作者: L liym27 提交者: GitHub

[slice getitem] Support getitem idx is Tensor or List (#33000)

上级 b30a7e31
...@@ -164,12 +164,75 @@ class TestVariable(unittest.TestCase): ...@@ -164,12 +164,75 @@ class TestVariable(unittest.TestCase):
self.assertTrue( self.assertTrue(
np.array_equal(local_out[15], tensor_array[::-1, ::-1, ::-1])) np.array_equal(local_out[15], tensor_array[::-1, ::-1, ::-1]))
def test_slice(self): def _test_slice_index_tensor(self, place):
place = fluid.CPUPlace() data = np.random.rand(2, 3).astype("float32")
self._test_slice(place) prog = paddle.static.Program()
with paddle.static.program_guard(prog):
x = paddle.assign(data)
idx0 = [1, 0]
idx1 = [0, 1]
idx2 = [0, 0]
idx3 = [1, 1]
out0 = x[paddle.assign(np.array(idx0))]
out1 = x[paddle.assign(np.array(idx1))]
out2 = x[paddle.assign(np.array(idx2))]
out3 = x[paddle.assign(np.array(idx3))]
exe = paddle.static.Executor(place)
result = exe.run(prog, fetch_list=[out0, out1, out2, out3])
expected = [data[idx0], data[idx1], data[idx2], data[idx3]]
self.assertTrue((result[0] == expected[0]).all())
self.assertTrue((result[1] == expected[1]).all())
self.assertTrue((result[2] == expected[2]).all())
self.assertTrue((result[3] == expected[3]).all())
with self.assertRaises(IndexError):
one = paddle.ones(shape=[1])
res = x[one, [0, 0]]
def _test_slice_index_list(self, place):
data = np.random.rand(2, 3).astype("float32")
prog = paddle.static.Program()
with paddle.static.program_guard(prog):
x = paddle.assign(data)
idx0 = [1, 0]
idx1 = [0, 1]
idx2 = [0, 0]
idx3 = [1, 1]
out0 = x[idx0]
out1 = x[idx1]
out2 = x[idx2]
out3 = x[idx3]
exe = paddle.static.Executor(place)
result = exe.run(prog, fetch_list=[out0, out1, out2, out3])
expected = [data[idx0], data[idx1], data[idx2], data[idx3]]
self.assertTrue((result[0] == expected[0]).all())
self.assertTrue((result[1] == expected[1]).all())
self.assertTrue((result[2] == expected[2]).all())
self.assertTrue((result[3] == expected[3]).all())
with self.assertRaises(IndexError):
res = x[[1, 0], [0, 0]]
with self.assertRaises(TypeError):
res = x[[1.2, 0]]
def test_slice(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda(): if core.is_compiled_with_cuda():
self._test_slice(core.CUDAPlace(0)) places.append(core.CUDAPlace(0))
for place in places:
self._test_slice(place)
self._test_slice_index_tensor(place)
self._test_slice_index_list(place)
def _tostring(self): def _tostring(self):
b = default_main_program().current_block() b = default_main_program().current_block()
......
...@@ -87,7 +87,7 @@ def _getitem_impl_(var, item): ...@@ -87,7 +87,7 @@ def _getitem_impl_(var, item):
Returns: Returns:
Sliced variable Sliced variable
""" """
from .framework import default_main_program from .framework import default_main_program, Variable
if not isinstance(item, tuple): if not isinstance(item, tuple):
item = (item, ) item = (item, )
...@@ -126,6 +126,31 @@ def _getitem_impl_(var, item): ...@@ -126,6 +126,31 @@ def _getitem_impl_(var, item):
start = 0 if start is None else start start = 0 if start is None else start
end = MAX_INTEGER if end is None else end end = MAX_INTEGER if end is None else end
elif isinstance(slice_item, list):
for i in slice_item:
if not isinstance(i, int):
raise TypeError("Only support int value in list")
if len(item) != 1:
raise IndexError(
"When index contains a list, its length must be 1, but received {}".
format(len(item)))
from .layers import assign
from ..tensor import index_select
idx = assign(np.array(slice_item))
return index_select(var, index=idx, axis=0)
elif isinstance(slice_item, Variable):
if len(item) != 1:
raise IndexError(
"When index contains a Tensor, its length must be 1, but received {}".
format(len(item)))
from ..tensor import index_select
return index_select(var, index=slice_item, axis=0)
else: else:
raise IndexError( raise IndexError(
"Valid index accept int or slice or ellipsis, but received {}.". "Valid index accept int or slice or ellipsis, but received {}.".
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册