未验证 提交 0f5e90a2 编写于 作者: F FlyingQianMM 提交者: GitHub

support get_item where the index is a bool scalar tensor (#40829)

* support get_item where the index is a bool scalar tensor

* add unittests for supporting get_item where the index is a bool scalar tensor
上级 8df91763
...@@ -829,6 +829,15 @@ class TestVarBase(unittest.TestCase): ...@@ -829,6 +829,15 @@ class TestVarBase(unittest.TestCase):
with self.assertRaises(IndexError): with self.assertRaises(IndexError):
var_tensor[paddle.to_tensor([[True, False, False, False]])] var_tensor[paddle.to_tensor([[True, False, False, False]])]
def _test_scalar_bool_index(self):
shape = (1, 2, 5, 64)
np_value = np.random.random(shape).astype('float32')
var_tensor = paddle.to_tensor(np_value)
index = [True]
tensor_index = paddle.to_tensor(index)
var = [var_tensor[tensor_index].numpy(), ]
self.assertTrue(np.array_equal(var[0], np_value[index]))
def _test_for_var(self): def _test_for_var(self):
np_value = np.random.random((30, 100, 100)).astype('float32') np_value = np.random.random((30, 100, 100)).astype('float32')
w = fluid.dygraph.to_variable(np_value) w = fluid.dygraph.to_variable(np_value)
...@@ -883,6 +892,7 @@ class TestVarBase(unittest.TestCase): ...@@ -883,6 +892,7 @@ class TestVarBase(unittest.TestCase):
self._test_for_getitem_ellipsis_index() self._test_for_getitem_ellipsis_index()
self._test_none_index() self._test_none_index()
self._test_bool_index() self._test_bool_index()
self._test_scalar_bool_index()
self._test_numpy_index() self._test_numpy_index()
self._test_list_index() self._test_list_index()
...@@ -1219,6 +1229,88 @@ class TestVarBaseSetitemFp64(TestVarBaseSetitem): ...@@ -1219,6 +1229,88 @@ class TestVarBaseSetitemFp64(TestVarBaseSetitem):
self.dtype = "float64" self.dtype = "float64"
class TestVarBaseSetitemBoolIndex(unittest.TestCase):
def setUp(self):
paddle.disable_static()
self.set_dtype()
self.set_input()
def set_input(self):
self.tensor_x = paddle.to_tensor(np.ones((4, 2, 3)).astype(self.dtype))
self.np_value = np.random.random((2, 3)).astype(self.dtype)
self.tensor_value = paddle.to_tensor(self.np_value)
def set_dtype(self):
self.dtype = "int32"
def _test(self, value):
paddle.disable_static()
self.assertEqual(self.tensor_x.inplace_version, 0)
id_origin = id(self.tensor_x)
index_1 = paddle.to_tensor(np.array([True, False, False, False]))
self.tensor_x[index_1] = value
self.assertEqual(self.tensor_x.inplace_version, 1)
if isinstance(value, (six.integer_types, float)):
result = np.zeros((2, 3)).astype(self.dtype) + value
else:
result = self.np_value
self.assertTrue(np.array_equal(self.tensor_x[0].numpy(), result))
self.assertEqual(id_origin, id(self.tensor_x))
index_2 = paddle.to_tensor(np.array([False, True, False, False]))
self.tensor_x[index_2] = value
self.assertEqual(self.tensor_x.inplace_version, 2)
self.assertTrue(np.array_equal(self.tensor_x[1].numpy(), result))
self.assertEqual(id_origin, id(self.tensor_x))
index_3 = paddle.to_tensor(np.array([True, True, True, True]))
self.tensor_x[index_3] = value
self.assertEqual(self.tensor_x.inplace_version, 3)
self.assertTrue(np.array_equal(self.tensor_x[3].numpy(), result))
self.assertEqual(id_origin, id(self.tensor_x))
def test_value_tensor(self):
paddle.disable_static()
self._test(self.tensor_value)
def test_value_numpy(self):
paddle.disable_static()
self._test(self.np_value)
def test_value_int(self):
paddle.disable_static()
self._test(10)
class TestVarBaseSetitemBoolScalarIndex(unittest.TestCase):
def set_input(self):
self.tensor_x = paddle.to_tensor(np.ones((1, 2, 3)).astype(self.dtype))
self.np_value = np.random.random((2, 3)).astype(self.dtype)
self.tensor_value = paddle.to_tensor(self.np_value)
def _test(self, value):
paddle.disable_static()
self.assertEqual(self.tensor_x.inplace_version, 0)
id_origin = id(self.tensor_x)
index = paddle.to_tensor(np.array([True]))
self.tensor_x[index] = value
self.assertEqual(self.tensor_x.inplace_version, 1)
if isinstance(value, (six.integer_types, float)):
result = np.zeros((2, 3)).astype(self.dtype) + value
else:
result = self.np_value
self.assertTrue(np.array_equal(self.tensor_x[0].numpy(), result))
self.assertEqual(id_origin, id(self.tensor_x))
class TestVarBaseInplaceVersion(unittest.TestCase): class TestVarBaseInplaceVersion(unittest.TestCase):
def test_setitem(self): def test_setitem(self):
paddle.disable_static() paddle.disable_static()
......
...@@ -336,6 +336,23 @@ class TestVariable(unittest.TestCase): ...@@ -336,6 +336,23 @@ class TestVariable(unittest.TestCase):
with paddle.static.program_guard(prog): with paddle.static.program_guard(prog):
res = x[[False, False]] res = x[[False, False]]
def _test_slice_index_scalar_bool(self, place):
data = np.random.rand(1, 3, 4).astype("float32")
np_idx = np.array([True])
prog = paddle.static.Program()
with paddle.static.program_guard(prog):
x = paddle.assign(data)
idx = paddle.assign(np_idx)
out = x[idx]
exe = paddle.static.Executor(place)
result = exe.run(prog, fetch_list=[out])
expected = [data[np_idx]]
self.assertTrue((result[0] == expected[0]).all())
def test_slice(self): def test_slice(self):
places = [fluid.CPUPlace()] places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda(): if core.is_compiled_with_cuda():
...@@ -347,6 +364,7 @@ class TestVariable(unittest.TestCase): ...@@ -347,6 +364,7 @@ class TestVariable(unittest.TestCase):
self._test_slice_index_list(place) self._test_slice_index_list(place)
self._test_slice_index_ellipsis(place) self._test_slice_index_ellipsis(place)
self._test_slice_index_list_bool(place) self._test_slice_index_list_bool(place)
self._test_slice_index_scalar_bool(place)
def _tostring(self): def _tostring(self):
b = default_main_program().current_block() b = default_main_program().current_block()
...@@ -705,7 +723,7 @@ class TestListIndex(unittest.TestCase): ...@@ -705,7 +723,7 @@ class TestListIndex(unittest.TestCase):
fetch_list=fetch_list) fetch_list=fetch_list)
self.assertTrue( self.assertTrue(
np.array_equal(array2, setitem_pp[0]), np.allclose(array2, setitem_pp[0]),
msg='\n numpy:{},\n paddle:{}'.format(array2, setitem_pp[0])) msg='\n numpy:{},\n paddle:{}'.format(array2, setitem_pp[0]))
def test_static_graph_setitem_list_index(self): def test_static_graph_setitem_list_index(self):
...@@ -769,6 +787,42 @@ class TestListIndex(unittest.TestCase): ...@@ -769,6 +787,42 @@ class TestListIndex(unittest.TestCase):
index_mod = (index % (min(array.shape))).tolist() index_mod = (index % (min(array.shape))).tolist()
self.run_setitem_list_index(array, index_mod, value_np) self.run_setitem_list_index(array, index_mod, value_np)
def test_static_graph_setitem_bool_index(self):
paddle.enable_static()
# case 1:
array = np.ones((4, 2, 3), dtype='float32')
value_np = np.random.random((2, 3)).astype('float32')
index = np.array([True, False, False, False])
program = paddle.static.Program()
with paddle.static.program_guard(program):
self.run_setitem_list_index(array, index, value_np)
# case 2:
array = np.ones((4, 2, 3), dtype='float32')
value_np = np.random.random((2, 3)).astype('float32')
index = np.array([False, True, False, False])
program = paddle.static.Program()
with paddle.static.program_guard(program):
self.run_setitem_list_index(array, index, value_np)
# case 3:
array = np.ones((4, 2, 3), dtype='float32')
value_np = np.random.random((2, 3)).astype('float32')
index = np.array([True, True, True, True])
program = paddle.static.Program()
with paddle.static.program_guard(program):
self.run_setitem_list_index(array, index, value_np)
def test_static_graph_setitem_bool_scalar_index(self):
paddle.enable_static()
array = np.ones((1, 2, 3), dtype='float32')
value_np = np.random.random((2, 3)).astype('float32')
index = np.array([True])
program = paddle.static.Program()
with paddle.static.program_guard(program):
self.run_setitem_list_index(array, index, value_np)
def test_static_graph_tensor_index_setitem_muti_dim(self): def test_static_graph_tensor_index_setitem_muti_dim(self):
paddle.enable_static() paddle.enable_static()
inps_shape = [3, 4, 5, 4] inps_shape = [3, 4, 5, 4]
......
...@@ -255,6 +255,13 @@ def is_integer_or_scalar_tensor(ele): ...@@ -255,6 +255,13 @@ def is_integer_or_scalar_tensor(ele):
return False return False
def is_bool_tensor(ele):
from .framework import Variable
if isinstance(ele, Variable) and ele.dtype == paddle.bool:
return True
return False
def deal_attrs(attrs, attr, attr_name, tensor_attr_name, inputs, infer_flags): def deal_attrs(attrs, attr, attr_name, tensor_attr_name, inputs, infer_flags):
from .framework import Variable from .framework import Variable
from .layers import utils from .layers import utils
...@@ -304,7 +311,8 @@ def _getitem_impl_(var, item): ...@@ -304,7 +311,8 @@ def _getitem_impl_(var, item):
slice_info = SliceInfo() slice_info = SliceInfo()
for dim, slice_item in enumerate(item): for dim, slice_item in enumerate(item):
if is_integer_or_scalar_tensor(slice_item): if is_integer_or_scalar_tensor(slice_item) and not is_bool_tensor(
slice_item):
if isinstance(slice_item, if isinstance(slice_item,
int) and var.shape[dim] is not None and var.shape[ int) and var.shape[dim] is not None and var.shape[
dim] >= 0 and slice_item >= var.shape[dim]: dim] >= 0 and slice_item >= var.shape[dim]:
...@@ -523,7 +531,8 @@ def _setitem_impl_(var, item, value): ...@@ -523,7 +531,8 @@ def _setitem_impl_(var, item, value):
slice_info = SliceInfo() slice_info = SliceInfo()
dim = 0 dim = 0
for _, slice_item in enumerate(item): for _, slice_item in enumerate(item):
if is_integer_or_scalar_tensor(slice_item): if is_integer_or_scalar_tensor(slice_item) and not is_bool_tensor(
slice_item):
decrease_axes.append(dim) decrease_axes.append(dim)
start = slice_item start = slice_item
end = slice_item + 1 if slice_item != -1 else MAX_INTEGER end = slice_item + 1 if slice_item != -1 else MAX_INTEGER
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册