未验证 提交 b6dc16cb 编写于 作者: Z zyfncg 提交者: GitHub

Support gettiem by Bool index (#35026)

* Support getitem by Bool index

* delete some debug info of bool index

* support the case that the shape of bool index is different from indexed tensor
上级 97fef015
......@@ -54,6 +54,10 @@ class IndexSelectOp : public framework::OperatorWithKernel {
"the dimension of Input(Index) is [%d].",
index_dim, index_dim.size()));
PADDLE_ENFORCE_EQ(index_dim[0] != 0, true,
platform::errors::InvalidArgument(
"The length of Input(Index) can't be 0."));
auto output_dim = framework::vectorize(input_dim);
if (dim < 0) {
dim += input_dim.size();
......
......@@ -414,17 +414,15 @@ static int _PySlice_GetIndices(PySliceObject *r, Py_ssize_t length,
return 0;
}
static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index,
std::vector<int> *slice_axes,
std::vector<int> *slice_starts,
std::vector<int> *slice_ends,
std::vector<int> *slice_strides,
std::vector<int> *decrease_axis,
std::vector<int> *none_axes,
std::vector<int> *infer_flags) {
// We allow indexing by Integers, Slices, and tuples of those
// types.
// Ellipsis and None are not supported yet.
static void ParseIndexingSlice(
framework::LoDTensor *tensor, PyObject *_index,
std::vector<int> *slice_axes, std::vector<int> *slice_starts,
std::vector<int> *slice_ends, std::vector<int> *slice_strides,
std::vector<int> *decrease_axis, std::vector<int> *none_axes,
std::vector<int> *infer_flags, std::vector<int> *list_select_idxs,
bool *list_select_flag) {
// We allow indexing by Integers, Slices, Ellipsis, None, tuples of those
// types, and list of Bool and Integers.
// wrap to tuple
PyObject *index = !PyTuple_Check(_index) ? PyTuple_Pack(1, _index) : _index;
PADDLE_ENFORCE_EQ(
......@@ -490,11 +488,58 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index,
dim += rank - specified_dims;
} else if (slice_item == Py_None) {
none_axes->push_back(dim);
} else if (PyList_Check(slice_item)) {
*list_select_flag = true;
if (size != 1) {
PADDLE_THROW(platform::errors::InvalidArgument(
"When index contains a list, its length is excepted to 1, "
"but received %d",
size));
}
bool all_bool = true;
int list_size = PyList_GET_SIZE(slice_item);
for (int j = 0; j < list_size; ++j) {
PyObject *list_item = PyList_GetItem(slice_item, j);
if (PyCheckInteger(list_item)) {
all_bool = false;
} else if (!PyBool_Check(list_item)) {
PADDLE_THROW(platform::errors::InvalidArgument(
"Only support int or bool in index list."));
}
}
if (all_bool) {
if (list_size != shape[0]) {
PADDLE_THROW(platform::errors::InvalidArgument(
"The dimension of bool index doesn't match indexed array along "
"dimension 0, the target dimension is %d, but received %d.",
shape[0], list_size));
}
for (int j = 0; j < list_size; ++j) {
PyObject *list_item = PyList_GetItem(slice_item, j);
if (list_item == Py_True) {
list_select_idxs->push_back(j);
}
}
} else {
for (int j = 0; j < list_size; ++j) {
PyObject *list_item = PyList_GetItem(slice_item, j);
if (PyCheckInteger(list_item)) {
list_select_idxs->push_back(
static_cast<int>(PyLong_AsLong(list_item)));
} else if (list_item == Py_True) {
list_select_idxs->push_back(1);
} else {
list_select_idxs->push_back(0);
}
}
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Currently, VarBase.__getitem__() only allows indexing"
"by Integers, Slices, Ellipsis, None and tuples of "
"these types, but received %s in %dth slice item",
"Currently, VarBase.__getitem__() only allows indexing "
"by Integers, Slices, Ellipsis, None, tuples of these types "
"and list of Bool and Integers, but received "
"%s in %dth slice item",
std::string(Py_TYPE(slice_item)->tp_name), i + 1));
}
}
......@@ -798,10 +843,13 @@ void BindImperative(py::module *m_ptr) {
// copys data to cpu place, which reduces performance.
if (parse_index && value_is_tensor) {
std::vector<int> axes, starts, ends, steps, decrease_axes,
none_axes, infer_flags;
none_axes, infer_flags, list_select_idxs;
// if index is a list, list_select_flag will be true
bool list_select_flag;
ParseIndexingSlice(self_tensor, index_ptr, &axes, &starts, &ends,
&steps, &decrease_axes, &none_axes,
&infer_flags);
&infer_flags, &list_select_idxs,
&list_select_flag);
framework::AttributeMap attrs = {
{"axes", axes},
......@@ -860,21 +908,26 @@ void BindImperative(py::module *m_ptr) {
.def("_getitem_index_not_tensor",
[](std::shared_ptr<imperative::VarBase> &self, py::handle _index) {
std::vector<int> slice_axes, slice_starts, slice_ends,
slice_strides, decrease_axis, none_axes, infer_flags;
slice_strides, decrease_axis, none_axes, infer_flags,
list_select_idxs;
// if index is a list, list_select_flag will be true
bool list_select_flag = false;
auto tensor =
self->MutableVar()->GetMutable<framework::LoDTensor>();
ParseIndexingSlice(tensor, _index.ptr(), &slice_axes,
&slice_starts, &slice_ends, &slice_strides,
&decrease_axis, &none_axes, &infer_flags);
&decrease_axis, &none_axes, &infer_flags,
&list_select_idxs, &list_select_flag);
// release gil and do tracing
py::gil_scoped_release release;
const auto &tracer = imperative::GetCurrentTracer();
auto out = slice_axes.empty()
auto out = slice_axes.empty() && !list_select_flag
? self
: std::shared_ptr<imperative::VarBase>(
new imperative::VarBase(
tracer->GenerateUniqueName()));
if (!slice_axes.empty()) {
imperative::NameVarBaseMap ins = {{"Input", {self}}};
framework::AttributeMap attrs = {
......@@ -960,6 +1013,22 @@ void BindImperative(py::module *m_ptr) {
}
}
// the index is a list
if (list_select_flag) {
auto select_index = std::shared_ptr<imperative::VarBase>(
new imperative::VarBase(tracer->GenerateUniqueName()));
auto *idx_tensor = select_index->MutableVar()
->GetMutable<framework::LoDTensor>();
auto *dev_ctx = platform::DeviceContextPool::Instance().Get(
tracer->ExpectedPlace());
TensorFromVector(list_select_idxs, *dev_ctx, idx_tensor);
imperative::NameVarBaseMap ins = {{"X", {self}},
{"Index", {select_index}}};
imperative::NameVarBaseMap outs = {{"Out", {out}}};
tracer->TraceOp("index_select", ins, outs, {{"dim", 0}});
}
return out;
})
.def(
......
......@@ -733,6 +733,45 @@ class TestVarBase(unittest.TestCase):
# self.assertTrue(
# np.array_equal(var[10], np_value[0, 1:10:2, None, None, ...]))
def _test_bool_index(self):
shape = (4, 2, 5, 64)
np_value = np.random.random(shape).astype('float32')
var_tensor = paddle.to_tensor(np_value)
index = [[True, True, True, True], [True, False, True, True],
[True, False, False, True], [False, 0, 1, True, True]]
index2d = np.array([[True, True], [False, False], [True, False],
[True, True]])
tensor_index = paddle.to_tensor(index2d)
var = [
var_tensor[index[0]].numpy(),
var_tensor[index[1]].numpy(),
var_tensor[index[2]].numpy(),
var_tensor[index[3]].numpy(),
var_tensor[paddle.to_tensor(index[0])].numpy(),
var_tensor[tensor_index].numpy(),
]
self.assertTrue(np.array_equal(var[0], np_value[index[0]]))
self.assertTrue(np.array_equal(var[1], np_value[index[1]]))
self.assertTrue(np.array_equal(var[2], np_value[index[2]]))
self.assertTrue(np.array_equal(var[3], np_value[index[3]]))
self.assertTrue(np.array_equal(var[4], np_value[index[0]]))
self.assertTrue(np.array_equal(var[5], np_value[index2d]))
self.assertTrue(
np.array_equal(var_tensor[var_tensor > 0.67], np_value[np_value >
0.67]))
self.assertTrue(
np.array_equal(var_tensor[var_tensor < 0.55], np_value[np_value <
0.55]))
with self.assertRaises(ValueError):
var_tensor[[False, False, False, False]]
with self.assertRaises(ValueError):
var_tensor[[True, False]]
with self.assertRaises(ValueError):
var_tensor[[True, False, False, False, False]]
with self.assertRaises(IndexError):
var_tensor[paddle.to_tensor([[True, False, False, False]])]
def _test_for_var(self):
np_value = np.random.random((30, 100, 100)).astype('float32')
w = fluid.dygraph.to_variable(np_value)
......@@ -747,6 +786,7 @@ class TestVarBase(unittest.TestCase):
self._test_for_var()
self._test_for_getitem_ellipsis_index()
self._test_none_index()
self._test_bool_index()
var = fluid.dygraph.to_variable(self.array)
self.assertTrue(np.array_equal(var[1, :].numpy(), self.array[1, :]))
......
......@@ -246,32 +246,49 @@ class TestVariable(unittest.TestCase):
res = x[[1.2, 0]]
def _test_slice_index_list_bool(self, place):
data = np.random.rand(2, 3).astype("float32")
data = np.random.rand(2, 3, 4).astype("float32")
np_idx = np.array([[True, False, False], [True, False, True]])
prog = paddle.static.Program()
with paddle.static.program_guard(prog):
x = paddle.assign(data)
idx0 = [True, False]
idx1 = [False, True]
idx2 = [False, False]
idx3 = [True, True]
idx2 = [True, True]
idx3 = [False, False, 1]
idx4 = [True, False, 0]
idx5 = paddle.assign(np_idx)
out0 = x[idx0]
out1 = x[idx1]
out2 = x[idx2]
out3 = x[idx3]
out4 = x[idx4]
out5 = x[idx5]
out6 = x[x < 0.36]
out7 = x[x > 0.6]
exe = paddle.static.Executor(place)
result = exe.run(prog, fetch_list=[out0, out1, out2, out3])
result = exe.run(
prog, fetch_list=[out0, out1, out2, out3, out4, out5, out6, out7])
expected = [data[idx0], data[idx1], data[idx2], data[idx3]]
expected = [
data[idx0], data[idx1], data[idx2], data[idx3], data[idx4],
data[np_idx], data[data < 0.36], data[data > 0.6]
]
self.assertTrue((result[0] == expected[0]).all())
self.assertTrue((result[1] == expected[1]).all())
self.assertTrue((result[2] == expected[2]).all())
self.assertTrue((result[3] == expected[3]).all())
self.assertTrue((result[4] == expected[4]).all())
self.assertTrue((result[5] == expected[5]).all())
self.assertTrue((result[6] == expected[6]).all())
self.assertTrue((result[7] == expected[7]).all())
with self.assertRaises(TypeError):
res = x[[True, 0]]
with self.assertRaises(IndexError):
res = x[[True, False, False]]
with self.assertRaises(ValueError):
res = x[[False, False]]
def test_slice(self):
places = [fluid.CPUPlace()]
......
......@@ -150,31 +150,37 @@ def _getitem_impl_(var, item):
end = MAX_INTEGER if step > 0 else -1
elif isinstance(slice_item, list):
is_bool_list = False
all_bool = True
for i in slice_item:
if not isinstance(i, (int, bool)):
if type(i) is int:
all_bool = False
elif not isinstance(i, bool):
raise TypeError("Only support int or bool in index list.")
if isinstance(i, bool):
is_bool_list = True
break
if len(item) != 1:
raise IndexError(
"When index contains a list, its length must be 1, but received {}".
"When index contains a list, its length must be 1, but received {}.".
format(len(item)))
if is_bool_list:
new_slice_item = []
new_slice_item = []
if all_bool:
if len(slice_item) != var.shape[0]:
raise IndexError(
"The dimension of bool index doesn't match indexed array along "\
"dimension 0, the target dimension is {}, but received {}.".
format(var.shape[0], len(slice_item)))
for idx, ele in enumerate(slice_item):
if not isinstance(ele, bool):
raise TypeError(
"Mixed bool index with other types is not supported."
)
if ele is True:
new_slice_item.append(idx)
slice_item = new_slice_item
else:
for idx, ele in enumerate(slice_item):
if type(ele) is int:
new_slice_item.append(ele)
elif ele is True:
new_slice_item.append(1)
else:
new_slice_item.append(0)
slice_item = new_slice_item
from .layers import assign
from ..tensor import index_select
......@@ -185,10 +191,27 @@ def _getitem_impl_(var, item):
elif isinstance(slice_item, Variable):
if len(item) != 1:
raise IndexError(
"When index contains a Tensor, its length must be 1, but received {}".
"When index contains a Tensor, its length must be 1, but received {}.".
format(len(item)))
from ..tensor import index_select
from ..tensor import index_select, gather_nd
from .layers.nn import where
if slice_item.dtype == core.VarDesc.VarType.BOOL:
if len(slice_item.shape) > len(var.shape):
raise IndexError(
"The dims of bool index doesn't match indexed array, "
"the dims of bool index except to be equal or less "
"than {}, but received {}.".format(
len(var.shape), len(slice_item.shape)))
for i, dim_len in enumerate(slice_item.shape):
if dim_len != var.shape[i]:
raise IndexError(
"The dimension of bool index doesn't match indexed array along "\
"dimension {}, the target dimension is {}, but received {}.".
format(i, var.shape[i], dim_len))
bool_2_idx = where(slice_item == True)
return gather_nd(var, bool_2_idx)
return index_select(var, index=slice_item, axis=0)
else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册