From 99d30bfc367d6b472d8b667656678f6b6d84db0c Mon Sep 17 00:00:00 2001 From: songyouwei Date: Wed, 1 Apr 2020 11:47:56 +0800 Subject: [PATCH] speedup slice impl (#23340) test=develop --- paddle/fluid/pybind/imperative.cc | 135 ++++++++++++++++-------------- 1 file changed, 71 insertions(+), 64 deletions(-) diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index b334e009bf3..18b82292603 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -217,6 +217,68 @@ static imperative::NameVarBaseMap ConvertToNameVarBaseMap( return result; } +static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index, + std::vector<int> *slice_axes, + std::vector<int> *slice_starts, + std::vector<int> *slice_ends, + std::vector<int> *slice_strides, + std::vector<int> *decrease_axis, + std::vector<int> *infer_flags) { + // We allow indexing by Integers, Slices, and tuples of those + // types. + // Ellipsis and None are not supported yet. + // wrap to tuple + PyObject *index = !PyTuple_Check(_index) ? 
PyTuple_Pack(1, _index) : _index; + PADDLE_ENFORCE_EQ( + tensor->IsInitialized(), true, + platform::errors::InvalidArgument("tensor has not been initialized")); + const auto &shape = tensor->dims(); + const int rank = shape.size(); + const int size = PyTuple_GET_SIZE(index); + PADDLE_ENFORCE_EQ( + size <= rank, true, + platform::errors::InvalidArgument( + "too many indices (%d) for tensor of dimension %d", size, rank)); + for (int dim = 0; dim < size; ++dim) { + PyObject *slice_item = PyTuple_GetItem(index, dim); + PADDLE_ENFORCE_EQ( + PyNumber_Check(slice_item) || PySlice_Check(slice_item), true, + platform::errors::InvalidArgument( + "We allow indexing by Integers, Slices, and tuples of " + "these types, but received %s in %dth slice item", + std::string(Py_TYPE(slice_item)->tp_name), dim + 1)); + infer_flags->push_back(1); + int dim_len = shape[dim]; + if (PyNumber_Check(slice_item)) { + // integer + int start = static_cast<int>(PyLong_AsLong(slice_item)); + start = start < 0 ? start + dim_len : start; + slice_axes->push_back(dim); + slice_starts->push_back(start); + slice_ends->push_back(start + 1); + slice_strides->push_back(1); + decrease_axis->push_back(dim); + } else { + // slice + Py_ssize_t start, end, step; +// The parameter type for the slice parameter was PySliceObject* before 3.2 +#if PY_VERSION_HEX >= 0x03020000 + PySlice_GetIndices(slice_item, dim_len, &start, &end, &step); +#else + PySlice_GetIndices(reinterpret_cast<PySliceObject *>(slice_item), dim_len, + &start, &end, &step); +#endif + // :: or : or 0:dim_len:1 + if (start == 0 && end == dim_len && step == 1) continue; + slice_axes->push_back(dim); + slice_starts->push_back(start); + slice_ends->push_back(end); + slice_strides->push_back(step); + } + } + if (!PyTuple_Check(_index)) Py_DecRef(index); +} + // Bind Methods void BindImperative(py::module *m_ptr) { auto &m = *m_ptr; @@ -396,77 +458,22 @@ void BindImperative(py::module *m_ptr) { .def("__init__", &InitVarBaseFromNumpyWithArgDefault, py::arg("value")) 
.def("__init__", &InitVarBaseFromNumpyWithKwargs) .def("__getitem__", - [](imperative::VarBase &self, py::handle _index) { - // We allow indexing by Integers, Slices, and tuples of those - // types. - // Ellipsis and None are not supported yet. + [](std::shared_ptr<imperative::VarBase> &self, py::handle _index) { std::vector<int> slice_axes, slice_starts, slice_ends, - slice_strides, decrease_axis; - // wrap to tuple - PyObject *index = !PyTuple_Check(_index.ptr()) - ? PyTuple_Pack(1, _index.ptr()) - : _index.ptr(); - const auto &tensor = self.Var().Get<framework::LoDTensor>(); - PADDLE_ENFORCE_EQ(tensor.IsInitialized(), true, - platform::errors::InvalidArgument( - "%s has not been initialized", self.Name())); - const auto &shape = tensor.dims(); - const int rank = shape.size(); - const int size = PyTuple_GET_SIZE(index); - PADDLE_ENFORCE_EQ( - size <= rank, true, - platform::errors::InvalidArgument( - "too many indices (%d) for tensor of dimension %d", size, - rank)); - for (int dim = 0; dim < size; ++dim) { - PyObject *slice_item = PyTuple_GetItem(index, dim); - PADDLE_ENFORCE_EQ( - PyNumber_Check(slice_item) || PySlice_Check(slice_item), - true, - platform::errors::InvalidArgument( - "We allow indexing by Integers, Slices, and tuples of " - "these types, but received %s in %dth slice item", - std::string(Py_TYPE(slice_item)->tp_name), dim + 1)); - int dim_len = shape[dim]; - if (PyNumber_Check(slice_item)) { - // integer - int start = static_cast<int>(PyLong_AsLong(slice_item)); - start = start < 0 ? 
start + dim_len : start; - slice_axes.push_back(dim); - slice_starts.push_back(start); - slice_ends.push_back(start + 1); - slice_strides.push_back(1); - decrease_axis.push_back(dim); - } else { - // slice - Py_ssize_t start, end, step; -// The parameter type for the slice parameter was PySliceObject* before 3.2 -#if PY_VERSION_HEX >= 0x03020000 - PySlice_GetIndices(slice_item, dim_len, &start, &end, &step); -#else - PySlice_GetIndices( - reinterpret_cast<PySliceObject *>(slice_item), dim_len, - &start, &end, &step); -#endif - // :: or : or 0:dim_len:1 - if (start == 0 && end == dim_len && step == 1) continue; - slice_axes.push_back(dim); - slice_starts.push_back(start); - slice_ends.push_back(end); - slice_strides.push_back(step); - } - } - if (!PyTuple_Check(_index.ptr())) Py_DecRef(index); + slice_strides, decrease_axis, infer_flags; + auto tensor = + self->MutableVar()->GetMutable<framework::LoDTensor>(); + ParseIndexingSlice(tensor, _index.ptr(), &slice_axes, + &slice_starts, &slice_ends, &slice_strides, + &decrease_axis, &infer_flags); // release gil and do tracing py::gil_scoped_release release; const auto &tracer = imperative::GetCurrentTracer(); - auto _self = self.NewVarBase(tensor.place(), false); if (slice_axes.empty()) { - return _self; + return self; } else { - std::vector<int> infer_flags(size, 1); - imperative::NameVarBaseMap ins = {{"Input", {_self}}}; + imperative::NameVarBaseMap ins = {{"Input", {self}}}; framework::AttributeMap attrs = { {"axes", slice_axes}, {"starts", slice_starts}, -- GitLab