未验证 提交 a0bbc992 编写于 作者: Z zyfncg 提交者: GitHub

Support getitem by None index in dynamic mode (#34338)

* Support getitem by ellipsis index in dynamic mode

* change some code style

* Support getitem by none index in dynamic mode

* modify a comments style and remove useless code
上级 df27c264
...@@ -420,6 +420,7 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index, ...@@ -420,6 +420,7 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index,
std::vector<int> *slice_ends, std::vector<int> *slice_ends,
std::vector<int> *slice_strides, std::vector<int> *slice_strides,
std::vector<int> *decrease_axis, std::vector<int> *decrease_axis,
std::vector<int> *none_axes,
std::vector<int> *infer_flags) { std::vector<int> *infer_flags) {
// We allow indexing by Integers, Slices, and tuples of those // We allow indexing by Integers, Slices, and tuples of those
// types. // types.
...@@ -443,10 +444,6 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index, ...@@ -443,10 +444,6 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index,
} }
} }
PADDLE_ENFORCE_EQ(
size <= rank, true,
platform::errors::InvalidArgument(
"too many indices (%d) for tensor of dimension %d", size, rank));
for (int i = 0, dim = 0; i < size; ++i) { for (int i = 0, dim = 0; i < size; ++i) {
PyObject *slice_item = PyTuple_GetItem(index, i); PyObject *slice_item = PyTuple_GetItem(index, i);
...@@ -491,14 +488,24 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index, ...@@ -491,14 +488,24 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index,
dim++; dim++;
} else if (slice_item == Py_Ellipsis) { } else if (slice_item == Py_Ellipsis) {
dim += rank - specified_dims; dim += rank - specified_dims;
} else if (slice_item == Py_None) {
none_axes->push_back(dim);
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"Currently, VarBase.__getitem__() only allows " "Currently, VarBase.__getitem__() only allows indexing"
"indexing by Integers, Slices, Ellipsis, and tuples of " "by Integers, Slices, Ellipsis, None and tuples of "
"these types, but received %s in %dth slice item", "these types, but received %s in %dth slice item",
std::string(Py_TYPE(slice_item)->tp_name), i + 1)); std::string(Py_TYPE(slice_item)->tp_name), i + 1));
} }
} }
// valid_index is the number of dimensions exclude None index
const int valid_indexs = size - none_axes->size();
PADDLE_ENFORCE_EQ(valid_indexs <= rank, true,
platform::errors::InvalidArgument(
"Too many indices (%d) for tensor of dimension %d.",
valid_indexs, rank));
if (!PyTuple_Check(_index)) Py_DecRef(index); if (!PyTuple_Check(_index)) Py_DecRef(index);
} }
...@@ -790,9 +797,10 @@ void BindImperative(py::module *m_ptr) { ...@@ -790,9 +797,10 @@ void BindImperative(py::module *m_ptr) {
// copys data to cpu place, which reduces performance. // copys data to cpu place, which reduces performance.
if (parse_index && value_is_tensor) { if (parse_index && value_is_tensor) {
std::vector<int> axes, starts, ends, steps, decrease_axes, std::vector<int> axes, starts, ends, steps, decrease_axes,
infer_flags; none_axes, infer_flags;
ParseIndexingSlice(self_tensor, index_ptr, &axes, &starts, &ends, ParseIndexingSlice(self_tensor, index_ptr, &axes, &starts, &ends,
&steps, &decrease_axes, &infer_flags); &steps, &decrease_axes, &none_axes,
&infer_flags);
framework::AttributeMap attrs = { framework::AttributeMap attrs = {
{"axes", axes}, {"axes", axes},
...@@ -850,18 +858,22 @@ void BindImperative(py::module *m_ptr) { ...@@ -850,18 +858,22 @@ void BindImperative(py::module *m_ptr) {
.def("_getitem_index_not_tensor", .def("_getitem_index_not_tensor",
[](std::shared_ptr<imperative::VarBase> &self, py::handle _index) { [](std::shared_ptr<imperative::VarBase> &self, py::handle _index) {
std::vector<int> slice_axes, slice_starts, slice_ends, std::vector<int> slice_axes, slice_starts, slice_ends,
slice_strides, decrease_axis, infer_flags; slice_strides, decrease_axis, none_axes, infer_flags;
auto tensor = auto tensor =
self->MutableVar()->GetMutable<framework::LoDTensor>(); self->MutableVar()->GetMutable<framework::LoDTensor>();
ParseIndexingSlice(tensor, _index.ptr(), &slice_axes, ParseIndexingSlice(tensor, _index.ptr(), &slice_axes,
&slice_starts, &slice_ends, &slice_strides, &slice_starts, &slice_ends, &slice_strides,
&decrease_axis, &infer_flags); &decrease_axis, &none_axes, &infer_flags);
// release gil and do tracing // release gil and do tracing
py::gil_scoped_release release; py::gil_scoped_release release;
const auto &tracer = imperative::GetCurrentTracer(); const auto &tracer = imperative::GetCurrentTracer();
if (slice_axes.empty()) {
return self; auto out = slice_axes.empty()
} else { ? self
: std::shared_ptr<imperative::VarBase>(
new imperative::VarBase(
tracer->GenerateUniqueName()));
if (!slice_axes.empty()) {
imperative::NameVarBaseMap ins = {{"Input", {self}}}; imperative::NameVarBaseMap ins = {{"Input", {self}}};
framework::AttributeMap attrs = { framework::AttributeMap attrs = {
{"axes", slice_axes}, {"axes", slice_axes},
...@@ -869,8 +881,6 @@ void BindImperative(py::module *m_ptr) { ...@@ -869,8 +881,6 @@ void BindImperative(py::module *m_ptr) {
{"ends", slice_ends}, {"ends", slice_ends},
{"infer_flags", infer_flags}, {"infer_flags", infer_flags},
{"decrease_axis", decrease_axis}}; {"decrease_axis", decrease_axis}};
auto out = std::shared_ptr<imperative::VarBase>(
new imperative::VarBase(tracer->GenerateUniqueName()));
imperative::NameVarBaseMap outs = {{"Out", {out}}}; imperative::NameVarBaseMap outs = {{"Out", {out}}};
std::string op_type = "slice"; std::string op_type = "slice";
for (auto stride : slice_strides) { for (auto stride : slice_strides) {
...@@ -882,8 +892,50 @@ void BindImperative(py::module *m_ptr) { ...@@ -882,8 +892,50 @@ void BindImperative(py::module *m_ptr) {
} }
} }
tracer->TraceOp(op_type, ins, outs, std::move(attrs)); tracer->TraceOp(op_type, ins, outs, std::move(attrs));
return out;
} }
if (!none_axes.empty()) {
// Deal with cases when all axes are decreased.
// After slice, the shape of out is [1], which should have been
// [], but Paddle doesn't support scalar.
// In order to ensure the correctness of the final shape of out,
// one dimension of out needs to be decreased.
// For example:
// # x.shape: (2,3,4)
// out = x[0, 1, 1, None] # out.shape : (1)
if (static_cast<int>(decrease_axis.size()) ==
tensor->dims().size()) {
none_axes.pop_back();
}
if (!none_axes.empty()) {
// Deal with cases that decrease_axes is not empty
// For example:
// # x.shape: (2,3,4)
// out = x[0, 0:2, None] # out.shape : (2, 1, 4)
for (auto &axis : none_axes) {
int len = 0;
for (int da : decrease_axis) {
if (da < axis) {
len++;
}
}
axis -= len;
}
imperative::NameVarBaseMap ins = {{"X", {out}}};
framework::AttributeMap attrs = {{"axes", none_axes}};
auto new_out = std::shared_ptr<imperative::VarBase>(
new imperative::VarBase(tracer->GenerateUniqueName()));
auto out_xshape = std::shared_ptr<imperative::VarBase>(
new imperative::VarBase(tracer->GenerateUniqueName()));
imperative::NameVarBaseMap outs = {{"Out", {new_out}},
{"XShape", {out_xshape}}};
tracer->TraceOp("unsqueeze2", ins, outs, std::move(attrs));
return new_out;
}
}
return out;
}) })
.def( .def(
"_getitem_from_offset", "_getitem_from_offset",
......
...@@ -689,6 +689,40 @@ class TestVarBase(unittest.TestCase): ...@@ -689,6 +689,40 @@ class TestVarBase(unittest.TestCase):
assert_getitem_ellipsis_index(var_fp32, np_fp32_value) assert_getitem_ellipsis_index(var_fp32, np_fp32_value)
assert_getitem_ellipsis_index(var_int, np_int_value) assert_getitem_ellipsis_index(var_int, np_int_value)
def _test_none_index(self):
shape = (8, 64, 5, 256)
np_value = np.random.random(shape).astype('float32')
var_tensor = paddle.to_tensor(np_value)
var = [
var_tensor[1, 0, None].numpy(),
var_tensor[None, ..., 1, 0].numpy(),
var_tensor[:, :, :, None].numpy(),
var_tensor[1, ..., 1, None].numpy(),
var_tensor[2, ..., None, None].numpy(),
var_tensor[None, 2, 0, ...].numpy(),
var_tensor[None, 2, None, 1].numpy(),
var_tensor[None].numpy(),
var_tensor[0, 0, None, 0, 0, None].numpy(),
var_tensor[0, 1:10:2, None, None, ...].numpy(),
]
self.assertTrue(np.array_equal(var[0], np_value[1, 0, None]))
self.assertTrue(np.array_equal(var[1], np_value[None, ..., 1, 0]))
self.assertTrue(np.array_equal(var[2], np_value[:, :, :, None]))
self.assertTrue(np.array_equal(var[3], np_value[1, ..., 1, None]))
self.assertTrue(np.array_equal(var[4], np_value[2, ..., None, None]))
self.assertTrue(np.array_equal(var[5], np_value[None, 2, 0, ...]))
self.assertTrue(np.array_equal(var[6], np_value[None, 2, None, 1]))
self.assertTrue(np.array_equal(var[7], np_value[None]))
self.assertTrue(
np.array_equal(var[8], np_value[0, 0, None, 0, 0, None]))
# TODO(zyfncg) there is a bug of dimensions when slice step > 1 and
# indexs has int type
# self.assertTrue(
# np.array_equal(var[9], np_value[0, 1:10:2, None, None, ...]))
def _test_for_var(self): def _test_for_var(self):
np_value = np.random.random((30, 100, 100)).astype('float32') np_value = np.random.random((30, 100, 100)).astype('float32')
w = fluid.dygraph.to_variable(np_value) w = fluid.dygraph.to_variable(np_value)
...@@ -702,6 +736,7 @@ class TestVarBase(unittest.TestCase): ...@@ -702,6 +736,7 @@ class TestVarBase(unittest.TestCase):
self._test_slice_for_tensor_attr() self._test_slice_for_tensor_attr()
self._test_for_var() self._test_for_var()
self._test_for_getitem_ellipsis_index() self._test_for_getitem_ellipsis_index()
self._test_none_index()
var = fluid.dygraph.to_variable(self.array) var = fluid.dygraph.to_variable(self.array)
self.assertTrue(np.array_equal(var[1, :].numpy(), self.array[1, :])) self.assertTrue(np.array_equal(var[1, :].numpy(), self.array[1, :]))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册