未验证 提交 a2256366 编写于 作者: L liym27 提交者: GitHub

[static getitem]Support index is list bool for getitem in static mode (#33298)

上级 11b57760
......@@ -245,6 +245,34 @@ class TestVariable(unittest.TestCase):
with self.assertRaises(TypeError):
res = x[[1.2, 0]]
def _test_slice_index_list_bool(self, place):
data = np.random.rand(2, 3).astype("float32")
prog = paddle.static.Program()
with paddle.static.program_guard(prog):
x = paddle.assign(data)
idx0 = [True, False]
idx1 = [False, True]
idx2 = [False, False]
idx3 = [True, True]
out0 = x[idx0]
out1 = x[idx1]
out2 = x[idx2]
out3 = x[idx3]
exe = paddle.static.Executor(place)
result = exe.run(prog, fetch_list=[out0, out1, out2, out3])
expected = [data[idx0], data[idx1], data[idx2], data[idx3]]
self.assertTrue((result[0] == expected[0]).all())
self.assertTrue((result[1] == expected[1]).all())
self.assertTrue((result[2] == expected[2]).all())
self.assertTrue((result[3] == expected[3]).all())
with self.assertRaises(TypeError):
res = x[[True, 0]]
def test_slice(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
......@@ -255,6 +283,7 @@ class TestVariable(unittest.TestCase):
self._test_slice_index_tensor(place)
self._test_slice_index_list(place)
self._test_slice_index_ellipsis(place)
self._test_slice_index_list_bool(place)
def _tostring(self):
b = default_main_program().current_block()
......
......@@ -140,19 +140,36 @@ def _getitem_impl_(var, item):
end = MAX_INTEGER if end is None else end
elif isinstance(slice_item, list):
is_bool_list = False
for i in slice_item:
if not isinstance(i, int):
raise TypeError("Only support int value in list")
if not isinstance(i, (int, bool)):
raise TypeError("Only support int or bool in index list.")
if isinstance(i, bool):
is_bool_list = True
break
if len(item) != 1:
raise IndexError(
"When index contains a list, its length must be 1, but received {}".
format(len(item)))
if is_bool_list:
new_slice_item = []
for idx, ele in enumerate(slice_item):
if not isinstance(ele, bool):
raise TypeError(
"Mixed bool index with other types is not supported."
)
if ele is True:
new_slice_item.append(idx)
slice_item = new_slice_item
from .layers import assign
from ..tensor import index_select
idx = assign(np.array(slice_item))
idx = assign(np.array(slice_item).astype("int32"))
return index_select(var, index=idx, axis=0)
elif isinstance(slice_item, Variable):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册