未验证 提交 99d30bfc 编写于 作者: S songyouwei 提交者: GitHub

Speed up the slice implementation (#23340)

test=develop
上级 1a6ce8b9
......@@ -217,6 +217,68 @@ static imperative::NameVarBaseMap ConvertToNameVarBaseMap(
return result;
}
// Translates a Python indexing expression (`_index`) applied to `tensor` into
// the per-axis vectors consumed by the caller.
//
// `_index` may be a single item or a tuple of items; each item must satisfy
// PyNumber_Check (treated as an integer index) or PySlice_Check. Ellipsis and
// None are not supported yet.
//
// For every item, one entry `1` is appended to `infer_flags`. Additionally:
//   * integer item i on axis d : axes += d, starts += i (negative values are
//     offset by the axis length), ends += i + 1, strides += 1,
//     decrease_axis += d.
//   * slice item on axis d     : axes/starts/ends/strides receive the values
//     computed by PySlice_GetIndices, unless the slice selects the whole axis
//     with step 1 (`:`, `::`, `0:dim_len:1`), in which case nothing is added.
//
// Raises (through PADDLE_ENFORCE_EQ) an InvalidArgument error when the tensor
// is uninitialized, when there are more items than tensor dimensions, when an
// item has an unsupported type, or when an item cannot be converted by the
// Python C API.
//
// NOTE(review): integer indices are not range-checked here; presumably the
// consumer of these vectors validates them -- confirm against the caller.
static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index,
                               std::vector<int> *slice_axes,
                               std::vector<int> *slice_starts,
                               std::vector<int> *slice_ends,
                               std::vector<int> *slice_strides,
                               std::vector<int> *decrease_axis,
                               std::vector<int> *infer_flags) {
  // Releases the temporary 1-tuple on every exit path. PADDLE_ENFORCE_EQ
  // leaves the function early on failure, so a trailing Py_DecRef alone would
  // leak the tuple.
  struct IndexTupleGuard {
    PyObject *obj;
    bool owned;
    ~IndexTupleGuard() {
      if (owned) Py_DecRef(obj);  // Py_DecRef tolerates a null pointer.
    }
  };

  // Wrap a non-tuple index so that the loop below handles a single shape.
  const bool wrapped = !PyTuple_Check(_index);
  PyObject *index = wrapped ? PyTuple_Pack(1, _index) : _index;
  IndexTupleGuard index_guard{index, wrapped};
  if (index == nullptr) PyErr_Clear();
  PADDLE_ENFORCE_EQ(
      index != nullptr, true,
      platform::errors::InvalidArgument("failed to wrap index into a tuple"));

  PADDLE_ENFORCE_EQ(
      tensor->IsInitialized(), true,
      platform::errors::InvalidArgument("tensor has not been initialized"));
  const auto &shape = tensor->dims();
  const int rank = shape.size();
  const int size = PyTuple_GET_SIZE(index);
  PADDLE_ENFORCE_EQ(
      size <= rank, true,
      platform::errors::InvalidArgument(
          "too many indices (%d) for tensor of dimension %d", size, rank));

  for (int dim = 0; dim < size; ++dim) {
    // Borrowed reference; no DecRef needed.
    PyObject *slice_item = PyTuple_GetItem(index, dim);
    PADDLE_ENFORCE_EQ(
        PyNumber_Check(slice_item) || PySlice_Check(slice_item), true,
        platform::errors::InvalidArgument(
            "We allow indexing by Integers, Slices, and tuples of "
            "these types, but received %s in %dth slice item",
            std::string(Py_TYPE(slice_item)->tp_name), dim + 1));
    infer_flags->push_back(1);
    const int dim_len = shape[dim];

    if (PyNumber_Check(slice_item)) {
      // integer
      const long raw_index = PyLong_AsLong(slice_item);  // NOLINT
      // -1 is a legal index, so a failure is only signalled by the pending
      // Python error. Clear it before raising so that it does not linger
      // behind the C++ exception.
      const bool converted = !(raw_index == -1 && PyErr_Occurred());
      if (!converted) PyErr_Clear();
      PADDLE_ENFORCE_EQ(
          converted, true,
          platform::errors::InvalidArgument(
              "%s in %dth slice item cannot be converted to an integer",
              std::string(Py_TYPE(slice_item)->tp_name), dim + 1));

      int start = static_cast<int>(raw_index);
      if (start < 0) start += dim_len;
      slice_axes->push_back(dim);
      slice_starts->push_back(start);
      slice_ends->push_back(start + 1);
      slice_strides->push_back(1);
      decrease_axis->push_back(dim);
    } else {
      // slice
      Py_ssize_t start = 0;
      Py_ssize_t end = 0;
      Py_ssize_t step = 0;
// The parameter type for the slice parameter was PySliceObject* before 3.2
#if PY_VERSION_HEX >= 0x03020000
      const int status =
          PySlice_GetIndices(slice_item, dim_len, &start, &end, &step);
#else
      const int status =
          PySlice_GetIndices(reinterpret_cast<PySliceObject *>(slice_item),
                             dim_len, &start, &end, &step);
#endif
      // On failure the outputs are not meaningful and a Python error may or
      // may not be set; normalise both cases into a single C++ error.
      if (status != 0) PyErr_Clear();
      PADDLE_ENFORCE_EQ(
          status, 0,
          platform::errors::InvalidArgument(
              "invalid slice received in %dth slice item", dim + 1));

      // :: or : or 0:dim_len:1 selects the whole axis; nothing to record.
      if (start == 0 && end == dim_len && step == 1) continue;
      slice_axes->push_back(dim);
      slice_starts->push_back(static_cast<int>(start));
      slice_ends->push_back(static_cast<int>(end));
      slice_strides->push_back(static_cast<int>(step));
    }
  }
}
// Bind Methods
void BindImperative(py::module *m_ptr) {
auto &m = *m_ptr;
......@@ -396,77 +458,22 @@ void BindImperative(py::module *m_ptr) {
.def("__init__", &InitVarBaseFromNumpyWithArgDefault, py::arg("value"))
.def("__init__", &InitVarBaseFromNumpyWithKwargs)
.def("__getitem__",
[](imperative::VarBase &self, py::handle _index) {
// We allow indexing by Integers, Slices, and tuples of those
// types.
// Ellipsis and None are not supported yet.
[](std::shared_ptr<imperative::VarBase> &self, py::handle _index) {
std::vector<int> slice_axes, slice_starts, slice_ends,
slice_strides, decrease_axis;
// wrap to tuple
PyObject *index = !PyTuple_Check(_index.ptr())
? PyTuple_Pack(1, _index.ptr())
: _index.ptr();
const auto &tensor = self.Var().Get<framework::LoDTensor>();
PADDLE_ENFORCE_EQ(tensor.IsInitialized(), true,
platform::errors::InvalidArgument(
"%s has not been initialized", self.Name()));
const auto &shape = tensor.dims();
const int rank = shape.size();
const int size = PyTuple_GET_SIZE(index);
PADDLE_ENFORCE_EQ(
size <= rank, true,
platform::errors::InvalidArgument(
"too many indices (%d) for tensor of dimension %d", size,
rank));
for (int dim = 0; dim < size; ++dim) {
PyObject *slice_item = PyTuple_GetItem(index, dim);
PADDLE_ENFORCE_EQ(
PyNumber_Check(slice_item) || PySlice_Check(slice_item),
true,
platform::errors::InvalidArgument(
"We allow indexing by Integers, Slices, and tuples of "
"these types, but received %s in %dth slice item",
std::string(Py_TYPE(slice_item)->tp_name), dim + 1));
int dim_len = shape[dim];
if (PyNumber_Check(slice_item)) {
// integer
int start = static_cast<int>(PyLong_AsLong(slice_item));
start = start < 0 ? start + dim_len : start;
slice_axes.push_back(dim);
slice_starts.push_back(start);
slice_ends.push_back(start + 1);
slice_strides.push_back(1);
decrease_axis.push_back(dim);
} else {
// slice
Py_ssize_t start, end, step;
// The parameter type for the slice parameter was PySliceObject* before 3.2
#if PY_VERSION_HEX >= 0x03020000
PySlice_GetIndices(slice_item, dim_len, &start, &end, &step);
#else
PySlice_GetIndices(
reinterpret_cast<PySliceObject *>(slice_item), dim_len,
&start, &end, &step);
#endif
// :: or : or 0:dim_len:1
if (start == 0 && end == dim_len && step == 1) continue;
slice_axes.push_back(dim);
slice_starts.push_back(start);
slice_ends.push_back(end);
slice_strides.push_back(step);
}
}
if (!PyTuple_Check(_index.ptr())) Py_DecRef(index);
slice_strides, decrease_axis, infer_flags;
auto tensor =
self->MutableVar()->GetMutable<framework::LoDTensor>();
ParseIndexingSlice(tensor, _index.ptr(), &slice_axes,
&slice_starts, &slice_ends, &slice_strides,
&decrease_axis, &infer_flags);
// release gil and do tracing
py::gil_scoped_release release;
const auto &tracer = imperative::GetCurrentTracer();
auto _self = self.NewVarBase(tensor.place(), false);
if (slice_axes.empty()) {
return _self;
return self;
} else {
std::vector<int> infer_flags(size, 1);
imperative::NameVarBaseMap ins = {{"Input", {_self}}};
imperative::NameVarBaseMap ins = {{"Input", {self}}};
framework::AttributeMap attrs = {
{"axes", slice_axes},
{"starts", slice_starts},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册