未验证 提交 d387820d 编写于 作者: Z zyfncg 提交者: GitHub

Support settiem by Bool index (#35133)

* Support getitem by Bool index

* delete some debug info of bool index

* support the case that the shape of bool index is different from indexed tensor

* support setitem by bool index

* add the unittest for throwing exception

* merge conflict

* add check for int tensor when index is bool
上级 884011a4
......@@ -499,12 +499,12 @@ static void ParseIndexingSlice(
none_axes->push_back(dim);
} else if (PyList_Check(slice_item)) {
*list_select_flag = true;
if (size != 1) {
PADDLE_THROW(platform::errors::InvalidArgument(
PADDLE_ENFORCE_EQ(
size, 1,
platform::errors::InvalidArgument(
"When index contains a list, its length is excepted to 1, "
"but received %d",
size));
}
bool all_bool = true;
int list_size = PyList_GET_SIZE(slice_item);
for (int j = 0; j < list_size; ++j) {
......@@ -517,12 +517,13 @@ static void ParseIndexingSlice(
}
}
if (all_bool) {
if (list_size != shape[0]) {
PADDLE_THROW(platform::errors::InvalidArgument(
PADDLE_ENFORCE_EQ(
list_size, shape[0],
platform::errors::InvalidArgument(
"The dimension of bool index doesn't match indexed array along "
"dimension 0, the target dimension is %d, but received %d.",
shape[0], list_size));
}
for (int j = 0; j < list_size; ++j) {
PyObject *list_item = PyList_GetItem(slice_item, j);
if (list_item == Py_True) {
......@@ -818,7 +819,7 @@ void BindImperative(py::module *m_ptr) {
.def("__setitem_varbase__",
[](std::shared_ptr<imperative::VarBase> &self, py::handle _index,
py::object &value_obj) {
VLOG(4) << "Call __setitem__";
VLOG(4) << "Call __setitem_varbase__";
auto self_tensor =
self->MutableVar()->GetMutable<framework::LoDTensor>();
......@@ -871,7 +872,6 @@ void BindImperative(py::module *m_ptr) {
// TODO(liym27): Try not to call TensorToPyArray because it always
// copys data to cpu place, which reduces performance.
if (parse_index && value_is_tensor) {
VLOG(4) << "index is integer/slice/ellipsis and value is tensor";
std::vector<int> axes, starts, ends, steps, decrease_axes,
none_axes, infer_flags, list_select_idxs;
// if index is a list, list_select_flag will be true
......@@ -880,6 +880,7 @@ void BindImperative(py::module *m_ptr) {
&steps, &decrease_axes, &none_axes,
&infer_flags, &list_select_idxs,
&list_select_flag);
framework::AttributeMap attrs = {
{"axes", axes},
{"starts", starts},
......
......@@ -587,14 +587,25 @@ def monkey_patch_varbase():
return self._getitem_index_not_tensor(item)
def __setitem__(self, item, value):
def contain_tensor_or_list(item):
if not isinstance(item, tuple):
item = [item]
if contain_tensor(item):
# 1. Call _setitem_impl_ when item contains tensor.
# Why not call a c++ function ? Because item can't be parsed when it contains tensor.
for slice_item in item:
if isinstance(slice_item, list):
return True
elif isinstance(slice_item, Variable):
return True
return False
if contain_tensor_or_list(item):
# To reuse code with static graph,
# Call _setitem_impl_ when item contains tensor or list.
return _setitem_impl_(self, item, value)
else:
# 2. Call c++ func __setitem_varbase__ to speedup.
# Call c++ func __setitem_varbase__ to speedup.
return self.__setitem_varbase__(item, value)
for method_name, method in (
......
......@@ -408,6 +408,61 @@ class TestSetValueItemNone9(TestSetValueApi):
self.data[None, :, 1, ..., None] = np.zeros(self.shape)[0, 0, :, None]
# 1.5 item is list or Tensor of bol
class TestSetValueItemBool1(TestSetValueApi):
def _call_setitem(self, x):
x[[True, False]] = self.value
def _get_answer(self):
self.data[[True, False]] = self.value
class TestSetValueItemBool2(TestSetValueApi):
def _call_setitem(self, x):
x[[False, False]] = self.value
def _get_answer(self):
self.data[[False, False]] = self.value
class TestSetValueItemBool3(TestSetValueApi):
def _call_setitem(self, x):
x[[False, True]] = np.zeros(self.shape[2])
def _get_answer(self):
self.data[[False, True]] = np.zeros(self.shape[2])
class TestSetValueItemBool4(TestSetValueApi):
def _call_setitem(self, x):
idx = paddle.assign(np.array([False, True]))
x[idx] = np.zeros(self.shape[2])
def _get_answer(self):
self.data[np.array([False, True])] = np.zeros(self.shape[2])
class TestSetValueItemBool5(TestSetValueApi):
def _call_setitem(self, x):
idx = paddle.assign(
np.array([[False, True, False], [True, True, False]]))
x[idx] = self.value
def _get_answer(self):
self.data[np.array([[False, True, False], [True, True, False]
])] = self.value
class TestSetValueItemBool6(TestSetValueApi):
def _call_setitem(self, x):
x[0, ...] = 0
x[x > 0] = self.value
def _get_answer(self):
self.data[0, ...] = 0
self.data[self.data > 0] = self.value
# 2. Test different type of value: int, float, numpy.ndarray, Tensor
# 2.1 value is int32, int64, float32, float64, bool
......@@ -830,6 +885,21 @@ class TestError(TestSetValueBase):
one = paddle.ones([1])
x[::one] = self.value
def _bool_list_error(self):
with self.assertRaises(TypeError):
x = paddle.ones(shape=self.shape, dtype=self.dtype)
x[[True, False, 0]] = 0
with self.assertRaises(IndexError):
x = paddle.ones(shape=self.shape, dtype=self.dtype)
x[[True, False], [True, False]] = 0
def _bool_tensor_error(self):
with self.assertRaises(IndexError):
x = paddle.ones(shape=self.shape, dtype=self.dtype)
idx = paddle.assign([True, False, True])
x[idx] = 0
def _broadcast_mismatch(self):
program = paddle.static.Program()
with paddle.static.program_guard(program):
......@@ -846,6 +916,8 @@ class TestError(TestSetValueBase):
self._value_type_error()
self._dtype_error()
self._step_error()
self._bool_list_error()
self._bool_tensor_error()
self._broadcast_mismatch()
......
......@@ -509,16 +509,6 @@ def _setitem_impl_(var, item, value):
start = slice_item
end = slice_item + 1 if slice_item != -1 else MAX_INTEGER
step = 1
elif isinstance(slice_item, list):
if not is_list_tuple(slice_item, int):
raise TypeError(
"Only support int or list in index list. But revceived {}.".
format(slice_item))
slice_info.update(slice_item)
continue
elif isinstance(slice_item, (Variable, np.ndarray)):
slice_info.update(slice_item)
continue
elif isinstance(slice_item, slice):
start = slice_item.start
......@@ -547,10 +537,43 @@ def _setitem_impl_(var, item, value):
if end is None:
end = MAX_INTEGER if step > 0 else (0 - MAX_INTEGER)
elif isinstance(slice_item, list):
if is_list_tuple(slice_item, int):
slice_info.update(slice_item)
continue
for i in slice_item:
if not isinstance(i, bool):
raise TypeError("Doesn't support {} in index list.".format(
type(i)))
if len(item) != 1:
raise IndexError(
"When index contains a bool list, its length must be 1, but received {}.".
format(len(item)))
from .layers import assign
idx_tensor = assign(slice_item)
return set_value_for_bool_tensor(var, idx_tensor, value)
elif isinstance(slice_item, np.ndarray):
slice_info.update(slice_item)
continue
elif isinstance(slice_item, Variable):
if slice_item.dtype == core.VarDesc.VarType.BOOL:
if len(item) != 1:
raise IndexError(
"When index contains a bool tensor, its length must be 1, but received {}.".
format(len(item)))
return set_value_for_bool_tensor(var, slice_item, value)
else:
slice_info.update(slice_item)
continue
else:
raise IndexError(
"Valid index accept int, slice, ellipsis or None, but received {}.".
format(slice_item))
"Valid index accept int, slice, ellipsis, None, list of bool, Variable, "
"but received {}.".format(slice_item))
axes.append(dim)
starts.append(start)
......@@ -632,3 +655,47 @@ def _setitem_impl_(var, item, value):
type="set_value", inputs=inputs, outputs={'Out': var}, attrs=attrs)
return var
# the item is a tensor of bool
def set_value_for_bool_tensor(var, item, value):
# TODO(zyfncg): Now scatter_nd_add only support float32 and float64 tensor,
# so in the current version we also only support float32 and float64 tensor,
# this problem will be fixed in the future.
if var.dtype != core.VarDesc.VarType.FP32 and var.dtype != core.VarDesc.VarType.FP64:
raise TypeError("Only support float and double tensor for bool index, "
"but received {}.".format(var.dtype))
if len(item.shape) > len(var.shape):
raise IndexError("The dims of bool index doesn't match indexed array, "
"the dims of bool index except to be equal or less "
"than {}, but received {}.".format(
len(var.shape), len(item.shape)))
for i, dim_len in enumerate(item.shape):
if dim_len != var.shape[i]:
raise IndexError(
"The dimension of bool index doesn't match indexed array along "
"dimension {}, the target dimension is {}, but received {}.".
format(i, var.shape[i], dim_len))
def idx_not_empty(var, item, value):
from .framework import Variable
from .layers import assign
from .layers.nn import where
from ..tensor import gather_nd, scatter_nd_add
if not isinstance(value, Variable):
value = assign(value).cast(var.dtype)
idx = where(item)
gather_val = gather_nd(var, idx)
gather_val_new = value - gather_val
out = scatter_nd_add(var, idx, gather_val_new)
var[:] = out
from .layers.control_flow import cond
# If all the bool index is False, just do nothing
cond(item.any(), lambda: idx_not_empty(var, item, value))
return var
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册