未验证 提交 eb6d7da9 编写于 作者: F FlyingQianMM 提交者: GitHub

support getitem when index is a all-false bool tensor (#41297)

* support getitem when index is a all-false bool tensor

* use cond to replace if

* add static_graph geitem unit test when index is a bool tensor
上级 3e9ad093
......@@ -795,17 +795,17 @@ class TestVarBase(unittest.TestCase):
np_value = np.random.random(shape).astype('float32')
var_tensor = paddle.to_tensor(np_value)
index = [[True, True, True, True], [True, False, True, True],
[True, False, False, True], [False, 0, 1, True, True]]
[True, False, False, True], [False, 0, 1, True, True],
[False, False, False, False]]
index2d = np.array([[True, True], [False, False], [True, False],
[True, True]])
tensor_index = paddle.to_tensor(index2d)
var = [
var_tensor[index[0]].numpy(),
var_tensor[index[1]].numpy(),
var_tensor[index[2]].numpy(),
var_tensor[index[3]].numpy(),
var_tensor[index[0]].numpy(), var_tensor[index[1]].numpy(),
var_tensor[index[2]].numpy(), var_tensor[index[3]].numpy(),
var_tensor[paddle.to_tensor(index[0])].numpy(),
var_tensor[tensor_index].numpy(),
var_tensor[paddle.to_tensor(index[4])].numpy()
]
self.assertTrue(np.array_equal(var[0], np_value[index[0]]))
self.assertTrue(np.array_equal(var[1], np_value[index[1]]))
......@@ -813,6 +813,7 @@ class TestVarBase(unittest.TestCase):
self.assertTrue(np.array_equal(var[3], np_value[index[3]]))
self.assertTrue(np.array_equal(var[4], np_value[index[0]]))
self.assertTrue(np.array_equal(var[5], np_value[index2d]))
self.assertTrue(np.array_equal(var[6], np_value[index[4]]))
self.assertTrue(
np.array_equal(var_tensor[var_tensor > 0.67], np_value[np_value >
0.67]))
......
......@@ -690,6 +690,61 @@ class TestListIndex(unittest.TestCase):
y = x[index_t1, index_t2]
self.assertTrue(np.array_equal(y.numpy(), y_np))
def run_getitem_list_index(self, array, index):
x = paddle.static.data(name='x', shape=array.shape, dtype='float32')
y = x[index]
place = paddle.fluid.CPUPlace()
prog = paddle.static.default_main_program()
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
fetch_list = [y.name]
array2 = array.copy()
try:
value_np = array2[index]
except:
with self.assertRaises(ValueError):
getitem_pp = exe.run(prog,
feed={x.name: array},
fetch_list=fetch_list)
return
getitem_pp = exe.run(prog, feed={x.name: array}, fetch_list=fetch_list)
print(getitem_pp)
self.assertTrue(
np.array_equal(value_np, getitem_pp[0]),
msg='\n numpy:{},\n paddle:{}'.format(value_np, getitem_pp[0]))
def test_static_graph_getitem_bool_index(self):
paddle.enable_static()
# case 1:
array = np.ones((4, 2, 3), dtype='float32')
value_np = np.random.random((2, 3)).astype('float32')
index = np.array([True, False, False, False])
program = paddle.static.Program()
with paddle.static.program_guard(program):
self.run_getitem_list_index(array, index)
# case 2:
array = np.ones((4, 2, 3), dtype='float32')
value_np = np.random.random((2, 3)).astype('float32')
index = np.array([False, True, False, False])
program = paddle.static.Program()
with paddle.static.program_guard(program):
self.run_getitem_list_index(array, index)
# case 3:
array = np.ones((4, 2, 3), dtype='float32')
value_np = np.random.random((2, 3)).astype('float32')
index = np.array([True, True, True, True])
program = paddle.static.Program()
with paddle.static.program_guard(program):
self.run_getitem_list_index(array, index)
def run_setitem_list_index(self, array, index, value_np):
x = paddle.static.data(name='x', shape=array.shape, dtype='float32')
......
......@@ -279,6 +279,37 @@ def deal_attrs(attrs, attr, attr_name, tensor_attr_name, inputs, infer_flags):
attrs[attr_name] = attr
# the item is a tensor of bool
def get_value_for_bool_tensor(var, item):
if len(item.shape) > len(var.shape):
raise IndexError("The dims of bool index doesn't match indexed array, "
"the dims of bool index except to be equal or less "
"than {}, but received {}.".format(
len(var.shape), len(item.shape)))
for i, dim_len in enumerate(item.shape):
if dim_len != var.shape[i]:
raise IndexError(
"The dimension of bool index doesn't match indexed array along "\
"dimension {}, the target dimension is {}, but received {}.".
format(i, var.shape[i], dim_len))
def idx_not_empty(var, item):
from .layers.nn import where
from ..tensor import gather_nd
bool_2_idx = where(item == True)
return gather_nd(var, bool_2_idx)
def idx_empty(var):
var_shape = list(var.shape)
var_shape[0] = 0
return paddle.empty(var_shape, dtype=var.dtype)
from .layers.control_flow import cond
return cond(item.any(), lambda: idx_not_empty(var, item),
lambda: idx_empty(var))
def _getitem_impl_(var, item):
"""
Slice the variable.
......@@ -393,24 +424,10 @@ def _getitem_impl_(var, item):
elif isinstance(slice_item, (Variable, core.eager.Tensor)):
if len(item) == 1:
from ..tensor import index_select, gather_nd
from .layers.nn import where
from ..tensor import index_select
if slice_item.dtype == paddle.bool:
if len(slice_item.shape) > len(var.shape):
raise IndexError(
"The dims of bool index doesn't match indexed array, "
"the dims of bool index except to be equal or less "
"than {}, but received {}.".format(
len(var.shape), len(slice_item.shape)))
for i, dim_len in enumerate(slice_item.shape):
if dim_len != var.shape[i]:
raise IndexError(
"The dimension of bool index doesn't match indexed array along "\
"dimension {}, the target dimension is {}, but received {}.".
format(i, var.shape[i], dim_len))
bool_2_idx = where(slice_item == True)
return gather_nd(var, bool_2_idx)
return get_value_for_bool_tensor(var, slice_item)
else:
if len(slice_item.shape) == 1:
return index_select(var, index=slice_item, axis=0)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册